[开发技巧]·深度学习使用生成器加速数据读取与训练简明教程(TensorFlow,pytorch,keras)

2019-06-27 14:53:17 浏览数 (1)

开发技巧·深度学习使用生成器加速数据读取与训练简明教程(TensorFlow,pytorch,keras)

1.问题描述

在深度学习里面有句名言,数据决定深度应用效果的上限,而网络模型与算法的功能是不断逼近这个上限。由此也可以看出数据的重要程度。

在进行深度学习的开发中,我们在建模与训练之前很重要的部分就是数据特征分析与读取,这篇文章的主要内容是数据的读取与组织,其他的方面等以后在其他博客中阐述。

数据读取的一般方式使同一放到一个数组里面去,在一些小的数据上这样处理可以,但是在一些数据量比较多的数据集上就会有很大问题了:

  • 占用太大内存,我们在训练网络时,一般采取minibatch的方法,没必要一下读取很多数据在使用切片选取一部分。
  • 花费更长时间,我们生成包含所有数据的数组时,会去读取每个元素,所有的时间在累加在一起,很耗时,此时神经网络也没有在训练,这样会导致总体的时间加长很多。

笔者在开发的过程中,在使用大规模的数据集(上百万条音频数据)时就遇到了这些问题。首先全部读取到内存,内存空间肯定不够用,再者读取耗时累加就会超过好几天。最终还是解决上述的问题,这归功于Python的一个强大功能,生成器。

生成器实现了这些功能,可以按批次读取返回数据,返回完一批数据后重新从上次结束的地方继续读取返回

2.编程实战

2.1生成一些假数据用于演示

代码语言:javascript复制
import numpy as np
import math
data = np.array([[x*10,x] for x in range(16)])

print(data)

输出结果

代码语言:javascript复制
[[  0   0]
 [ 10   1]
 [ 20   2]
 [ 30   3]
 [ 40   4]
 [ 50   5]
 [ 60   6]
 [ 70   7]
 [ 80   8]
 [ 90   9]
 [100  10]
 [110  11]
 [120  12]
 [130  13]
 [140  14]
 [150  15]]

2.2构建生成器

代码语言:javascript复制
def xs_gen(data,batch_size):
    lists = data
    num_batch = math.ceil(len(lists) / batch_size)    # 确定每轮有多少个batch
    for i in range(num_batch):
        batch_list = lists[i * batch_size : i * batch_size   batch_size]
        np.random.shuffle(batch_list)  
        batch_x = np.array([x for x in batch_list[:,0]])
        batch_y = np.array([y for y in batch_list[:,1]])

为了方便演示,上面是直接对列表进行读入操作,一般在用的时候是读取path列表,在按照path提取数据

2.3演示输出

代码语言:javascript复制
if __name__ == "__main__":

    #data_gen = xs_gen(data,5)
    for x,y in xs_gen(data,5):
        print("item",x,y)
    for x,y in xs_gen(data,5):
        print("item",x,y)

结果如下

代码语言:javascript复制
item [30 20 10  0 40] [3 2 1 0 4]
item [50 70 80 90 60] [5 7 8 9 6]
item [110 120 140 100 130] [11 12 14 10 13]
item [150] [15]
item [ 0 30 20 10 40] [0 3 2 1 4]
item [60 80 90 70 50] [6 8 9 7 5]
item [130 100 110 120 140] [13 10 11 12 14]
item [150] [15]

的确是按照我们的想法组织了,但是有个问题。对比上方的第一行和第五行可以发现,虽然会打乱数据,但是数据还是那五个,最好的结果应该是随机的五个数据。

怎么实现呢,我们可以通过增加一个判断条件,当为返回第一批数据时,打乱整个表格。

2.4改进的生成器函数

代码语言:javascript复制
def xs_gen_pro(data,batch_size):
    lists = data
    num_batch = math.ceil(len(lists) / batch_size)    # 确定每轮有多少个batch
    for i in range(num_batch):
        if(i==0):
            np.random.shuffle(lists)
        batch_list = lists[i * batch_size : i * batch_size   batch_size]
        np.random.shuffle(batch_list)
        batch_x = np.array([x for x in batch_list[:,0]])
        batch_y = np.array([y for y in batch_list[:,1]])

        yield batch_x, batch_y

再次输出数据

代码语言:javascript复制
item [50 30 20 90 80] [5 3 2 9 8]
item [ 60   0 100 110  40] [ 6  0 10 11  4]
item [120  10 140 130 150] [12  1 14 13 15]
item [70] [7]
item [120  90  70  80 130] [12  9  7  8 13]
item [ 10 150 100   0  50] [ 1 15 10  0  5]
item [140  30  60  20 110] [14  3  6  2 11]
item [40] [4]

这次数据随机很彻底了。

如何在深度学习应用生成器

2.1如何在TensorFlow,pytorch应用生成器

在TensorFlow,pytorch应用生成器时可以直接应用

代码语言:javascript复制
for e in Epochs:
    for x,y in xs_gen():
    train(x,y)

2.1如何在keras应用生成器

在keras使用生成器要做些小修改

代码语言:javascript复制
def xs_gen_keras(data,batch_size):
    lists = data
    num_batch = math.ceil(len(lists) / batch_size)    # 确定每轮有多少个batch
    while True:
        np.random.shuffle(lists)
        for i in range(num_batch):
                
            batch_list = lists[i * batch_size : i * batch_size   batch_size]
            np.random.shuffle(batch_list)
            batch_x = np.array([x for x in batch_list[:,0]])
            batch_y = np.array([y for y in batch_list[:,1]])

            yield batch_x, batch_y

keras使用生成器训练

代码语言:javascript复制
train_iter = xs_gen_keras()
val_iter = xs_gen_keras()
model.fit_generator(   
        generator=train_iter,
        steps_per_epoch=Lens1//Batch_size,
        epochs=10,
        validation_data = val_iter,
        nb_val_samples = Lens2//Batch_size,
        )

简单讲解几个参数,val_iter就是自己定义的测试生成器,我上面直接用训练生成器来做了,大家使用时注意仿照训练生成器自己修改一下。其中steps_per_epoch就是一个epoch中有多少个batch,nb_val_samples 定义类似,使用的时候就是那总的数据个数整除Batch_size。具体的参数可以查阅keras的文档。

具体例子的应用生成器训练网络可以参考我的这个实战博文:https://cloud.tencent.com/developer/article/1451538

0 人点赞