大家好,又见面了,我是你们的朋友全栈君。
本博客讲解了pytorch框架下DataLoader
的多种用法,每一种方法都展示了实例,虽然有一点复杂,但是小伙伴静下心看一定能看懂哦 :)
个人建议,在1.1.1节介绍的三种方法中,推荐 方法二>方法一>方法三
(方法三实在是过于复杂不做推荐),另外,第三节中的处理示例使用了非DataLoader
的方法进行数据集处理,也可以借鉴~
目录
- 1 torch.utils.data.DataLoader
- 1.1 dataset
- 1.1.1 Map-style datasets
- 实现方法一(简单直白法)
- 实现方法二(借助TensorDataset直接将数据包装成dataset类)
- 实现方法三(地址读取法)
- 1.1.1 Iterable-style datasets
- 1.1.1 Map-style datasets
- 1.1 dataset
- 2 torchvision.datasets
- 2.1 ImageFolder
- 3 处理示例
- 5 实用功能
- 5.1 分割dataloader
我们一般使用一个for循环(或多层的)来训练神经网络,每一次迭代,加载一个batch的数据,神经网络前向反向传播各一次并更新一次参数。 而这个过程中加载一个batch的数据这一步需要使用一个torch.utils.data.DataLoader对象,并且DataLoader是一个基于某个dataset的iterable,这个iterable每次从dataset中基于某种采样原则取出一个batch的数据。 也可以这样说:Torch中可以创建一个torch.utils.data.Dataset对象,并与torch.utils.data.DataLoader一起使用,在训练模型时不断为模型提供数据。
1 torch.utils.data.DataLoader
定义:Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset.
我们先来看一看其构造函数的参数
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
接下来我们具体分析下它的参数。 其中最重要的参数是dataset,是一个抽象类,包含两种类型:map-style datasets 和 iterable-style datasets.
dataset (Dataset) – dataset from which to load the data.
batch_size (int, optional) – how many samples per batch to load (default: 1).
shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with len implemented. If specified, shuffle must not be specified.
batch_sampler (Sampler or Iterable, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
collate_fn (callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
pin_memory (bool, optional) – If True, the data loader will copy Tensors into CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the example below.
drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)
timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)
worker_init_fn (callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers – 1]) as input, after seeding and before data loading. (default: None)
1.1 dataset
只支持两种类型的数据集:map-style datasets
, iterable-style datasets.
1.1.1 Map-style datasets
是一个类,要求有 __getitem__()
and__len__()
这两个构造函数,代表一个从索引映射到数据样本。
(1)其中__getitem__函数的作用是根据索引index遍历数据
(2)__len__函数的作用是返回数据集的长度
(3)在创建的dataset类中可根据自己的需求对数据进行处理。可编写独立的数据处理函数,在__getitem__函数中进行调用;或者直接将数据处理方法写在__getitem__函数中或者__init__函数中,但__getitem__必须根据index返回响应的值,该值会通过index传到dataloader中进行后续的batch批处理。
即基本满足:
代码语言:javascript复制def __getitem__(self, index):
return self.src[index], self.trg[index]
代码语言:javascript复制def __len__(self):
return len(self.src)
看一下他的大概构造:
代码语言:javascript复制class Dataset(object):
"""An abstract class representing a :class:`Dataset`. All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite :meth:`__len__`, which is expected to return the size of the dataset by many :class:`~torch.utils.data.Sampler` implementations and the default options of :class:`~torch.utils.data.DataLoader`. .. note:: :class:`~torch.utils.data.DataLoader` by default constructs a index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided. """
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
# No `def __len__(self)` default?
# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
# in pytorch/torch/utils/data/sampler.py
上述代码是pytorch中Datasets的源码,注意成员方法__getitem__和__len__都是未实现的。我们要实现自定义Datasets类来完成数据的读取,则只需要完成这两个成员方法的重写。
首先,getitem()方法用来从datasets中读取一条数据,这条数据包含训练图片(已CV距离)和标签,参数index表示图片和标签在总数据集中的Index。
len()方法返回数据集的总长度(训练集的总数)。
下面介绍两种简单实现MyDatasets类
实现方法一(简单直白法)
重点是把 x 和 label 都分别装入两个列表 self.src 和 self.trg ,然后通过 getitem(self, index)返回对应元素。
代码语言:javascript复制import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
class My_dataset(Dataset):
def __init__(self):
super().__init__()
# 使用sin函数返回10000个时间序列,如果不自己构造数据,就使用numpy,pandas等读取自己的数据为x即可。
# 以下数据组织这块既可以放在init方法里,也可以放在getitem方法里
self.x = torch.randn(1000,3)
self.y = self.x.sum(axis=1)
self.src, self.trg = [], []
for i in range(1000):
self.src.append(self.x[i])
self.trg.append(self.y[i])
def __getitem__(self, index):
return self.src[index], self.trg[index]
def __len__(self):
return len(self.src)
# 或者return len(self.trg), src和trg长度一样
data_train = My_dataset()
data_test = My_dataset()
data_loader_train = DataLoader(data_train, batch_size=5, shuffle=False)
data_loader_test = DataLoader(data_test, batch_size=5, shuffle=False)
# i_batch的多少根据batch size和def __len__(self)返回的长度确定
# batch_data返回的值根据def __getitem__(self, index)来确定
# 对训练集:(不太清楚enumerate返回什么的时候就多print试试)
for i_batch, batch_data in enumerate(data_loader_train):
print(i_batch) # 打印batch编号
print(batch_data[0]) # 打印该batch里面src
print(batch_data[1]) # 打印该batch里面trg
# 对测试集:(下面的语句也可以)
for i_batch, (src, trg) in enumerate(data_loader_test):
print(i_batch) # 打印batch编号
print(src) # 打印该batch里面src的尺寸
print(trg) # 打印该batch里面trg的尺寸
多说几句:生成的data_train可以通过 data_train[xxx] 直接索引某个元素,或者通过next(iter(data_train))得到一条条的数据。
实现方法二(借助TensorDataset直接将数据包装成dataset类)
另一种方法是直接使用 TensorDataset 来将数据包装成Dataset类,再使用dataloader。
代码语言:javascript复制import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, TensorDataset
src = torch.sin(torch.arange(1, 1000, 0.1))
trg = torch.cos(torch.arange(1, 1000, 0.1))
data = TensorDataset(src, trg)
data_loader = DataLoader(data, batch_size=5, shuffle=False)
for i_batch, batch_data in enumerate(data_loader):
print(i_batch) # 打印batch编号
print(batch_data[0].size()) # 打印该batch里面src
print(batch_data[1].size()) # 打印该batch里面trg
output:
代码语言:javascript复制0
torch.Size([5])
torch.Size([5])
1
torch.Size([5])
torch.Size([5])
...
实现方法三(地址读取法)
适用于lfw这样的数据集,每一份数据都对应一个文件夹,或者说数据量过大,无法一次加载出来的数据集。并且要求这样的数据集,有一个txt文件
可以进行索引!
import os
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import matplotlib.image as mpimg
# 对所有图片生成path-label map.txt 这个程序可根据实际需要适当修改
def generate_map(root_dir):
#得到当前绝对路径
current_path = os.path.abspath('.')
#os.path.dirname()向前退一个路径
father_path = os.path.abspath(os.path.dirname(current_path) os.path.sep ".")
with open(root_dir 'map.txt', 'w') as wfp:
for idx in range(10):
subdir = os.path.join(root_dir, '%d/' % idx)
for file_name in os.listdir(subdir):
abs_name = os.path.join(father_path, subdir, file_name)
# linux_abs_name = abs_name.replace("\", '/')
wfp.write('{file_dir} {label}n'.format(file_dir=linux_abs_name, label=idx))
# 实现MyDatasets类
class MyDatasets(Dataset):
def __init__(self, dir):
# 获取数据存放的dir
# 例如d:/images/
self.data_dir = dir
# 用于存放(image,label) tuple的list,存放的数据例如(d:/image/1.png,4)
self.image_target_list = []
# 从dir--label的map文件中将所有的tuple对读取到image_target_list中
# map.txt中全部存放的是d:/.../image_data/1/3.jpg 1 路径最好是绝对路径
with open(os.path.join(dir, 'map.txt'), 'r') as fp:
content = fp.readlines()
#s.rstrip()删除字符串末尾指定字符(默认是字符)
# 得到 [['d:/.../image_data/1/3.jpg', '1'], ...,]
str_list = [s.rstrip().split() for s in content]
# 将所有图片的dir--label对都放入列表,如果要执行多个epoch,可以在这里多复制几遍,然后统一shuffle比较好
self.image_target_list = [(x[0], int(x[1])) for x in str_list]
def __getitem__(self, index):
image_label_pair = self.image_target_list[index]
# 按path读取图片数据,并转换为图片格式例如[3,32,32]
# 可以用别的代替
img = mpimg.imread(image_label_pair[0])
return img, image_label_pair[1]
def __len__(self):
return len(self.image_target_list)
if __name__ == '__main__':
# 生成map.txt
# generate_map('train/')
train_loader = DataLoader(MyDatasets('train/'), batch_size=128, shuffle=True)
for step in range(20000):
for idx, (img, label) in enumerate(train_loader):
print(img.shape)
print(label.shape)
如果使用其他形式的数据,例如二进制文件,则需要字节读取文件,分割成每一张图片和label,然后从__getitem__中返回就可以了。例如cifar-10数据,我们只需要在__getitem__方法中,按index来读取对应位置的字节,然后转换为label和img,并返回。在__len__中返回cifar-10训练集的总样本数。DataLoader就可以根据我们提供的index,len以及batch_size,shuffle来返回相应的batch数据和label。
1.1.1 Iterable-style datasets
可迭代样式的数据集是IterableDataset的一个实例,该实例必须重写__iter__方法,该方法用于对数据集进行迭代。这种类型的数据集特别适合随机读取数据不太可能实现的情况,并且批处理大小batchsize取决于获取的数据。比如读取数据库,远程服务器或者实时日志等数据的时候,可使用该样式,一般时序数据不使用这种样式。
For example, such a dataset, when called iter(dataset), could return a stream of data reading from a database, a remote server, or even logs generated in real time. 。。。。。 这里就不详细讲了,太特么复杂了~
2 torchvision.datasets
这个包的作用是方便提供现成数据集。
torchvision.datasets
中包含了以下数据集
- MNIST -COCO(用于图像标注和目标检测)(Captioning and Detection) -LSUN Classification -ImageFolder -Imagenet-12 -CIFAR10 and CIFAR100 -STL10
Datasets 拥有以下API: __getitem__
和 __len__
具体用法看参考第四条(搭配torch.utils.data.DataLoader)
2.1 ImageFolder
这个和DatasetFolder一样,适合用于已经下载好的并且符合一定要求的数据集,ImageFolder要求数据呈这样分布:
代码语言:javascript复制root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
使用方法:
代码语言:javascript复制my_transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
torchvision.datasets.ImageFolder(root="./my_dataset/", transform=my_transform)
3 处理示例
我们在1.1.1节已经讨论了三种加载数据集的方法,现在以Crime数据集另介绍一种数据集加载办法。这种方法和 DataLoader
没有任何关系,实现起来的复杂度一般。
import numpy as np
from matplotlib import pyplot as plt
import os
import torch
class CrimeDataset():
def __init__(self, device):
reader = open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data/communities.data'))
attributes = []
while True:
# 读取 用逗号作为分隔符的数据集文件
line = reader.readline().split(',')
if len(line) < 128:
break
# set the ? as -1
line = ['-1' if val == '?' else val for val in line]
line = np.array(line[5:], dtype=np.float)
attributes.append(line)
reader.close()
# attributes.shape=(1994, 123)
attributes = np.stack(attributes, axis=0)
# load the name of each column; total: 128
reader = open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data/names'))
names = []
for i in range(128):
# reader.readline().split() = ['@attribute', 'county', 'numeric'] and we choose 'county'
line = reader.readline().split()[1]
# exclude the first 5 columns. Thus the number of column names = 123, arroding with attributes.shape
if i >= 5:
names.append(line)
names = np.array(names)
# shuffle the attribute by axis0
attributes = attributes[np.random.permutation(range(attributes.shape[0])), :]
val_size = 500
# the last column of attributes is the labels
self.train_labels = attributes[val_size:, -1:]
self.test_labels = attributes[:val_size:, -1:]
# exclude the last column of attributes. Thus attributes.shape = (1994,122)
attributes = attributes[:, :-1]
# select the column whose minimum >= 0. selected has 99 features
selected = np.argwhere(np.array([np.min(attributes[:, i]) for i in range(attributes.shape[1])]) >= 0).flatten()
self.train_features = attributes[val_size:, selected]
self.test_features = attributes[:val_size:, selected]
self.names = names[selected]
# self.train_ptr is the counter which counts the number of data records having been loaded
self.train_ptr = 0
self.test_ptr = 0
self.x_dim = self.train_features.shape[1]
# train_size = 1494; test_size = 500
self.train_size = self.train_features.shape[0]
self.test_size = self.test_features.shape[0]
self.device = device
def train_batch(self, batch_size=None):
# if batch_size is None, then each iteration outputs all the training set
if batch_size is None:
batch_size = self.train_features.shape[0]
self.train_ptr = 0
# if all data has been outputed, reset the trailoader.
if self.train_ptr batch_size > self.train_features.shape[0]:
self.train_ptr = 0
bx, by = self.train_features[self.train_ptr:self.train_ptr batch_size],
self.train_labels[self.train_ptr:self.train_ptr batch_size]
self.train_ptr = batch_size
if self.train_ptr == self.train_features.shape[0]:
self.train_ptr = 0
return torch.from_numpy(bx).float().to(self.device), torch.from_numpy(by).float().to(self.device)
def test_batch(self, batch_size=None):
if batch_size is None:
batch_size = self.test_features.shape[0]
self.train_ptr = 0
if self.test_ptr batch_size > self.test_features.shape[0]:
self.test_ptr = 0
bx, by = self.test_features[self.test_ptr:self.test_ptr batch_size],
self.test_labels[self.test_ptr:self.test_ptr batch_size]
self.test_ptr = batch_size
if self.test_ptr == self.test_features.shape[0]:
self.test_ptr = 0
return torch.from_numpy(bx).float().to(self.device), torch.from_numpy(by).float().to(self.device)
if __name__ == '__main__':
dataset = CrimeDataset("cpu")
print(dataset.names)
print(dataset.train_features.shape, dataset.train_labels.shape)
5 实用功能
5.1 分割dataloader
有时候从 torchvision 里下载下来的是一个完整的数据集,包装成 dataloader
`以后我们想把该数据集进行进一步划分:
def split(dataloader, batch_size, split=0.2):
"""Splits the given dataset into training/validation. Args: dataset[torch dataloader]: Dataset which has to be split batch_size[int]: Batch size split[float]: Indicates ratio of validation samples Returns: train_set[list]: Training set val_set[list]: Validation set """
index = 0
length = len(dataloader)
train_set = []
val_set = []
for data, target in dataloader:
if index <= (length * split):
train_set.append([data, target])
else:
val_set.append([data, target])
index = 1
return train_set, val_set
还有更好的分割方法见:pytorch数据集的分割
参考: https://www.cnblogs.com/leokale-zz/p/11275800.html https://www.lagou.com/lgeduarticle/74174.html 太感谢啦!https://blog.csdn.net/zuiyishihefang/article/details/105985760 torchvision-datasets
发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/131858.html原文链接:https://javaforall.cn