文献阅读:Cross-Batch Negative Sampling for Training Two-Tower Recommenders

2021-11-16 16:03:51 浏览数 (1)

  • 文献阅读:Cross-Batch Negative Sampling for Training Two-Tower Recommenders
    • 1. 问题描述
    • 2. 方法优化
    • 3. 实验结果考察
    • 4. 结论

文献链接:https://arxiv.org/pdf/2110.15154.pdf

1. 问题描述

这篇文章是今年入选到sigir的一篇文章,他针对的同样是双塔推荐系统当中的负采样问题。

关于推荐系统的负采样问题,事实上我们在之前的论文笔记(文献阅读:Mixed Negative Sampling for Learning Two-tower Neural Networks in Recommendations)当中也讨论过对应的内容,其主要的痛点在于说数据本身缺乏负反馈信号,因此我们只能将全部的其他样本都作为负信号来让模型学习正样本与其他样本之间的相对关系,具体到实现层面,就是用其他的一堆样本作为负采样进行模型训练。

而有关负采样的方式,常见的包括以下三种:

  1. inbatch sampling
  2. MNS(mixed negative sampling)
  3. uniform sampling

而在具体的使用当中,则往往需要根据实际的场景来平衡效果和计算效率,然后看一下具体的使用方式。

整体上来说,单就效果而言,肯定是uniform是最好的,但是其计算开销也是最大的,然后inbatch sampling虽然效果偏差,但是其计算开销是最小的,因此在业界似乎使用地非常频繁。然后MNS给我的感觉就是类似前面两者的缝合怪,效果上来说也貌似就是前述二者的一个折中。

而这篇文章当中,本质上也是要优化负采样问题,不过较之MNS的暴力缝合,这篇文章的方法显得更加优雅一些,他的核心思路是跨batch的进行负例采样,并利用encoder在训练过程中的稳定性来保证计算成本几乎可以保持和inbatch采样方式一致。

2. 方法优化

下面,我们来看一下其具体的采样方法。

其采样思路其实很直接,就是我保留下前几个batch的计算结果,然后添加到一个队列当中,然后直接应用到后续的计算当中直接取用之前的计算结果来加入到我的负例当中,由此,就可以将负例的选择范围从当前的batch扩展到前后连续的几个batch范围内。

但是,这里的思路成立的一个大前提是,之前的模型计算的embedding是可以复用的。而这个假设是不显然成立的,因为很明显,每一个batch计算完成之后模型都会对参数进行更新,因此原则上计算得到的embedding结果是必然会发生变化的,因此上述假设正常来说是不可能成立的。

不过,这篇文献给出了一定的数学证明,证明了在参数更新前后模型的item embedding的变化是可以找到一个上界的,尤其在训练的后期,embedding的变化是相对较小的,因此确实可以视之为相对稳定的,也就是说,之前的计算结果可以直接扔到后续的训练过程当作计算结果而减少计算量。

但是,如前所述,这里需要稍微注意的是,前期由于训练参数更新较大,因此这个策略在前期是不会使用的,而是在经过一个warm up,在模型有了一定的训练基础之后才会加入上述cross batch训练策略。

3. 实验结果考察

现在,我们来看一下其具体的实验结果。

文中在公开的amazon数据集上对不同的模型架构进行了实验,得到结果如下:

可以看到:

  • 在YouTubeDNN,GRU4Rec以及MIND等架构当中,CBNS策略都是有效的,且都能够获得较好的效果。

另外,文章中还考察了负采样数量对训练结果的影响,对应结果图如下:

可以看到:

  • 对于不同的模型架构,CBNS采样都是有效的,但是对应的最优实验参数配置需要对应的进行一定的微调。

4. 结论

整体来说,关于这篇文章的改进点,个人觉得非常的fancy,尤其是数学上证明参数更新前后embedding变化较小这一点,个人觉得可以玩的空间还很大。

不过对于CBNS采样方案本身,个人表示很怀疑,感觉非常奇怪。

本质上来说,如果batch size比较大,那么inbatch采样的最大问题在于说SSB,即用户反馈行为覆盖的数据集与全数据集存在一定的偏差,这也是uniform采样能够干趴下inbatch采样的最本质原因。

但是cross batch采样并无法带来这个问题的优化,长尾数据依然无法被看到,虽然执行效率上可以有所提升,但是上述实验结果中显示结果居然在效果上面干趴下了MNS和uniform采样方案,这个就让人感觉有点无法理解了。

倒不是认为实验作假,但是感觉性能提升的最关键原因必然不是由于CBNS的问题,也许换一个数据集CBNS方案就废了,或者就是复用之前的embedding计算结果的时候somehow刚好模拟了长尾数据的表达,但横竖不觉得是cross batch采样本身带来的效果的提升。

当然,以上仅仅是我个人的观点,也没有经过实验的论证,后续如果有时间的话可能会做实验进行一下简单地考察,如果上述哪里我的理解有什么问题,也欢迎大家批评指正。

0 人点赞