24 | 使用PyTorch完成医疗图像识别大项目:图像分割数据准备

2022-07-11 15:52:39 浏览数 (1)

本周有点丧,前面几天不是忙于面试就是忙于塞尔达炸鱼,一直没更新,好在这周把这本书读完了,今天再更一篇,终于快要结束了。

之前的模型主要是研究怎么去给可能的图像分类,决定这是不是一个结节。 回头看看我们的步骤,我们前面处理的都属于第4步的事情,这中间还忽略了两个步骤,步骤2和步骤3。其中步骤2是要对结节信息进行分割,步骤3是给分割出来的结节像素分组。

今天我们就从步骤2开始。下面就看一下这个设计好的分割模型的训练流程。

1.使用一个开源的分割模型U-Net。我们不再自己编写一个模型,而是用开源的,这大概也是工作中处理业务问题的主要方式。 2.调整模型适配我们的数据。因为U-Net原本是为二维图像准备的,而我们的数据是一个三维体数据,所以这里要进行一些修改,主要有下面三点:

  • 更新模型。这里主要指的把U-Net模型融合到我们的代码里,并能够让它跑通。
  • 修改数据集。为我们的分割模型构建一套可以使用的数据集,这个跟我们之前的分类模型使用的数据集有些区别。因为我们分类模型给出的结果是简单的分类结果,而分割模型需要输出被分割的一块图像。
  • 修改训练循环。这里主要是使用心得损失函数来适配模型输出的图像结果。 3.结果观察。

图像分割

图像分割方法主要可分为两种类型:语义分割和实例分割。语义分割会使用相同的类标签标注同一类目标(下图左),而在实例分割中,相似的目标也会使用不同标签进行标注(下图右)。我们这里需要使用的是语义分割。

由于我们前面一直在做的是分类,这里我们先简单了解一下分割和分类的区别。如下是同一张图,对于分类模型需要解答的是这是不是一张关于猫的图片,但是对于分割模型来说,它需要给出的是这个图上的哪一部分是猫,并需要把猫的图像标记出来。

我们的分类模型是在一步步抽象,比如我们之前的卷积模型,通过卷积和池化不断的压缩特征,最后输出一个分类,我们可以使用它识别到图像里有猫,但是猫到底在哪一块分类就解决不了了,因为它的结果是高度抽象的结果,因此,我们可以想到要把这种抽象的信息再还原回图像。

这就涉及到我们要使用的U-Net模型的核心信息。

U-Net模型

U-Net是2015年提出来的分割模型。这个网络可以输出图像像素,也就是把分割出来的像素输出出来,用于标记要分割的形状。

上图就是图像的整体结构。左半边跟我们的卷积神经网络基本上没有什么区别,每一层都对数据进行了压缩,这部分的作用是用于捕获更多高阶特征。而模型的右侧使用反卷积实现上采样(之前介绍过就是把图像变大的方法),把图像还原回去。最后仍然采用softmax输出结果。另外在图上可以看到一些灰色的横向箭头,这是跳跃连接,主要是为了让这个阶段能够学习到更早期同尺寸层的特征。因为在压缩的过程中,原本的图像空间信息已经丢失了,让它自己去恢复可能就是野蛮生长了。 U-Net模型的论文:https://arxiv.org/abs/1505.04597

实现分割模型

确定我们要用的模型,接下来就是动手,去Github上去找我们需要的模型。我找了源地址发现打不开,好在已经有人把代码传到github的新项目上,地址是https://github.com/schopfej/pytorch_unet_jvanvugt/blob/master/unet.py。当然这个代码在书附赠的代码里也有了。这个代码很精简,让我们看看。

代码语言:javascript复制
import torchfrom torch import nnimport torch.nn.functional as Fclass UNet(nn.Module):#默认参数是原论文中使用的参数
    def __init__(self, in_channels=1, n_classes=2, depth=5, wf=6, padding=False,
                 batch_norm=False, up_mode='upconv'):
        """
    参数说明
        Args:
            in_channels (int): 输入通道数
            n_classes (int): 输出通道数
            depth (int): 网络深度
            wf (int): 过滤器数量,也就是用于控制内部的卷积块和上采样块的输入输出通道数,第一层过滤器数量是 2**wf
            padding (bool): 补充,是否使用padding
                            
            batch_norm (bool):是否使用batch norm
            up_mode (str):使用上卷积或上采样 'upconv' or 'upsample'.
                           'upconv' 使用转置卷积实现上采样
                           'upsample' 使用线性上采样
        """
        super(UNet, self).__init__()#参数处理
        assert up_mode in ('upconv', 'upsample')
        self.padding = padding
        self.depth = depth
        prev_channels = in_channels
        self.down_path = nn.ModuleList()#根据深度决定实现几个卷积块以及上采样
        for i in range(depth):
            self.down_path.append(UNetConvBlock(prev_channels, 2**(wf i),
                                                padding, batch_norm))
            prev_channels = 2**(wf i)

        self.up_path = nn.ModuleList()
        for i in reversed(range(depth - 1)):
            self.up_path.append(UNetUpBlock(prev_channels, 2**(wf i), up_mode,
                                            padding, batch_norm))
            prev_channels = 2**(wf i)#最后的卷积
        self.last = nn.Conv2d(prev_channels, n_classes, kernel_size=1)#前向传播
    def forward(self, x):
        blocks = []#首先执行卷积块
        for i, down in enumerate(self.down_path):
            x = down(x)
            if i != len(self.down_path)-1:
                blocks.append(x)#平均池化
                x = F.avg_pool2d(x, 2)#然后执行上采样块
        for i, up in enumerate(self.up_path):
            x = up(x, blocks[-i-1])

        return self.last(x)#卷积块的实现class UNetConvBlock(nn.Module):#在一个卷积块里面执行了两次卷积,两次激活函数变换,两次batch norm,对应架构图上每层向右的两个箭头
    def __init__(self, in_size, out_size, padding, batch_norm):
        super(UNetConvBlock, self).__init__()
        block = []

        block.append(nn.Conv2d(in_size, out_size, kernel_size=3,
                               padding=int(padding)))
        block.append(nn.ReLU())
        # block.append(nn.LeakyReLU()) 不同的激活函数
        if batch_norm:
            block.append(nn.BatchNorm2d(out_size))

        block.append(nn.Conv2d(out_size, out_size, kernel_size=3,
                               padding=int(padding)))
        block.append(nn.ReLU())
        # block.append(nn.LeakyReLU())
        if batch_norm:
            block.append(nn.BatchNorm2d(out_size))

        self.block = nn.Sequential(*block)#卷积块的前向传播
    def forward(self, x):
        out = self.block(x)
        return out#定义上采样块class UNetUpBlock(nn.Module):
    def __init__(self, in_size, out_size, up_mode, padding, batch_norm):
        super(UNetUpBlock, self).__init__()#使用转置卷积实现上采样
        if up_mode == 'upconv':
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2,
                                         stride=2)#使用双线性模型实现上采样
        elif up_mode == 'upsample':
            self.up = nn.Sequential(nn.Upsample(mode='bilinear', scale_factor=2),
                                    nn.Conv2d(in_size, out_size, kernel_size=1))#卷积块
        self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm)#中心裁剪
    def center_crop(self, layer, target_size):
        _, _, layer_height, layer_width = layer.size()
        diff_y = (layer_height - target_size[0]) // 2
        diff_x = (layer_width - target_size[1]) // 2
        return layer[:, :, diff_y:(diff_y   target_size[0]), diff_x:(diff_x   target_size[1])]#前向传播
    def forward(self, x, bridge):
        up = self.up(x)#在这里实现跳链以及裁剪
        crop1 = self.center_crop(bridge, up.shape[2:])
        out = torch.cat([up, crop1], 1)
        out = self.conv_block(out)

        return out

看完了UNet的代码,下面就把它加到我们自己的代码里,需要先对UNet模型包装一下。这里定义一个新的模型类,修改的地方就两个,一个是在输入前增加一个batchnorm,一个是在输出前增加一个sigmoid激活函数,中间完全使用UNet模型结构。

代码语言:javascript复制
from util.unet import UNetclass UNetWrapper(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()

        self.input_batchnorm = nn.BatchNorm2d(kwargs['in_channels'])
        self.unet = UNet(**kwargs)
        self.final = nn.Sigmoid()#这里增加了一个权重初始化动作
        self._init_weights()

前向传播就很简单。

代码语言:javascript复制
   def forward(self, input_batch):
       bn_output = self.input_batchnorm(input_batch)
       un_output = self.unet(bn_output)
       fn_output = self.final(un_output)
       return fn_output

在这里可以看到我们使用的是2d的batchnorm,因为UNet比较适合处理二维图像,如果把它转换成处理三维图像内存消耗太大,所以我们采用的方案是处理数据,把我们的3D体数据转换成UNet可以处理的二维图像。接下来就要构建数据集了。

构建数据集

第一个问题是关于输入输出图像的尺寸。根据论文中的情况,UNet网络接收一个572×572大小的图像,输出一个388×388大小的图像,而我们期望输入和输出能够一样大,毕竟我们是在做一个医学项目,边角信息的丢失也可能会导致判断失误。正好在UNet网络中开启padding就可以解决这个问题。

第二个问题是我们的数据是三维数据,是512×512×128的图像,如果直接塞进UNet我们的内存就炸了。我算了一下,这一个图像就是128MB,UNet的第一层有64个channel,那我们就需要128×64MB,也就是8GB的空间。考虑之前二维图像的RGB三个通道,这里我们把不同的切片也看做通道,只保留正在处理的切片上下相邻的几个切片数据以通道的形式传入模型。当然这里会有一些信息的损失,因为本来这些切片之间是有上下顺序的,按二维图像的通道输入就没有了这种关系。

第三个问题是原数据不匹配。前几节里面介绍的标注数据,给出了中心点坐标以及直径尺寸。但是我们需要的是一个图像区域,来标明里面的哪些像素块是结节。像下面画的,我们期望要这样一个效果。

当然,在对体素进行处理的时候还有一些实际问题要考虑。因为体素存储的是密度信息,我们需要以中心点出发去寻找周边的体素,如果体素的密度在某个范围内,我就可以把它标记为1,如果不在设定的范围内,我们就标记为0,大概是这样一个思路。下图展示了搜索方式,先从中心点开始,向左右两侧搜索,确定边界,再进行上下两侧的搜索。

代码实现

代码语言:javascript复制
#获取中心点信息,从xyz坐标转换到irc坐标
            center_irc = xyz2irc(
                candidateInfo_tup.center_xyz,
                self.origin_xyz,
                self.vxSize_xyz,
                self.direction_a,
            )
            ci = int(center_irc.index)
            cr = int(center_irc.row)
            cc = int(center_irc.col)#索引轴方向上的半径检索,下面还有行方向的检索和列方向的检索
            index_radius = 2
            try:
                while self.hu_a[ci   index_radius, cr, cc] > threshold_hu and                         self.hu_a[ci - index_radius, cr, cc] > threshold_hu:
                    index_radius  = 1
            except IndexError:  #异常捕获,如果超出阈值就返回1格
                index_radius -= 1

构建掩码数据。

代码语言:javascript复制
    def buildAnnotationMask(self, positiveInfo_list, threshold_hu = -700):
        boundingBox_a = np.zeros_like(self.hu_a, dtype=np.bool)

        for candidateInfo_tup in positiveInfo_list:#遍历所有的结节数据…………#这里省略的就是上面的检索代码
            boundingBox_a[#根据前面检索的半径生成整个框
                 ci - index_radius: ci   index_radius   1,
                 cr - row_radius: cr   row_radius   1,
                 cc - col_radius: cc   col_radius   1] = True#然后对框内的数据进行判断,超过阈值的设为True
        mask_a = boundingBox_a & (self.hu_a > threshold_hu)

        return mask_a

把已经处理好的数据展示出来,可以看到这么处理数据有一点点问题,因为在最初的边框附近包含了肺部边缘,这种组织的密度也较高,最终也判定为1。虽然能够把它处理掉,但是会花费比较大的精力,我们先不管它了。

接下来在CT初始化的时候调用这些掩码

代码语言:javascript复制
class Ct:
    def __init__(self, series_uid):
     ……#省略了一些代码
        candidateInfo_list = getCandidateInfoDict()[self.series_uid] #获取候选信息

        self.positiveInfo_list = [  #正样本列表
            candidate_tup            for candidate_tup in candidateInfo_list            if candidate_tup.isNodule_bool        ]
        self.positive_mask = self.buildAnnotationMask(self.positiveInfo_list)  #正样本掩码构建#最后这行是把具有非0计数的掩码切片的索引存下来
        self.positive_indexes = (self.positive_mask.sum(axis=(1,2))
                                 .nonzero()[0].tolist())

跟前面一样,接下来是构建缓存。先读取原始数据。

代码语言:javascript复制
    def getRawCandidate(self, center_xyz, width_irc):#获取中心位置
        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz,
                             self.direction_a)#起始位置和终止位置索引
        slice_list = []
        for axis, center_val in enumerate(center_irc):
            start_ndx = int(round(center_val - width_irc[axis]/2))
            end_ndx = int(start_ndx   width_irc[axis])#断言 用来处理异常
            assert center_val >= 0 and center_val < self.hu_a.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])

            if start_ndx < 0:
                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
                start_ndx = 0
                end_ndx = int(width_irc[axis])

            if end_ndx > self.hu_a.shape[axis]:
                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
                end_ndx = self.hu_a.shape[axis]
                start_ndx = int(self.hu_a.shape[axis] - width_irc[axis])#加入切片信息
            slice_list.append(slice(start_ndx, end_ndx))

        ct_chunk = self.hu_a[tuple(slice_list)]#提取其中的正样本掩码数据
        pos_chunk = self.positive_mask[tuple(slice_list)]#返回数据        return ct_chunk, pos_chunk, center_irc

然后把原始数据缓存到磁盘。

代码语言:javascript复制
@raw_cache.memoize(typed=True)def getCtRawCandidate(series_uid, center_xyz, width_irc):
    ct = getCt(series_uid)
    ct_chunk, pos_chunk, center_irc = ct.getRawCandidate(center_xyz,
                                                         width_irc)
    ct_chunk.clip(-1000, 1000, ct_chunk)
    return ct_chunk, pos_chunk, center_irc

接下来获取数据信息。这里需要注意的是,使用的标注数据不再是原来的annotation.csv,作者说之前那个文件里有一些重复信息会影响我们的图像分割,但是对分类没什么影响,所以在这里要把里面的重复数据去掉,具体怎么操作的没有说明,这里只是给出了新的标注数据文件。

代码语言:javascript复制
    candidateInfo_list = []
    with open('data/part2/luna/annotations_with_malignancy.csv', "r") as f:
        for row in list(csv.reader(f))[1:]:
            series_uid = row[0]
            annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
            annotationDiameter_mm = float(row[4])
            isMal_bool = {'False': False, 'True': True}[row[5]]#从中取出我们需要的数据并合成结果
            candidateInfo_list.append(
                CandidateInfoTuple(
                    True,
                    True,
                    isMal_bool,
                    annotationDiameter_mm,
                    series_uid,
                    annotationCenter_xyz,
                )
            )#然后从candidates.csv取出候选信息
    with open('data/part2/luna/candidates.csv', "r") as f:
        for row in list(csv.reader(f))[1:]:
            series_uid = row[0]

            if series_uid not in presentOnDisk_set and requireOnDisk_bool:
                continue

            isNodule_bool = bool(int(row[4]))
            candidateCenter_xyz = tuple([float(x) for x in row[1:4]])#这里只使用非结节数据
            if not isNodule_bool:
                candidateInfo_list.append(
                    CandidateInfoTuple(
                        False,#是否结节
                        False,#是否恶性
                        False,#是否有标注
                        0.0,
                        series_uid,
                        candidateCenter_xyz,
                    )
                )

    candidateInfo_list.sort(reverse=True)
    return candidateInfo_list

数据的准备工作已经完成,接下来是实现Dataset类。还记得Dataset类里面两个重要的方法,一个是len,一个是getitem。不过这次的数据实现跟之前有点区别,我们准备了两种数据,在训练集中加入了随机方法和裁剪构建样本,而验证集仍然保持原样。这个数据集类称作Luna2dSegmentationDataset。

代码语言:javascript复制
class Luna2dSegmentationDataset(Dataset):
    def __init__(self,
                 val_stride=0, #验证集长度
                 isValSet_bool=None,  #是否验证集
                 series_uid=None, #序列uid
                 contextSlices_count=3, #上下文切片数,就是从中心点切片向上向下选取的切片数量,这里设置为3,就是选取7片切片作为我们的训练数据7个通道。
                 fullCt_bool=False,#这个参数用来标记是否使用所有切片,当评估的时候使用True,当训练的时候使用False,这样可以让模型更关注那些带有阳性掩码的切片,而评估的时候我们也不知道哪个有阳性,所以使用所有切片。
            ):#处理参数默认信息
        self.contextSlices_count = contextSlices_count        self.fullCt_bool = fullCt_bool        if series_uid:
            self.series_list = [series_uid]
        else:
            self.series_list = sorted(getCandidateInfoDict().keys())#在判断是否验证集的时候,如果不是验证集则需要按验证集长度剔除掉数据
        if isValSet_bool:
            assert val_stride > 0, val_stride            self.series_list = self.series_list[::val_stride]
            assert self.series_list
        elif val_stride > 0:
            del self.series_list[::val_stride]
            assert self.series_list#由于我们的CT数据切片数不一样,这里构建了一个缓存数据,用来存放CT数据的长度getCtSampleSize,等会看一下这个代码。
        self.sample_list = []
        for series_uid in self.series_list:
            index_count, positive_indexes = getCtSampleSize(series_uid)#这里处理fullct参数,可以看到如果True,则使用的index_count,如果是False使用的只有positive_indexes
            if self.fullCt_bool:
                self.sample_list  = [(series_uid, slice_ndx)
                                     for slice_ndx in range(index_count)]
            else:
                self.sample_list  = [(series_uid, slice_ndx)
                                     for slice_ndx in positive_indexes]#缓存数据
        self.candidateInfo_list = getCandidateInfoList()#建立一个set方便查找
        series_set = set(self.series_list)
        self.candidateInfo_list = [cit for cit in self.candidateInfo_list                                   if cit.series_uid in series_set]

        self.pos_list = [nt for nt in self.candidateInfo_list                            if nt.isNodule_bool]#记录日志信息
        log.info("{!r}: {} {} series, {} slices, {} nodules".format(
            self,
            len(self.series_list),
            {None: 'general', True: 'validation', False: 'training'}[isValSet_bool],
            len(self.sample_list),
            len(self.pos_list),
        ))

这是刚才中间的那个方法,用来缓存CT的尺寸

代码语言:javascript复制
@raw_cache.memoize(typed=True)def getCtSampleSize(series_uid):
    ct = Ct(series_uid)
    return int(ct.hu_a.shape[0]), ct.positive_indexes

实现getitem

代码语言:javascript复制
# getitem的功能都包装在getitemfullslice中了
    def __getitem__(self, ndx):
        series_uid, slice_ndx = self.sample_list[ndx % len(self.sample_list)]
        return self.getitem_fullSlice(series_uid, slice_ndx)#
    def getitem_fullSlice(self, series_uid, slice_ndx):
        ct = getCt(series_uid)#初始化tensor
        ct_t = torch.zeros((self.contextSlices_count * 2   1, 512, 512))#设置起始块和终止块位置
        start_ndx = slice_ndx - self.contextSlices_count
        end_ndx = slice_ndx   self.contextSlices_count   1#把块数据取出来放进tensor中
        for i, context_ndx in enumerate(range(start_ndx, end_ndx)):
            context_ndx = max(context_ndx, 0)
            context_ndx = min(context_ndx, ct.hu_a.shape[0] - 1)
            ct_t[i] = torch.from_numpy(ct.hu_a[context_ndx].astype(np.float32))
        # 设置数据值的上下限
        ct_t.clamp_(-1000, 1000)#获取对应的掩码数据
        pos_t = torch.from_numpy(ct.positive_mask[slice_ndx]).unsqueeze(0)

        return ct_t, pos_t, ct.series_uid, slice_ndx

前面提到我们这次要准备两个数据集,刚刚已经把基础数据集准备好了,下面是训练数据集TrainingLuna2dSegmentationDataset。训练数据围绕着结节候选进行选取,以结节为中心96×96的区域中,随机选取一个64×64的区域。这个方法不是随便想出来的,而是经过了各种实验得出的结果。作者原本使用整个切片信息去训练,但是效果不稳定,我想是因为结节数据太小了,而CT切片比较大,也就是数据不平衡问题。通过这种基础剪裁之后,候选结节和无结节区域的比例明显提升了。

代码语言:javascript复制
class TrainingLuna2dSegmentationDataset(Luna2dSegmentationDataset):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.ratio_int = 2

    def __len__(self):
        return 300000

    def shuffleSamples(self):
        random.shuffle(self.candidateInfo_list)
        random.shuffle(self.pos_list)

    def __getitem__(self, ndx):
        candidateInfo_tup = self.pos_list[ndx % len(self.pos_list)]
        return self.getitem_trainingCrop(candidateInfo_tup)

    def getitem_trainingCrop(self, candidateInfo_tup):
        ct_a, pos_a, center_irc = getCtRawCandidate(
            candidateInfo_tup.series_uid,
            candidateInfo_tup.center_xyz,
            (7, 96, 96),
        )
        pos_a = pos_a[3:4]#同时裁剪CT和掩码数据
        row_offset = random.randrange(0,32)
        col_offset = random.randrange(0,32)
        ct_t = torch.from_numpy(ct_a[:, row_offset:row_offset 64,
                                     col_offset:col_offset 64]).to(torch.float32)
        pos_t = torch.from_numpy(pos_a[:, row_offset:row_offset 64,
                                       col_offset:col_offset 64]).to(torch.long)

        slice_ndx = center_irc.index        return ct_t, pos_t, candidateInfo_tup.series_uid, slice_ndx

关于数据准备先写到这里,明天继续。

0 人点赞