num_workers in torch DataLoader

2022-09-02 22:12:34

Consider this scenario: there is a huge number of txt files to read in, batch by batch. Let's test how efficient torch's DataLoader actually is.

Basic setup:

  • Machine: 8 cores, 32 GB RAM; the workstation has a built-in 2 TB mechanical hard drive, and all data lives on that drive
  • OS: Ubuntu 16.04 LTS
  • PyTorch: 1.0
  • Python: 3.6

1. First, generate a large number of random txt files

import os
import random
import string


def gen_test_txt():
    os.makedirs('./test_txt', exist_ok=True)
    # candidate characters: ASCII letters plus newline
    population = list(string.ascii_letters) + ['\n']
    for i in range(1000):
        with open(f'./test_txt/{i}.txt', 'w') as f:
            # each file is 1,000,000 random characters, roughly 1 MB
            f.write(''.join(random.choices(population, k=1000000)))

2. Then read the files sequentially as a benchmark

from time import time

import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader


def test_torch_reader():
    class Dst(Dataset):
        def __init__(self, paths):
            self.paths = paths

        def __len__(self):
            return len(self.paths)

        def __getitem__(self, i):
            # read and discard the content: the benchmark measures pure file I/O
            open(self.paths[i], 'r').read()
            return 1

    dst = Dst([f'./test_txt/{i}.txt' for i in range(1000)])
    loader = DataLoader(dst, batch_size=128, num_workers=0)

    ts = time()
    time_cost = []
    for i, ele in enumerate(loader, 1):
        dur = time() - ts
        time_cost.append(dur)
        print(i, dur)
        ts = time()

    print(f"{sum(time_cost):.3f}, "
          f"{np.mean(time_cost):.3f}, "
          f"{np.std(time_cost):.3f}, "
          f"{max(time_cost):.3f}, "
          f"{min(time_cost):.3f}")

    plt.plot(time_cost)
    plt.grid()
    plt.show()

The basic statistics of the per-batch time are below; it holds steady at roughly 0.9 sec / batch:

total, mean, std, max, min

7.148, 0.893, 0.074, 1.009, 0.726

So: 1000 files at batch size 128 gives 8 batches, about 7.1 s in total. Before the next run the OS page cache is cleared, so the files are read from disk again; a minimal way to do that from Python is sketched below.
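A minimal sketch of the cache-clearing step, assuming Linux with sudo available (the post itself only says the cache was cleared):

import subprocess

# flush dirty pages, then drop the page cache so the next run
# reads from the disk instead of from memory
subprocess.run('sync && echo 3 | sudo tee /proc/sys/vm/drop_caches',
               shell=True, check=True)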

3. Set num_workers to 4
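The only change from the benchmark in step 2 is the num_workers argument; the rest of test_torch_reader stays the same:

loader = DataLoader(dst, batch_size=128, num_workers=4)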

Every 4th batch, the main process has to wait while the workers prepare the next 4 batches; since the reads effectively serialize on the single disk, that wait takes about 4x a single batch, and the following 3 batches then take almost no time (note the max near 4 s and the min of 0.000):

total, mean, std, max, min

7.667, 0.958, 1.652, 3.983, 0.000

The experiment was then repeated on an SSD, again with num_workers = 0 first and then 4. With num_workers = 0:

total, mean, std, max, min

3.251, 0.406, 0.026, 0.423, 0.338

On the SSD the timings are clearly more stable than on the mechanical disk (std 0.026 vs 0.074).

Then with num_workers = 4:

total, mean, std, max, min

1.934, 0.242, 0.421, 1.088, 0.000

The same pattern appears, but while the spike should in theory be 0.4 × 4 = 1.6 s, here it instead drops to about half of that, roughly 0.8 s, at batch 4 (0-indexed).

Basic conclusion: whether on the SSD or on the mechanical disk, the total time stays essentially the same (slightly smaller on the SSD, though that may just mean the experiment is not thorough enough) and does not shrink as num_workers grows, which puzzles me. My understanding has always been: with num_workers = 4, each worker prepares one batch; since this machine has more than 4 cores, the 4 workers should run in parallel, so the total time should be 1/4 of the num_workers = 0 case. So what is the reason? (This experiment was originally meant for loading audio data; a similar experiment on audio shows exactly the same behavior.)
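To separate DataLoader overhead from raw disk behavior, one quick check (my own sketch, not part of the original experiments) is to read the same files with a plain multiprocessing.Pool and see whether parallel reads scale at all:

from multiprocessing import Pool
from time import time


def read_one(path):
    # same work as Dst.__getitem__: read and discard the content
    open(path, 'r').read()
    return 1


def test_mp_pool(n_proc=4):
    paths = [f'./test_txt/{i}.txt' for i in range(1000)]
    ts = time()
    with Pool(n_proc) as pool:
        # chunksize=128 mirrors the DataLoader batch size
        pool.map(read_one, paths, chunksize=128)
    print(f"{n_proc} processes: {time() - ts:.3f} s")

If the total time does not drop as n_proc grows here either, the bottleneck is the disk (or the page cache state), not the DataLoader machinery.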

I added one more experiment, trying to do the reads with ray; the code is as follows,

import os
from time import time

import numpy as np
import matplotlib.pyplot as plt
import ray


def test_ray():
    ray.init()

    @ray.remote
    def read(paths):
        # remote task: read and discard every file in the chunk
        for path in paths:
            open(path, 'r').read()
        return 1

    def ray_read(paths, n_cpu=4):
        chunk_size = len(paths) // n_cpu
        object_ids = []
        for i in range(n_cpu):
            x = read.remote(paths[i * chunk_size: (i + 1) * chunk_size])
            object_ids.append(x)

        return ray.get(object_ids)

    def batch(l, bs):
        # split the list l into consecutive chunks of size bs
        out = []
        i = 0
        while i < len(l):
            out.append(l[i: i + bs])
            i += bs
        return out

    paths = [os.path.expanduser(f'~/test_txt/{i}.txt') for i in range(1000)]
    paths = batch(paths, 128)

    time_cost = []
    ts = time()
    for i, ele in enumerate(paths, 1):
        # read(paths[i - 1])
        ray_read(paths[i - 1], 8)
        dur = time() - ts
        time_cost.append(dur)
        print(i, dur)
        ts = time()

    print(f"{sum(time_cost):.3f}, "
          f"{np.mean(time_cost):.3f}, "
          f"{np.std(time_cost):.3f}, "
          f"{max(time_cost):.3f}, "
          f"{min(time_cost):.3f}")

    plt.plot(time_cost)
    plt.grid()
    plt.show()

The flow: the input paths are split into n_cpu chunks, and the chunks execute asynchronously via ray. The result: on the same SSD, each batch should in theory take about 1/4 of the sequential time, i.e. roughly 0.1 s, but the measured value is 0.2 s. In other words, the effective maximum of n_cpu is only 2.
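One detail worth noting: ray_read blocks on ray.get() before the next batch is submitted, so ray can only overlap reads within a single batch of 128 paths. A variant (my own sketch, reusing the read task defined above) that submits every chunk up-front would let ray overlap across batches as well:

def ray_read_all(batches, n_cpu=8):
    # submit one remote task per chunk of every batch before blocking,
    # so ray is free to schedule all of them across the available CPUs
    ids = []
    for b in batches:
        chunk_size = max(1, len(b) // n_cpu)
        ids += [read.remote(b[i:i + chunk_size])
                for i in range(0, len(b), chunk_size)]
    return ray.get(ids)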
