PyTorch 小课堂!带你解析数据处理全流程(二)

2022-05-25 16:05:20 浏览数 (1)

目录

1 单进程

2 多进程

3 锁页内存 (Memory Pinning)

4 预取 (prefetch)

5 代码讲解

小伙伴们大家好呀,PyTorch 源码解读系列又来更新啦!在上一篇文章中,我们介绍了对于 torch.utils.data 而言,重点的 Dataset,Sampler,DataLoader 三个模块基本内容。今天,我们着重对单进程/多进程,prefetch,pin_memory 等组件进行介绍,并对其特定功能予以解读,最后也会附上数据处理代码详解。感兴趣的小伙伴们,继续往下看吧~

1. 单进程

在单进程模式下,DataLoader 初始化的进程和取数据的进程是一样的 。因此,数据加载可能会阻止计算。但是,当用于在进程之间共享数据的资源(例如共享内存,文件描述符)有限时,或者当整个数据集很小并且可以完全加载到内存中时,此模式可能是我们首选。此外,单进程加载通常可以显示更多可读的错误跟踪,这对于我们调试代码很有用

2. 多进程

多进程处理(multi-process)

为了避免在加载数据时阻塞计算,PyTorch 提供了一个简单的开关,只需将参数设置 num_workers 为正整数即可执行多进程数据加载,而设置为 0 时执行单线程数据加载。

在设置多进程模式时,每次 DataLoader 创建 iterator 时(例如,当调用 enumerate(dataloader) 时),都会创建 num_workers 个工作进程。此时dataset, collate_fn, worker_init_fn 都会被传到每个worker中,而每个 worker 都用独立的进程。

对于 map-style 数据,主线程会用 Sampler 产生 indices,并将它们送到 worker 里。因此,shuffle 是在主线程做的。

而对于 iterable-style 数据,因为每个 worker 都有相同的 data 复制样本,并在各个进程里进行不同的操作,以防止每个进程输出的数据是重复的,所以一般会使用 torch.utils.data.get_worker_info() 来进行辅助处理。这里,torch.utils.data.get_worker_info() 会返回 worker 进程的一些信息(如id, dataset, num_workers, seed),如果在主线程的话返回 None。

注意,通常不建议在多进程加载中返回 CUDA 张量,因为在使用 CUDA 和在多处理中共享 CUDA 张量时存在许多微妙之处(文档中提出:只要接收过程保留张量的副本,就需要发送过程来保留原始张量)。建议采用 pin_memory=True ,以将数据快速传输到支持 CUDA 的 GPU。简而言之,不建议在使用多线程的情况下返回 CUDA 的 Tensor

3. 锁页内存

首先我们先了解一下锁页内存的概念。

主机中的内存,有两种存在方式,一是锁页,二是不锁页。锁页内存存放的内容在任何情况下都不会与主机的虚拟内存进行交换(注:虚拟内存就是硬盘),而不锁页内存在主机内存不足时,数据会存放在虚拟内存中。主机到 GPU 副本源自固定(页面锁定)内存时,速度要快得多。CPU 张量和存储暴露了一种 pin_memory() 方法,该方法返回对象的副本,并将数据放在固定的区域中。

而显卡中的显存全部是锁页内存!当计算机的内存充足的时候,可以设置 pin_memory=True。设置 pin_memory=True,则意味着生成的 Tensor 数据最开始是属于内存中的锁页内存,这样将内存的 Tensor 转义到 GPU 的显存就会更快一些。同时,由于 pin_memory 的作用是将张量返回之前将其复制到 CUDA 固定的内存中,所以只有在 CUDA 环境支持下才有用。

PyTorch 原生的 pin_memory 方法如下,其支持大部分 python 数据类型的处理:

代码语言:javascript复制
def pin_memory(data):
    if isinstance(data, torch.Tensor):
        return data.pin_memory()
    elif isinstance(data, string_classes):
        return data
    elif isinstance(data, container_abcs.Mapping):
        return {k: pin_memory(sample) for k, sample in data.items()}
    elif isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple
        return type(data)(*(pin_memory(sample) for sample in data))
    elif isinstance(data, container_abcs.Sequence):
        return [pin_memory(sample) for sample in data]
    elif hasattr(data, "pin_memory"):
        return data.pin_memory()
    else:
        return data

默认情况下,如果固定逻辑对于一个属于自定义类型(custom type)的 batch(如果有一个 collate_fn 返回自定义批处理类型的批处理,则会发生),或者如果该批处理的每个元素都是 custom type,则该固定逻辑将无法识别它们,它会返回该批处理(或那些元素)而无需固定内存。而要为自定义批处理或数据类型启用内存固定,我们需使用 pin_memory() 在自定义类型上自定义一个方法。如下:

代码语言:javascript复制
class SimpleCustomBatch:
    # 自定义一个类,该类不能被PyTorch原生的pin_memory方法所支持

    def __init__(self, data):
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.stack(transposed_data[1], 0)

    # custom memory pinning method on custom type
    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

def collate_wrapper(batch):
    return SimpleCustomBatch(batch)

inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                    pin_memory=True)

for batch_ndx, sample in enumerate(loader):
    print(sample.inp.is_pinned())  # True
    print(sample.tgt.is_pinned())  # True

4. 预取(prefetch)

DataLoader 通过指定 prefetch_factor (默认为 2)来进行数据的预取。

代码语言:javascript复制
class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        ...
        self._reset(loader, first_iter=True)

    def _reset(self, loader, first_iter=False):
        ...
        # prime the prefetch loop
        for _ in range(self._prefetch_factor * self._num_workers):
            self._try_put_index()

通过源码可以看到,prefetch 功能仅适用于多进程加载中(下面也会有多进程 dataloader 的部分代码分析)。

5. 代码详解

那么现在让我们来看看具体的代码调用流程:

代码语言:javascript复制
for data, label in train_loader:
    ......

for 循环会调用 dataloader 的 __iter__(self) 方法,以此获得迭代器来遍历 dataset。

代码语言:javascript复制
class DataLoader(Generic[T_co]):
    ...
    def __iter__(self) -> '_BaseDataLoaderIter':

        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()

__iter__(self) 方法中,dataloader 调用了 self._get_iterator() 方法,根据 num_workers 获得迭代器,并指示是进行单进程还是多进程处理。

代码语言:javascript复制
class DataLoader(Generic[T_co]):
    ...
    def _get_iterator(self) -> '_BaseDataLoaderIter':
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self)

为了描述更加清晰,我们只考虑单进程的代码。下面是 class _SingleProcessDataLoaderIter(_BaseDataLoaderIter) ,以及其父类 class _BaseDataLoaderIter(object): 的重点代码片段:

代码语言:javascript复制
class _BaseDataLoaderIter(object):
    def __init__(self, loader: DataLoader) -> None:
        # 初始化赋值一些 DataLoader 参数,
        # 以及用户输入合法性进行校验
        self._dataset = loader.dataset
        self._dataset_kind = loader._dataset_kind
        self._index_sampler = loader._index_sampler
        ...

    def __iter__(self) -> '_BaseDataLoaderIter':
        return self

    def _reset(self, loader, first_iter=False):
        self._sampler_iter = iter(self._index_sampler)
        self._num_yielded = 0
        self._IterableDataset_len_called = loader._IterableDataset_len_called

    def _next_index(self):
        return next(self._sampler_iter)  # may raise StopIteration

    def _next_data(self):
        raise NotImplementedError

    def __next__(self) -> Any:
        with torch.autograd.profiler.record_function(self._profile_name):
            if self._sampler_iter is None:
                self._reset()
            data = self._next_data() # 重点代码行,通过此获取数据
            self._num_yielded  = 1
            ...
            return data

    next = __next__  # Python 2 compatibility

    def __len__(self) -> int:
        return len(self._index_sampler) # len(_BaseDataLoaderIter) == len(self._index_sampler)

    def __getstate__(self):
        raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)

_BaseDataLoaderIter 是所有 DataLoaderIter 的父类。dataloader获得了迭代器之后,for 循环需要调用 __next__() 来获得下一个对象,从而实现遍历。通过 __next__() 方法调用 _next_data() 获取数据。

代码语言:javascript复制
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

从 _SingleProcessDataLoaderIter 的初始化参数可以看到,其在父类 _BaseDataLoaderIter 的基础上定义了 _dataset_fetcher,并传入 _dataset,_auto_collation,_collate_fn 等参数,用于定义获取数据的方式。其具体实现会在稍后解释。

在 _next_data() 被调用后,其需要 _next_index() 获取 index,并通过获得的 index 传入 _dataset_fetcher 中获取对应样本。

代码语言:javascript复制
class DataLoader(Generic[T_co]):
    ...
    @property
    def _auto_collation(self):
        return self.batch_sampler is not None

    @property
    def _index_sampler(self):
        if self._auto_collation:
            return self.batch_sampler
        else:
            return self.sampler

class _BaseDataLoaderIter(object):
    ...
    def _reset(self, loader, first_iter=False):
        self._sampler_iter = iter(self._index_sampler)
        ...

    def _next_index(self):
        # sampler_iter 来自于 index_sampler
        return next(self._sampler_iter)  # may raise StopIteration

从这里看出,dataloader 提供了 sampler(可以是batch_sampler 或者是其他 sampler 子类),然后 _SingleProcessDataLoaderIter 迭代 sampler 获得索引。

下面我们来看看 fetcher,fetcher 需要 index 来获取元素,并同时支持 Map-style dataset(对应 _MapDatasetFetcher)和 Iterable-style dataset(对应 _IterableDatasetFetcher),使其在 Dataloader 内能使用相同的接口 fetch,代码更加简洁。

· 对于 Map-style:直接输入索引 index,作为 map 的 key,获得对应的样本(即 value)。

代码语言:javascript复制
class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            # 有batch_sampler,_auto_collation就为True,
            # 就优先使用batch_sampler,对应在fetcher中传入的就是一个batch的索引
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)

· 对于 Iterable-style: __init__ 方法内设置了 dataset 初始的迭代器,fetch 方法内获取元素,此时 index 其实已经没有多大作用了。

代码语言:javascript复制
class _IterableDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
        self.dataset_iter = iter(dataset)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            # 对于batch_sampler(即auto_collation==True)
            # 直接使用往后遍历并提取len(possibly_batched_index)个样本(即1个batch的样本)
            data = []
            for _ in possibly_batched_index:
                try:
                    data.append(next(self.dataset_iter))
                except StopIteration:
                    break
            if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)):
                raise StopIteration
        else:
            # 对于sampler,直接往后遍历并提取1个样本
            data = next(self.dataset_iter)
        return self.collate_fn(data)

最后,我们通过索引传入 fetcher,fetch 得到想要的样本。因此,整个过程调用关系总结如下:

loader.iter --> self._get_iterator() --> class _SingleProcessDataLoaderIter --> class _BaseDataLoaderIter --> __next__() --> self._next_data() --> self._next_index() -->next(self._sampler_iter) 即 next(iter(self._index_sampler)) --> 获得 index --> self._dataset_fetcher.fetch(index) --> 获得 data

而对于多进程而言,借用 PyTorch 内源码的注释,其运行流程解释如下:

代码语言:javascript复制
# Our data model looks like this (queues are indicated with curly brackets):
#
#                main process                              ||
#                     |                                    ||
#               {index_queue}                              ||
#                     |                                    ||
#              worker processes                            ||     DATA
#                     |                                    ||
#            {worker_result_queue}                         ||     FLOW
#                     |                                    ||
#      pin_memory_thread of main process                   ||   DIRECTION
#                     |                                    ||
#               {data_queue}                               ||
#                     |                                    ||
#                data output                               /
#
# P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
#      `pin_memory=False`.

首先 dataloader 基于 multiprocessing 产生多进程,每个子进程的输入输出通过两个主要的队列(multiprocessing.Queue() 类)产生,分别为:

· index_queue:每个子进程的队列中需要处理的任务的下标

· _worker_result_queue:返回时处理完任务的下标

· data_queue:表明经过 pin_memory 处理后的数据队列

并且有以下这些比较重要的 flag 参数来协调各个 worker 之间的工作:

· _send_idx: 发送索引,用来记录这次要放 index_queue 中 batch 的 idx

· _rcvd_idx: 接受索引,记录要从 data_queue 中取出的 batch 的 idx

· _task_info: 存储将要产生的 data 信息的 dict,key为 task idx(由 0 开始的整形索引),value 为 (worker_id,) 或 (worker_id, data),分别对应数据未取和已取的情况

· _tasks_outstanding: 整形,代表已经准备好的 task/batch 的数量(可能有些正在准备中)

每个 worker 一次产生一个 batch 的数据,返回 batch 数据前放入下一个批次要处理的数据下标,对应构造函数子进程初始化如下:

代码语言:javascript复制
class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_MultiProcessingDataLoaderIter, self).__init__(loader)
        ...
        self._worker_result_queue = multiprocessing_context.Queue()  # 把该worker取出的数放入该队列,用于进程间通信
        ...
        self._workers_done_event = multiprocessing_context.Event()
        self._index_queues = []
        self._workers = []
        for i in range(self._num_workers):
            index_queue = multiprocessing_context.Queue()  # 索引队列,每个子进程一个队列放要处理的下标
            index_queue.cancel_join_thread()
            # _worker_loop 的作用是:从index_queue中取索引,然后通过collate_fn处理数据,
            # 然后再将处理好的 batch 数据放到 data_queue 中。(发送到队列中的idx是self.send_idx)
            w = multiprocessing_context.Process(
                target=_utils.worker._worker_loop,  # 每个worker子进程循环执行的函数,主要将数据以(idx, data)的方式传入_worker_result_queue中
                args=(self._dataset_kind, self._dataset, index_queue, 
                      self._worker_result_queue, self._workers_done_event,
                      self._auto_collation, self._collate_fn, self._drop_last,
                      self._base_seed   i, self._worker_init_fn, i, self._num_workers,
                      self._persistent_workers))
            w.daemon = True
            w.start()
            self._index_queues.append(index_queue)
            self._workers.append(w)
        if self._pin_memory:
            self._pin_memory_thread_done_event = threading.Event()
            self._data_queue = queue.Queue()  # 用于存取出的数据进行 pin_memory 操作后的结果
            pin_memory_thread = threading.Thread(
                target=_utils.pin_memory._pin_memory_loop,
                args=(self._worker_result_queue, self._data_queue,
                      torch.cuda.current_device(),
                      self._pin_memory_thread_done_event))
            pin_memory_thread.daemon = True
            pin_memory_thread.start()
            # Similar to workers (see comment above), we only register
            # pin_memory_thread once it is started.
            self._pin_memory_thread = pin_memory_thread
        else:
            self._data_queue = self._worker_result_queue
        ...
        self._reset(loader, first_iter=True)
    def _reset(self, loader, first_iter=False):
        super()._reset(loader, first_iter)
        self._send_idx = 0  # idx of the next task to be sent to workers,发送索引,用来记录这次要放 index_queue 中 batch 的 idx
        self._rcvd_idx = 0  # idx of the next task to be returned in __next__,接受索引,记录要从 data_queue 中取出的 batch 的 idx
        # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
        # map: task idx => - (worker_id,)        if data isn't fetched (outstanding)
        #                   (worker_id, data)   if data is already fetched (out-of-order)
        self._task_info = {}
        # _tasks_outstanding 指示当前已经准备好的 task/batch 的数量(可能有些正在准备中)
        # 初始值为 0, 在 self._try_put_index() 中  1,在 self._next_data 中-1
        self._tasks_outstanding = 0  # always equal to count(v for v in task_info.values() if len(v) == 1)
        # this indicates status that a worker still has work to do *for this epoch*.
        self._workers_status = [True for i in range(self._num_workers)] 
        # We resume the prefetching in case it was enabled
        if not first_iter:
            for idx in range(self._num_workers):
                self._index_queues[idx].put(_utils.worker._ResumeIteration())
            resume_iteration_cnt = self._num_workers
            while resume_iteration_cnt > 0:
                data = self._get_data()
                if isinstance(data, _utils.worker._ResumeIteration):
                    resume_iteration_cnt -= 1
        ...
        # 初始化的时候,就将 2*num_workers 个 (batch_idx, sampler_indices) 放到 index_queue 中
        for _ in range(self._prefetch_factor * self._num_workers):
            self._try_put_index() # 进行预取

dataloader 初始化的时候,每个 worker 的 index_queue 默认会放入两个 batch 的 index,从 index_queue 中取出要处理的下标。

代码语言:javascript复制
def _try_put_index(self):
        # self._prefetch_factor 默认为 2
        assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
        try:
            index = self._next_index()
        except StopIteration:
            return
        for _ in range(self._num_workers):  # find the next active worker, if any
            worker_queue_idx = next(self._worker_queue_idx_cycle)
            if self._workers_status[worker_queue_idx]:
                break
        else:
            # not found (i.e., didn't break)
            return
        self._index_queues[worker_queue_idx].put((self._send_idx, index)) # 放入 任务下标 和 数据下标
        self._task_info[self._send_idx] = (worker_queue_idx,)
        # _tasks_outstanding   1,表明预备好的batch个数 1
        self._tasks_outstanding  = 1
        # send_idx 发送索引, 记录从sample_iter中发送索引到index_queue的次数
        self._send_idx  = 1

调用 _next_data(self) 方法进行数据读取,其中 _process_data(self, data) 用于返回数据。

代码语言:javascript复制
def _next_data(self):
        while True:

            while self._rcvd_idx < self._send_idx: # 确保待处理的任务(待取的batch)下标 > 处理完毕要返回的任务(已经取完的batch)下标
                info = self._task_info[self._rcvd_idx]
                worker_id = info[0]
                if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still active
                    break
                del self._task_info[self._rcvd_idx]
                self._rcvd_idx  = 1
            else:
                # no valid `self._rcvd_idx` is found (i.e., didn't break)
                if not self._persistent_workers:
                    self._shutdown_workers()
                raise StopIteration

            # Now `self._rcvd_idx` is the batch index we want to fetch

            # Check if the next sample has already been generated
            if len(self._task_info[self._rcvd_idx]) == 2:
                data = self._task_info.pop(self._rcvd_idx)[1]
                return self._process_data(data)

            assert not self._shutdown and self._tasks_outstanding > 0
            idx, data = self._get_data() # 调用 self._try_get_data() 从 self._data_queue 中取数
            self._tasks_outstanding -= 1  # 表明预备好的batch个数需要减1
            if self._dataset_kind == _DatasetKind.Iterable:
                # Check for _IterableDatasetStopIteration
                if isinstance(data, _utils.worker._IterableDatasetStopIteration):
                    if self._persistent_workers:
                        self._workers_status[data.worker_id] = False
                    else:
                        self._mark_worker_as_unavailable(data.worker_id)
                    self._try_put_index()
                    continue

            if idx != self._rcvd_idx:
                # store out-of-order samples
                self._task_info[idx]  = (data,)
            else:
                del self._task_info[idx]
                return self._process_data(data) # 返回数据

    def _process_data(self, data):
        self._rcvd_idx  = 1
        self._try_put_index() # 同上,主要放入队列索引 以及 更新flag
        if isinstance(data, ExceptionWrapper):
            data.reraise()
        return data

这样,多进程模式的 dataloader 就能通过多个 worker 的协作来共同完成数据的加载。

以上就是本次数据处理全流程解析全部内容了,你,学会了嘛?感兴趣的小伙伴,不要忘记点赞收藏评论呀~在之后的系列文章里,我们还会带大家回味 PyTorch 中的神经网络模块,即 torch.nn 模块,记得来看噢!

0 人点赞