Pytorch中的数据加载艺术

2019-08-05 11:28:31 浏览数 (1)

数据库DataBase 数据集DataSet 采样器Sampler = 加载器Loader

from torch.utils.data import *

IMDB Dataset Sampler || BatchSampler = DataLoader

数据库 DataBase

Image DataBase 简称IMDB,指的是存储在文件中的数据信息。

文件格式可以多种多样。比如xml, yaml, json, sql.

VOC是xml格式的,COCO是JSON格式的。

构造IMDB的过程,就是解析这些文件,并建立数据索引的过程。

一般会被解析为Python列表, 以方便后续迭代读取。

数据集 DataSet

数据集 DataSet: 在数据库IMDB的基础上,提供对数据的单例或切片访问方法。

换言之,就是定义数据库中对象的索引机制,如何实现单例索引或切片索引。

简言之,DataSet,通过__getitem__定义了数据集DataSet是一个可索引对象,An Indexerable Object。

即传入一个给定的索引Index之后,如何按此索引进行单例或切片访问,单例还是切片视Index是单值还是列表。

Pytorch源码如下:

代码语言:txt复制
class Dataset(object):
    """An abstract class representing a Dataset.
    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """
    # 定义单例/切片访问方法,即 dataItem = Dataset[index]
    def __getitem__(self, index):
        raise NotImplementedError
    def __len__(self):
        raise NotImplementedError
    def __add__(self, other):
        return ConcatDataset([self, other])

自定义数据集要基于上述Dataset基类、IMDB基类,有两种方法。

代码语言:txt复制
# 方法一: 单继承
class XxDataset(Dataset)
    # 将IMDB作为参数传入,进行二次封装
    imdb = IMDB()
    pass
# 方法二: 双继承
class XxDataset(IMDB, Dataset):
    pass

采样器 Sampler & BatchSampler

在实际应用中,数据并不一定是循规蹈矩的序惯访问,而需要随机打乱顺序来访问,或需要随机加权访问,

因此,按某种特定的规则来读取数据,就是采样操作,需要定义采样器:Sampler

另外,数据也可能并不是一个一个读取的,而需要一批一批的读取,即需要批量采样操作,定义批量采样器:BatchSampler

所以,只有Dataset的单例访问方法还不够,还需要在此基础上,进一步的定义批量访问方法。

简言之,采样器定义了索引(index)的产生规则,按指定规则去产生索引,从而控制数据的读取机制

BatchSampler 是基于 Sampler 来构造的: BatchSampler = Sampler BatchSize

Pytorch源码如下,

代码语言:txt复制
class Sampler(object):
    """Base class for all Samplers.
    采样器基类,可以基于此自定义采样器。
    Every Sampler subclass has to provide an __iter__ method, providing a way
    to iterate over indices of dataset elements, and a __len__ method that
    returns the length of the returned iterators.
    """
    def __init__(self, data_source):
        pass
    def __iter__(self):
        raise NotImplementedError
    def __len__(self):
        raise NotImplementedError
# 序惯采样
class SequentialSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source
    def __iter__(self):
        return iter(range(len(self.data_source)))
    def __len__(self):
        return len(self.data_source)
# 随机采样
class RandomSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source
    def __iter__(self):
        return iter(torch.randperm(len(self.data_source)).long())
    def __len__(self):
        return len(self.data_source)
# 随机子采样
class SubsetRandomSampler(Sampler):
    pass
# 加权随机采样
class WeightedRandomSampler(Sampler):
    pass
代码语言:txt复制
class BatchSampler(object):
    """Wraps another sampler to yield a mini-batch of indices.
    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``
    Example:
        >>> list(BatchSampler(range(10), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(range(10), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """
    def __init__(self, sampler, batch_size, drop_last):
        self.sampler = sampler  # ******
        self.batch_size = batch_size
        self.drop_last = drop_last
    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch
    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler)   self.batch_size - 1) // self.batch_size

由上可见,Sampler本质就是个具有特定规则的可迭代对象,但只能单例迭代。

[x for x in range(10)], range(10)就是个最基本的Sampler,每次循环只能取出其中的一个值.

代码语言:txt复制
[x for x in range(10)]
Out[10]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
from torch.utils.data.sampler import SequentialSampler
[x for x in SequentialSampler(range(10))]
Out[14]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
from torch.utils.data.sampler import RandomSampler
[x for x in RandomSampler(range(10))]
Out[12]: [4, 9, 5, 0, 2, 8, 3, 1, 7, 6]

BatchSampler对Sampler进行二次封装,引入了batchSize参数,实现了批量迭代。

代码语言:txt复制
from torch.utils.data.sampler import BatchSampler
[x for x in BatchSampler(range(10), batch_size=3, drop_last=False)]
Out[9]: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
[x for x in BatchSampler(RandomSampler(range(10)), batch_size=3, drop_last=False)]
Out[15]: [[1, 3, 7], [9, 2, 0], [5, 4, 6], [8]]

加载器 DataLoader

在实际计算中,如果数据量很大,考虑到内存有限,且IO速度很慢,

因此不能一次性的将其全部加载到内存中,也不能只用一个线程去加载。

因而需要多线程、迭代加载, 因而专门定义加载器:DataLoader

DataLoader 是一个可迭代对象, An Iterable Object, 内部配置了魔法函数——iter——,调用它将返回一个迭代器。

该函数可用内置函数iter直接调用,即 DataIteror = iter(DataLoader)

代码语言:txt复制
dataloader = DataLoader(dataset=Dataset(imdb=IMDB()), sampler=Sampler(), num_works, ...)

__init__参数包含两部分,前半部分用于指定数据集 采样器,后半部分为多线程参数

代码语言:txt复制
class DataLoader(object):
    """
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.
    """
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        if timeout < 0:
            raise ValueError('timeout option should be non-negative')
        # 检测是否存在参数冲突: 默认batchSampler vs 自定义BatchSampler
        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')
        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')
        if self.num_workers < 0:
            raise ValueError('num_workers cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')
        # 在此处会强行指定一个 BatchSampler
        if batch_sampler is None:
            # 在此处会强行指定一个 Sampler
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)
        # 使用自定义的采样器和批采样器
        self.sampler = sampler
        self.batch_sampler = batch_sampler
    def __iter__(self):
        # 调用Pytorch的多线程迭代器加载数据
        return DataLoaderIter(self)
    def __len__(self):
        return len(self.batch_sampler)

数据迭代器 DataLoaderIter

迭代器与可迭代对象之间是有区别的。

可迭代对象,意思是对其使用Iter函数时,它可以返回一个迭代器,从而可以连续的迭代访问它。

迭代器对象,内部有额外的魔法函数__next__,用内置函数next作用其上,则可以连续产生下一个数据,产生规则即是由此函数来确定的。

可迭代对象描述了对象具有可迭代性,但具体的迭代规则由迭代器来描述,这样解耦的好处是可以对同一个可迭代对象配置多种不同规则的迭代器。

数据集/容器遍历的一般化流程:NILIS

NILIS规则: data = next(iter(loader(DataSetsampler)))data=next(iter(loader(DataSetsampler)))

  1. sampler 定义索引index的生成规则,返回一个index列表,控制后续的索引访问过程。
  2. indexer 基于__item__在容器上定义按索引访问的规则,让容器成为可索引对象,可用[]操作。
  3. loader 基于__iter__在容器上定义可迭代性,描述加载规则,包括返回一个迭代器,让容器成为可迭代对象, 可用iter()操作。
  4. next 基于__next__在容器上定义迭代器,描述具体的迭代规则,让容器成为迭代器对象, 可用next()操作。
代码语言:txt复制
## 初始化
sampler = Sampler()
dataSet = DataSet(sampler)            # __getitem__
dataLoader = DataLoader(dataSet, sampler) / DataIterable()        # __iter__()
dataIterator = DataLoaderIter(dataLoader)     #__next__()
data_iter = iter(dataLoader)
## 遍历方法1
for _ in range(len(data_iter))
    data = next(data_iter)
## 遍历方法2
for i, data in enumerate(dataLoader):
    data = data

<center>

代码语言:txt复制
<img src="" style="border:5px solid black;border-radius:15px;">

</center>

<b style="color:tomato;"></b>

<footer style="color:white;;background-color:rgb(24,24,24);padding:10px;border-radius:10px;"><br>

<h3 style="text-align:center;color:tomato;font-size:16px;" id="autoid-2-0-0"><br>

<b>MARSGGBO</b><b style="color:white;"><span style="font-size:25px;">♥</span>原创</b>

<b style="color:white;">

2019-8-4<p></p>

</b><p><b style="color:white;"></b>

</p></h3><br>

</footer>

0 人点赞