本文使用 Zhihu On VSCode 创作并发布
在训练GAN的过程中,一次只训练一个类别据说有助于模型收敛,但是PyTorch里面没有预设这种数据加载方式,要这样训练的话,需要自己定义Sampler,即自定义数据采样方式。下面是自定义的方法:
首先,我们虚构一个Dataset类,用于测试。
这个类中的label标签是混乱的,无法通过控制index范围来实现单类别训练。
代码语言:javascript复制class Data(Dataset):
def __init__(self):
self.img = torch.cat([torch.ones(2, 2) for i in range(100)], dim=0)
self.num_classes = 2
self.label = torch.tensor(
[random.randint(0, self.num_classes - 1) for i in range(100)]
)
def __getitem__(self, index):
return self.img[index], self.label[index]
def __len__(self):
return len(self.label)
然后,自定义一个Sampler类,这个类的作用是生成一个index列表,可以理解为重排data中的index。
代码语言:javascript复制class CustomSampler(Sampler):
def __init__(self, data):
self.data = data
def __iter__(self):
indices = []
for n in range(self.data.num_classes):
index = torch.where(self.data.label == n)[0]
indices.append(index)
indices = torch.cat(indices, dim=0)
return iter(indices)
def __len__(self):
return len(self.data)
定义好了之后可以封装成DataLoader并查看运行结果:
代码语言:javascript复制d = Data()
s = CustomSampler(d)
dl = DataLoader(d, 8, sampler=s)
for img, label in dl:
print(label)
结果
代码语言:javascript复制tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1])
显然,这样的结果并不能让人满意,有一个batch中还是包含了两种不同类型的标签,为了达到目的,我们还需要再定义一个BatchSampler类
代码语言:javascript复制class CustomBatchSampler:
def __init__(self, sampler, batch_size, drop_last):
self.sampler = sampler
self.batch_size = batch_size
self.drop_last = drop_last
def __iter__(self):
batch = []
i = 0
sampler_list = list(self.sampler)
for idx in sampler_list:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if (
i < len(sampler_list) - 1
and self.sampler.data.label[idx]
!= self.sampler.data.label[sampler_list[i 1]]
):
if len(batch) > 0 and not self.drop_last:
yield batch
batch = []
else:
batch = []
i = 1
if len(batch) > 0 and not self.drop_last:
yield batch
def __len__(self):
if self.drop_last:
return len(self.sampler) // self.batch_size
else:
return (len(self.sampler) self.batch_size - 1) // self.batch_size
虽然PyTorch要求Sampler需要定义成一个迭代器,但是如果你自己定义BatchSampler的话,Sampler的形式可以自己定,就算写成一个普通的列表也没关系。
再次封装成DataLoader并查看运行结果:
代码语言:javascript复制d = Data()
s = CustomSampler(d)
bs = CustomBatchSampler(s, 8, False)
dl = DataLoader(d, batch_sampler=bs)
for img, label in dl:
print(label)
drop_last = False 的结果:
代码语言:javascript复制tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1])
drop_last = True 的结果:
代码语言:javascript复制tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
以上就是自定义Sampler的步骤了。