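... and its various usages. Each approach comes with a worked example. It gets a little involved, but anyone who reads it patiently should be able to follow :)
Personal recommendation: of the three methods introduced in section 1.1.1, I would rank them Method 2 > Method 1 > Method 3.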
- 1 torch.utils.data.DataLoader
- 1.1 dataset
- 1.1.1 Map-style datasets
- Method 1 (the plain, direct way)
- Method 2 (wrapping the data as a dataset with TensorDataset)
- Method 3 (reading from file paths)
- 1.1.1 Iterable-style datasets
- 2 torchvision.datasets
- 2.1 ImageFolder
- 3 Worked example
- 5 Utilities
- 5.1 Splitting a dataloader
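Neural networks are usually trained with a for loop (or several nested loops): each iteration loads one batch of data, runs one forward and one backward pass, and updates the parameters once. Loading a batch is the job of a torch.utils.data.DataLoader object. A DataLoader is an iterable built on top of a dataset; on each iteration it draws one batch from the dataset according to some sampling strategy. Put differently, in Torch you create a torch.utils.data.Dataset object and use it together with torch.utils.data.DataLoader to keep feeding data to the model during training.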
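The loop described above can be sketched roughly as follows. This is only a minimal illustration, not the training code of any particular project: the toy data, the nn.Linear model, the learning rate and the number of epochs are all assumptions made for the example.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical toy data: 100 samples with 3 features; the target is the row sum.
x = torch.randn(100, 3)
y = x.sum(dim=1, keepdim=True)

loader = DataLoader(TensorDataset(x, y), batch_size=10, shuffle=True)
model = nn.Linear(3, 1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

steps = 0
for epoch in range(2):                  # outer loop over epochs
    for batch_x, batch_y in loader:     # each iteration loads one batch
        optimizer.zero_grad()
        loss = loss_fn(model(batch_x), batch_y)  # forward pass
        loss.backward()                          # backward pass
        optimizer.step()                         # one parameter update
        steps += 1

print(steps)  # 2 epochs x 10 batches per epoch
```

The DataLoader is simply re-iterated once per epoch; with shuffle=True it draws a new sample order each time.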
1 torch.utils.data.DataLoader
Definition: Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset.
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
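Let's go through its parameters. The most important one is dataset. Dataset is an abstract class, and DataLoader works with two kinds of dataset: map-style datasets and iterable-style datasets.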
dataset (Dataset) – dataset from which to load the data.
batch_size (int, optional) – how many samples per batch to load (default: 1).
shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shuffle must not be specified.
batch_sampler (Sampler or Iterable, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
collate_fn (callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
pin_memory (bool, optional) – If True, the data loader will copy Tensors into CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the example below.
drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)
timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)
worker_init_fn (callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
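To make a few of these parameters concrete, here is a small sketch of how batch_size, drop_last and collate_fn change what the loader yields. The 10-sample dataset and the helper collate_as_list are made up for this illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical dataset of 10 samples, so a batch size of 4 does not divide it evenly.
dataset = TensorDataset(torch.arange(10))

loader_keep = DataLoader(dataset, batch_size=4, drop_last=False)
loader_drop = DataLoader(dataset, batch_size=4, drop_last=True)

sizes_keep = [batch[0].shape[0] for batch in loader_keep]
sizes_drop = [batch[0].shape[0] for batch in loader_drop]
print(sizes_keep)  # [4, 4, 2]
print(sizes_drop)  # [4, 4]

# collate_fn receives the list of samples of one batch and decides how to merge them.
def collate_as_list(samples):
    return [int(sample[0]) for sample in samples]

loader_lists = DataLoader(dataset, batch_size=4, collate_fn=collate_as_list)
first = next(iter(loader_lists))
print(first)  # [0, 1, 2, 3]
```

Without a custom collate_fn, the default collation stacks the samples of a batch into tensors, which is what the first two loaders rely on.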
1.1 dataset
DataLoader supports only two types of dataset: map-style datasets and iterable-style datasets.
1.1.1 Map-style datasets
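A map-style dataset is a class that implements __getitem__(), which maps an index (or key) to a data sample:
def __getitem__(self, index):
    return self.src[index], self.trg[index]
and, in most cases, __len__(), which returns the number of samples:
def __len__(self):
    return len(self.src)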
class Dataset(object):
    """An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices. To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py
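One detail in the class above is __add__: adding two datasets returns a ConcatDataset. A small sketch of what that gives you, using two made-up TensorDataset objects:

```python
import torch
from torch.utils.data import ConcatDataset, TensorDataset

# Two hypothetical datasets with 3 and 4 samples.
first = TensorDataset(torch.zeros(3, 2))
second = TensorDataset(torch.ones(4, 2))

combined = first + second       # Dataset.__add__ builds a ConcatDataset
print(type(combined).__name__)
print(len(combined))            # 3 + 4 samples
print(combined[3])              # index 3 is the first sample of `second`
```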
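Method 1 (the plain, direct way)
The key point is to put x and the labels into two lists, self.src and self.trg, and then return the matching elements from __getitem__(self, index).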
import torch
from torch.utils.data import Dataset, DataLoader

class My_dataset(Dataset):
    def __init__(self):
        # Build 1000 random samples with 3 features each; the label is the row sum.
        # To use your own data, read it with numpy, pandas, etc. and assign it to x.
        # The data organization below can live either in __init__ or in __getitem__.
        self.x = torch.randn(1000, 3)
        self.y = self.x.sum(dim=1)
        self.src, self.trg = [], []
        for i in range(1000):
            self.src.append(self.x[i])
            self.trg.append(self.y[i])

    def __getitem__(self, index):
        return self.src[index], self.trg[index]

    def __len__(self):
        # len(self.trg) would work too, since src and trg have the same length
        return len(self.src)

data_train = My_dataset()
data_test = My_dataset()
data_loader_train = DataLoader(data_train, batch_size=5, shuffle=False)
data_loader_test = DataLoader(data_test, batch_size=5, shuffle=False)

# The number of batches follows from batch_size and the length returned by __len__.
# The content of batch_data follows from what __getitem__ returns.
# Training set (when unsure what enumerate returns, just print it):
for i_batch, batch_data in enumerate(data_loader_train):
    print(i_batch)        # batch index
    print(batch_data[0])  # src of this batch
    print(batch_data[1])  # trg of this batch

# Test set (unpacking the batch directly works as well):
for i_batch, (src, trg) in enumerate(data_loader_test):
    print(i_batch)  # batch index
    print(src)      # src of this batch
    print(trg)      # trg of this batch
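A few more words: the resulting data_train can be indexed directly with data_train[xxx] to get a single sample, or walked through one sample at a time with an iterator (next(iter(data_train)) returns the first sample).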
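The two access patterns can be tried out on a tiny dataset. TinyDataset below is a hypothetical three-sample class written only for this illustration; it follows the same src/trg layout as Method 1.

```python
import torch
from torch.utils.data import Dataset

class TinyDataset(Dataset):
    """Hypothetical three-sample dataset, used only for this illustration."""

    def __init__(self):
        self.src = [torch.tensor([1.0, 2.0]),
                    torch.tensor([3.0, 4.0]),
                    torch.tensor([5.0, 6.0])]
        self.trg = [s.sum() for s in self.src]

    def __getitem__(self, index):
        return self.src[index], self.trg[index]

    def __len__(self):
        return len(self.src)

data = TinyDataset()
src, trg = data[1]      # direct indexing calls __getitem__(1)
print(src, trg)

it = iter(data)         # Python falls back to __getitem__ with 0, 1, 2, ...
first = next(it)
second = next(it)
print(first[1].item(), second[1].item())  # 3.0 7.0
```

Note that calling next(iter(data)) repeatedly creates a fresh iterator each time and so always returns the first sample; keep the iterator in a variable to step through the data.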
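Method 2 (wrapping the data as a dataset with TensorDataset)
Another option is to wrap the data in a Dataset directly with TensorDataset and then pass it to a DataLoader.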
import torch
from torch.utils.data import DataLoader, TensorDataset

src = torch.sin(torch.arange(1, 1000, 0.1))
trg = torch.cos(torch.arange(1, 1000, 0.1))
data = TensorDataset(src, trg)
data_loader = DataLoader(data, batch_size=5, shuffle=False)
for i_batch, batch_data in enumerate(data_loader):
    print(i_batch)               # batch index
    print(batch_data[0].size())  # size of src in this batch
    print(batch_data[1].size())  # size of trg in this batch
Method 3 (reading from file paths)
This method keeps only (path, label) pairs in memory and reads each image from disk inside __getitem__.
import os
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import matplotlib.image as mpimg

# Generate a path-label file map.txt for all images; adapt this to your own layout as needed.
def generate_map(root_dir):
    with open(root_dir + 'map.txt', 'w') as wfp:
        for idx in range(10):
            subdir = os.path.join(root_dir, '%d/' % idx)
            for file_name in os.listdir(subdir):
                # resolve against the same directory that os.listdir just read
                abs_name = os.path.abspath(os.path.join(subdir, file_name))
                linux_abs_name = abs_name.replace("\\", '/')
                wfp.write('{file_dir} {label}\n'.format(file_dir=linux_abs_name, label=idx))

# The MyDatasets class
class MyDatasets(Dataset):
    def __init__(self, dir):
        # Directory that holds the data, e.g. d:/images/
        self.data_dir = dir
        # List of (image, label) tuples, e.g. (d:/image/1.png, 4)
        self.image_target_list = []
        # Read every pair from the path-label map file into image_target_list.
        # Each line of map.txt looks like "d:/.../image_data/1/3.jpg 1"; absolute paths are best.
        with open(os.path.join(dir, 'map.txt'), 'r') as fp:
            content = fp.readlines()
            # gives [['d:/.../image_data/1/3.jpg', '1'], ...]
            str_list = [s.rstrip().split() for s in content]
            # Store all path-label pairs. To run several epochs you could also
            # repeat the list here a few times and shuffle it as a whole.
            self.image_target_list = [(x[0], int(x[1])) for x in str_list]

    def __getitem__(self, index):
        image_label_pair = self.image_target_list[index]
        # Read the image from its path as an array (mpimg.imread returns height x width x channels).
        # Any other image reader can be used instead.
        img = mpimg.imread(image_label_pair[0])
        return img, image_label_pair[1]

    def __len__(self):
        return len(self.image_target_list)

if __name__ == '__main__':
    # Generate map.txt
    # generate_map('train/')
    train_loader = DataLoader(MyDatasets('train/'), batch_size=128, shuffle=True)
    for step in range(20000):
        for idx, (img, label) in enumerate(train_loader):
            pass  # run one training step on (img, label) here
1.1.1 Iterable-style datasets
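An iterable-style dataset is an instance of a subclass of IterableDataset that implements __iter__(). For example, such a dataset, when called iter(dataset), could return a stream of data reading from a database, a remote server, or even logs generated in real time. I won't go into the details here; it gets rather complicated.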
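For readers who still want a first taste, here is a minimal sketch of an iterable-style dataset. CountingStream is a hypothetical class invented for this example; a real one would read from a database, a socket or a log file inside __iter__.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset

class CountingStream(IterableDataset):
    """Hypothetical stream that yields the integers start, ..., end - 1."""

    def __init__(self, start, end):
        super().__init__()
        self.start = start
        self.end = end

    def __iter__(self):
        # A real stream would pull records from its source here.
        for value in range(self.start, self.end):
            yield torch.tensor(value)

stream = CountingStream(0, 7)
loader = DataLoader(stream, batch_size=3)   # num_workers=0: a single reader
batches = [batch.tolist() for batch in loader]
print(batches)  # [[0, 1, 2], [3, 4, 5], [6]]
```

There is no __getitem__ or __len__ here, so the loader cannot shuffle or sample by index; it just batches whatever the stream yields, in order. With num_workers > 0 every worker gets its own copy of the dataset, so __iter__ would have to divide the work between workers (for instance via torch.utils.data.get_worker_info()) to avoid duplicated samples.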
2 torchvision.datasets
torchvision.datasets includes the following datasets:
- MNIST
- COCO (image captioning and object detection)
- LSUN Classification
- ImageFolder
- Imagenet-12
- CIFAR10 and CIFAR100
- STL10
All of these Datasets expose the same API: __getitem__ and __len__.
2.1 ImageFolder
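from torchvision import datasets, transforms
my_transform = transforms.Compose([
    transforms.ToTensor(),  # Normalize works on tensors, so convert the loaded image first
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
dataset = datasets.ImageFolder(root="./my_dataset/", transform=my_transform)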
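3 Worked example
Section 1.1.1 covered three ways of loading a dataset. Here is one more way, shown on the Crime dataset. The class below does its own batching through train_batch() and test_batch() rather than going through DataLoader.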
import os
import numpy as np
import torch

class CrimeDataset():
    def __init__(self, device):
        base_dir = os.path.dirname(os.path.realpath(__file__))
        reader = open(os.path.join(base_dir, 'data/communities.data'))
        attributes = []
        while True:
            # read the comma-separated dataset file line by line
            line = reader.readline().split(',')
            if len(line) < 128:
                # fewer than 128 fields: end of file reached
                break
            # set the ? as -1
            line = ['-1' if val == '?' else val for val in line]
            # drop the first 5 columns (np.float was removed from NumPy, so use float)
            line = np.array(line[5:], dtype=float)
            attributes.append(line)
        reader.close()
        # attributes.shape=(1994, 123)
        attributes = np.stack(attributes, axis=0)

        # load the name of each column; total: 128
        reader = open(os.path.join(base_dir, 'data/names'))
        names = []
        for i in range(128):
            # reader.readline().split() = ['@attribute', 'county', 'numeric'] and we choose 'county'
            line = reader.readline().split()[1]
            # exclude the first 5 columns, so there are 123 column names, matching attributes.shape
            if i >= 5:
                names.append(line)
        reader.close()
        names = np.array(names)

        # shuffle the rows of attributes (axis 0)
        attributes = attributes[np.random.permutation(range(attributes.shape[0])), :]
        val_size = 500
        # the last column of attributes is the labels
        self.train_labels = attributes[val_size:, -1:]
        self.test_labels = attributes[:val_size, -1:]
        # exclude the last column of attributes. Thus attributes.shape = (1994, 122)
        attributes = attributes[:, :-1]
        # select the columns whose minimum >= 0. selected has 99 features
        selected = np.argwhere(np.array([np.min(attributes[:, i]) for i in range(attributes.shape[1])]) >= 0).flatten()
        self.train_features = attributes[val_size:, selected]
        self.test_features = attributes[:val_size, selected]
        self.names = names[selected]

        # train_ptr / test_ptr count how many records have already been served
        self.train_ptr = 0
        self.test_ptr = 0
        self.x_dim = self.train_features.shape[1]
        # train_size = 1494; test_size = 500
        self.train_size = self.train_features.shape[0]
        self.test_size = self.test_features.shape[0]
        self.device = device

    def train_batch(self, batch_size=None):
        # if batch_size is None, each call returns the whole training set
        if batch_size is None:
            batch_size = self.train_features.shape[0]
            self.train_ptr = 0
        # if the next batch would run past the end, restart from the beginning
        if self.train_ptr + batch_size > self.train_features.shape[0]:
            self.train_ptr = 0
        bx = self.train_features[self.train_ptr:self.train_ptr + batch_size]
        by = self.train_labels[self.train_ptr:self.train_ptr + batch_size]
        self.train_ptr += batch_size
        if self.train_ptr == self.train_features.shape[0]:
            self.train_ptr = 0
        return torch.from_numpy(bx).float().to(self.device), torch.from_numpy(by).float().to(self.device)

    def test_batch(self, batch_size=None):
        if batch_size is None:
            batch_size = self.test_features.shape[0]
            self.test_ptr = 0
        if self.test_ptr + batch_size > self.test_features.shape[0]:
            self.test_ptr = 0
        bx = self.test_features[self.test_ptr:self.test_ptr + batch_size]
        by = self.test_labels[self.test_ptr:self.test_ptr + batch_size]
        self.test_ptr += batch_size
        if self.test_ptr == self.test_features.shape[0]:
            self.test_ptr = 0
        return torch.from_numpy(bx).float().to(self.device), torch.from_numpy(by).float().to(self.device)

if __name__ == '__main__':
    dataset = CrimeDataset("cpu")
    print(dataset.train_features.shape, dataset.train_labels.shape)
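5 Utilities
5.1 Splitting a dataloader
Sometimes the data downloaded from torchvision comes as one complete dataset wrapped in a dataloader, and it still has to be divided into a training part and a validation part. The function below does that batch by batch.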
def split(dataloader, batch_size, split=0.2):
    """Splits the given dataset into training/validation.

    Args:
        dataloader[torch dataloader]: Dataset which has to be split
        batch_size[int]: Batch size (not used in the body)
        split[float]: Indicates ratio of validation samples
    Returns:
        train_set[list]: Training set
        val_set[list]: Validation set
    """
    index = 0
    length = len(dataloader)
    train_set = []
    val_set = []
    for data, target in dataloader:
        # the first (1 - split) share of the batches goes to training, the rest to validation
        if index < length * (1 - split):
            train_set.append([data, target])
        else:
            val_set.append([data, target])
        index += 1
    return train_set, val_set
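The function above splits the batches that a dataloader yields. When the underlying Dataset object is still at hand, an alternative is to split the dataset itself with torch.utils.data.random_split and build one loader per part. The sketch below assumes a made-up 100-sample TensorDataset and a 20% validation share; the seed is only there to make the split repeatable.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Hypothetical full dataset of 100 samples with binary labels.
full = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))

n_val = int(len(full) * 0.2)
n_train = len(full) - n_val
train_set, val_set = random_split(
    full, [n_train, n_val], generator=torch.Generator().manual_seed(0)
)

train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
val_loader = DataLoader(val_set, batch_size=16, shuffle=False)
print(len(train_set), len(val_set))  # 80 20
```

Because the split happens at the sample level, the two parts never share a sample, and each loader can use its own batch size and shuffling.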
References (many thanks!):
https://www.cnblogs.com/leokale-zz/p/11275800.html
https://www.lagou.com/lgeduarticle/74174.html
https://blog.csdn.net/zuiyishihefang/article/details/105985760
torchvision-datasets