一.前言
训练模型一般都是先处理 数据的输入问题 和 预处理问题 。Pytorch提供了几个有用的工具:torch.utils.data.Dataset 类和 torch.utils.data.DataLoader 类 。
而这也是我们在之前的文章里说过的三件套之一。
流程是先把原始数据转变成 torch.utils.data.Dataset 类,随后再把得到的 torch.utils.data.Dataset 类当作一个参数传递给 torch.utils.data.DataLoader 类,得到一个数据加载器,这个数据加载器每次可以返回一个 Batch 的数据供模型训练使用。
所以整体的流程是
数据=》Datasets=》DataLoader
在 pytorch 中,提供了一种十分方便的数据读取机制,即使用 torch.utils.data.Dataset 与 Dataloader 组合得到数据迭代器。在每次训练时,利用这个迭代器输出每一个 batch 数据,并能在输出时对数据进行相应的预处理或数据增广操作。
二.Datasets类
如果我们要自己定义一个读取数据的方法,就得继承torch.utils.data.Dataset这个父类,并且需要重写两个方法
我们可以看一下Dataset父类的源码:
代码语言:javascript复制class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
可以清楚的看到我们需要重写两个方法分别是getitem和len方法
下面我们自定义我们自己的数据读取类
代码语言:javascript复制import torch
from torch.utils import data
from PIL import Image
import numpy as np
from torchvision import transforms
#继承data.Dataset
# __init__方法
# __getitem__必须创建,作用:对数据切片
#__len__必须创建,作用:返回对象长度
class Tomdataset(data.Dataset):
def __init__(self,root):
self.imgs_path=root
def __getitem__(self,index):
img_path=self.imgs_path[index]
return img_path
def __len__(self):
return len(Self.imgs_path)
这里的Tomdataset类就简单的实现了一下数据加载类的方法
因为可以重写方法的实现,所以我觉得可玩性还是很高的,比如在getitem方法内我们不仅可以返回一个单纯的元素,如果在构造方法中有其他的参数也可以一并返回。所以客制性很高,可玩性也很高。
在后面的文章中,我们会使用Tomdataset类对我们的数据进行加载和处理。
未完,待续