25 | 使用PyTorch完成医疗图像识别大项目:分割模型实现

2022-07-11 15:53:22 浏览数 (2)

前面已经把分割模型的数据处理的差不多了,最后再加一点点关于数据增强的事情,我们就可以开始训练模型了。

常见的瓶颈

在搞机器学习项目的时候,总会有各种各样的瓶颈问题,比如IO问题,内存问题,GPU问题等等。因为我们的设备总会有一个短板的地方。

  • 1.数据加载环节,数据的大量IO(读写)可能会比较慢。
  • 2.使用CPU进行数据预处理环节可能出现瓶颈,通常来说是进行正则化和数据增强的时候。
  • 3.在模型训练的时候GPU可能是最大的瓶颈,如果说一定存在瓶颈那么我们希望是在GPU这块,因为GPU是最贵的。
  • 4.在CPU和GPU直接传输数据的带宽会影响GPU的运算。 这里我们要处理的就是在数据增强环节使用GPU。这里所做的数据增强方式跟之前一模一样,只不过这次通过类似于模型的方式实现,我们把这些步骤放在前向传播的方法里面,把它变成一个看起来像模型训练的过程,
代码语言:javascript复制
class SegmentationAugmentation(nn.Module):
    def __init__(
            self, flip=None, offset=None, scale=None, rotate=None, noise=None
    ):
        super().__init__()

        self.flip = flip        self.offset = offset        self.scale = scale        self.rotate = rotate        self.noise = noise    def forward(self, input_g, label_g):#这里是获取变换方法
        transform_t = self._build2dTransformMatrix()
        transform_t = transform_t.expand(input_g.shape[0], -1, -1)#因为GPU适合处理浮点数,这里传入GPU的同时转换成浮点数
        transform_t = transform_t.to(input_g.device, torch.float32)#affine_grid和grid_sample就是实现变换和重新采样(生成新图像)的方法
        affine_t = F.affine_grid(transform_t[:,:2],
                input_g.size(), align_corners=False)

        augmented_input_g = F.grid_sample(input_g,
                affine_t, padding_mode='border',
                align_corners=False)#这里同时在掩码操作
        augmented_label_g = F.grid_sample(label_g.to(torch.float32),
                affine_t, padding_mode='border',
                align_corners=False)#最后是增加噪声
        if self.noise:
            noise_t = torch.randn_like(augmented_input_g)
            noise_t *= self.noise

            augmented_input_g  = noise_t        return augmented_input_g, augmented_label_g > 0.5

    def _build2dTransformMatrix(self):
        transform_t = torch.eye(3)

        for i in range(2):
            if self.flip:
                if random.random() > 0.5:
                    transform_t[i,i] *= -1

            if self.offset:
                offset_float = self.offset
                random_float = (random.random() * 2 - 1)
                transform_t[2,i] = offset_float * random_float            if self.scale:
                scale_float = self.scale
                random_float = (random.random() * 2 - 1)
                transform_t[i,i] *= 1.0   scale_float * random_float        if self.rotate:
            angle_rad = random.random() * math.pi * 2
            s = math.sin(angle_rad)
            c = math.cos(angle_rad)

            rotation_t = torch.tensor([
                [c, -s, 0],
                [s, c, 0],
                [0, 0, 1]])

            transform_t @= rotation_t        return transform_t

接下来就是实现training环节。我们先把内部的一些方法写好。第一个是给模型进行初始化。

代码语言:javascript复制
    def initModel(self):#使用我们封装的UNet模型
        segmentation_model = UNetWrapper(
            in_channels=7,
            n_classes=1,
            depth=3,
            wf=4,
            padding=True,
            batch_norm=True,
            up_mode='upconv',
        )#数据增强模型,实际上并不是一个真的模型
        augmentation_model = SegmentationAugmentation(**self.augmentation_dict)#设置使用GPU,甚至是GPU并行运算
        if self.use_cuda:
            log.info("Using CUDA; {} devices.".format(torch.cuda.device_count()))
            if torch.cuda.device_count() > 1:
                segmentation_model = nn.DataParallel(segmentation_model)
                augmentation_model = nn.DataParallel(augmentation_model)#把模型传入GPU
            segmentation_model = segmentation_model.to(self.device)
            augmentation_model = augmentation_model.to(self.device)#返回模型实例
        return segmentation_model, augmentation_model

第二个要定义的是优化器。在这里使用Adam优化器。Adam有很多的优点,比如说不太需要我们去调整参数,它会为每个参数维护一个单独的学习率,并且可以根据训练的进行自动更新学习率。这个只需要一行调用就可以实现,如果你想了解Adam的细节,可以点进去研究一下它的源代码。

代码语言:javascript复制
    def initOptimizer(self):
        return Adam(self.segmentation_model.parameters())

第三个是定义损失函数。这块我们又要换一个新的损失计算方法了。前面我们已经学过L1损失,L2损失,交叉熵损失,现在新加一个骰子损失(Dice Loss)。它的计算逻辑也不难理解,是按照实际的图像面积和预测出来的图像面积进行比较的,这是在图像分割领域常用的损失计算方法。看下面这张图,考虑实际的图像是圆圈内的图像,预测的图像是方框区域的图像,其中阴影部分就是预测命中的部分,而dice系数的计算就是阴影面积的二倍除方框加圆圈的面积。可以想象,当预测完全准确的时候这个系数计算出来是1.0,所以我们使用1-dice系数作为损失,因为我们期望预测越准确的时候损失越小。

image.png

代码语言:javascript复制
    def diceLoss(self, prediction_g, label_g, epsilon=1):
        diceLabel_g = label_g.sum(dim=[1,2,3])
        dicePrediction_g = prediction_g.sum(dim=[1,2,3])
        diceCorrect_g = (prediction_g * label_g).sum(dim=[1,2,3])

        diceRatio_g = (2 * diceCorrect_g   epsilon)             / (dicePrediction_g   diceLabel_g   epsilon)

        return 1 - diceRatio_g

计算批量损失。

代码语言:javascript复制
    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g,
                         classificationThreshold=0.5):
        input_t, label_t, series_list, _slice_ndx_list = batch_tup#数据传入GPU
        input_g = input_t.to(self.device, non_blocking=True)
        label_g = label_t.to(self.device, non_blocking=True)#判断是否需要增强数据,训练时候需要,验证时候不需要
        if self.segmentation_model.training and self.augmentation_dict:
            input_g, label_g = self.augmentation_model(input_g, label_g)#运行分割模型
        prediction_g = self.segmentation_model(input_g)#计算损失
        diceLoss_g = self.diceLoss(prediction_g, label_g)#这个fnLoss使用的是prediction_g * label_g输入,也就是只保留了预测正确的那一部分,用于后面我们对损失进行加权
        fnLoss_g = self.diceLoss(prediction_g * label_g, label_g)#结果指标存储,批数据的起始位置和终止位置
        start_ndx = batch_ndx * batch_size
        end_ndx = start_ndx   input_t.size(0)

        with torch.no_grad():
            predictionBool_g = (prediction_g[:, 0:1]
                                > classificationThreshold).to(torch.float32)#计算真阳性,假阴性,假阳性数目
            tp = (     predictionBool_g *  label_g).sum(dim=[1,2,3])
            fn = ((1 - predictionBool_g) *  label_g).sum(dim=[1,2,3])
            fp = (     predictionBool_g * (~label_g)).sum(dim=[1,2,3])

            metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = diceLoss_g
            metrics_g[METRICS_TP_NDX, start_ndx:end_ndx] = tp
            metrics_g[METRICS_FN_NDX, start_ndx:end_ndx] = fn
            metrics_g[METRICS_FP_NDX, start_ndx:end_ndx] = fp#这个地方进行了损失加权,这里×了8,表示正向的像素重要性比负向像素高8倍,用来增强我们把图像分割出结节的情况,因为我们希望能更多的找到结节,所以哪怕召回多一些也没关系,总比丢掉了一部分要好。
        return diceLoss_g.mean()   fnLoss_g.mean() * 8

再往下,我们研究把图像导入TensorBoard,以便我们能够在TensorBoard上显性地观察模型效果。做图像任务的好处就是比较容易观察中间结果,其实我自己做NLP比较多,中间结果输出出来也看不出什么效果。

图像记录方法。

代码语言:javascript复制
    def logImages(self, epoch_ndx, mode_str, dl):#把模型设置为eval模式
        self.segmentation_model.eval()#获取12个CT
        images = sorted(dl.dataset.series_list)[:12]
        for series_ndx, series_uid in enumerate(images):
            ct = getCt(series_uid)#取出6个切片
            for slice_ndx in range(6):
                ct_ndx = slice_ndx * (ct.hu_a.shape[0] - 1) // 5
                sample_tup = dl.dataset.getitem_fullSlice(series_uid, ct_ndx)

                ct_t, label_t, series_uid, ct_ndx = sample_tup

                input_g = ct_t.to(self.device).unsqueeze(0)
                label_g = pos_g = label_t.to(self.device).unsqueeze(0)

                prediction_g = self.segmentation_model(input_g)[0]
                prediction_a = prediction_g.to('cpu').detach().numpy()[0] > 0.5
                label_a = label_g.cpu().numpy()[0][0] > 0.5

                ct_t[:-1,:,:] /= 2000
                ct_t[:-1,:,:]  = 0.5

                ctSlice_a = ct_t[dl.dataset.contextSlices_count].numpy()

                image_a = np.zeros((512, 512, 3), dtype=np.float32)
                image_a[:,:,:] = ctSlice_a.reshape((512,512,1))
                image_a[:,:,0]  = prediction_a & (1 - label_a) #把假阳性区域标记成红色
                image_a[:,:,0]  = (1 - prediction_a) & label_a #假阴性标记为橙色
                image_a[:,:,1]  = ((1 - prediction_a) & label_a) * 0.5 

                image_a[:,:,1]  = prediction_a & label_a  #真阳性标记为绿色
                image_a *= 0.5
                image_a.clip(0, 1, image_a)

                writer = getattr(self, mode_str   '_writer')
                writer.add_image(
                    f'{mode_str}/{series_ndx}_prediction_{slice_ndx}',
                    image_a,
                    self.totalTrainingSamples_count,
                    dataformats='HWC',
                )

                if epoch_ndx == 1:
                    image_a = np.zeros((512, 512, 3), dtype=np.float32)
                    image_a[:,:,:] = ctSlice_a.reshape((512,512,1))
                    # image_a[:,:,0]  = (1 - label_a) & lung_a # Red
                    image_a[:,:,1]  = label_a  # Green
                    # image_a[:,:,2]  = neg_a  # Blue

                    image_a *= 0.5
                    image_a[image_a < 0] = 0
                    image_a[image_a > 1] = 1
                    writer.add_image(
                        '{}/{}_label_{}'.format(
                            mode_str,
                            series_ndx,
                            slice_ndx,
                        ),
                        image_a,
                        self.totalTrainingSamples_count,
                        dataformats='HWC',
                    )
                # This flush prevents TB from getting confused about which
                # data item belongs where.
                writer.flush()

然后在训练的main方法里面调用它。

代码语言:javascript复制
def main(self):……self.validation_cadence = 5
        for epoch_ndx in range(1, self.cli_args.epochs   1):
        ……
            trnMetrics_t = self.doTraining(epoch_ndx, train_dl)
            self.logMetrics(epoch_ndx, 'trn', trnMetrics_t)#记录第一个epoch或者每隔几个周期的时候记录图像信息
            if epoch_ndx == 1 or epoch_ndx % self.validation_cadence == 0:
                # if validation is wanted
                valMetrics_t = self.doValidation(epoch_ndx, val_dl)
                score = self.logMetrics(epoch_ndx, 'val', valMetrics_t)
                best_score = max(score, best_score)

                self.saveModel('seg', epoch_ndx, score == best_score)

                self.logImages(epoch_ndx, 'trn', train_dl)
                self.logImages(epoch_ndx, 'val', val_dl)

下图是书上给出的样例图,每个图上还有滚动条,通过滚动可以查看在不同迭代周期的图像。

image.png

除了记录图像,我们还得把迭代的指标信息也记录下来,这部分跟前面基本一样,就不再过多解释了。 最后,如果我们的模型效果还不错,我们要把它存下来,实际上我们存储的是模型训练好的参数信息。

代码语言:javascript复制
    def saveModel(self, type_str, epoch_ndx, isBest=False):#存储文件路径信息
        file_path = os.path.join(
            'data-unversioned',
            'part2',
            'models',
            self.cli_args.tb_prefix,
            '{}_{}_{}.{}.state'.format(
                type_str,
                self.time_str,
                self.cli_args.comment,
                self.totalTrainingSamples_count,
            )
        )#创建目录
        os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)#获取模型
        model = self.segmentation_model        if isinstance(model, torch.nn.DataParallel):
            model = model.module#需要存储的状态信息
        state = {
            'sys_argv': sys.argv,  #系统参数
            'time': str(datetime.datetime.now()), #时间信息
            'model_state': model.state_dict(), #模型状态
            'model_name': type(model).__name__, #模型名称
            'optimizer_state' : self.optimizer.state_dict(), #优化器状态
            'optimizer_name': type(self.optimizer).__name__, #优化器名称
            'epoch': epoch_ndx, #迭代周期
            'totalTrainingSamples_count': self.totalTrainingSamples_count, #训练样本数量
        }#存储,通过存储模型,我们可以在下次接着训练
        torch.save(state, file_path)

        log.info("Saved model params to {}".format(file_path))#这里做一个备份,如果这是效果最好的一版模型,就再存一次,记得多做这种操作,并且文件命名一定要好,具体为什么你自己考虑,说多了都是泪。
        if isBest:
            best_path = os.path.join(
                'data-unversioned', 'part2', 'models',
                self.cli_args.tb_prefix,
                f'{type_str}_{self.time_str}_{self.cli_args.comment}.best.state')
            shutil.copyfile(file_path, best_path)

            log.info("Saved model params to {}".format(best_path))#最后这个hash是用于校验文件的
        with open(file_path, 'rb') as f:
            log.info("SHA1: "   hashlib.sha1(f.read()).hexdigest())

接下来就是训练模型然后看结果了。

0 人点赞