“ 在此记录对PyTorch框架学习过程中的思考。”
数据加载处理是深度学习模型训练的前奏,是很重要的一部分。这一过程需要把原始数据,影像或者文本等进行封装、转换,并以合适的格式传递给模型。这个过程依赖torch.utils.data模块,常用以上三个类:
torch.utils.data.Dataset
torch.utils.data.Sampler
torch.utils.data.DataLoader
01
—
三者关系
三者的关系可以表示如下图:
三个类形成对数据的层层封装。
Dataset对原始数据进行封装,暴露数据提取的接口。
Sampler决定了采样策略,根据不同索引方式来从Dataset中提取部分数据。
DataLoader通过封装Dataset和Sampler,设定batch_size等参数,构造了方便快速遍历的mini batch数据集。
02
—
Dataset
Dataset是一个抽象类,迭代器。负责对原始数据进行封装,形成模型可以识别的数据结构,其暴露了获取单个数据的接口。
Dataset有两种:Map-style datasets 和 Iterable-style datasets
torch.utils.data.Dataset通过实现__len__()和__getitem__()来获取数据。
torch.utils.data.IterableDataset 通过实现__iter()__来获取数据。
可以通过集成Dataset类来自定义自己的数据集,如下示例,通过改写__getitem__()方法自定义提取数据,可以在其中加入数据增强的方法。
代码语言:javascript复制import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
03
—
Sampler
Sampler定义了对数据集的采样策略。通过Sampler类中的__iter__()方法来获取数据集的索引,其基类如下。
代码语言:javascript复制class Sampler(Generic[T_co]):
r"""Base class for all Samplers.
Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
way to iterate over indices of dataset elements, and a :meth:`__len__` method
that returns the length of the returned iterators.
.. note:: The :meth:`__len__` method isn't strictly required by
:class:`~torch.utils.data.DataLoader`, but is expected in any
calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
"""
def __init__(self, data_source: Optional[Sized]) -> None:
pass
def __iter__(self) -> Iterator[T_co]:
raise NotImplementedError
所有子类都继承自Sampler,通过改写__iter__()方法来实现。
torch.utils.data.SequentialSampler 指定顺序采样样本。
torch.utils.data.RandomSampler 随机采样,可指定是否放回样本
torch.utils.data.DistributeSampler 数据加载限制为数据集子集,每个进程都可以把一个DistributeSampler实例作为DataLoader采样器传递
torch.utils.data.BatchSampler 在一个小batch中封装一个Sampler,返回小batch的索引
04
—
DataLoader
DataLoader是数据加载的核心,它对Dataset和Sampler进行封装,以mini batch的形式加载数据。支持单进程和多进程.
代码语言:javascript复制torch.utils.data.DataLoader(dataset, batch_size=1,
shuffle=False, sampler=None, batch_sampler=None,
num_workers=0, collate_fn=None, pin_memory=False,
drop_last=False, timeout=0, worker_init_fn=None,
multiprocessing_context=None, generator=None, *,
prefetch_factor=2, persistent_workers=False)
DataLoader是数据加载的核心,它对Dataset和Sampler进行封装,以mini batch的形式加载数据。支持单进程和多进程.
Dataset,加载的数据集,Dataset实例
batch_size,每个batch的样本数
shuffle:设置为True,在每个epoch开始前,都会随机抽取数据,调用了RandomSampler
sampler:定义从数据集的抽取策略,指定了sampler,shuffle必须为False
batch_sampler:和sampler功能一样,传入BatchSampler,和batch_size
shuffle互斥。
num_workers:指定进程数。默认0,只在主进程加载数据
drop_last:True的话,会删除最后一个不完整的batch数据。
总结来讲,DataLoader通过Sampler定义的索引策略,从Dataset中遍历提取数据。