code:https://github.com/WenjieWWJ/DecRS
paper:https://arxiv.fenshishang.com/pdf/2105.10648.pdf
title:Deconfounded Recommendation for Alleviating Bias Amplification
论文阅读推文:推荐系统 因果推断(一)——KDD2021推荐系统中去除混淆缓解偏差方大
DecRS主要是通过因果推断来缓解推荐系统中的数据偏差问题,以及历史数据中的数据分布不平衡问题,具体可以看文章和之前的论文解读推文。由于推文是早期创作的,所以可能格式之类的不太好,请大家多担待。
本文主要的思想是:通过去除因果图中的混杂因子来缓解数据有偏导致的问题,因此方法中没有用到用户的点击数据而是直接使用了不同item组的先验分布,从而避免使用到有偏数据。具体可见论文或公众号推文。
今天这篇文章主要和大家分享改论文的代码部分。
文件
主要包含code和data两部分,code部分分别实现了amazon-book和ml-1m的对应的相关DecFM和DecNFM。本文后续代码以DecFM为例。
代码文件主要包含以下几部分:
- data_utils:读取数据,并且将特征转换成对应的index,得到对应的datasets类,对数据进行负采样得到负样本
- main:主文件,用于读取相应的超参数,以及实现训练和测试的迭代等
- model:实现不同模型的具体代码
- inference:用于测试
数据文件以amazon_book为例,主要包含(只写一部分,具体可详见代码):
- category_list:包含数据中的不同分组
- confounder_prior:不同组的占比,先验分布
- item_category:商品属于什么组
- user_feature:用户特征
模型
DecFM即为因果推断在FM模型上的应用,这里有两个部分用到了FM,第一部分用到的FM是为了计算M(d, u);第二部分的FM是为了对user-item之间的可能点击概率进行预测。
代码语言:javascript复制class DecFM(nn.Module):
def __init__(self, num_features, num_groups, num_factors, batch_norm,
drop_prob, num_user_features, confounder_prior=None):
super(DecFM, self).__init__()
"""
num_features: number of features,
num_factors: number of hidden factors,
batch_norm: bool type, whether to use batch norm or not,
drop_prob: list of the dropout rate for FM and MLP,
"""
self.num_features = num_features
self.num_factors = num_factors
self.num_groups = num_groups
self.num_user_features = num_user_features
self.batch_norm = batch_norm
self.drop_prob = drop_prob
self.embeddings = nn.Embedding(num_features, num_factors)
self.biases = nn.Embedding(num_features, 1)
self.bias_ = nn.Parameter(torch.tensor([0.0]))
# 用于得到不同组的embedding
self.confounder_embed = nn.Embedding(num_groups, num_factors)
# confounder_prior表示每个组所占比例,即文中图1a举得例子[0.7,0.3]这种
if confounder_prior is not None:
# bar{d} in the paper
self.confounder_prior = torch.tensor(confounder_prior, dtype=torch.float32).cuda().unsqueeze(dim=-1)
else:
# confounder prior is assumed as [1/n, 1/n, ..., 1/n]
self.confounder_prior =
torch.tensor([1.0/num_groups for x in range(num_groups)]).cuda().unsqueeze(dim=-1)
FM_modules = []
if self.batch_norm:
FM_modules.append(nn.BatchNorm1d(num_factors))
FM_modules.append(nn.Dropout(drop_prob[0]))
self.FM_layers = nn.Sequential(*FM_modules)
nn.init.normal_(self.embeddings.weight, std=0.01)
nn.init.constant_(self.biases.weight, 0.0)
def forward(self, features, feature_values):
### U & D => M: use FM to calculate M(bar{d}, u)
# map the id of group features. eg., [123, 124, 125] => [0, 1, 2]
confounder_part = features[:, -self.num_groups:] - torch.min(features[:, -self.num_groups:])
# 得到不同组的embedding
confounder_embed = self.confounder_embed(confounder_part) # N*num_confoudners*hidden_factor
# 这个对应论文中的公式7的p(g_a)v_a,也就是公式5中的sum(p(d)d),
# 利用先验对不同组的embedding进行加权,用于后续计算和用户表征的关系
# 反映用户对不同组的偏向
weighted_confounder_embed = confounder_embed * self.confounder_prior
# 得到用户特征的embedding
user_features = features[:, :self.num_user_features]
user_feature_values = feature_values[:, :self.num_user_features].unsqueeze(dim=-1)
user_features = self.embeddings(user_features)
# 对应公式7中的后半部分x*u,计算用户表征
weighted_user_features = user_features * user_feature_values # N*num_user_features*hidden_factor
user_confounder_embed = torch.cat([weighted_user_features, weighted_confounder_embed], 1)
# 后面就是正常的FM模型的计算流程,利用FM来得到M(bar{d}, u)
# 总体上是为了计算不同组的先验与用户embedding之间的关系
# FM的第一项,求和的平方
sum_square_user_confounder_embed = user_confounder_embed.sum(dim=1).pow(2) # N*hidden_factor
# FM的第二项,求平方和
square_sum_user_confounder_embed = (user_confounder_embed.pow(2)).sum(dim=1) # N*hidden_factor
user_confounder_mediator = 0.5 * (sum_square_user_confounder_embed - square_sum_user_confounder_embed)
user_confounder_mediator = user_confounder_mediator.unsqueeze(dim=1)
batch_num = user_features.size()[0]
assert list(user_confounder_mediator.size())==[batch_num, 1, self.num_factors]
### 后续是利用FM来进行预测,U & M & I => Y: similar to FM
## 以用户,商品和上面求得的M为特征
nonzero_embed = self.embeddings(features)
feature_values = feature_values.unsqueeze(dim=-1)
nonzero_embed = nonzero_embed * feature_values
nonzero_embed = torch.cat([nonzero_embed, user_confounder_mediator], 1)
# 和上面FM的计算流程一样
sum_square_embed = nonzero_embed.sum(dim=1).pow(2)
square_sum_embed = (nonzero_embed.pow(2)).sum(dim=1)
# FM model
FM = 0.5 * (sum_square_embed - square_sum_embed)
# 后续经过BN和dropout来增强模型的鲁棒性
FM = self.FM_layers(FM).sum(dim=1, keepdim=True)
# bias addition
feature_bias = self.biases(features)
feature_bias = (feature_bias * feature_values).sum(dim=1)
FM = FM feature_bias self.bias_
return FM.view(-1)
推理阶段
在inference的时候,作者考虑到了有的用户他可能就是有偏的,比如有的用户可能只喜欢喜剧。因此文中采用了KL散度来判断该用户是否是兴趣多变的用户。将原有的用户u的历史序列分为两段序列,然后计算这两段序列的对称KL散度。值越大说明越容易改变兴趣,即需要后门调整;反之,则不需要后门调整。
以下为部分代码,具体可见inference文件夹
代码语言:javascript复制# 这部分是使用了公式11,12的情况,kl_score标准化后,
# 对两个模型(去除混杂和不去除混杂)的输出分别加权后得到预测分数
threshold_list = [0, 0.5, 1, 2, 3, 4]
alpha_list = [0.5]
# alpha_list = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
# 最外层的for是在探索哪个阈值最好
for threshold in threshold_list:
print(f'threshold {threshold}')
# 探索哪个alpha比较好,即公式12的幂
for alpha in alpha_list:
print(f'alpha {alpha}')
user_item_pred = []
user_item_gt = []
user_item_gt_dict = {}
for userID in FM_user_item_score:
# 对于分数小于阈值的就直接采用在有偏数据上训练得到的FM计算的分数
# 对于大于阈值的,说明兴趣是容易变化的,需要结合Dec
if user_kl_score[userID]<=threshold:
continue
FM_item_scores = np.array(FM_user_item_score[userID])
DecFM_item_scores = np.array(DecFM_user_item_score[userID])
# 计算标准化后的系数
lamda = ((user_kl_score[userID] - min_kl_score)/(max_kl_score-min_kl_score))**alpha
# 利用eta*decfm_score (1-eta)*fm_score得到最终分数
combined_scores = lamda * util.sigmoid(DecFM_item_scores) (1-lamda) * util.sigmoid(FM_item_scores)
DecFM_scores = []
for i in range(len(FM_item_scores)):
itemID = user_candidates[userID][i]
DecFM_scores.append([itemID, combined_scores[i]])
DecFM_scores.sort(reverse=True, key=util.sort_function)
user_item_gt.append(test_dict[userID])
user_item_gt_dict[userID] = test_dict[userID]
user_item_pred.append([x[0] for x in DecFM_scores[:1000]])
print(f'user num {len(user_item_gt)}')
test_result = util.computeTopNAccuracy(user_item_gt, user_item_pred, topN)
util.print_results(None, None, test_result, None)