上一小节修改了我们的评估指标,然而效果并没有什么变化,甚至连指标都不能正常的输出出来。我们期望的是下面这种样子,安全事件都聚集在左边,危险事件都聚集在右边,中间只有少量的难以判断的事件,这样我们的模型很容易分出来,错误率也会比较低。
然而实际上我们的数据是下面这个样子的,大部分都是负样本,正样本只有一点点,在我们的数据集里面,阳性和阴性比值为1:400。
如果我们把模型构建的足够深,而且能够训练无数个周期,那模型还是有可能学出一个比较好的效果,不过我们的GPU可能等不到那会就已经累死了。所以我们还不如想想办法怎么让正样本能够多一些。
重复采样
我们期望正负样本能够平衡,就像下面右图中,传入模型的每个批次的数据中,正负样本都间隔出现,而实际情况是左边这样,若干批次数据只有负样本。
在这里,回到我们的dsets.py代码中,加入一个参数ratio_int,并为正样本和负样本分别创建两个索引。
代码语言:javascript复制class LunaDataset(Dataset):
def __init__(self,
val_stride=0,
isValSet_bool=None,
ratio_int=0,
):
self.ratio_int = ratio_int
…… self.negative_list = [
nt for nt in self.candidateInfo_list if not nt.isNodule_bool ]
self.pos_list = [
nt for nt in self.candidateInfo_list if nt.isNodule_bool ]
我们期望把数据处理成下面这样,每隔两个负样本,有一个正样本。
代码语言:javascript复制 def __getitem__(self, ndx):
if self.ratio_int:
pos_ndx = ndx // (self.ratio_int 1)
if ndx % (self.ratio_int 1):
neg_ndx = ndx - 1 - pos_ndx
neg_ndx %= len(self.negative_list)
candidateInfo_tup = self.negative_list[neg_ndx]
else:
pos_ndx %= len(self.pos_list)
candidateInfo_tup = self.pos_list[pos_ndx]
else:
candidateInfo_tup = self.candidateInfo_list[ndx]
通过上面这样的编写,我们可以重复的取到阳性样本,我们把总样本量设为20w,因为原来有50w数据,但是我们的正样本只有那么多,重复取50w其实效果也差不多,这样我们还能训练更快一点。
代码语言:javascript复制 def __len__(self):
if self.ratio_int:
return 200000
else:
return len(self.candidateInfo_list)
最后在我们的初始化里添加一个参数,用于记录是否需要平衡样本。
代码语言:javascript复制 def __init__(self, sys_argv=None):
……
parser.add_argument('--balanced',
help="Balance the training data to half positive, half negative.",
action='store_true',
default=False,
)
在初始化dataloader的时候把这个参数传进去
代码语言:javascript复制 def initTrainDl(self):
train_ds = LunaDataset(
val_stride=10,
isValSet_bool=False,
ratio_int=int(self.cli_args.balanced), #这就是转成了int,为1
augmentation_dict=self.augmentation_dict,
)
接下来就可以运行程序了。 这个地方遇到了程序崩溃,原因是内存超标了,经过多次尝试,如果内存不超过32G,建议把nums_works设置为6以下,batch_size设置为16比较稳妥。当然设置成这样后训练速度会变慢。这里仍然先准备缓存数据。
代码语言:javascript复制run('test12ch.prepcache.LunaPrepCacheApp')
然后开始训练
代码语言:javascript复制run('test12ch.training.LunaTrainingApp', '--epochs=1', '--balanced')
我这里为了训练比较快,还是只取了一个subset的数据。花了大概三个小时训练完以后,这里可以看到训练集20w数据,准确率是0.5,精确度是0.5召回0.47,f1 score是0.48。可以看出模型的效果不怎么好,但是,这个的优势是我们已经获得了足够多的正样本数据,同时在预测的时候能够分出正样本和负样本。
如原书中所用的训练结果,由于使用了全部的数据集,同时训练了10个 epoch,获得了0.92的f1 score,在验证集上正样本都可以有79.4的准确率,比起前两章的训练结果要好太多了,虽然我们还不能很完美的识别有问题的数据,但是至少已经能够去解决一些问题了。
数据增强
第二个用来解决样本不均衡问题的方法就是数据增强。所谓的数据增强就是通过对原始的样本做一些修改,在基本不改变核心信息的情况下制造出新的样本,同时这些样本跟原始数据又要有一些区别。比如说之前对于飞机的图像,我们对它进行镜像翻转,或者进行旋转,里面的飞机主体还在,但是图像已经不一样了,再或者对飞机的位置进行平移也可以生成一个新图像。但是对于一张图上,我们去修改几个像素的值,虽然图像发生了变化,但往往就不具有太大的价值。
常见的图像数据增强方法: 1.各种翻转图像,上下翻转,左右翻转等等 2.像素整体移动 3.放大缩小图像 4.图像旋转 5.添加噪声
接下来就是写代码了,在数据处理的代码dsets.py中加入getCtAugmentedCandidate方法,用来获取CT数据并对其进行修改。
代码语言:javascript复制def getCtAugmentedCandidate(
augmentation_dict,
series_uid, center_xyz, width_irc,
use_cache=True):
if use_cache: #从缓存中获取CT
ct_chunk, center_irc =
getCtRawCandidate(series_uid, center_xyz, width_irc)
else: #直接获取CT
ct = getCt(series_uid)
ct_chunk, center_irc = ct.getRawCandidate(center_xyz, width_irc)#转换为张量
ct_t = torch.tensor(ct_chunk).unsqueeze(0).unsqueeze(0).to(torch.float32)
transform_t = torch.eye(4)
# ... <1>
for i in range(3):#镜像方法
if 'flip' in augmentation_dict:
if random.random() > 0.5:
transform_t[i,i] *= -1#随机偏移方法
if 'offset' in augmentation_dict:
offset_float = augmentation_dict['offset']
random_float = (random.random() * 2 - 1)
transform_t[i,3] = offset_float * random_float#缩放
if 'scale' in augmentation_dict:
scale_float = augmentation_dict['scale']
random_float = (random.random() * 2 - 1)
transform_t[i,i] *= 1.0 scale_float * random_float#旋转
if 'rotate' in augmentation_dict:
angle_rad = random.random() * math.pi * 2
s = math.sin(angle_rad)
c = math.cos(angle_rad)
rotation_t = torch.tensor([
[c, -s, 0, 0],
[s, c, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
])
transform_t @= rotation_t
affine_t = F.affine_grid(
transform_t[:3].unsqueeze(0).to(torch.float32),
ct_t.size(),
align_corners=False,
)#这个地方是复制数据,前几个都是会产生一个新的块数据
augmented_chunk = F.grid_sample(
ct_t,
affine_t,
padding_mode='border',
align_corners=False,
).to('cpu')#加噪声
if 'noise' in augmentation_dict:
noise_t = torch.randn_like(augmented_chunk)
noise_t *= augmentation_dict['noise']
augmented_chunk = noise_t return augmented_chunk[0], center_irc
如下就是各种图像增强的效果,最后一行是合并的效果。
这时候就把各种增强手段对应的参数加入到训练环节,通过参数决定启用哪种增强手段。这里是修改traing.py代码。在init中设置接收参数
代码语言:javascript复制 parser.add_argument('--augmented',
help="Augment the training data.",
action='store_true',
default=False,
)
parser.add_argument('--augment-flip',
help="Augment the training data by randomly flipping the data left-right, up-down, and front-back.",
action='store_true',
default=False,
)
parser.add_argument('--augment-offset',
help="Augment the training data by randomly offsetting the data slightly along the X and Y axes.",
action='store_true',
default=False,
)
parser.add_argument('--augment-scale',
help="Augment the training data by randomly increasing or decreasing the size of the candidate.",
action='store_true',
default=False,
)
parser.add_argument('--augment-rotate',
help="Augment the training data by randomly rotating the data around the head-foot axis.",
action='store_true',
default=False,
)
parser.add_argument('--augment-noise',
help="Augment the training data by randomly adding noise to the data.",
action='store_true',
default=False,
)
然后是给这些增强方法设定预设值
代码语言:javascript复制 self.augmentation_dict = {}
if self.cli_args.augmented or self.cli_args.augment_flip:
self.augmentation_dict['flip'] = True
if self.cli_args.augmented or self.cli_args.augment_offset:
self.augmentation_dict['offset'] = 0.1
if self.cli_args.augmented or self.cli_args.augment_scale:
self.augmentation_dict['scale'] = 0.2
if self.cli_args.augmented or self.cli_args.augment_rotate:
self.augmentation_dict['rotate'] = True
if self.cli_args.augmented or self.cli_args.augment_noise:
self.augmentation_dict['noise'] = 25.0
如果我挨个尝试训练它,估计一周就过去了,所以我直接把原书的效果贴上来。
这里打开了TensorBoard的页面,对各种增强数据的训练效果做了对比。 其中完全增强和未增强都训练了20个epoch,其他情况训练了10个epoch。 从准确率来看,完全增强的整体准确率偏低,未增强和使用单一增强策略的整体准确率较高,但是完全增强数据在正样本的准确率上有很好的效果,比如说像这个业务,我们就是期望能够准确的发现有问题的结节,哪怕错误的判断了某些安全的结节都可以,因为处理完之后还会有人再去审核。从损失来看,全局损失和负样本损失各个策略都差不多,但是未增强的数据在正样本上损失偏高。未增强数据的f1 score和precision都比较高,但是完全增强数据的召回比较高,像我刚说过的,我们得看业务需求是什么样子的,来决定使用哪个方案。
通过两种补充数据的方法,我们的模型虽然还没有达到特别好的效果,但是显然已经能够开始工作了。