【云+社区年度征文】tensorflow 2.0 Estimator Keras读取saved model并预测

2020-12-07 17:35:03 浏览数 (1)

背景

使用tensorflow2.0以上版本框架用Keras或者Estimator方式保存模型有两种方式加载模型并预测。

Keras框架保存模型后可以直接加载并调用predict方法预测;

estimator将比较麻烦,需要签名并传入tensor才可以预测;

Keras模型预测

代码语言:txt复制
import tensorflow as tf
from tensorflow import keras
model = tf.keras.models.load_model(export_dir)

# dataframe 特征读取与处理
X = dict(dataframe)
c = model.predict(X)
output = np.argmax(c, axis=1)

Estimator模型预测

代码语言:txt复制
import tensorflow as tf
# 加载模型 & 签名
imported = tf.saved_model.load(export_dir)
f = imported.signatures["predict"]
代码语言:txt复制
# 转换为tensor并预测
out_df = pd.DataFrame()
def predict(dataframe):
    examples = []
    for row in dataframe.itertuples():
        feature_map = {}
        # 特征处理 将特征放入dict中
        example = tf.train.Example(
            features=tf.train.Features(
                feature = feature_map
            )
        )
        examples.append(example.SerializeToString())
            
    ex = tf.constant(examples)
    result = f(examples=ex)
    out_df['high_rank_score'] = np.max(result["probabilities"].numpy(), axis=1)
    out_df['tag'] = np.argmax(result["probabilities"].numpy(), axis=1)
    return out_df

Ref

  1. http://d0evi1.com/tensorflow/custom_estimators/
  2. https://www.tensorflow.org/guide/saved_model?hl=zh-cn#加载和使用自定义模型
  3. https://zhuanlan.zhihu.com/p/66872472
  4. https://yinguobing.com/load-savedmodel-of-estimator-by-keras/

0 人点赞