作者 | Andrew Long 来源 | Medium
编辑 | 代码医生团队
介绍
最近正在审查Andrew Ng的团队在使用卷积神经网络(CNN)的心律不齐检测器上的工作。发现这尤其令人着迷,尤其是随着可穿戴产品(例如Apple Watch和便携式EKG机器)的出现,它们能够在家中监测心脏。因此很好奇如何构建可以检测异常心跳的机器学习算法。在这里,将使用ECG信号(对心脏进行连续电测量)并训练3个神经网络来预测心脏心律不齐:密集神经网络,CNN和LSTM。
在本文中,将探讨3个课程:
- 将数据集拆分为患者而不是样本
- 学习曲线可以告诉您获得更多数据
- 测试多种类型的深度学习模型
数据集
将使用MIH-BIH Arrythmia数据集。这是一个数据集,包含从1970年代以360 Hz测量的48个半小时两通道ECG记录。录音有心脏病专家对每个心跳的注释。注释的符号可以在链接中找到
项目定义
预测以心跳峰值为中心的每6秒窗口中,来自ECG信号的心跳是否有心律不齐。
为了简化问题,将假定QRS检测器能够自动识别每个心跳的峰值。由于数据减少,将忽略记录的前3秒或后3秒中的任何非搏动注释和任何心跳。将使用6秒的窗口,以便可以将当前搏动与之前和之后的搏动进行比较。这个决定是在与医生交谈后作出的,该医生说这样比较容易确定是否可以将其进行比较。
资料准备
开始列出data_path中所有患者的列表。
在这里,将使用pypi包wfdb来加载ecg和注释。
加载所有注释,并查看心跳类型在所有文件中的分布。
现在可以列出非搏动和异常搏动的列表:
可以按类别分组并查看此数据集中的分布:
该数据集中约30%的异常。如果这是一个实际的项目,那么最好检查一下文献。这比正常情况要高,因为这是关于心律失常的数据集!
编写一个用于加载单个患者的信号和注释的函数。注意,注释值是信号数组的索引。
检查一下患者的心电图有哪些异常搏动:
可以通过以下方式绘制信号在异常搏动之一周围:
制作数据集
制作一个数据集,该数据集中在[apologies for broken screenshots]前后 -3秒的搏动上
第1课:按病人分成样本
开始处理所有患者。
想象一下,天真地决定将样本中的数据随机分成训练和验证集。
现在准备构建第一个密集NN。为了简单起见,将在Keras中进行此操作。
可以构建一些用于指标报告的功能。
可以从Keras模型获得预测 predict_proba
为简单起见,将阈值设置为异常搏动的发生率并计算报告:
这对新患者有效吗?如果每个患者都有独特的心脏信号,也许不会。从技术上讲,同一患者可以同时出现在训练和验证集中。这意味着可能在数据集中意外泄漏了信息。可以通过分割患者而不是样本来检验这个想法。
并训练一个新的密集模型:
验证AUC下降了很多,这确认了之前的数据泄漏。获得的经验:对患者而不是样本的分裂!
第二课:学习曲线可以告诉应该获取更多数据!
考虑到训练和验证之间的过度拟合。做一个简单的学习曲线,看看是否应该去收集更多的数据。
获得的经验教训:更多数据似乎对该项目有所帮助!
怀疑安德鲁·伍(Andrew Ng)的团队得出了相同的结论,因为花时间注释了29,163名患者的64,121条ECG记录,这比任何其他公共数据集都要多2个数量级。
https://stanfordmlgroup.github.io/projects/ecg /
第3课:测试多种类型的深度学习模型
有线电视新闻网
从制作CNN开始。在这里将使用一维CNN(与用于图像的2D CNN相反)。
CNN是一种特殊类型的深度学习算法,它使用一组滤波器和卷积运算符来减少参数数量。该算法激发了用于图像分类的最新技术。本质上,此方法对于1D CNN的工作方式是kernel_size从第一个时间戳开始获取一个大小的过滤器(内核)。卷积运算符获取过滤器,并将每个元素与第一kernel_size时间步长相乘。然后,将这些乘积累加到神经网络下一层的第一个单元。过滤器然后按stride时间步长移动并重复。strideKeras中的默认值为1,我们将使用它。在图像分类中,大多数人使用padding这允许通过添加“额外”单元格来拾取图像边缘上的某些特征,将使用默认填充为0。然后将卷积的输出乘以一组权重W并添加到偏差b然后通过密集神经网络中的非线性激活函数。然后如果需要,可以添加其他的CNN层重复此操作。在这里,将使用Dropout,它是一种通过随机删除一些节点来减少过拟合的技术。
对于Keras的CNN模型,需要稍微重塑数据
在这里,将成为具有退出功能的一层CNN
CNN的性能似乎比密集的NN高。
RNN:LSTM
由于此数据信号是时间序列的,因此测试递归神经网络(RNN)很自然。在这里,将测试双向长短期记忆(LSTM)。与密集的NN和CNN不同,RNN在网络中具有循环以保留过去发生的事情。这允许网络将信息从早期步骤传递到以后的时间步骤,而这些信息通常会在其他类型的网络中丢失。从本质上讲,在通过非线性激活函数之前,该存储状态在计算中还有一个额外的术语。在这里,使用双向信息,因此信息可以在两个方向(从左到右和从右到左)传递。这将帮助获取有关中心心跳左右两侧正常心跳的信息。
如下所示,这花了很长时间训练。为了使它成为一个周末项目,将训练集减少到10,000个样本。对于真实的项目,将增加时期数并使用所有样本。
似乎该模型需要从其他时期进行正则化(即退出)。
最终ROC曲线
这是这3个模型的最终ROC曲线
给定更多的时间,最好尝试优化超参数,看看是否可以使Dense或CNN更高。
局限性
由于这只是一个周末项目,因此存在一些限制:
- 没有优化超参数或层数
- 没有按照学习曲线的建议收集其他数据
- 没有探索心律失常患病率的文献,以查看该数据集是否可以代表一般人群(可能不是)
推荐阅读
机器学习中四种算法预测潜在的心脏病