背景
使用tensorflow2.0以上版本框架用Keras或者Estimator方式保存模型有两种方式加载模型并预测。
Keras框架保存模型后可以直接加载并调用predict方法预测;
estimator将比较麻烦,需要签名并传入tensor才可以预测;
Keras模型预测
代码语言:txt复制import tensorflow as tf
from tensorflow import keras
model = tf.keras.models.load_model(export_dir)
# dataframe 特征读取与处理
X = dict(dataframe)
c = model.predict(X)
output = np.argmax(c, axis=1)
Estimator模型预测
代码语言:txt复制import tensorflow as tf
# 加载模型 & 签名
imported = tf.saved_model.load(export_dir)
f = imported.signatures["predict"]
代码语言:txt复制# 转换为tensor并预测
out_df = pd.DataFrame()
def predict(dataframe):
examples = []
for row in dataframe.itertuples():
feature_map = {}
# 特征处理 将特征放入dict中
example = tf.train.Example(
features=tf.train.Features(
feature = feature_map
)
)
examples.append(example.SerializeToString())
ex = tf.constant(examples)
result = f(examples=ex)
out_df['high_rank_score'] = np.max(result["probabilities"].numpy(), axis=1)
out_df['tag'] = np.argmax(result["probabilities"].numpy(), axis=1)
return out_df
Ref
- http://d0evi1.com/tensorflow/custom_estimators/
- https://www.tensorflow.org/guide/saved_model?hl=zh-cn#加载和使用自定义模型
- https://zhuanlan.zhihu.com/p/66872472
- https://yinguobing.com/load-savedmodel-of-estimator-by-keras/