目录
1 迭代器介绍
2 Dataset
2 Sampler
3 DataLoader
4 三者关系
一张图带你看懂全文
最近被迫开始了居家办公,这不,每天认真工(mo)作(yu)之余,也有了更多时间重新学习分析起了 PyTorch 源码分享,属于是直接站在巨人的肩膀上了。在简单捋一捋思路之后,就从 torch.utils.data 数据处理模块开始,一步步重新学习 PyTorch 的一些源码模块解析,希望也能让大家重新认识已经不陌生的 PyTorch 这个小伙伴。
1. 迭代器介绍
OK,在正式解析 PyTorch 中的 torch.utils.data 模块之前,我们需要理解一下 Python 中的迭代器(Iterator),因为在源码的 Dataset, Sampler 和 DataLoader 这三个类中都会用到包括 __len__(self),__getitem__(self) 和 __iter__(self) 的抽象类的魔法方法。
· __len__(self):定义当被 len() 函数调用时的行为,一般返回迭代器中元素的个数。
· __getitem__(self):定义获取容器中指定元素时的行为,相当于 self[key] ,即允许类对象拥有索引操作。
· __iter__(self):定义当迭代容器中的元素时的行为。
除此之外,我们也需要清楚两个概念:
· 迭代(Iteration):当我们用一个循环(比如 for 循环)来遍历容器(比如列表,元组)中的元素时,这种遍历的过程可称为迭代。
· 可迭代对象(Iterable):一般指含有 __iter__() 方法或 __getitem__() 方法的对象。我们通常接触的数据结构,如序列(列表、元组和字符串)还有字典等,都支持迭代操作,也可称为可迭代对象。
那什么是迭代器(Iterator)呢?简而言之,迭代器就是一种可以被遍历的容器类对象,但它又比较特别,它需要遵循迭代器协议,那什么又是迭代器协议呢?迭代器协议(iterator protocol)是指要实现对象的 __iter()__ 和 __next__() 方法。一个容器或者类如果是迭代器,那么就必须实现 __iter__() 方法以及重点实现 __next__() 方法,前者会返回一个迭代器(通常是迭代器对象本身),而后者决定了迭代的规则。现在,为更好地理解迭代器的内部运行机制,我们可以看一个斐波那契数列的迭代器实现例子:
代码语言:javascript复制class Fibs:
def __init__(self, n=20):
self.a = 0
self.b = 1
self.n = n
def __iter__(self):
return self
def __next__(self):
self.a, self.b = self.b, self.a self.b
if self.a > self.n:
raise StopIteration
return self.a
fibs = Fibs()
for each in fibs:
print(each)
# 输出
# 1 1 2 3 5 8 13
一般而言,迭代器满足以下几种特性:
· 迭代器是⼀个对象,但比较特别,需要满足迭代器协议,他还可以被 for 语句循环迭代直到终⽌。
· 迭代器可以被 next() 函数调⽤,并返回⼀个值,亦可以被 iter() 函数调⽤,但返回的是一个迭代器(可以是自身)。
· 迭代器连续被 next() 函数调⽤时,依次返回⼀系列的值,但如果到了迭代的末尾,则抛出 StopIteration 异常,另外他可以没有末尾,但只要被 next() 函数调⽤,就⼀定会返回⼀个值。
· Python3 中, next() 内置函数调⽤的是对象的 __next__() ⽅法,iter() 内置函数调⽤的是对象的 __iter__() ⽅法。
那么,了解了什么是迭代器后,我们马上开始解析 torch.utils.data 模块,对于 torch.utils.data 而言,重点是其 Dataset,Sampler,DataLoader 三个模块,辅以 collate,fetch,pin_memory 等组件对特定功能予以支持。
Tips:涉及的源码皆以 PyTorch 1.7 为准。
2. Dataset
Dataset 主要负责对 raw data source 封装,将其封装成 Python 可识别的数据结构,其必须提供提取数据个体的接口。Dataset 中共有 Map-style datasets 和 Iterable-style datasets 两种:
1.1 Map-style dataset
torch.utils.data.Dataset 它是一种通过实现 __len__() 和 __getitem__() 方法来获取数据的 Dataset,它表示从(可能是非整数)索引/关键字到数据样本的映射。因而,在我们访问 Map-style 的数据集时,使用 dataset[idx] 即可访问 idx 对应的数据。通常,我们使用 Map-style 类型的 dataset 居多,可以看到其数据接口定义如下:
代码语言:javascript复制class Dataset(Generic[T_co]):
# Generic is an Abstract base class for generic types.
def __getitem__(self, index) -> T_co:
raise NotImplementedError
def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
return ConcatDataset([self, other])
# No `def __len__(self)` default?
# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
# in pytorch/torch/utils/data/sampler.py
在 PyTorch 1.7 源码中所有定义的 Dataset 都是其子类,而对于一般计算机视觉任务,我们通常也会在其中进行一些 resize,crop,flip 等预处理的操作。
值得一提的是,PyTorch 源码中并没有提供默认的 __len__() 方法实现,原因是 return NotImplemented 或者 raise NotImplementedError() 之类的默认实现都会存在各自的问题,这点我们在源码 pytorch/torch/utils/data/sampler.py 中的注释也可以得到解释。
1.2 Iterable-style dataset
torch.utils.data.IterableDataset 它是一种实现 __iter__() 来获取数据的 Dataset,Iterable-style 的数据集特别适用于以下情况:随机读取代价很大甚至不可能,且 batch size 取决于获取到的数据。其接口定义如下:
代码语言:javascript复制class IterableDataset(Dataset[T_co]):
def __iter__(self) -> Iterator[T_co]:
raise NotImplementedError
def __add__(self, other: Dataset[T_co]):
return ChainDataset([self, other])
# No `def __len__(self)` default? Subclasses raise `TypeError` when needed.
# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
特别地,当 DataLoader 的 num_workers > 0 时, 每个 worker 都将具有数据对象的不同样本。因此需要独立地对每个副本进行配置,以防止每个 worker 产生的数据不重复。同时,数据加载顺序完全由用户定义的可迭代样式控制。这允许更容易地实现块读取和动态批次大小(例如,通过每次产生一个批次的样本)。
1.3 其他 Dataset
除了 Map-style dataset 和 Iterable-style dataset 以外,PyTorch 也在此基础上提供了其他类型的 Dataset 子类:
· torch.utils.data.ConcatDataset:用于连接多个 ConcatDataset 数据集。
· torch.utils.data.ChainDataset:用于连接多个 IterableDataset 数据集,在 IterableDataset 的 __add__() 方法中被调用。
· torch.utils.data.Subset:用于获取指定一个索引序列对应的子数据集。
代码语言:javascript复制class Subset(Dataset[T_co]):
dataset: Dataset[T_co]
indices: Sequence[int]
def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
self.dataset = dataset
self.indices = indices
def __getitem__(self, idx):
return self.dataset[self.indices[idx]]
def __len__(self):
return len(self.indices)
· torch.utils.data.TensorDataset:用于获取封装成 tensor 的数据集,每一个样本都可通过索引张量来获得。
代码语言:javascript复制class TensorDataset(Dataset):
def __init__(self, *tensor):
assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
self.tensors = tensors
def __getitem__(self, index):
return tuple(tensor[index] for tensor in tensors
def __len__(self):
return self.tensors[0].size(0)
3. Sampler
torch.utils.data.Sampler 主要负责提供一种遍历数据集所有元素索引的方式。可支持我们自定义,也可以使用 PyTorch 本身提供的,其基类接口定义如下:
代码语言:javascript复制lass Sampler(Generic[T_co]):
r"""Base class for all Samplers.
Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
way to iterate over indices of dataset elements, and a :meth:`__len__` method
that returns the length of the returned iterators.
.. note:: The :meth:`__len__` method isn't strictly required by
:class:`~torch.utils.data.DataLoader`, but is expected in any
calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
"""
def __init__(self, data_source: Optional[Sized]) -> None:
pass
def __iter__(self) -> Iterator[T_co]:
raise NotImplementedError
特别地,__len()__ 方法虽不是必要的,但是当 DataLoader 需要计算 length 的时候必须定义,这点在源码中也有注释加以体现。
同样,PyTorch 也在此基础上提供了其他类型的 Sampler 子类:
· torch.utils.data.SequentialSampler:顺序采样样本,始终按照同一个顺序。
· torch.utils.data.RandomSampler:可指定有无放回地,进行随机采样样本元素。
· torch.utils.data.SubsetRandomSampler:无放回地按照给定的索引列表采样样本元素。
· torch.utils.data.WeightedRandomSampler:按照给定的概率来采样样本。样本元素来自 [0,…,len(weights)-1] ,给定概率(权重)。
· torch.utils.data.BatchSampler:在一个 batch 中封装一个其他的采样器, 返回一个 batch 大小的 index 索引。
· torch.utils.data.DistributedSample:将数据加载限制为数据集子集的采样器。与 torch.nn.parallel.DistributedDataParallel 结合使用。在这种情况下,每个进程都可以将 DistributedSampler 实例作为 DataLoader 采样器传递。
4. DataLoader
torch.utils.data.DataLoader 是 PyTorch 数据加载的核心,负责加载数据,同时支持 Map-style 和 Iterable-style Dataset,支持单进程/多进程,还可以通过参数设置如 sampler, batch size, pin memory 等自定义数据加载顺序以及控制数据批处理功能。其接口定义如下:
代码语言:javascript复制DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
对于每个参数的含义,下面通过一个表格进行直观地介绍:
从参数定义中,我们可以看到 DataLoader 主要支持以下几个功能:
· 支持加载 map-style 和 iterable-style 的 dataset,主要涉及到的参数是 dataset。
· 自定义数据加载顺序,主要涉及到的参数有 shuffle,sampler,batch_sampler,collate_fn。
· 自动把数据整理成batch序列,主要涉及到的参数有 batch_size,batch_sampler,collate_fn,drop_last。
· 单进程和多进程的数据加载,主要涉及到的参数有 num_workers,worker_init_fn。
· 自动进行锁页内存读取 (memory pinning),主要涉及到的参数 pin_memory。
· 支持数据预加载,主要涉及的参数 prefetch_factor。
3.1 批处理
3.1.1 自动批处理(默认)
DataLoader 支持通过参数 batch_size, drop_last, batch_sampler,自动地把取出的数据整理(collate)成批次样本(batch),其中 batch_size 和 drop_last 参数用于指定 DataLoader 如何获取 dataset 的 key。特别地,对于 map-style 类型的 dataset,用户可以选择指定 batch_sample 参数,一次就生成一个 keys list。
在使用 sampler 产生的 indices 获取采样到的数据时,DataLoader 使用 collate_fn 参数将样本列表整理成 batch。抽象整个过程,其表示方式大致如下:
代码语言:javascript复制# For Map-style
for indices in batch_sampler:
yield collate_fn([dataset[i] for i in indices])
# For Iterable-style
dataset_iter = iter(dataset)
for indices in batch_sampler:
yield collate_fn([next(dataset_iter) for _ in indices])
3.1.2 关闭自动批处理
当我们想用 dataset 代码手动处理 batch,或仅加载单个 sample data 时,可将 batch_size 和 batch_sampler 设为 None, 将关闭自动批处理。此时,由 Dataset 产生的 sample 将会直接被 collate_fn 处理。抽象整个过程,其表示方式大致如下:
代码语言:javascript复制# For Map-style
for index in sampler:
yield collate_fn(dataset[index])
# For Iterable-style
for data in iter(dataset):
yield collate_fn(data)
3.1.3 collate_fn
当关闭自动批处理 (automatic batching) 时,collate_fn 作用于单个数据样本,只是在 PyTorch 张量中转换 NumPy 数组。
而当开启自动批处理 (automatic batching) 时,collate_fn 作用于数据样本列表,将输入样本整理为一个 batch,一般做下面 3 件事情:
· 添加新的批次维度(一般是第一维)。
· 它会自动将 NumPy 数组和 Python 数值转换为 PyTorch 张量。
· 它保留数据结构,例如,如果每个样本都是 dict,则输出具有相同键集但批处理过的张量作为值的字典(或 list,当数据类型不能转换的时候)。这在 list,tuples,namedtuples 同样适用。
自定义 collate_fn 可用于自定义排序规则,例如,将顺序数据填充到批处理的最大长度,添加对自定义数据类型的支持等。
5. 三者关系
通过以上解析的三者工作内容,不难可以推出其内在关系:
1)设置 Dataset,将数据 data source 包装成 Dataset 类,暴露出提取接口。
2)设置 Sampler,决定采样方式。我们虽然能从 Dataset 中提取元素了,但还是需要设置 Sampler 告诉程序提取 Dataset 的策略。
3)将设置好的 Dataset 和 Sampler 传入 DataLoader,同时可以设置 shuffle,batch_size 等参数。使用 DataLoader 对象可以方便快捷地在数据集上遍历。
至此我们就可以了解到了 Dataset,Sampler,Dataloader 三个类的基本定义以及对应实现功能,同时也介绍了批处理对应参数组件。总结来说,我们需要记得的是三点,即 Dataloader 负责总的调度,命令 Sampler 定义遍历索引的方式,然后用索引去 Dataset 中提取元素。于是就实现了对给定数据集的遍历。
今天的分享就到此为止啦,关于 prefetch,pin_memory 等组件的介绍,我们会在后续系列文章中和大家分享,并对其特定功能予以解读,相关的数据处理代码详解也会一并附上。