作者 | Chilia 哥伦比亚大学 nlp搜索推荐 整理 | NewBeeNLP
之前我们讲过了知识蒸馏在精排中的应用:知识蒸馏在推荐精排中的应用与实践
其实,知识蒸馏在召回-粗排-精排
这三个模块都有用武之地,今天我们就来讲讲在粗排和召回中的应用。
0x01. 背景
一般知识蒸馏采取 Teacher-Student 模式:将复杂模型作为 Teacher,Student 模型结构则较为简单;用 Teacher 来辅助 Student 模型的训练。
Teacher 学习能力强,可以将它学到的暗知识 ( Dark Knowledge ) 迁移给学习能力相对弱的 Student 模型,以此来增强 Student 模型的泛化能力。复杂笨重但是效果好的 Teacher 模型不上线,就单纯是个导师角色,真正上线的是灵活轻巧的 Student 小模型。
0x02. 召回和粗排在推荐系统中的角色
召回环节从海量物品库里快速筛选部分用户可能感兴趣的物品(召回模型可以选择基于业务逻辑的tag召回、ItemCF、userCF,也可以选择FM、双塔模型);
之后是粗排模块,通常采取使用少量特征的简单排序模型 ,对召回物料进行初步排序,并做截断,进一步将物品集合缩小到合理数量;然后传递给精排模块,精排环节采用利用较多特征的复杂模型,精准排序。
在召回和粗排中,我们不追求最高的推荐精度,毕竟在这两个环节,如果准确性不足可以靠返回物品数量多来弥补。而模型小、速度快则是模型召回及粗排的重要目标之一 -- 这就和知识蒸馏的优势相吻合。
0x03. 具体方法
3.1 logits蒸馏
一个直观的方法是,用复杂的精排模型作为 Teacher,而召回模型(FM/DNN双塔)作为 Student。Student 模型去学习Teacher的输出logits。
用精排模型做teacher有两个好处:
- 推荐效果好。因为 Student 经过复杂精排模型的知识蒸馏,所以效果虽然弱于、但是可以非常接近于精排模型效果;
- 全链路一致性。通过 Student 模型模拟精排模型的排序结果, 可以使得召回/粗排这两个前置环节的优化目标和推荐任务的最终优化目标保持一致 ,在推荐系统中,前两个环节优化目标保持和精排优化目标一致,其实是很重要的,但是这点往往在实做中容易被忽略,或者因条件所限无法考虑这一因素:例如非模型召回(tag召回、itemCF、userCF),从机制上是没办法考虑这点的。
对于粗排或者召回模型,一般会用 DNN 双塔模型建模,只需要将粗排或召回模型作为 Student,精排的复杂模型作为 Teacher,两者同时 联合训练 (训练的样本是召回样本,即正样本为点击、负样本为随机负采样),要求 Student 学习 Teacher 的 Logits,同时采取特征 Embedding 共享(这三点都和精排蒸馏一样)。
但是,由于双塔结构将用户侧和物品侧特征分离编码,所以中间层蒸馏(要求 Student 隐层学习 Teacher 隐层响应)是很难做到的。粗排尚有可能,因为粗排不一定要使用双塔模型,但是召回环节几无可能,除非把精排模型也改成双塔结构,可能才能实现这点,但这样可能会影响精排模型的效果。
我们也可以使用 两阶段 而不是联合训练的方法,第一阶段先训练好 Teacher 模型,第二阶段再训练 Student。在 Student 训练过程中会使用训练好 Teacher 提供额外的 Logits 信息 (让teacher过一遍粗排的训练数据来打logit),辅助 Student 的训练。
不过个人感觉这种方法会带来selection bias问题,因为毕竟我们线上训练的精排模型的训练样本集是曝光样本,而召回模型的训练样本集是全空间,如果只用线上的精排模型做Teacher,它对于那些未曝光的随机样本会有比较好的输出结果吗?这个是未可知的,需要实际尝试看一看。
代表论文:Distillation based Multi-task Learning: A Candidate Generation Model for Improving Reading Duration (SIGIR'21,腾讯)
这篇文章主要解决的问题是,如何把精排时的 多任务目标 融合到粗排/召回中。
多任务预估的重要性不言而喻,我们在排序的时候不仅要考虑CTR,还要考虑诸多指标,例如阅读时长。因为如果只考虑CTR,会导致排序算法把那些“标题党”的文章排在前面,虽然“引诱”用户点击了,但是用户并不喜欢。
在排序中引入多任务预估有很多成熟的算法了,比如MMoE、ESSM、PLE、SNR等,但是把排序中的多任务预估引入召回则不是一件容易的事。因为召回的双塔模型似乎只能建模一个目标(两个塔求点积然后sigmoid,只能得到一个输出。我们一般就用双塔模型建模CTR,正样本为用户点击、负样本为随机抽样)。
然而,把排序中的多目标引入召回正是全链路一致性的优化方法。所以,这篇文章试图通过知识蒸馏的方法,把精排的多任务目标(点击 阅读时长)融合进召回/粗排模型。
首先我们来看多目标的复杂Teacher模型。对于点击任务,正例当然就是点击样本了,负例是随机负样本(而不是曝光未点击!);对于阅读时长任务,正例是阅读超过50s的文章、其他作为负例(正例、负例都一定是点击样本的子集)。由于点击任务和阅读时长任务的样本空间本来就不一样,所以我们采用ESSM的方法,用“行为序列”点击->转化来建模两个任务。
记输入的特征(user item embedding拼接)为 x_i , 点击label为 y_i 、阅读时长(是否超过50s)label为 z_i ,则有:
即:要想转化,必须首先点击。为了避免selection bias,使用ESSM的方法显式建模CTR和CTCVR(loss也只是对CTR和CTCVR计算),然后隐式的学习CVR。同时,文章还使用了MMoE来建模CTR和CVR,使用2个expert和2个gate,防止CTR、CVR这两个任务之间的冲突。
如何去把复杂多目标Teacher模型的这个多目标给迁移到粗排/召回双塔模型中呢?
这就要用到logits蒸馏了。由于召回/粗排是采用双塔模型,两个塔的点积只能计算一个目标,那么我们可以用这个logits和CTCVR的logits计算KL loss,这个KL loss作为更新粗排模型的唯一的loss。两个模型联合训练,训练结束之后把双塔的item塔拿去建立索引。
代表论文:Privileged Features Distillation at Taobao Recommendations(KDD2020,淘宝)
什么是优势特征(privileged feature)呢?
在【粗排阶段】,主要的任务是预估召回阶段返回的候选集中每个物品的CTR,然后选择排序分最高的一些物品进入精排阶段。粗排阶段输入的特征主要有用户的行为特征(如用户的历史点击/购买行为)、用户属性特征(如用户id、性别、年龄等)、物品特征(如物品id、类别、品牌等)。
有一些【交叉特征】对粗排效果有明显的提升,比如用户在过去24小时内在 待预估商品类目 下的点击次数、用户过去24小时内在 待预估商品所在店铺 中的点击次数等。但是对于这些交叉特征,如果放到用户侧,那么针对每个物品都需要计算一次用户侧的塔;如果放到物品侧,同样针对每个物品都需要计算一次物品侧的塔,这会大大加大计算复杂度,增加线上的推理延时。
因此,这些交叉特征对于粗排阶段的模型来说,通常在线上无法应用,我们就称它们为 粗排CTR预估中的Privileged Features 。
那么,怎么能把优势特征的信息教给粗排双塔模型呢?
这就要用到“优势特征蒸馏”了。
“优势特征蒸馏”本身和普通的模型蒸馏不同,因为它不能够降低参数量,而只是起到融合优势特征的作用。具体地,令X表示普通特征,X*表示优势特征,y表示标签,L表示损失函数。Student模型只在普通特征X上构建模型,而Teacher模型在普通模型X和优势特征X*上构建模型。则特征蒸馏的目标函数抽象如下:
上面的损失函数被分为两部分,两部分都是计算交叉熵。损失的第一部分是可以称为hard loss,其label是0或者1,就是student输出对真实label的损失;第二部分可以称为distillation loss,其label是Teacher网络的带温度Softmax输出,是概率值。
如上述公式中所示,教师模型的参数需要预先学好,这直接导致模型的训练时间加倍。所以,能不能同步更新学生和教师模型呢?同步更新的目标函数变成如下形式,即增加了Teacher对于真实标签的loss:
尽管同步更新能显著缩短训练时间,但也会导致训练不稳定的问题。尤其在训练初期,教师模型还处于欠拟合的情况下,学生模型直接学习教师模型的输出会导致训练偏离正常。
解决这个问题的方法也非常简单,只需要在开始阶段将 lambda 设定为0,不让Student跟Teacher学;然后在预设的迭代步将 lambda 设为固定值。
值得一提的是,我们让蒸馏误差项(distillation loss)只影响学生网络的参数的更新,而对教师网络的参数不做梯度回传,从而避免学生网络和教师网络相互适应(co-adaption)而损失精度。也就是说,只能学生向老师学,而不能老师向学生学。
另外,为了能够对粗排模型进行参数量的缩减,我们可以使用优势特征蒸馏 普通模型蒸馏的方法:
其实上面这个图想表达的意思非常简单·,就是我们使用更加复杂的精排模型、使用优势特征(交叉特征X*)作为Teacher,然后粗排的双塔模型作为Student,二者联合训练,做logits蒸馏。然后上线的时候只使用双塔模型。
3.2 without-logits蒸馏
当然了,我们也可以使用without-logit的方法,这是因为精排的输出结果是有序的(排在前面的结果,就是精排认为用户更会点击/转化等的物品),这个顺序可以对粗排进行一些指导。
也就是说,粗排或者召回环节的 Student 的学习目标是: 像精排模型一样排序 。当然,也可以用推荐系统在精排之后,经过 Re-ranker 重排后的排序结果,这样甚至可以学习到一些去重、打散等业务规则,更能实现全链路一致性。
具体的做法是,我们可以取精排每次返回的 Top K 物品,将其指定为正例,位置K之后的例子,作为负例,用这个正负样本来训练粗排模型。其实,我认为这样的做法也可以不叫作“知识蒸馏”,而看作一种hard negative采样。即,选取精排的靠后样本作为召回时候的难负例。
也可以采用直接构造 pair-wise 对的方法:精排的排序结果是有序列表,那么我们可以在列表内随机任意抽取两个 Item,并利用它们的偏序关系,以前面的item为正例,以后面的某个 Item 作为负例,以此方式构造训练数据来训练模型。
在推荐领域,最常用的 Pair Wise Loss 是 BPR 损失 函数。我们构造足够多的成对训练数据,以此目标来优化召回模型,使得模型可以学会 Item 间的偏序关系。
参考资料:知识蒸馏在推荐系统的应用[1](张俊林大佬的文章真的写的深入浅出,非常有启发!)
本文参考资料
[1]
知识蒸馏在推荐系统的应用: https://zhuanlan.zhihu.com/p/143155437