python多线程结合DataLoader加载数据

2021-09-14 21:13:28 浏览数 (1)

在模型训练过程中,通常大家都会将注意力集中在模型加速以及提升GPU使用率,但是有时我们的耗时瓶颈也会在读取数据上,gpu处理太快,反而cpu喂数据跟不上。当然框架也会提供一些数据读取加速方案,比如tensorflow的 tf.data.TFRecordDataset,pytorch的DataLoader使用num_workers参数内部采用多线程方案等,还有些代码是将所有数据制作到一个二进制文件读入内存,然后从内存中快速读取数据,但是这种方案无法处理大数据项目。

tensorflow的record也需要先生成record文件格式然后读取,pytorch的DataLoader在设置num_workers时特别在windows中有些版本设置为非0会存在一些问题,本文介绍自己使用python的多线程来处理数据的一种方案,然后结合pytorch的Dataset和DataLoader获取数据,供大家参考。

一 创建buffer类

先建立一个buffer类,其中读写数据需要使用两个锁

代码语言:txt复制
import threading
import random

class Buffer:
    def __init__(self, size):
        self.size = size
        self.buffer = []
        self.lock = threading.Lock()
        self.has_data = threading.Condition(self.lock)
        self.has_pos = threading.Condition(self.lock)

    def get_size(self):
        return self.size

    def get(self):
        with self.has_data:
            while len(self.buffer) == 0:
                self.has_data.wait()
            result = self.buffer[0]
            # print("get buffer size", len(self.buffer))
            del self.buffer[0]
            self.has_pos.notify_all()
        return result

    def put(self, data):
        with self.has_pos:
            while len(self.buffer) >= self.size:
                self.has_pos.wait()
            self.buffer.append(data)
            self.has_data.notify_all()

# test
def get():
    while True:
        get_data = buffer.get()
# test
def put():
    while True:
        data = random.randint(0, 9)
        buffer.put(a)

buffer类参考:https://cloud.tencent.com/developer/article/1724559

二 创建Dataset

生成一个DataReader创建多线程写数据,以及单线程读数据。以下为多线程的关键代码

代码语言:txt复制
class DataReader:
    def __init__(self, max_buffer_size=5000):
        self.audio_files = files_to_list(training_files)
        random.shuffle(self.audio_files)
        self.buffer = Buffer(max_buffer_size)
    # 消费数据
    def comsume(self):
        while True:
            result = self.buffer.get()
    # 生产数据 
    def produce(self):
        while True:
            global index
            index  = 1
            if index >= len(self.audio_files)-1:
                index = 0
            start = time.time()
            file = self.audio_files[index]
            audio = load_wav(file)
            end = time.time()
            self.buffer.put(audio)

    def run_produce(self, thread_num=16):
        # 多线程生产
        for _ in range(thread_num):
            th = threading.Thread(target=self.produce)
            th.start()

    def get_item(self, index):
        result = self.buffer.get()
        return result
       

下面使用一个Dataset来使用DataReader获取数据

代码语言:txt复制
class AudioDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.data_reader = DataReader()
        self.data_reader.run_produce()
        
    def __getitem__(self, index):
        # 从buffer中获取一个数据
        start = time.time()
        audio = self.data_reader.get_item(index)
        # 进行数据处理
        ...
        audio = torch.from_numpy(audio).float()
        end = time.time()
        # print("get item time cost", (end - start) * 1000, audio.shape)
        return audio.unsqueeze(0)
    def __len__(self):
        return len(self.audio_files)

三 创建DataLoader

最后就可以通过DataLoader从DataSet中循环获取batch数据输入到模型进行训练了

代码语言:python代码运行次数:0复制
dataset = AudioDataset()
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

0 人点赞