【CTR】DeepGBM:知识蒸馏技术在微软在线预测系统中的应用

2020-07-21 11:42:03 浏览数 (1)

作者:阿泽

今天学习的是微软 2019 年的工作《DeepGBM: A Deep Learning Framework Distilled by GBDT for Online Prediction Tasks》。从标题中我们可以看出,DeepGBM 是一个从 GBDT 提炼知识并用于在线预测任务的深度学习框架。

虽然 GBDT 和神经网络在实际应用中得到了广泛应用,但是它们都有各自的缺点,比如说 GBDT 不适合稀疏的类别数据,而神经网络面对稠密的数值数据时表现也不太好。

我们知道神经网路可以拟合各种函数,所以作者想到可否利用神经网络来学习决策树的知识,这一过程也叫做知识蒸馏

作者在本文中提出的 DeepGBM 综合了神经网络和 GBDT 的优点,并由以下两部分组成:

  1. CatNN:集中处理稀疏的类别特征;
  2. GBDT2NN:利用 GBDT 提取的知识,专注于稠密的数值特征。

1.Introduction

这里先定义两个名词:

  1. 表输入空间(Tabular input space):包括类别特征和数值特征,在线预测任务如 CTR 等通常包含如广告类别的类别特征、广告相似度的数值型特征;
  2. 在线数据生成(Online data generation):主要是指数据分布是实时动态的,比如新闻推荐系统,不断出现的新闻会在不同的时间产生动态的特征分布。

一个优秀的在线预测模型,同时需要适配类别特征和数值特征,也要适应动态的特征分布(一般来说,模型的训练数据和测试数据需要具备相似的特征分布,这样模型才不会学出偏差)。

我们知道,GBDT 是通过迭代选取信息增益最大的特征来构建树的,因此它可以自动选择并组合有用的数值特征,这也是为什么 GBDT 在 CTR、搜索等领域具有广泛应用的原因。但 GBDT 有两个缺点:

  1. 无法实时更新模型:GBDT 学到的树是不可微的,所以无法在线更新 GBDT;
  2. 学稀疏特征时效率不高:如果将类别数据转换为 one-hot 编码后,其信息增益会变得很小,所以其无法有效利用 GBDT。虽然通过枚举类别特征可以近似进行划分出分类边界,但是在数据稀疏的情况下往往会出现过拟合情况。

我们也知道,神经网络的优势在于可以对大规模数据进行有效的学习,并且可以利用批处理技术进行反向传播实现在线更新,同时通过 Embedding 技术也能很好的适应稀疏的类别特征,神经网络在CTR、推荐等领域也取得非常好的成绩。然而神经网络的主要挑战在于不适合学习稠密特征,虽然 FCNN 可以直接用于学习稠密特征,但是其全连接层结构会建模出非常复杂的优化超平面,很容易潜入局部最优解,常常会导致性能不理想,因此在稠密特征中其性能往往不如 GBDT。

本文作者提出了一个新的架构 DeepGBM,其结合了 GBDT 和神经网络的优点,同时也解决了在线预测任务中的两大难点(在线更新和可扩展性)。不同模型间的对比如下图所示:

接下来,我们来看下 DeepGBM 的具体细节。

2.DeepGBM

先来看下 DeepGBM 的架构:

其由两个基于神经网路的组件组成——CatNN 和 GBDT2NN,前者的输入是稀疏的类别特征,后者的输入是稠密的数值特征。

2.1 CatNN

先来看 CatNN。

CatNN 专注解决稀疏类别特征,与之类似的算法有 Wide&deep、PNN、DeepFM、xDeepFM 等。由于 CatNN 的目标与之相同,本着不重复造轮子的原则,直接应用现有的神经网络算法即可。作者设计的 CatNN 利用 FM 和 Deep 组建来实现特征交叉。(CatNN 不受具体组件的约束)

与之前的工作类似,CatNN 也依赖于 Embedding 技术将高维稀疏特征转变为低维稠密特征:

E_{V_i}(x_i) = embedding_lookup(V_i, x_i) \

其中,

x_i

为第 i 个特征的值;

V_i

为第 i 个特征的 Embedding 矩阵,通过反向传播进行学习。

我们用 FM 学习线形表达并进行特征交叉:

y_{FM}(x) = w_0 < w, x > sum_{i=1}^dsum_{j=i 1}^d < E_{V_i}(x_i), E_{V_j}(x_j) > x_ix_j \

其中,d 为特征数;

w_0

w

为线性变换的参数;

< .,. >

为内积运算。

再用 Deep 组件实现高阶特征交互:

y_{Deep}(x) = N([E_{V_1}(x_1)^T, E_{V_2}(x_2)^T, ...,E_{V_d}(x_d)^T]^T;theta ) \

其中,

N(x;theta)

为多层神经网络模型。

综合 FM 和 Deep 组件,CatNN 的最终输出为:

y_{Cat}(x) = y_{FM}(x) y_{Deep}(x) \

2.2 GBDT2NN

我们再来看下 GBDT2NN 部分,看下作者是如何将 GBDT 中的学习知识用 NN 进行知识蒸馏的。

2.2.1 Single Tree Distillation

简单起见,我们先以单颗树为例。

传统的知识蒸馏方法大部分都是根据所学函数来传递模型的知识,从而确保模型和函数具有相同的输出。

然而,由于树模型和神经网络的本质不相同,用神经网络来代替传统方法可以从树模型中学出更多的知识,并将此转换到神经网络中。

此外,树模型除了输出外,其本身的特征选择和树结构所隐含的数据划分能力也是很重要的知识:

特征选择能力:树模型在构建树是并使用所有的特征,而是在每次分裂时选择增益最大的特征。因此神经网络可以仅利用这些筛选过的特征进行输入,从而提高神经网络的学习效率。我们定义

mathbb{I}^t

为树 t 使用的特征的索引,

x[mathbb{I}^t]

表示神经网络的输入;

树结构知识:本质上来说,决策树的树结构是将数据划分成多个不重叠的区域(叶子)即将数据聚类为不同的类,同一个叶子结点的数据可以视为一类。由于决策树和神经网络结构上就有本质区别,所以这种结构知识很难直接转换到神经网络中。不过所幸神经网络可以逼近任何函数,所以我们可以使用神经网络模型来逼近树结构的函数输出,并实现结构知识的蒸馏。如下图所示,作者使用神经网络来拟合树生成的聚类结果,从而使得神经网络逼近决策树的结构函数。我们定义树 t 为

C^t(x)

的结构函数,其输入为样本,输出为叶子索引,即树生成的聚类结果。使用神经网络模型来逼近结构函数

C^t(.)

,学习过程可以表示为:

min_{theta} frac{1}{n}sum_{i=1}^n zeta^{'} (N(x^i[mathbb{I}^t;theta]),L^{t,i}) \

其中,n 为训练样本的数量;

x^i

为第 i 个训练样本;

L^{t,i}

为样本

x^i

的树

C^t(x^i)

的叶子输出值的 ont-hot 编码;

theta

为模型参数;

zeta^{'}

为交叉熵损失函数。

经过学习的神经网络

N(.;theta)

具有很强的表达能力,能够完美逼近决策树的结构函数。

树的输出:除了前面学习到的特征选择和结构知识外,我们还会学习最主要的树的输出。由于之前学习了树结构知识,所以我们只需要知道从树结构到树输出的映射即可。决策树的每个叶子索引都有相应的值,所以实际上不需要学习此映射,只需要将树 t 的叶子值表示为

q^t

即可,此时树模型的输出为

p^t=L^ttimes q^t

综上所述,对单个树 t 的蒸馏,最终得到的神经网络的输出为:

y^t(x) = N(x^i[mathbb{I}^t;theta]) times q^t \

2.2.2 Multiple Tree Distillation

我们再来看多棵树的知识蒸馏。

我们知道 GBDT 有多棵树,结合单颗树的蒸馏方法,要想其推广到多棵树,做简单的做法就是利用多个神经网络分别来拟合对应的树模型。但是结构蒸馏的维度过高,时间复杂度为

o(Ltimes nn)

,使得这种方法的效率非常低。

为了提高效率,作者提出了 LeafEmbedding 蒸馏法和树结构分组法来降低时间复杂度。

LeafEmbedding 蒸馏法:主要利用 Embedding 技术对叶子索引个数进行降维。由于叶子值和叶子索引具有双射关系,所以可以直接使用叶子值来学习 Embedding。学习过程可以表示为:

min_{w,w_0,w^t} frac{1}{n}sum_{i=1}^n zeta^{''} (w^TH(L^{t,i};w^t) w_0,p^{t,i}) \

其中,

H^{t,i}=H(L^{t,i};w^t)

是以

w^t

为参数的单层全连接神经网络,主要作用是将叶子索引

L^{t,i}

转换成稠密的 Embedding 向量

H^{t,i}

p^{t,i}

为样本落在叶子结点的预测值;

zeta^{''}

为树过程的中的损失函数。

所以我们可以用稠密饿 Embedding 向量

H^{t,i}

作为目标来逼近树结构函数:

min_{theta} frac{1}{n}sum_{i=1}^n zeta (N(x^i[mathbb{I}^t;theta]),H^{t,i}) \

其中,

zeta

为回归损失函数。

由于

H^{t,i}

的维度比

L^{t,i}

的 one-hot 后的维度小很多,所以可以使得神经网络更加高效。

树结构分组法:为了减少神经网络的个数,我们可以实现对树进行分组,然后对分组后的树模型用神经网络进行知识蒸馏。此时会出现两个问题:怎么分组和分组后怎么进行知识蒸馏。

对于第一个问题来说,有很多方法,比如说:随机分组、顺序分组、相似性分组等等。本文作者采用的是等随机分组,即 m 棵树,随机分成 k 组,每组有

s=[m/k]

棵树。考虑 LeafEmbedding 技术,多棵树的学习过程表示为:

min_{w,w_0,w^{mathbb{T}}} frac{1}{n}sum_{i=1}^n zeta^{''} (w^TH({||}_{tin mathbb{T}}L^{t,i};w^{mathbb{T}}) w_0,sum_{tin mathbb{T}} p^{t,i}) \

其中,

||(.)

为拼接操作;

mathbb{T}_j

表示第 j 组中的所有树。

我们另

G^{mathbb{T},i}=H({||}_{tin mathbb{T}}L^{t,i};w^{mathbb{T}})

表示为

mathbb{T}

组的稠密 Embedding 向量,然后用新的 Embedding 向量作为神经网络模型的蒸馏目标:

min_{theta^{mathbb{T}}} frac{1}{n}sum_{i=1}^n zeta (N(x^i[mathbb{I}^{mathbb{T}};theta]),G^{mathbb{T},i}) \

其中,

mathbb{I}^mathbb{T}

表示为树分组

mathbb{T}

中用到特征。当树分组

mathbb{T}

中的树数量较大时,

mathbb{I}^mathbb{T}

可能会包含很多特征,从而影响树模型的特征选择能力,因此可以治选择重要性较高的特征。

综上所述,从数组

mathbb{T}

中提取神经网络模型的最终输出是:

y_{mathbb{T}}(x) = w^T times N(x[mathbb{I}^mathbb{T};theta^{mathbb{T}}]) w_0 \

综合 k 个数组的 GBDT 模型的输出为:

y_{GBDT2NN}(x) = sum_{j=1}^k y_{mathbb{T}_j}(x) \

2.3 Training for DeepGBM

最后来看下训练部分,包括线下训练和线上模型更新两块内容。

2.3.1 End-to-End Offline Training

要训练 DeepGBM,首先得训练 GBDT 模型,然后对 GBDT 的叶子节点进行 Embedding 表示,最后便可以完成端到端的训练。DeepGBM 的输出可以表示为:

hat y(x) = sigma^{'} (w_1 times y_{GBDT2NN}(x) w_2 times y_{Cat}(x)) \

其中,

w_1,w_2

是可训练的参数,

sigma^{'}

为输出变换函数。

然后利用下面的损失函数进行端到端额度训练:

zeta_{offline} = alpha zeta^{''}(hat y(x), y) beta sum_{j=1}^k zeta^{mathbb{T}_j} \

其中,y 为标签,

zeta^{''}

为分类任务的交叉熵损失函数;

zeta^{mathbb{T}}

为数组

mathbb{T}

的 Embedding 的损失函数;

alpha,beta

为超参。

2.3.2 Online Update

在线更新时不包含 Embedding,所以在线更新模型时损失函数为:

zeta_{online} = zeta^{''}(hat y(x), y) \

因此,DeepGBM 线上运行时无需重新训练。

3.Experiment

简单看下实验部分。

首先是数据集:

然后是多模型的线下对比:

其中,

D1

表示直接用 GBDT 而不是 GBDT2NN;

D2

表示只用 GBDT2NN 而不用 CatNN。

再来看下 Epoch-AUC 曲线:

可以看到 DeepGBM 的手链素很快,而且收敛点也更好。

最后来看下线上实验:

4.Conclusion

总结:本文提出了 DeepGBM 结合了 GBDT 和神经网络的优点,在有效保留在线更新能力的同时,还能充分利用类别特征和数值特征。DeepGBM 由两大块组成,CatNN 主要侧重于利用 Embedding 技术将高维稀疏特征转为低维稠密特征,而 GBDT2NN 则利用树模型筛选出的特征作为神经网络的输入,并通过逼近树结构来进行知识蒸馏。诸多离线和在线实验表明,DeepGBM 具有不错的实验结果。

最后附上论文的代码:GitHub - motefly/DeepGBM

5.Reference

  1. 《DeepGBM: A Deep Learning Framework Distilled by GBDT for Online Prediction Tasks》

0 人点赞