超干!Gain 算法实现缺失值预测

2022-09-04 09:39:37 浏览数 (2)

作者 | 李秋键

出品 | AI科技大本营(ID:rgznai100)

随着计算机和信息技术的快速发展,大数据和人工智能技术表现出越来越好的发展前景。数据在互联网、物联网、医疗、金融等诸多领域迅速累积,形成大规模数据时代。大数据和人工智能技术相辅相成,一方面数据是人工智能算法做出决策的基础,另一方面数据也需要人工智能算法实现其价值。高质量的数据是实现人工智能、数据挖掘等技术最原始的驱动力,但是在现实世界中,许多数据集存在数据质量问题。数据集来源于人工或机器的收集,即使是关系型数据库中存储的数据,也很容易存在数据缺失、数据冗余、数据不一致等问题。低质量的数据不仅增加了算法设计的难度,还降低了算法分析结果的准确性。因此,拥有高质量的数据才是实现人工智能等算法的前提。在大数据等领域,数据预处理就是实现高质量数据的过程,其包括数据清洗、数据集成、数据转换、数据规约。不同的任务其数据集成、数据转换和数据规约方式不同,但都离不开数据清洗。由此可见处理原始数据,实现高质量数据起着重要作用。

然而在数据处理过程中,经常面临数据缺失问题。例如,在问卷调查中,被调查人员由于遗漏或者涉及隐私而没有填写完整信息造成的数据缺失;在医疗系统中,由于患者没有相关疾病或未经过检测而造成的部分数据缺失;一些传感器设备的故障造成的信息缺失,等等。这些在实际生活中面临的切实问题不可避免,并且已经成为亟待解决的数据质量题。

故为了解决数据缺失值预测的问题,今天我们尝试使用Gain算法训练深度学习模型,对其缺少的数据进行匹配性的预测,得到的训练均方根误差如下图可见,代码放置文末:

常用缺失值处理方法

1.1 基于传统统计学的方法

数据填补问题最早可归类于统计学领域,其方法又可分为单一填补法和多重填补法。常用的单一填补法如均值填补法、众数填补法、热卡填补法、冷卡填补法、回归填补法等。

均值填补是利用缺失值所在属性列中存在值的均值填补,在该属性列中填补的缺失值都相等。均值填补法是针对数值型数据,而众数填补是针对离散型数据,使用不完整属性列中存在值的众数填补该列中的缺失值。均值填补和众数填补虽然简便,但是使填补值缺少了随机性,损失了大量数据信息。

1.2 基于模型的方法

高斯混合模型是基于模型的填补方法的代表性方法,其求解通常采用 EM 算法,因此也被称为 EM 填补法。EM 填补法假设数据集服从多元正态分布,且数据缺失为任意缺失模式,通过迭代模型和填补值的方式填补。

1.3 基于机器学习的方法

机器学习学科内有已经有多种方法被拓展在数据填补上,常见的包括 K 近邻填补法、基于聚类的填补法、基于决策树的填补法、基于神经网络的填补法等。KNN 算法在机器学习内比较适用于数据的分类,算法从带有标签的数据库内选取离待测试样本最近邻的 K 个样本,通过统计 K 个最近邻样本的标签来标识测试样本的类别。样本之间的距离通常选取欧式距离、曼哈顿距离、闵可夫斯基距离等距离公式。KNN 算法用于填补数据,通常会改进样本之间距离的计算公式,使其可以表示不完整样本和完整样本之间的距离。KNN 填补法通过选取和不完整样本最近邻的 K 个完整样本,通过加权平均完整样本的属性值填补相应的缺失值。

基于神经网络的填补法更是多种多样,其常用的网络结构包括多层感知机、广义回归神经网络、自组织映射神经网络、多任务学习网络等,每一种网络都可以用来填补缺失值。

而本文使用的Gain算法就归属于神经网络中的一种,是基于GAN网络的框架生成缺失数据。

其中系统流程图如下:

项目搭建

Gain算法是由GAN网络推广而来,其中生成器用来准确估算缺失数据,判别器为判别预测值和真实值之间的误差,从而更新生成器和判别器的参数。同样按照GAN网络基本原则,其基本目标为寻找纳什平衡点,使其生成器和判别器loss相同得到最佳结果。项目整体过程分为数据集准备、数据处理、以及网络结构搭建和模型训练,具体介绍如下:

2.1 训练数据集

这里使用的数据集为开源的UCI Spam数据和UCI Letter数据集,数据集内容如下:

2.2 数据处理

按照数据集的不同,读取对应数据集,然后将其中为0的值填充为nan,为后续预测和模型训练做基本处理,对应data_loader函数。然后对数据做基本的标准化处理,核心代码如下:

代码语言:javascript复制
def normalization (data, parameters=None):
  _, dim = data.shape
  norm_data = data.copy()
  if parameters is None:
    min_val = np.zeros(dim)
    max_val = np.zeros(dim)
    for i in range(dim):
      min_val[i] = np.nanmin(norm_data[:,i])
      norm_data[:,i] = norm_data[:,i] - np.nanmin(norm_data[:,i])
      max_val[i] = np.nanmax(norm_data[:,i])
      norm_data[:,i] = norm_data[:,i] / (np.nanmax(norm_data[:,i])   1e-6)   
    norm_parameters = {'min_val': min_val,
                       'max_val': max_val}
  else:
    min_val = parameters['min_val']
    max_val = parameters['max_val']
    for i in range(dim):
      norm_data[:,i] = norm_data[:,i] - min_val[i]
      norm_data[:,i] = norm_data[:,i] / (max_val[i]   1e-6)  
    norm_parameters = parameters    
  return norm_data, norm_parameters

2.3 模型搭建

按照Gain算法基本架构,分别构建生成器generator和判别器discriminator,然后按照伪代码过程搭建架构。

其中伪代码如下:

代码语言:javascript复制
def generator(x,m):
  inputs = tf.concat(values = [x, m], axis = 1) 
  G_h1 = tf.nn.relu(tf.matmul(inputs, G_W1)   G_b1)
  G_h2 = tf.nn.relu(tf.matmul(G_h1, G_W2)   G_b2)   
  G_prob = tf.nn.sigmoid(tf.matmul(G_h2, G_W3)   G_b3) 
  return G_prob
def discriminator(x, h):
  inputs = tf.concat(values = [x, h], axis = 1) 
  D_h1 = tf.nn.relu(tf.matmul(inputs, D_W1)   D_b1)  
  D_h2 = tf.nn.relu(tf.matmul(D_h1, D_W2)   D_b2)
  D_logit = tf.matmul(D_h2, D_W3)   D_b3
  D_prob = tf.nn.sigmoid(D_logit)
  return D_prob
X = tf.placeholder(tf.float32, shape = [None, dim])
M = tf.placeholder(tf.float32, shape = [None, dim])
H = tf.placeholder(tf.float32, shape = [None, dim])
D_W1 = tf.Variable(xavier_init([dim*2, h_dim])) 
D_b1 = tf.Variable(tf.zeros(shape = [h_dim]))
D_W2 = tf.Variable(xavier_init([h_dim, h_dim]))
D_b2 = tf.Variable(tf.zeros(shape = [h_dim]))
D_W3 = tf.Variable(xavier_init([h_dim, dim]))
D_b3 = tf.Variable(tf.zeros(shape = [dim]))  
theta_D = [D_W1, D_W2, D_W3, D_b1, D_b2, D_b3]
G_W1 = tf.Variable(xavier_init([dim*2, h_dim]))  
G_b1 = tf.Variable(tf.zeros(shape = [h_dim]))
G_W2 = tf.Variable(xavier_init([h_dim, h_dim]))
G_b2 = tf.Variable(tf.zeros(shape = [h_dim]))
G_W3 = tf.Variable(xavier_init([h_dim, dim]))
G_b3 = tf.Variable(tf.zeros(shape = [dim]))
theta_G = [G_W1, G_W2, G_W3, G_b1, G_b2, G_b3]

完整代码:

链接:https://pan.baidu.com/s/1EdlmEx7lXsGKMm8DQR4lZA

提取码:yeyz

李秋键,CSDN博客专家,CSDN达人课作者。硕士在读于中国矿业大学,开发有taptap竞赛获奖等。

0 人点赞