pytorch之Dataset

2022-11-21 12:05:13 浏览数 (1)

一.前言

训练模型一般都是先处理 数据的输入问题 和 预处理问题 。Pytorch提供了几个有用的工具:torch.utils.data.Dataset 类和 torch.utils.data.DataLoader 类 。

而这也是我们在之前的文章里说过的三件套之一。

流程是先把原始数据转变成 torch.utils.data.Dataset 类,随后再把得到的 torch.utils.data.Dataset 类当作一个参数传递给 torch.utils.data.DataLoader 类,得到一个数据加载器,这个数据加载器每次可以返回一个 Batch 的数据供模型训练使用。

所以整体的流程是

数据=》Datasets=》DataLoader

在 pytorch 中,提供了一种十分方便的数据读取机制,即使用 torch.utils.data.Dataset 与 Dataloader 组合得到数据迭代器。在每次训练时,利用这个迭代器输出每一个 batch 数据,并能在输出时对数据进行相应的预处理或数据增广操作。

二.Datasets类

如果我们要自己定义一个读取数据的方法,就得继承torch.utils.data.Dataset这个父类,并且需要重写两个方法

我们可以看一下Dataset父类的源码:

代码语言:javascript复制
class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

可以清楚的看到我们需要重写两个方法分别是getitem和len方法

下面我们自定义我们自己的数据读取类

代码语言:javascript复制
import torch
from torch.utils import data
from PIL import Image
import numpy as np
from torchvision import transforms

#继承data.Dataset
# __init__方法
# __getitem__必须创建,作用:对数据切片
#__len__必须创建,作用:返回对象长度
class Tomdataset(data.Dataset):
    def __init__(self,root):
        self.imgs_path=root
    def __getitem__(self,index):
        img_path=self.imgs_path[index]
        return img_path
    def __len__(self):
        return len(Self.imgs_path)

这里的Tomdataset类就简单的实现了一下数据加载类的方法

因为可以重写方法的实现,所以我觉得可玩性还是很高的,比如在getitem方法内我们不仅可以返回一个单纯的元素,如果在构造方法中有其他的参数也可以一并返回。所以客制性很高,可玩性也很高。

在后面的文章中,我们会使用Tomdataset类对我们的数据进行加载和处理。

未完,待续

0 人点赞