Dual Sparse Attention Network For Session-based Recommendation https://ojs.aaai.org/index.php/AAAI/article/view/16593 代码(code):GitHub - SamHaoYuan/DSANForAAAI2021
1. 背景
注意力机制已经在基于会话推荐中得到广泛应用,但是现有的方法面临两个问题:
- 一种是只将会话中的最后一个点击行为作为查询(query,即目标item)来反映用户的兴趣,但是只用最后一个行为不一定每次都能反映用户的兴趣。
- 另一种是考虑会话内的所有商品都对最终结果有利,这其中包括了不相关商品的影响(即虚假用户行为)。
2. 定义
表示所有商品的集合
表示根据时间排序的会话,
表示第p个被点击的商品,n表示会话长度
3. Sparse Transformation
这里作者用到了一个新的激活函数,当然不是作者提出的。通常我们采用softmax来做最后的激活函数,或者作为注意力机制的归一化函数。但是softmax的归一化方式会为向量中的每一个元素都赋值,即他不会存在0的值,顶多是很小,比如10e-5等。而sparsemax是2016年提出的softmax的改进版,他可以得到稀疏的结果,即让一些值为0,它和softmax具有相似的性能,但具有选择性、更紧凑、注意力集中。正如背景中作者所说的,会话中包含的商品可能存在噪声,比如误点击的,而softmax的非零概率可能会为无用数据分配权重,从而影响找到相关项目的能力,并且一些本来分配高权重的位置也会有“缩水”。如下如所示,sparsemax相比如softmax是更硬的,在过大过小的地方对应1和0,即可以得到稀疏解。
在19年有人提出了sparsemax的改进版本,即可学习的介于softmax和sparsemax中间的激活函数,公式如下:
如图所示,当α为1是该函数是softmax,α为2时该函数为sparsemax。
4. 方法
如图所示为模型的整体框架图,主要包含四个部分:embedding layer、target embedding learning、target attention layer、prediction layer。
4.1 Embedding Layer
该模块用户将会话数据转化为两个向量,分别是item embedding和positional embedding。positional embedding通过文献[1]中的方式获得,主要是将位置信息转化为稠密向量从而不过时间信息。对于会话中的每一个item,得到的embedding为下式,其中x为item embedding,p为时间embedding。
为了在没有特定目标商品的情况下学习用户的偏好,这里作者加了一个特殊的embedding
,他是我们需要预测的目标embedding,他的位置就为t 1。得到C的序列为:
注意点
笔者简单阅读了一下源码,本文所提方法中序列中对商品做embedding的时候是直接采用商品的id进行embedding的,而没有包含其他特征,例如商品的类型,产地等,因此才能直接在最后面添加一个,即作者添加的这个特殊的商品,id为最大id 1,位置为t 1。
4.2 Target Embedding Learning
该部分通过自注意力机制在会话的embedding c中学习到其中商品之间的协作信息,即同一会话中不同商品之间的关系。注意力机制如下,其中
,但是
,其中f为relu函数。
其中激活函数的α计算方式如下,这里的
就是前面说的特殊的embedding,其中w,b为可学习参数,σ为sigmoid,整体公式将α控制在[1,2],即为了使激活函数达到softmax和sparsemax的中间效果。
然后经过逐位置的FFN得到下式,然后接上残差连接,layer norm和dropout。
上述的整体过程可以表示为
,
就是稀疏自注意力机制的得到的输出。前t个是item的embedding,字后一个es是目标embedding。
4.3 Target Attention Layer
上述的自注意力网络可以理解为特征提取,其中每个商品embedding包含了其他商品的信息,并且包含了目标embedding,但是忽略了初始信息。这一层采用vanilla注意力机制,来学习得到整个会话的表征,公式如下,可以发现就是经过多层的FFN然后经过激活函数得到权重β,该权重表示前t个商品的embedding和目标embedding的权重关系。
得到所有的权重
后,可以计算整个会话的表征
。
4.4 预测层
预测层的总体流程为,将
拼接后,进过全连接层,在经过激活函数SELU(f函数),最后通过L2norm和softmax得到预测概率,具体公式如下,其中
为原始的item的embedding,
为标准化的权重。
最后损失函数依旧采用交叉熵函数:
5. 结果
image.png
6. 总结
本文主要针对背景汇总所述的两个问题,即会话中的最后一个点击的商品未必能反映用户的兴趣,另一方面会话中的点击可能存在噪声,因为可能存在虚假行为(点错)。
- 针对第一个问题,作者采用学习target embedding的方式,而不是直接采用会话中最后一个点击的商品
- 针对第二个问题,作者采用α-entmax的激活函数,主要是通过该方法产生稀疏解,从而避免给一些不感兴趣的商品加权。这里的这个方案相当于是把原先softmax认为不太相关(权重很小)的商品直接处置为不相关(权重为0)。
通过target embedding学习到和会话中点击商品相关的表征,另一方面结合整个会话的表征来共同预测,从而避免单个商品无法反应用户偏好的问题。