使用内存映射加快PyTorch数据集的读取

2022-11-11 17:49:51 浏览数 (1)

点击上方“Deephub Imba”,关注公众号,好文章不错过 !

本文将介绍如何使用内存映射文件加快PyTorch数据集的加载速度

在使用Pytorch训练神经网络时,最常见的与速度相关的瓶颈是数据加载的模块。如果我们将数据通过网络传输,除了预取和缓存之外,没有任何其他的简单优化方式。

但是如果数据本地存储,我们可以通过将整个数据集组合成一个文件,然后映射到内存中来优化读取操作,这样我们每次文件读取数据时就不需要访问磁盘,而是从内存中直接读取可以加快运行速度。

什么是内存映射文件

内存映射文件(memory-mapped file)是将完整或者部分文件加载到内存中,这样就可以通过内存地址相关的load或者store指令来操纵文件。为了支持这个功能,现代的操作系统会提供一个叫做mmap的系统调用。这个系统调用会接收一个虚拟内存地址(VA),长度(len),protection,一些标志位,一个打开文件的文件描述符,和偏移量(offset)。

由于虚拟内存代表的附加抽象层,我们可以映射比机器的物理内存容量大得多的文件。正在运行的进程所需的内存段(称为页)从外部存储中获取,并由虚拟内存管理器自动复制到主内存中。

使用内存映射文件可以提高I/O性能,因为通过系统调用进行的普通读/写操作比在本地内存中进行更改要慢得多,对于操作系统来说,文件以一种“惰性”的方式加载,通常一次只加载一个页,因此即使对于较大的文件,实际RAM利用率也是最低的,但是使用内存映射文件可以改善这个流程。

什么是PyTorch数据集

Pytorch提供了用于在训练模型时处理数据管道的两个主要模块:Dataset和DataLoader。

DataLoader主要用作Dataset的加载,它提供了许多可配置选项,如批处理、采样、预读取、变换等,并抽象了许多方法。

Dataset是我们进行数据集处理的实际部分,在这里我们编写训练时读取数据的过程,包括将样本加载到内存和进行必要的转换。

对于Dataset,必须实现:__init_,__len____getitem__ 三个方法

实现自定义数据集

接下来,我们将看到上面提到的三个方法的实现。

最重要的部分是在__init__中,我们将使用 numpy 库中的 np.memmap() 函数来创建一个ndarray将内存缓冲区映射到本地的文件。

在数据集初始化时,将ndarray使用可迭代对象进行填充,代码如下:

代码语言:javascript复制
class MMAPDataset(Dataset):
    def __init__(
        self,
        input_iter: Iterable[np.ndarray],
        labels_iter: Iterable[np.ndarray],
        mmap_path: str = None,
        size: int = None,
        transform_fn: Callable[..., Any] = None,
        
    ) -> None:
        super().__init__()

        self.mmap_inputs: np.ndarray = None
        self.mmap_labels: np.ndarray = None
        self.transform_fn = transform_fn

        if mmap_path is None:
            mmap_path = os.path.abspath(os.getcwd())
        self._mkdir(mmap_path)

        self.mmap_input_path = os.path.join(mmap_path, DEFAULT_INPUT_FILE_NAME)
        self.mmap_labels_path = os.path.join(mmap_path, DEFAULT_LABELS_FILE_NAME)
        self.length = size

        for idx, (input, label) in enumerate(zip(input_iter, labels_iter)):
            if self.mmap_inputs is None:
                self.mmap_inputs = self._init_mmap(
                    self.mmap_input_path, input.dtype, (self.length, *input.shape)
                )
                self.mmap_labels = self._init_mmap(
                    self.mmap_labels_path, label.dtype, (self.length, *label.shape)
                )

            self.mmap_inputs[idx][:] = input[:]
            self.mmap_labels[idx][:] = label[:]


    def __getitem__(self, idx: int) -> Tuple[Union[np.ndarray, torch.Tensor]]:
        if self.transform_fn:
            return self.transform_fn(self.mmap_inputs[idx]), torch.tensor(self.mmap_labels[idx]) 
        return self.mmap_inputs[idx], self.mmap_labels[idx]


    def __len__(self) -> int:
        return self.length

我们在上面提供的代码中还使用了两个辅助函数。

代码语言:javascript复制
def _mkdir(self, path: str) -> None:
        if os.path.exists(path):
            return

        try:
            os.makedirs(os.path.dirname(path), exist_ok=True)
            return
        except:
            raise ValueError(
                "Failed to create the path (check the user write permissions)."
            )


    def _init_mmap(self, path: str, dtype: np.dtype, shape: Tuple[int], remove_existing: bool = False) -> np.ndarray:
        open_mode = "r "

        if remove_existing:
            open_mode = "w "
        
        return np.memmap(
            path,
            dtype=dtype,
            mode=open_mode,
            shape=shape,
        )

可以看到,上面我们自定义数据集与一般情况的主要区别就是_init_mmap中调用的np.memmap(),所以这里我们对np.memmap() 做一个简单的解释:

Numpy的memmap对象,它允许将大文件分成小段进行读写,而不是一次性将整个数组读入内存。memmap也拥有跟普通数组一样的方法,基本上只要是能用于ndarray的算法就也能用于memmap。

使用函数np.memmap并传入一个文件路径、数据类型、形状以及文件模式,即可创建一个新的memmap存储在磁盘上的二进制文件创建内存映射。

对于更多的介绍请参考Numpy的文档,这里就不做详细的解释了

基准测试

为了实际展示性能提升,我将内存映射数据集实现与以经典方式读取文件的普通数据集实现进行了比较。这里使用的数据集由 350 张 jpg 图像组成。

从下面的结果中,我们可以看到我们的数据集比普通数据集快 30 倍以上:

总结

本文中介绍的方法在加速Pytorch的数据读取是非常有效的,尤其是使用大文件时,但是这个方法需要很大的内存,在做离线训练时是没有问题的,因为我们能够完全的控制我们的数据,但是如果想在生产中应用还需要考虑使用,因为在生产中有些数据我们是无法控制的。

最后Numpy的文档地址如下:

https://numpy.org/doc/stable/reference/generated/numpy.memmap.html

有兴趣的可以详细了解

本文的作者在github上也创建了一个项目,有兴趣的可以看看:

https://github.com/DACUS1995/pytorch-mmap-dataset

作者:Tudor Surdoiu


MORE

kaggle比赛交流和组队

加我的微信,邀你进群

喜欢就关注一下吧:

点个 在看 你最好看!

0 人点赞