在模型训练过程中,通常大家都会将注意力集中在模型加速以及提升GPU使用率,但是有时我们的耗时瓶颈也会在读取数据上,gpu处理太快,反而cpu喂数据跟不上。当然框架也会提供一些数据读取加速方案,比如tensorflow的 tf.data.TFRecordDataset,pytorch的DataLoader使用num_workers参数内部采用多线程方案等,还有些代码是将所有数据制作到一个二进制文件读入内存,然后从内存中快速读取数据,但是这种方案无法处理大数据项目。
tensorflow的record也需要先生成record文件格式然后读取,pytorch的DataLoader在设置num_workers时特别在windows中有些版本设置为非0会存在一些问题,本文介绍自己使用python的多线程来处理数据的一种方案,然后结合pytorch的Dataset和DataLoader获取数据,供大家参考。
一 创建buffer类
先建立一个buffer类,其中读写数据需要使用两个锁
代码语言:txt复制import threading
import random
class Buffer:
def __init__(self, size):
self.size = size
self.buffer = []
self.lock = threading.Lock()
self.has_data = threading.Condition(self.lock)
self.has_pos = threading.Condition(self.lock)
def get_size(self):
return self.size
def get(self):
with self.has_data:
while len(self.buffer) == 0:
self.has_data.wait()
result = self.buffer[0]
# print("get buffer size", len(self.buffer))
del self.buffer[0]
self.has_pos.notify_all()
return result
def put(self, data):
with self.has_pos:
while len(self.buffer) >= self.size:
self.has_pos.wait()
self.buffer.append(data)
self.has_data.notify_all()
# test
def get():
while True:
get_data = buffer.get()
# test
def put():
while True:
data = random.randint(0, 9)
buffer.put(a)
buffer类参考:https://cloud.tencent.com/developer/article/1724559
二 创建Dataset
生成一个DataReader创建多线程写数据,以及单线程读数据。以下为多线程的关键代码
代码语言:txt复制class DataReader:
def __init__(self, max_buffer_size=5000):
self.audio_files = files_to_list(training_files)
random.shuffle(self.audio_files)
self.buffer = Buffer(max_buffer_size)
# 消费数据
def comsume(self):
while True:
result = self.buffer.get()
# 生产数据
def produce(self):
while True:
global index
index = 1
if index >= len(self.audio_files)-1:
index = 0
start = time.time()
file = self.audio_files[index]
audio = load_wav(file)
end = time.time()
self.buffer.put(audio)
def run_produce(self, thread_num=16):
# 多线程生产
for _ in range(thread_num):
th = threading.Thread(target=self.produce)
th.start()
def get_item(self, index):
result = self.buffer.get()
return result
下面使用一个Dataset来使用DataReader获取数据
代码语言:txt复制class AudioDataset(torch.utils.data.Dataset):
def __init__(self):
self.data_reader = DataReader()
self.data_reader.run_produce()
def __getitem__(self, index):
# 从buffer中获取一个数据
start = time.time()
audio = self.data_reader.get_item(index)
# 进行数据处理
...
audio = torch.from_numpy(audio).float()
end = time.time()
# print("get item time cost", (end - start) * 1000, audio.shape)
return audio.unsqueeze(0)
def __len__(self):
return len(self.audio_files)
三 创建DataLoader
最后就可以通过DataLoader从DataSet中循环获取batch数据输入到模型进行训练了
代码语言:python代码运行次数:0复制dataset = AudioDataset()
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=pin_memory,
)