自定义PyTorch中的Sampler

2020-10-23 14:51:18 浏览数 (1)

本文使用 Zhihu On VSCode 创作并发布

在训练GAN的过程中,一次只训练一个类别据说有助于模型收敛,但是PyTorch里面没有预设这种数据加载方式,要这样训练的话,需要自己定义Sampler,即自定义数据采样方式。下面是自定义的方法:

首先,我们虚构一个Dataset类,用于测试。

这个类中的label标签是混乱的,无法通过控制index范围来实现单类别训练。

代码语言:javascript复制
class Data(Dataset):
    def __init__(self):
        self.img = torch.cat([torch.ones(2, 2) for i in range(100)], dim=0)
        self.num_classes = 2
        self.label = torch.tensor(
            [random.randint(0, self.num_classes - 1) for i in range(100)]
        )

    def __getitem__(self, index):
        return self.img[index], self.label[index]

    def __len__(self):
        return len(self.label)

然后,自定义一个Sampler类,这个类的作用是生成一个index列表,可以理解为重排data中的index。

代码语言:javascript复制
class CustomSampler(Sampler):
    def __init__(self, data):
        self.data = data

    def __iter__(self):
        indices = []
        for n in range(self.data.num_classes):
            index = torch.where(self.data.label == n)[0]
            indices.append(index)
        indices = torch.cat(indices, dim=0)
        return iter(indices)

    def __len__(self):
        return len(self.data)

定义好了之后可以封装成DataLoader并查看运行结果:

代码语言:javascript复制
d = Data()
s = CustomSampler(d)
dl = DataLoader(d, 8, sampler=s)
for img, label in dl:
    print(label)

结果

代码语言:javascript复制
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1])

显然,这样的结果并不能让人满意,有一个batch中还是包含了两种不同类型的标签,为了达到目的,我们还需要再定义一个BatchSampler类

代码语言:javascript复制
class CustomBatchSampler:
    def __init__(self, sampler, batch_size, drop_last):
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        i = 0
        sampler_list = list(self.sampler)
        for idx in sampler_list:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []

            if (
                i < len(sampler_list) - 1
                and self.sampler.data.label[idx]
                != self.sampler.data.label[sampler_list[i   1]]
            ):
                if len(batch) > 0 and not self.drop_last:
                    yield batch
                    batch = []
                else:
                    batch = []
            i  = 1
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler)   self.batch_size - 1) // self.batch_size

虽然PyTorch要求Sampler需要定义成一个迭代器,但是如果你自己定义BatchSampler的话,Sampler的形式可以自己定,就算写成一个普通的列表也没关系。

再次封装成DataLoader并查看运行结果:

代码语言:javascript复制
d = Data()
s = CustomSampler(d)
bs = CustomBatchSampler(s, 8, False)
dl = DataLoader(d, batch_sampler=bs)
for img, label in dl:
    print(label)

drop_last = False 的结果:

代码语言:javascript复制
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1])

drop_last = True 的结果:

代码语言:javascript复制
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])

以上就是自定义Sampler的步骤了。

0 人点赞