深度学习实战之医学图像分割

2021-10-14 16:16:45 浏览数 (1)

计算机视觉领域有三大问题:图像分类、目标检测以及图像分割。前两类问题及应用在公众号之前的文章里都有介绍,那么今天我们就来介绍剩下的图像分割问题,并以医学图像分割为例介绍它在现实中的应用。

目录

1、语义分割问题简介

2、Unet模型简介

3、腹腔核磁共振数据集实战

01

语义分割问题简介

图像分割问题是图像处理和计算机视觉领域的关键问题之一。分割结果直接影响着后续任务的有效性。图像分割的目的就是把目标从背景中提取出来,分割过程主要是基于图像的固有特征,如灰度、纹理、对比度、亮度、彩色特征等将图像分成具有各自特性的同质区域。近年来随着深度学习的发展,许多图像分割问题正在采用深层次的结构来解决,最常见的就是卷积神经网络,它在精度上以及效率上大大超过了其他方法。

肺部CT分割

语义分割

我们通常听到的比较多的名词是图像语义分割,即为图像中每个像素赋予一个指定的标签(像素级类别预测问题),而图像分割泛指将图片划分为不同区域,对于每个区域的语义信息并没有要求,传统图像分割有很多这样的分割算法。但我们现在讨论的自然图像语义分割和医学图像分割,其实都属于图像语义分割范畴。医学图像分割的主要目的还是对图像中具有特殊语义信息(如肿瘤、器官、血管等)赋予标签,但医学图像分割的类别个数一般没有自然图像语义分割那么多。如VOC2012包含20个类别和一个背景类别,但医学图像分割很多都是二分类问题。

脑部图像

医学图像属于图像的子类,所以针对图像的方法,应用到医学图像中是没有问题的,但我们通常说的图像特指自然图像(RGB图像),而医学图像包含的图像种类(格式,例如CT、MRI等等)范围更加广泛,两者又有一定的区别。所以在进行训练和预测之前,我们通常需要对原始图像和标注数据进行预处理,以达到模型输入的要求。

02

Unet模型

Unet最早发表在2015的MICCAI上,截至现在,引用量已经达到了6.7K,足以见其影响力。而后Unet成为大多医疗影像语义分割任务的baseline,也启发了大量研究者去思考U型语义分割网络。如今在自然影像理解方面,也有越来越多的语义分割和目标检测SOTA模型开始关注和使用U型结构,比如语义分割Discriminative Feature Network(DFN)(CVPR2018),目标检测Feature Pyramid Networks for Object Detection(FPN)(CVPR 2017)等。

Unet包括两部分组成,可以看上图,第一部分是特征提取,与VGG类似。第二部分是上采样部分。由于网络结构像字母U,所以叫做Unet。特征提取部分,每经过一个池化层就变换一个尺度,包括原图尺度一共有5个。在上采样部分,每上采样一次,就和特征提取部分对应的通道数相同尺度融合(图中标注为copy and crop),但是融合之前要将其crop。这里的融合其实就是拼接。可以看到,输入是图像分辨率为572x572,但是输出变成了388x388,这说明图像经过网络以后,输出的结果和原图分辨率并不是完全对应的(可以进行手动调整)。

蓝色箭头代表3x3的卷积操作,步长(stride)为1,padding策略是vaild,因此,每个该卷积以后,特征图(feature map)的大小会缩小一倍。红色箭头代表2x2的最大池化(max pooling)操作,需要注意的是,此时的padding策略也是vaild,这就会导致如果pooling之前featuremap的大小是奇数,会损失一些信息 。因为2*2的max-pooling算子适用于偶数像素点的图像长宽,所以要选取合适的输入大小。绿色箭头代表2x2的卷积 上采样操作,此处应该注意,某些教程将其错误理解为了反卷积,此处是卷积 上采样,与反卷积(转置卷积)不同。该操作会将feature map的大小乘以2。灰色箭头表示复制和剪切操作,可以发现,在同一层左边的最后一层要比右边的第一层分辨率略大,这就导致了,想要利用浅层的特征,就要进行一些剪切。输出的最后一层,使用了1x1的卷积层做了分类,输出的两层为前景和背景(二分类时)。

代码实现参考 https://github.com/milesial/Pytorch-UNet

03

腹腔核磁共振数据集实战

本次我们使用的是CHAOS数据集(https://chaos.grand-challenge.org/)中的腹腔MRI数据,每张核磁共振图像中包括脾脏、肝脏、左肾和右肾等四个器官,我们的目标是分割出其中的肝脏。

数据集中有CT/MR两种数据,都是dcm格式,每一张就是一个slice。对于核磁共振图像,一共有40个病例,训练集和测试集各分了20例。因为是比赛用数据,其中的测试集并未提供ground truth,因此我们训练过程中能使用的只有原训练集中的数据(T1、InPhase),并要重新对其进行划分。为了达到尽可能好的训练效果我们将其中16例作为训练集,4例作为验证集,不设置测试集(受限于数据量且重点是展示训练过程,实际中请勿模仿),使用没有ground truth的数据检验分割效果。

ps:除注明参考的部分以外代码均为原创

①.数据集重新划分

代码语言:javascript复制
import os
import shutil
import random

class Spliter:
    def __init__(self):
        self.train_img_dir = "./data/train/img"
        self.train_lbl_dir = "./data/train/lbl"
        self.test_img_dir = "./data/val/img"
        self.test_lbl_dir = "./data/val/lbl"


    def get_path(self,patient_dir):
        lbl_paths = []
        img_paths = []

        t1_img_dir = os.path.join(patient_dir, "T1DUAL")
        lbl_dir = os.path.join(t1_img_dir, "lbl")
        lbl_names = os.listdir(lbl_dir)
        nums_lbl = len(lbl_names)
        # 拼接lbl文件夹的文件,存入到lbl_paths列表中
        for i in range(nums_lbl):
            lbl_paths.append(os.path.join(lbl_dir, lbl_names[i]))

        img_dir = os.path.join(t1_img_dir, "DICOM_anon", "img")
        img_names = os.listdir(img_dir)


        # 拼接img文件夹的文件,存入到img_paths列表中
        for i in range(len(img_names)):
            img_paths.append(os.path.join(img_dir, img_names[i]))

        return lbl_paths, img_paths
    def main(self):
        dataset_dir = os.path.join("CHAOS_Train_Sets", "Train_Sets", "MR")
        train_split_rate = 0.8
        val_split_rate = 0.2

        for root, dirs, files in os.walk(dataset_dir):
            random.shuffle(dirs)
            dir_count = len(dirs)
            
            i = 0
            for sub_dir in dirs:  # sub_dir代表病人编号
                if i <= int(dir_count * train_split_rate):
                    patient_dir = os.path.join(root, sub_dir)
                    lbl_paths, img_paths = self.get_path(patient_dir)
                    for j in range(len(lbl_paths)):
                        new_lbl_path = os.path.join(self.train_lbl_dir, "T1_Patient%s_No%d.png" % (sub_dir,j))
                        shutil.copy(lbl_paths[j], new_lbl_path)

                    for j in range(len(img_paths)):
                        new_img_path = os.path.join(self.test_img_dir, "T1_Patient%s_No%d.dcm" % (sub_dir, j))
                        shutil.copy(img_paths[j], new_img_path)

                    i  = 1
                else:
                    patient_dir = os.path.join(root, sub_dir)
                    lbl_paths, img_paths = self.get_path(patient_dir)
                    for j in range(len(lbl_paths)):
                        new_lbl_path = os.path.join(self.test_lbl_dir, "T1_Patient%s_No%d.png" % (sub_dir, j))
                        shutil.copy(lbl_paths[j], new_lbl_path)

                    for j in range(len(img_paths)):
                        new_img_path = os.path.join(self.test_img_dir, "T1_Patient%s_No%d.dcm" % (sub_dir, j))
                        shutil.copy(img_paths[j], new_img_path)

                    i  = 1


if __name__ == '__main__':
     Spliter().main()

该段代码的作用就是将训练集中T1/InPhase 的20个病例划分成16个训练集,4个测试集,并重新存储到自定义的文件夹下. 对于图像文件也进行了命名规范,对第i个病人的第j张slice,命名规则为T1_Patienti_Noj.dcm

效果展示

②.分离出肝脏Mask

GroundTruth的图像包含了四种器官,此处我们只需要肝脏的mask,根据数据集描述,肝脏的灰度值范围为55-70,参考百度提供的教程,使用OpenCV库可以根据灰度值分离出肝脏部分

代码语言:javascript复制
import os
import cv2

def extract_liver(dataset_dir):
    src_img_names = os.listdir(dataset_dir)
    if src_img_names[0] == 'Liver':
        src_img_names.remove('Liver')
    src_img_num = len(src_names)
    new_dir = os.path.join(dataset_dir, "Liver")
    for num in range(src_img_num):
        src_img_path = os.path.join(dataset_dir, src_img_names[num])
        src_img = cv2.imread(src_path,0)   # 0表示灰度图,默认参数为1(RGB图像)
        result = 0
        for i in range(src_img.shape[0]):
            for j in range(src_img.shape[1]):
                for k in range(src_img.shape[2]):
                    if 55 <= src_img.item(i, j, k) <= 70:
                        result = 1  # 表示有肝脏
                        src_img.itemset((i, j, k), 255)
                    else:
                        src_img.itemset((i, j, k), 0)
        if result == 1:
            new_path = os.path.join(new_dir, src_img_names[num])
            cv2.imwrite(new_path, src_img)


if __name__ == '__main__':
    train_dir = os.path.join("data", "train", "lbl")
    test_dir = os.path.join("data", "val", "lbl")
    extract_liver(train_dir)
    extract_liver(test_dir)

处理好的mask

③.数据增强

因为我们处理的是腹腔MRI图像,需要平移与缩放不变性,并且对形变和灰度变化鲁棒。所以在预处理环节将训练样本进行数据增广,是改善分割网络训练效果的关键。这里我们使用功能强大的数据增广库albumentations。

代码语言:javascript复制
import albumentations as albu

def _aug_img_lbl(img_np, lbl_np, img_name=None):
    aug_results = []

    # img augmentation
    tf_res = albu.Compose([
        albu.HorizontalFlip(p=0.8),#垂直翻转
        albu.Flip(p=0.8),#水平翻转
        albu.Transpose(p=0.8),#转置
        albu.RandomScale(scale_limit=0.1,interpolation=1,always_apply=False,p=0.5),#随机缩放
        albu.RandomRotate90(p=0.8),#随机旋转90度
        albu.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, brightness_by_max=None,
                                      always_apply=False, p=0.8),], p=1.0)(image=img_np, mask=lbl_np)#随机亮度对比度变化

    aug_results.append({"image": tf_res["image"],"mask": tf_res["mask"]})
    return aug_results

增广后的数据

④.自定义dataset

代码语言:javascript复制
from torch.utils.data import Dataset
from PIL import Image
import os
import torchvision.transforms.functional as TF


class LiverDataset(Dataset):
    def __init__(self, root_dir):
        self.data_path = self.get_data_path(root_dir)

    def __getitem__(self, index):
        img_path, lbl_path = self.data_path[index]
        img = Image.open(img_path).convert('L')
        lbl = Image.open(lbl_path).convert('L')
        img = TF.to_tensor(img)
        lbl = TF.to_tensor(lbl)
        return img, lbl

    def __len__(self):
        return len(self.data_path)

    def get_data_path(self, root):
        data_path = []
        img_path = os.path.join(root, "img")
        lbl_path = os.path.join(root, "lbl")
        names = os.listdir(img_path)
        n = len(names)
        for i in range(n):
            img = os.path.join(img_path, names[i])
            lbl = os.path.join(lbl_path, names[i])
            data_path.append((img, lbl))
        return data_path

⑤.开始训练

考虑到前景和背景所占的比例不平衡的问题,这里的损失函数我们选用二分类的Focal loss,初始学习率设定为1e-4,每4个epoch缩小10倍,共训练20个epoch(完整代码前往评论区下载)。

代码语言:javascript复制
   def train_one_epoch(self):
        t_epoch_start = time.time()

        self.model.train()

        losses = []

        print("train batches: %s" % (len(self.train_dataloader)))
        for batch_i, (features_batch, labels_batch) in enumerate(self.train_dataloader):
            print(f"batch_i={batch_i}")
            t0 = time.time()

            try:
                if self.cuda:
                    features_batch = features_batch.cuda()
                    labels_batch = labels_batch.cuda().long()

                # forward
                y, loss = self.model(features_batch, labels_batch)
                if loss == 0 or not torch.isfinite(loss):
                    continue

                # backward
                self.optimizer.zero_grad()
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()

                # for print
                losses.append(loss.item())

                time_span = time.time() - t0

                print(
                    "batch {:05d} | Time(s):{:.2f}s | Loss:{:.4f}".format(batch_i, time_span, loss.item()))

            except Exception as e:
                print('[Error]', traceback.format_exc())
                print(e)
                continue

        durations = time.time() - t_epoch_start

        return {
            "train_durations": durations,
            "average_train_loss": np.mean(losses),
        }

    def train(self):

        # ================================================
        # train loop
        # ================================================
        print("total epochs: %s" % (self.n_epochs))

        try:
            for epoch_i in range(self.n_epochs):
                print()
                print("------------------------")
                print("Epoch {:05d} training...".format(epoch_i))
                print("------------------------")

                self.epoch_i = epoch_i

                one_epoch_result = self.train_one_epoch()
                self.scheduler.step()

                print(
                    "Epoch {:05d} training complete...: | Time(s):{:.2f}s | Average Loss:{:.4f}".format(
                        epoch_i, one_epoch_result["train_durations"], one_epoch_result["average_train_loss"]
                    ))

                # ================================================
                # after each epoch ends
                # ================================================

                val_one_epoch_result = self.val_one_epoch()
                if val_one_epoch_result["average_val_loss"] < self.best_loss:
                    self.best_loss = val_one_epoch_result["average_val_loss"]
                    self.save_model()


        except KeyboardInterrupt:
            pass


    def save_model(self):
        print("saving model ...")
        torch.save({
            'model_state_dict': self.model.state_dict(),
        }, './checkpoint/best_parameters.pth')

训练结束之后,得到的权重文件会保存在./checkpoint/文件夹下。

⑥.进行预测

预测代码与训练代码类似,这里不再展示(完整项目代码可在评论区下载)。运行predict.py结束后,预测的结果会保存在data/predict_output文件夹下。下面为测试结果展示,横向的连续三张图分别为GroundTruth,网络预测图及原图。

- END -

0 人点赞