上一节我们理解了业务,也就是我们这个项目到底要做什么事情,并定好了一个方案。这一节我们就开始动手了,动手第一步就是把数据搞清楚,把原始数据搞成我们可以用PyTorch处理的样子。这个数据不同于我们之前用的图片数据,像之前那种RGB图像拿过来做一些简单的预处理就可以放进tensor中,这里的医学影像数据预处理部分就要复杂的多。比如说怎么去把影像数据导入进来,怎么转换成我们能处理的形式;数据可能存在错误,给定的结节位置和实际的坐标位置有偏差;数据量太大我们不能一次性加载怎么处理等等。今天理解数据这部分处理的就是之前整个项目框架图的第一步,关于数据加载的问题。
原始CT数据
把数据解压之后,我们可以看到CT数据源文件,一个CT数据实际上包含两个文件,‘.mhd’文件包含了元数据头部信息,‘.raw’则是存储的三维CT原始数据。前面的文件名称为它的uid,符合DICOM数据命名法。
除了CT数据源文件,我们还需要把上一节提到的网站上的其他数据也下载下来,其中比较重要的是annotations.csv和candidates.csv文件,这两个文件是对数据的标注信息。
解析标注数据
annotations.csv文件里面给出了结节的位置信息,我们可以看一下里面的数据,总共有5列,第一列是uid,中间三列是坐标位置,最后一列是结节的尺寸,直径信息。candidates.csv是对影像中小圆点的标注,数据也是五列,有区别的是最后一列,candidates.csv的最后一列标明了这个小圆点是否是结节。
其中candidates.csv文件中包含了551065条数据,如果我们对分类标签统计一下,可以发现总共有1351条被标注为是结节。这个时候数据的问题就来了,在annotations.csv文件里只给出了1187条结节的信息。
除了数据量的问题,让我们再仔细观察一下数据。
我们从里面抽出一条数据,比如说从candidates中找到第一个标注为1的数据,这个uid对应的有两条结节数据,结果我们发现candidates和annotations里面对这两个结节的位置信息标注并不是完全一样的,用眼观察我们大概能知道这两个怎么匹配,但是如果让我们的计算机硬匹配那估计是没戏了。所以这里还需要进行一个数据对齐的工作,把足够相近的两个点位认证成同一个结节,如离得太远又没办法配对的数据,我们可能得考虑把它丢掉。
candidates
annotations
由于项目代码是一个整体,而且里面涉及到很多操作没必要在这里都写出来,所以我只写上跟当前相关的部分,全部的代码可以在项目文档中获得。这里定义了一个结节候选集信息元组,使用这个元组来存储结节信息,我们需要把前面提到的两个文件中的信息整合起来。
代码语言:javascript复制CandidateInfoTuple = namedtuple( #命名元组
'CandidateInfoTuple', #元组名称
'isNodule_bool, diameter_mm, series_uid, center_xyz', #元组元素,包含是否是结节,结节尺寸,CT影响id,结节中心点坐标
)
这里的CandidateInfoTuple 当然还不是训练数据,它只能算作是一个索引文件,因为这里面还不包含我们的CT影像数据。接下来这个方法就是跟读取CT影像有关了,因为CT影像非常大,所以我们需要使用内存缓存,这也是我们加速运算的好方法,对数据的读取控制不好的话也会极大的影响模型训练的速度。
代码语言:javascript复制@functools.lru_cache(1) #这是一个内存缓存标准库
def getCandidateInfoList(requireOnDisk_bool=True):
#通过读取所有mhd文件,我们把所有的影像id都缓存到presentOnDisk_set里面,方便我们调用数据
mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
presentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
接下来我们就合并两个标注文件。首先是读取annotations.csv文件,并把uid作为key存入diameter_dict这样一个字典中。
代码语言:javascript复制diameter_dict = {}
with open('data/part2/luna/annotations.csv', "r") as f:
for row in list(csv.reader(f))[1:]:
series_uid = row[0]
annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
annotationDiameter_mm = float(row[4])
diameter_dict.setdefault(series_uid, []).append(
(annotationCenter_xyz, annotationDiameter_mm)
)
然后是读取candidates.csv文件,先是一个异常处理,判断从candidates中获取的uid能否在我们前面影像数据中找到对应的文件,如果找不到就跳过这一轮。接着读取标签和坐标信息。然后是去匹配我们之前读到的结节尺寸信息,这里判断标准是candidates存储的中心坐标和annotations中存储的中心坐标距离是否超过了结节直径的四分之一,如果在这个范围内我们可以认为这两条信息指向的是同一个结节,我们可以把信息合并,存入CandidateInfoTuple中,最后把每一个CandidateInfoTuple存入candidateInfo_list中。
代码语言:javascript复制candidateInfo_list = [] # 用来存储结构信息的list
with open('data/part2/luna/candidates.csv', "r") as f: #读取candidates.csv文件 for row in list(csv.reader(f))[1:]:
series_uid = row[0]
if series_uid not in presentOnDisk_set and requireOnDisk_bool: #异常处理 continue
isNodule_bool = bool(int(row[4])) #读取标签信息
candidateCenter_xyz = tuple([float(x) for x in row[1:4]]) #读取candidates.csv中的坐标
candidateDiameter_mm = 0.0 #用于记录结节直径信息,如果不是真的结节,这个值就为0
for annotation_tup in diameter_dict.get(series_uid, []): #对比两个文件中的结节坐标信息
annotationCenter_xyz, annotationDiameter_mm = annotation_tup for i in range(3):
delta_mm = abs(candidateCenter_xyz[i] - annotationCenter_xyz[i]) #判断坐标距离是否在可接受范围 if delta_mm > annotationDiameter_mm / 4:
break
else:
candidateDiameter_mm = annotationDiameter_mm break
candidateInfo_list.append(CandidateInfoTuple( #存储最终结果
isNodule_bool,
candidateDiameter_mm,
series_uid,
candidateCenter_xyz,
))
从上面这些代码我们已经遇到了一些数据预处理的问题,当然这里的问题已经比较简单了,而且我们直接给出了解决方案。在实际的工作中,你需要自己去思考这种数据异常的问题该怎么解决,或许有很多的方法,那你可能还需要尝试每种方法给结果带来的影响有多大,当你做了很多项目之后你或许对怎么处理数据会有一些经验,能够知道哪些数据比较重要,哪些数据我们不太需要关心它的异常。 最后我们把数据进行排序。
代码语言:javascript复制candidateInfo_list.sort(reverse=True) #这里没有指定排序key,就以第一个元素为依据,并按照reverse=True来处理,也就是‘是结节’将排在前面return candidateInfo_list
加载CT影像数据
处理完标注数据的索引,下一步就是研究怎么获取我们的CT影像数据了,毕竟这才是我们需要作为训练样本的东西。为了能够按照我们的索引把相关的数据取到,我们首先得把原始影像读出来,然后把它变换成我们可以用的坐标来表示。
对于这些特殊格式的数据处理是很重要的,如果你使用了错误的解析器那么可能得到错误的结果,或者是不好用的结果。好在大部分项目都是基础的图像或者文本,如果你要处理一些特种数据那就得去研究一下有没有现成的库可以使用了。
在这里我们可以使用一个叫做SimpleITK的库来访问我们的数据。再次重申一下,这里的代码目前不能像前面的章节一样直接运行,你需要使用完整的代码,这里我们只是用来讲解。
代码语言:javascript复制import SimpleITK as sitkclass Ct: #定义了一个类 用于处理CT文件
def __init__(self, series_uid): # 这里传入uid,因为我们后面是按照前面制定的索引对应的去取数据,对于其他的数据我们不需要关心,所以这里设置了一个参数,而不是把所有CT数据都读进来,我们的内存也扛不住。在代码中看不到哪里传入了raw文件,实际上在sitk.ReadImage方法里它自己会根据头文件的地址去寻找对应的raw文件。
mhd_path = glob.glob(
'data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid)
)[0]
ct_mhd = sitk.ReadImage(mhd_path)
ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
试想,如果你只是拿到了一些raw数据,但是你不知道能用什么去处理它,是不是会非常费劲?你可以输出一下ct_a的形状看一下,它是一个三维数组,存储了我们的三维体素数据。
关于体素所使用的的单位我们这时候需要介绍一下了。因为这也涉及到数据清洗的工作。raw数据中存储的体素数据单位是亨氏单位,中文也可以叫做CT值,查一下百度,其中水是0HU(1g/cm^3)
CT值是测定人体某一局部组织或器官密度大小的一种 计量单位 ,通常称亨氏单位(hounsfield unit ,HU)空气为-1000(0g/cm3),致密骨为 1000(2-3g/cm3)。实际上 CT 值 是CT 图像中各组织与X 线衰减系数相当的对应值。
这里我们限制数据中数值的上下界,使得数值介于-1000到1000之间,超出这个范围的会被设定为临界值。
代码语言:javascript复制 ct_a.clip(-1000, 1000, ct_a)
self.series_uid = series_uid self.hu_a = ct_a #然后把数据传递进来
制定坐标
我们现在能读进来CT数据了,但是怎么把CT数据上的一个点位信息和之前结节数据的坐标信息关联起来从而获取那一块CT数据呢?因为我们前面加载的结节信息是用毫米来表示,而不是体素,显然它俩之间需要有某种变换关系来实现关联。 为了方便记录,我们这里把以毫米为单位的坐标称为(X,Y,Z)坐标,以体素为单位的坐标称为(I,R,C)坐标。下图就是(X,Y,Z)坐标的展示。
除了坐标的问题,我们还需要知道的是,对于一块体素数据,它不是一个立方体,而是1.125mm×1.125mm×2.5mm的一个立体块,如果按照立方体来对数据进行绘制的时候,人会看起来更胖一些,所以如果要按真实的样貌进行展示还需要加入一个变换比例。当然对于不同的设备扫描的CT数据尺寸可能不一样,这个信息会存在它的头文件中。一般来说CT影像横切面是一个512行×512列的大小,然后会有100-250个切面,那总共会有2^25个体素,也就是3200w个。
知道了体素块和尺寸的对应关系,下面我们手写代码实现它。首先定义两个工具方法,实现xyz和irc之间的转换。除了原始的xyz或者irc数据,我们还需要三个输入参数,一个是毫米偏移量,一个是体素的尺寸,最后一个是方向矩阵,这三个参数都可以从mhd文件中获取,这里我们先不深入研究这个工具类,只要知道它实现了两种数据之间的转换就可以了。
代码语言:javascript复制#irc存储的是切片索引及切片的行列IrcTuple = collections.namedtuple('IrcTuple',['index','row','col'])XyzTuple = collections.namedtuple('XyzTuple',['x','y','z'])def irc2xyz(coord_irc, origin_xyz, vxSize_xyz, direction_a):
cri_a = np.array(coord_irc)[::-1]
origin_a = np.array(origin_xyz)
vxSize_a = np.array(vxSize_xyz)
coords_xyz = (direction_a @ (cri_a * vxSize_a)) origin_a return XyzTuple(*coords_xyz)def xyz2irc(coord_xyz, origin_xyz, vxSize_xyz, direction_a):
origin_a = np.array(origin_xyz)
vxSize_a = np.array(vxSize_xyz)
coords_xyz = np.array(coord_xyz)
cri_a = ((coord_a - origin_a) @ np.linalg.inv(direction_a)) / vxSize_a
cri_a = np.round(cri_a)
return IrcTuple(int(cri_a[2]), int(cri_a[1]), int(cri_a[0]))
接着就是在处理CT影像的类中,初始化的时候获取刚才说的三项参数。
代码语言:javascript复制 self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3)
CT类的初始化差不多可以了,接下来实现获取圆点对应的CT数据。
代码语言:javascript复制def getRawCandidate(self, center_xyz, width_irc):#首先把xyz转换成irc
center_irc = xyz2irc(
center_xyz,
self.origin_xyz,
self.vxSize_xyz,
self.direction_a,
)
slice_list = []
for axis, center_val in enumerate(center_irc):
#根据中心点找到起始位置和终止位置
start_ndx = int(round(center_val - width_irc[axis]/2))
end_ndx = int(start_ndx width_irc[axis])
#从这里开始是异常判断,判断是否超出界限
assert center_val >= 0 and center_val < self.hu_a.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
if start_ndx < 0: #这里还有一个日志记录,在实际的项目中,我们一般都会记录运行日志,方便出问题的时候查看 # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
# self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
start_ndx = 0
end_ndx = int(width_irc[axis])
if end_ndx > self.hu_a.shape[axis]:
# log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
# self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
end_ndx = self.hu_a.shape[axis]
start_ndx = int(self.hu_a.shape[axis] - width_irc[axis])
#异常判断到这里结束
slice_list.append(slice(start_ndx, end_ndx))#存储入ct_chunk
ct_chunk = self.hu_a[tuple(slice_list)]
return ct_chunk, center_irc
slice() 函数实现切片对象,主要用在切片操作函数里的参数传递。
这段代码读起来可能有点费劲,需要仔细检查一下每一个方法都是什么意思。
转为Dataset数据集
接下来,我们像前面做CIFAR数据一样,把它转换成Dataset数据集,方便我们使用同样的API。还记得之前说的要符合Dataset数据集的要求,需要实现两个方法 _ _ len _ _ 和 _ _ getitem _ _,在这之前我们还需要增加两个缓存方法,实际上一般在刚开始搞一个项目的时候可能不会用这样的方法,等到后面发现性能跟不上再考虑优化,不过这里我们的CT文件很大,就直接加入了缓存方法来优化速度。其中getCtRawCandidate对之前getRawCandidate进行了包装,输出是一样的。
代码语言:javascript复制@functools.lru_cache(1, typed=True)def getCt(series_uid):
return Ct(series_uid)@raw_cache.memoize(typed=True)def getCtRawCandidate(series_uid, center_xyz, width_irc):
ct = getCt(series_uid)
ct_chunk, center_irc = ct.getRawCandidate(center_xyz, width_irc)
return ct_chunk, center_irc
上面的代码中,首先对getCt的结果进行了缓存,它会存在内存中,但是需要注意的是,在内存中只会缓存一个CT文件,如果频繁访问不同的CT文件就会导致大量的miss,这种缓存就没有太大意义了,所以我们处理的时候需要注意顺序。然后是getCtRawCandidate,获取的CT的数值会在磁盘中缓存,同时我们减少了数据的数量级,从而降低了读取压力。接下来是编写两个方法。其中len很简单,我们的数据集大小就是所有小圆点信息的大小,getitem相对复杂。
代码语言:javascript复制 def __len__(self):
return len(self.candidateInfo_list)
def __getitem__(self, ndx): #ndx是一个整数,表示索引值#首先取出候选小圆点信息
candidateInfo_tup = self.candidateInfo_list[ndx]#定义尺寸信息
width_irc = (32, 48, 48)#获取对应的ct数据块
candidate_a, center_irc = getCtRawCandidate(
candidateInfo_tup.series_uid,
candidateInfo_tup.center_xyz,
width_irc,
)#转换成tensor
candidate_t = torch.from_numpy(candidate_a)
candidate_t = candidate_t.to(torch.float32)
candidate_t = candidate_t.unsqueeze(0)#构建分类标签ont-hot信息,一共两位,第一位是不是结节,第二位是结节
pos_t = torch.tensor([
not candidateInfo_tup.isNodule_bool,
candidateInfo_tup.isNodule_bool ],
dtype=torch.long,
)#返回结果,其中包含候选小圆点数据,分类标签,id,以及中心位置
return (
candidate_t,
pos_t,
candidateInfo_tup.series_uid,
torch.tensor(center_irc),
)
分割训练集和验证集
最后一步就是把我们的数据分割成训练集和验证集。这个方法直接写在我们的数据集类的初始化中。初始化的时候传入三个参数,验证集步长,是否验证集,以及uid。
代码语言:javascript复制class LunaDataset(Dataset):
def __init__(self,
val_stride=0,
isValSet_bool=None,
series_uid=None,
):
self.candidateInfo_list = copy.copy(getCandidateInfoList())
if series_uid:
self.candidateInfo_list = [
x for x in self.candidateInfo_list if x.series_uid == series_uid ]
if isValSet_bool: #如果是验证集
assert val_stride > 0, val_stride #异常处理,查看验证集步长是否大于0,在条件为false时触发,返回val_stride
self.candidateInfo_list = self.candidateInfo_list[::val_stride] #按步长长度取出数据
assert self.candidateInfo_list #获取数据失败返回异常
elif val_stride > 0:#不是验证集(也就是训练集)
del self.candidateInfo_list[::val_stride] #把验证集数据删掉
assert self.candidateInfo_list
log.info("{!r}: {} {} samples".format(
self,
len(self.candidateInfo_list),
"validation" if isValSet_bool else "training",
))
看起来数据处理代码也不算特别多,但是这是在原作者定好的计划前提下。对于实际项目的问题,可能会各种各样,包括里面使用的每一个步骤都包含着大量的思考。到了这里,终于磕磕绊绊看完了数据处理部分的代码,当然很多东西还没有亲自动手去实践,接下来可能需要先花点时间去实际跑一下这些代码,另外还有学习一下怎么能够把这个图像用matplot画出来,那估计能够好理解一点。过程中写的一些理解可能存在错误,如果你发现有什么不对劲的地方及时滴滴我。