PyTorch-数据处理流程

2022-11-14 11:00:07 浏览数 (2)

在此记录对PyTorch框架学习过程中的思考。

数据加载处理是深度学习模型训练的前奏,是很重要的一部分。这一过程需要把原始数据,影像或者文本等进行封装、转换,并以合适的格式传递给模型。这个过程依赖torch.utils.data模块,常用以上三个类:

torch.utils.data.Dataset

torch.utils.data.Sampler

torch.utils.data.DataLoader

01

三者关系

三者的关系可以表示如下图:

三个类形成对数据的层层封装。

Dataset对原始数据进行封装,暴露数据提取的接口。

Sampler决定了采样策略,根据不同索引方式来从Dataset中提取部分数据。

DataLoader通过封装Dataset和Sampler,设定batch_size等参数,构造了方便快速遍历的mini batch数据集。

02

Dataset

Dataset是一个抽象类,迭代器。负责对原始数据进行封装,形成模型可以识别的数据结构,其暴露了获取单个数据的接口。

Dataset有两种:Map-style datasets 和 Iterable-style datasets

torch.utils.data.Dataset通过实现__len__()和__getitem__()来获取数据。

torch.utils.data.IterableDataset 通过实现__iter()__来获取数据。

可以通过集成Dataset类来自定义自己的数据集,如下示例,通过改写__getitem__()方法自定义提取数据,可以在其中加入数据增强的方法。

代码语言:javascript复制
import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

03

Sampler

Sampler定义了对数据集的采样策略。通过Sampler类中的__iter__()方法来获取数据集的索引,其基类如下。

代码语言:javascript复制
class Sampler(Generic[T_co]):
    r"""Base class for all Samplers.

    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices of dataset elements, and a :meth:`__len__` method
    that returns the length of the returned iterators.

    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """

    def __init__(self, data_source: Optional[Sized]) -> None:
        pass

    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

所有子类都继承自Sampler,通过改写__iter__()方法来实现。

torch.utils.data.SequentialSampler 指定顺序采样样本。

torch.utils.data.RandomSampler 随机采样,可指定是否放回样本

torch.utils.data.DistributeSampler 数据加载限制为数据集子集,每个进程都可以把一个DistributeSampler实例作为DataLoader采样器传递

torch.utils.data.BatchSampler 在一个小batch中封装一个Sampler,返回小batch的索引

04

DataLoader

DataLoader是数据加载的核心,它对Dataset和Sampler进行封装,以mini batch的形式加载数据。支持单进程和多进程.

代码语言:javascript复制
torch.utils.data.DataLoader(dataset, batch_size=1, 
          shuffle=False, sampler=None, batch_sampler=None, 
          num_workers=0, collate_fn=None, pin_memory=False, 
          drop_last=False, timeout=0, worker_init_fn=None, 
          multiprocessing_context=None, generator=None, *, 
          prefetch_factor=2, persistent_workers=False)

DataLoader是数据加载的核心,它对Dataset和Sampler进行封装,以mini batch的形式加载数据。支持单进程和多进程.

Dataset,加载的数据集,Dataset实例

batch_size,每个batch的样本数

shuffle:设置为True,在每个epoch开始前,都会随机抽取数据,调用了RandomSampler

sampler:定义从数据集的抽取策略,指定了sampler,shuffle必须为False

batch_sampler:和sampler功能一样,传入BatchSampler,和batch_size

shuffle互斥。

num_workers:指定进程数。默认0,只在主进程加载数据

drop_last:True的话,会删除最后一个不完整的batch数据。

总结来讲,DataLoader通过Sampler定义的索引策略,从Dataset中遍历提取数据。

0 人点赞