# Video walkthrough (视频讲解): https://www.yuque.com/chudi/tzqav9/ny150b#aalY8
import tensorflow as tf
from tensorflow import keras
from utils import *
# Training hyper-parameters (passed to model.fit as epochs / batch_size).
EPOCH = 10
BATCH_SIZE = 32
# Embedding width for each categorical id in the deep part of the model.
VEC_DIM = 10
# Hidden-layer widths of the deep MLP, applied in this order.
DNN_LAYERS = [64, 128, 64]
# Dropout rate used after the flattened embeddings and after every hidden layer.
DROPOUT_RATE = 0.5
# Runs at import time. `base` is the training split and `test` the validation
# split. loadData is presumably provided by `from utils import *`.
# NOTE(review): presumably 2-D tables (DataFrames/arrays) - confirm in utils.
base, test = loadData()
# Sum over all features of each feature's number of category values; used as
# the Embedding vocabulary size. Computed as the column count of `base` minus
# one - presumably one column is the label; TODO confirm against utils.loadData.
FEAT_CATE_NUM = base.shape[1] - 1
# Shorthand for the Keras backend module; not referenced in the code shown here.
K = tf.keras.backend
def _build_model(cate_num, hot_num):
    """Build and compile the Wide & Deep binary classifier.

    Args:
        cate_num: Number of id-encoded columns fed to the deep (embedding) part.
        hot_num: Width of the one-hot vector fed directly to the wide part.

    Returns:
        A compiled ``keras.Model`` that takes ``[ids, one_hot]`` and outputs a
        single sigmoid probability per row.
    """
    # Deep part: embed every categorical id, flatten, then run an MLP.
    inputs_id = keras.Input((cate_num,))
    # NOTE(review): assumes every id lies in [0, FEAT_CATE_NUM); larger ids
    # would fall outside the embedding table. The sequence length is already
    # fixed by the Input shape, so Embedding needs no `input_length`.
    emb = keras.layers.Embedding(FEAT_CATE_NUM, VEC_DIM)(inputs_id)
    deep = keras.layers.Flatten()(emb)
    deep = keras.layers.Dropout(DROPOUT_RATE)(deep)
    for units in DNN_LAYERS:
        deep = keras.layers.Dense(units, activation='relu')(deep)
        deep = keras.layers.Dropout(DROPOUT_RATE)(deep)

    # Wide part: raw one-hot features concatenated with the deep output, so
    # the final Dense(1) is a linear model over both.
    wide = keras.Input((hot_num,))
    wide_deep = keras.layers.concatenate([wide, deep])
    outputs = keras.layers.Dense(1, activation='sigmoid')(wide_deep)

    model = keras.Model(inputs=[inputs_id, wide], outputs=outputs)
    # keras.optimizers.Adam replaces tf.train.AdamOptimizer, which exists only
    # in TensorFlow 1.x. The learning rate is passed positionally because the
    # keyword name differs between Keras versions (`lr` vs `learning_rate`).
    model.compile(loss='binary_crossentropy',
                  optimizer=keras.optimizers.Adam(0.001),
                  metrics=[keras.metrics.AUC()])
    return model


def run():
    """Train the Wide & Deep model on `base` and validate it on `test`.

    Uses the module-level `base` (training) and `test` (validation) data.
    `getAllData` converts each into id-encoded features, one-hot features and
    labels. Training progress is logged to ``./logs`` for TensorBoard.

    Returns:
        The trained ``keras.Model``.
    """
    # getAllData returns (id-encoded features, one-hot features, labels).
    val_x_id, val_x_hot, val_y = getAllData(test)
    train_x_id, train_x_hot, train_y = getAllData(base)

    # Per-row widths of the two feature views, read from the first sample
    # (assumes every row has the same width - TODO confirm in utils).
    cate_num = val_x_id[0].shape[0]
    hot_num = val_x_hot[0].shape[0]

    model = _build_model(cate_num, hot_num)

    # `write_grads` had no effect with histogram_freq=0, and
    # `embeddings_layer_names=None` was the default. Newer Keras versions
    # reject or ignore both, so they are omitted; the rest keep their values.
    tb_callback = keras.callbacks.TensorBoard(log_dir='./logs',
                                              histogram_freq=0,
                                              write_graph=True,
                                              write_images=True,
                                              embeddings_freq=0,
                                              embeddings_metadata=None)
    model.fit([train_x_id, train_x_hot], train_y,
              batch_size=BATCH_SIZE, epochs=EPOCH, verbose=2,
              validation_data=([val_x_id, val_x_hot], val_y),
              callbacks=[tb_callback])
    return model
if __name__ == "__main__":
    # Train only when executed as a script, not when this module is imported.
    run()