1. Pytorch模块总览
相比TensorFlow,PyTorch 是非常轻量级的:相比 TensorFlow 追求兼容并包,PyTorch 把外围功能放在了扩展包中,比如torchtext,以保持主体的轻便。
根据PyTorch 的 API,可知其核心大概如下:
torch.nn
&torch.nn.functional
:构建神经网络torch.nn.init
:初始化权重torch.optim
:优化器torch.utils.data
:载入数据
可以说,掌握了上面四个模块和前文中提到的底层 API,至少 80% 的 PyTorch 任务都可以完成。剩下的外围事物则有如下的模块支持:
torch.cuda
:管理 GPU 资源torch.distributed
:分布式训练torch.jit
:构建静态图提升性能torch.tensorboard
:神经网络的可视化
如果额外掌握了上面的四个的模块,PyTorch 就只剩下一些边边角角的特殊需求了。
2.torch.utils.data
这个功能包的作用是收集、打包数据,给数据索引,然后按照 batch 将数据分批喂给神经网络。
数据读取的核心是 torch.utils.data.DataLoader
类。它是一个数据迭代读取器,支持
- 映射方式和迭代方式读取数据;
- 自定义数据读取顺序;
- 自动批;
- 单线程或多线程数据读取;
- 自动内存定位。
所有上述功能都可以在 torch.utils.data.DataLoader
的变量中定义:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)
最重要的变量为 dataset
,它指明了数据的来源。
DataLoader 支持两种数据类型:
- 映射风格的数据封装(map-style datasets):这种数据结构拥有自定义的
__getitem__()
和__len__()
属性,可以以“索引/值”的方式读取数据,对应torch.utils.data.Dataset
类; - 迭代风格的数据封装(iterable-style datasets):这种数据结构拥有自定义的
__iter__()
属性,通常适用于不方便随机获取数据或不定长数据集的读取上,对应torch.utils.data.IterableDataset
类。
下面我们从顶层的 torch.utils.data.DataLoader
开始,然后一步一步深入到自定义的细节上。为了方便讨论,我们先人工构建一个数据集:
>>> samples = torch.arange(100)
>>> labels = torch.cat([torch.zeros(50), torch.ones(50)], dim=0)
2.1 torch.utils.data.DataLoader 数据加载器
首先看一下常用的变量:
dataset
:数据源;batch_size
:一个整数,定义每一批读取的元素个数;shuffle
:一个布尔值,定义是否随机读取;sampler
:定义获取数据的策略,必须与shuffle
互斥;num_workers
:一个整数,读取数据使用的线程数;collate_fn
:一个将读取的数据处理、聚合成一个一个 batch 的自定义函数;drop_last
:一个布尔值,如果最后一批数据的个数不足 batch 的大小,是否保留这个 batch。
dataset
, sampler
和 collate_fn
是自定义的类或功能,我们从后往前看。
2.2 数据集的分割
在介绍这三个变量以前,我们先看看如何将数据集分割,比如分成训练集和测试集。
torch.utils.data.Subset(dataset, indices)
这个函数可以根据索引indices将数据集dataset分割。
代码语言:javascript复制>>> even = [i for i in range(100) if i % 2 == 0]
>>> new1 = torch.utils.data.Subset(samples, even)
>>> print(new1[:5])
tensor([0, 2, 4, 6, 8])
torch.utils.data.random_split(dataset, lengths)
先将数据随机排列,然后按照指定的长度进行选择。长度的和必须等于数据集中的数据数量。
代码语言:javascript复制>>> train, test = torch.utils.data.random_split(samples, [90, 10])
>>> print(torch.tensor(test))
tensor([79, 60, 98, 74, 31, 43, 21, 69, 55, 76])
2.3. collate_fn 核对函数
这个变量的功能是在数据被读取后,送进模型前对所有数据进行处理、打包。
比如我们有一个不定长度的视频数据集或文本数据集,我们可以自定义一个函数将它们的长度归一化。比如:
代码语言:javascript复制>>> a = [[1,2,3],[4,5],[6,7,8,9]]
>>> def collate_fn(data):
... '''
... padding data, so they have same length.
... '''
... max_len = max([len(feature) for feature in data])
... new = torch.zeros(len(data), max_len)
... for i in range(len(data)):
... tmp = torch.as_tensor(data[i])
... j = len(tmp)
... new[i][:j] = tmp
... return new
>>> collate_fn(a)
tensor([[1., 2., 3., 0.],
[4., 5., 0., 0.],
[6., 7., 8., 9.]])
将这个函数赋值给 collate_fn
,在读取数据的时候就可以自动对数据进行 padding 并打包成一个 batch。
2.4 sampler 采样器
这个变量决定了数据读取的顺序。
注意,sampler
只对 iterable-style datasets 有效。
除了可以自定义采样器,Python 内置了几种不同的采样器:
torch.utils.data.SequentialSampler(data_source)
默认的采样器。torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None)
随机选择数据。可以指定一次读取 num_samples
个数据。replacement
为 True
的话可以指定 num_samples
。
>>> batch = torch.utils.data.RandomSampler(samples, replacement=True, num_samples=5) # 生成一个迭代器
>>> print(list(batch))
[85, 70, 5, 63, 79]
还有三个采样器无法独立使用,必须先实例化,然后放进 DataLoader
:
torch.utils.data.SubsetRandomSampler(indices)
:先按照索引选取数据,然后随机排列。torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)
:字面意思是按照概率选择不同类别的元素。torch.utils.data.BatchSampler(sampler, batch_size, drop_last)
:在一个 batch 中应用另外一个采样器。
2.5 dataset 数据集生成器
torch.utils.data.Dataset
这个类需要覆写 __getitem__
和 __len__
属性。
>>> class MyData(torch.utils.data.Dataset):
... def __init__(self, data):
... super(MyData, self).__init__()
... self.data = data
... def __len__(self, data):
... return len(self.data)
... def __getitem__(self, index):
... return self.data[index]
>>> mydata = MyData(samples)
>>> mydata[0]
tensor(0)
>>> mydata[10:15]
tensor([10, 11, 12, 13, 14])
除此以外,还有若干个 wrapper:
torch.utils.data.IterableDataset
torch.utils.data.TensorDataset(*tensors)
torch.utils.data.ConcatDataset(datasets)
torch.utils.data.ChainDataset(datasets)
2.6 总结
选择让我们把所有知识应用一下。假设我们想以 10 为一个 batch,随机选择数据:
代码语言:javascript复制>>> train = data.TensorDataset(torch.as_tensor(samples), torch.as_tensor(labels))
>>> ds = torch.utils.data.DataLoader(train, batch_size=10, shuffle=True)
>>> for _ in range(5):
... print(iter(ds).next())
[tensor([35, 19, 99, 58, 59, 10, 26, 86, 24, 25]), tensor([0., 0., 1., 1., 1., 0., 0., 1., 0., 0.])]
[tensor([ 6, 37, 24, 98, 96, 18, 88, 90, 19, 87]), tensor([0., 0., 0., 1., 1., 0., 1., 1., 0., 1.])]
[tensor([80, 75, 48, 34, 90, 67, 8, 63, 47, 32]), tensor([1., 1., 0., 0., 1., 1., 0., 1., 0., 0.])]
[tensor([48, 68, 64, 54, 87, 76, 18, 53, 65, 17]), tensor([0., 1., 1., 1., 1., 1., 0., 1., 1., 0.])]
[tensor([65, 26, 67, 5, 4, 8, 35, 47, 40, 96]), tensor([1., 0., 1., 0., 0., 0., 0., 0., 0., 1.])]