wide & deep 模型与优化器理解 代码实战

2021-09-12 15:46:00 浏览数 (3)

1. 背景

wide & deep模型是Google在2016年发布的一类用于分类和回归的模型。该模型应用到了Google Play的应用推荐中,有效的增加了Google Play的软件安装量。目前wide & deep模型已经开源,并且在TensorFlow上提供了高级API。

wide & deep模型旨在使得训练得到的模型能够同时兼顾记忆(Memorization)与泛化(Generalization)能力:

  • Memorization:模型能够从历史数据中学习到高频共现的特征组合,发掘特征之间的相关性,通过特征交叉产生特征相互作用的“记忆”,高效可解释。但要泛化,则需要更多的特征工程。
  • Generalization:代表模型能够利用相关性的传递性去探索历史数据中从未出现过的特征组合,通过embedding的方法,使用低维稠密特征输入,可以更好的泛化训练样本中从未出现的交叉特征。

这篇论文的主要贡献分为以下三点:

  • Wide & Deep联合训练具有嵌入的前馈神经网络和具有特征变换的线性模型,用于具有稀疏输入的通用推荐系统;
  • W&D在google应用商店上进行了线上测试和评估;
  • 在TensorFlow API中贡献了源码,方便调用;

其使用到的特征包括:

  • 用户维度的特征(城市,年龄,人口统计学特征等)
  • 上下文特征(设备,几点请求,周几请求)
  • APP维度特征(app上线时长,app的历史统计信息)

2. 模型结构

2.1 Wide 部分

wide部分对应上图中的左侧部分,通常是一个广义线性模型即LR:y=w*x b

  • y:是要预测的结果
  • x:是一组特征向量
  • w:模型的参数
  • b:偏置量

特征集合包含的是原始输入和他们对应的特征转换,其中一个比较重要的转换是:cross-product transformation特征交叉(特征交叉前需要各个特征进行one-hot),其对应的公式如下:

phi k(x)=prod^{d}_{i=1}x_i^{c_{ki}}

c_{ki}属于0,1

上边的公式实现的其实就是one-hot编码,比如当gender=female,language=en时为1,其他为0。

2.2 Deep 部分

deep部分对应上图中的右侧部分,是一个前馈神经网络。对于分类特征,原始输入是字符串(比如language=en)。这些稀疏、高维的分类特征第一步是转化为低维、密集的向量。这些向量通常在10-100维之间,一般采用随机的方法进行初始化,在训练过过程中通过最小化损失函数来优化模型。这些低维的向量传到神经网络的隐层中去。每个隐层的计算方式如下:

a^{l 1}=f(W^la^l b^l)

其中:

  • l:神经网络的层数
  • f:激活函数(通常是relu)
  • a^l:第l层的输出值
  • b^l:第l层的偏置
  • W^l:第l层的权重

2.3 Wide & Deep的联合训练

wide部分和deep部分使用输出结果的对数几率加权和作为预测值,然后将其输入到一个逻辑回归函数用来联合训练。论文中强调了联合训练(Join training)和整体训练(ensemble)的区别。

  • Ensemble:两个模型分别独立训练,只在最终预测的时候才将两个模型结合计算;单个模型需要更大(比如进行特征转换)来保证结合后的准确率
  • Join trainging:在训练时,同时考虑wide部分和deep部分以及两个模型拼接起来的权重来同时优化所有的参数;wide部分可以通过少量的特征交叉来弥补deep部分的弱势

wide & deep的join training采用的是下批量梯度下降算法(min-batch stochastic optimization)进行优化的。在实验中,wide部分采用的是Follow-the-regularized-leader(FTRL) L1,deep部分采用的是Adga。

L1 FTRL会让Wide部分的大部分权重都为0,我们准备特征的时候就不用准备那么多0权重的特征了,这大大压缩了模型权重,也压缩了特征向量的维度。

Deep部分的输入,要么是Age,#App Installs这些数值类特征,要么是已经降维并稠密化的Embedding向量,工程师们不会也不敢把过度稀疏的特征向量直接输入到Deep网络中。所以Deep部分不存在严重的特征稀疏问题,自然可以使用精度更好,更适用于深度学习训练的AdaGrad去训练。

对于LR,模型的预测结果如下:

P(Y=1|x)=sigma(w^T_{wide}[x,phi x] W^T_{deep}a^{(l_f)} b)

其中:

  • Y:label
  • sigma():表示sigmoid函数
  • phi(x):原始特征x的交叉转换
  • b:偏置
  • w_{wide}:wide模型的权重
  • W_{deep}a^{(l_f)}的权重

3. 实践

这个是原论文中的架构图,我们自己在实践的时候不一定完全遵守。比如架构图中Wide部分只是使用了交叉特征,我们在使用的时候可以把原始的离散特征或者打散后的连续特征加过来。

3.1 特征工程与处理

  • 用户特征:注册时长、上一次访问距今时长等基础特征,最近3/7/15/30/90天活跃/浏览/关注/im数量等行为特征,以及画像偏好特征和转化率特征;
  • item特征:item基础特征,以及热度值/点击率等连续特征;
  • 交叉特征:将画像偏好和item的特征进行交叉。

  • 缺失值与异常值处理:常规操作;不同特征使用不同缺失值填补方法;异常值使用四分位;
  • 等频分桶处理:常规操作;比如价格,是一个长尾分布,这就导致大部分样本的特征值都集中在一个小的取值范围内,使得样本特征的区分度减小。
  • 归一化:常规操作;效果得到显著提升;
  • 低频过滤:常规操作;对于离散特征,过于低频的归为一类;
  • embedding;

3.2 离线训练

  • 数据切分:采用7天的数据作为训练集,1天的作为测试集
  • embedding:
  • 模型调优:
    • 防止过拟合:加入dropOut 与 L2正则
    • 加快收敛:引入了Batch Normalization
    • 保证训练稳定和收敛:尝试不同的learning rate(wide侧0.001,deep侧0.01效果较好)和batch_size(目前设置的2048)
    • 优化器:对比了SGD、Adam、Adagrad等学习器

论文中提到了一个注意点:如果每一次都重新训练的话,将会花费大量的时间和精力,为了解决这个问题,采取的方案是热启动,即每次新产生训练数据的时候,从之前的模型中读取embedding和线性模型的权重来初始化新模型,在接入实时流之前使用之前的模型进行校验,保证不出问题。

4. 拓展

有些时候对于用户或者待推荐的物体会有Text和Image,为了增加效果,可能会使用到多模态特征。

Text 和 Image 的 embedding 向量,采用 和Wide模型一样的方式加入到整体模型。

几个简单的思路。

  1. Text 和 Image 的 embedding 向量,采用 和Wide模型一样的方式加入到整体模型中就可以了。至于 两者的Embedding向量如何获取,就看你自己了。
  2. Text和Image之间使用attention之后再加入
  3. Text和Image 和Deep 模型的输出拼接之后再做一次处理
  4. Paper关键词:Multimodal Fusion

5. 代码示例

代码语言:javascript复制
train_data = "./../data/adult/adult.train"
test_data = "./../data/adult/adult.test"
train = pd.read_csv(train_data, sep=",", names=["age", "workclass", "fnlwgt", "education", "education_num", "marital_status", "occupation", "relationship", "race","sex", "capital_gain", "capital_loss", "hours_per_week", "native_country", "label"])
print(train.head(5))
代码语言:javascript复制
   age          workclass  fnlwgt  ... hours_per_week  native_country   label
0   39          State-gov   77516  ...             40   United-States   <=50K
1   50   Self-emp-not-inc   83311  ...             13   United-States   <=50K
2   38            Private  215646  ...             40   United-States   <=50K
3   53            Private  234721  ...             40   United-States   <=50K
4   28            Private  338409  ...             40            Cuba   <=50K

定义基本特征、连续特征和dnn使用的特征:

代码语言:javascript复制
# 定义基本连续的特征,linear 和 dnn 都会使用到
age = tf.feature_column.numeric_column("age")
education_num = tf.feature_column.numeric_column("education_num")
capital_gain = tf.feature_column.numeric_column("capital_gain")
capital_loss = tf.feature_column.numeric_column("capital_loss")
hours_per_week = tf.feature_column.numeric_column("hours_per_week")

# 定义离散特征
workclass = tf.feature_column.categorical_column_with_vocabulary_list(
    key="workclass",
    vocabulary_list=["Private", "Self-emp-not-inc", "Self-emp-inc", "Federal-gov", "Local-gov", "State-gov",
                     "Without-pay", "Never-worked", "?"]
)
education = tf.feature_column.categorical_column_with_vocabulary_list(
    key="education",
    vocabulary_list=["Bachelors", "Some-college", "11th", "HS-grad", "Prof-school", "Assoc-acdm", "Assoc-voc", "9th",
                     "7th-8th", "12th", "Masters", "1st-4th", "10th", "Doctorate", "5th-6th", "Preschool"]

)
marital_status = tf.feature_column.categorical_column_with_vocabulary_list(
    key="marital_status",
    vocabulary_list=["Married-civ-spouse", "Divorced", "Never-married", "Separated", "Widowed", "Married-spouse-absent",
                     "Married-AF-spouse"]
)
relationship = tf.feature_column.categorical_column_with_vocabulary_list(
    key="relationship",
    vocabulary_list=["Wife", "Own-child", "Husband", "Not-in-family", "Other-relative", "Unmarried"]
)

# 定义Hash特征,展示embedding的使用
occupation = tf.feature_column.categorical_column_with_hash_bucket(
    key="occupation",
    hash_bucket_size=1000
)

# age 特征分桶
age_bucket = tf.feature_column.bucketized_column(
    source_column=age,
    boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65]
)

base_columns = [workclass, education, marital_status, relationship, occupation, age_bucket]

crossed_columns = [
    tf.feature_column.crossed_column(
        keys=["education", "occupation"], hash_bucket_size=1000
    ),
    tf.feature_column.crossed_column(
        keys=[age_bucket, "education", "occupation"], hash_bucket_size=1000
    )
]

deep_columns = [
    age,
    education_num,
    capital_gain,
    capital_loss,
    hours_per_week,
    tf.feature_column.indicator_column(workclass),  # 做one-hot,然后送入dnn layer
    tf.feature_column.indicator_column(education),
    tf.feature_column.indicator_column(marital_status),
    tf.feature_column.indicator_column(relationship),

    # 展示embedding的使用
    tf.feature_column.embedding_column(occupation, dimension=8)
]

定义数据:

代码语言:javascript复制
# 定义数据
_CSV_COLUMNS = [
    "age", "workclass", "fnlwgt", "education", "education_num",
    "marital_status", "occupation", "relationship", "race", "sex",
    "capital_gain", "capital_loss", "hours_per_week", "native_country", "label"
]
_CSV_COLUMN_DEFAULTS = [
    [0], [''], [0], [''], [0],
    [''], [''], [''], [''], [''],
    [0], [0], [0], [''], ['']
]

_NUM_EXAMPLES = {
    "train": 32561,
    "validation": 16281
}

定义模型:

代码语言:javascript复制
def create_model():
    model = tf.estimator.DNNLinearCombinedClassifier(
        model_dir="./model/wd/",
        linear_feature_columns=base_columns   crossed_columns,
        dnn_feature_columns=deep_columns,
        dnn_hidden_units=[100, 50],
        linear_optimizer="Ftrl",
        dnn_optimizer="Adagrad",
        n_classes=2,
        batch_norm=False
    )
    return model

定义input_fn函数:

代码语言:javascript复制
def input_fn(data_file, num_epochs, shuffle, batch_size):
    """为Estimator创建一个input function"""
    assert tf.io.gfile.GFile(data_file), "{0} not found.".format(data_file)

    def parse_csv(line):
        # tf.decode_csv会把csv文件转换成 a list of Tensor,一列一个
        # record_defaults用于指明每一列的缺失值用什么填充
        columns = tf.io.decode_csv(line, record_defaults=_CSV_COLUMN_DEFAULTS)
        features = dict(zip(_CSV_COLUMNS, columns))
        labels = features.pop('label')
        # tf.equal(x, y) 返回一个bool类型Tensor, 表示x == y, element-wise
        # 注意数据重的空格
        return features, tf.equal(labels, ' >50K')

    dataset = tf.data.TextLineDataset(data_file).map(parse_csv, num_parallel_calls=5)

    if shuffle:
        dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train']   _NUM_EXAMPLES['validation'])

    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)

    return dataset

main函数:

代码语言:javascript复制
if __name__ == "__main__":
    train_epochs = 20
    batch_size = 256

    model = create_model()
    for n in range(train_epochs):
        print("train model start ...")
        model.train(input_fn=lambda: input_fn(train_data, train_epochs, True, batch_size))
        predict_results = model.predict(input_fn=lambda: input_fn(test_data, train_epochs, False, batch_size))

        print("test model start ...")
        results = model.evaluate(input_fn=lambda: input_fn(test_data, train_epochs, False, batch_size))
        # print(results)

        print('{0:-^30}'.format('evaluate at epoch %d' % ((n   1))))
        # results 是一个字典
        print(pd.Series(results).to_frame('values'))

最后运行20个epoch之后,输出结果为:

代码语言:javascript复制
-----evaluate at epoch 20-----
                            values
accuracy                  0.826301
accuracy_baseline         0.763774
auc                       0.852878
auc_precision_recall      0.686778
average_loss              0.381445
label/mean                0.236226
loss                      0.381446
precision                 0.727232
prediction/mean           0.249888
recall                    0.423557
global_step           51008.000000

Process finished with exit code 0

Ref

  1. https://mp.weixin.qq.com/s?__biz=MzIyNTY1MDUwNQ==&mid=2247484238&idx=1&sn=c9700da77cad73f91420fe4309ff0100&chksm=e87d3168df0ab87eb8721dc7220877fb43ae5e66061c7b286fb570720dcc349d3d3d4346eeae&scene=21#wechat_redirect
  2. https://mp.weixin.qq.com/s/UOuT-E8g22iRRoKC9r23YQ
  3. https://mp.weixin.qq.com/s/Vur4kvsiXRbfOYyso81WVg
  4. https://mp.weixin.qq.com/s?__biz=MzI2ODA3NjcwMw==&mid=2247483659&idx=1&sn=deb9c5e22eabd3c52d2418150a40c68a&chksm=eaf452fbdd83dbed0d6de5e847e8569bdc0a75ef6aa23fcaa9c5586a2572cd0e216f499a529b&scene=21#wechat_redirect
  5. https://liam.page/2019/08/31/a-not-so-simple-introduction-to-FTRL/ FTRL介绍
  6. https://zhuanlan.zhihu.com/p/142958834

0 人点赞