Tensorflow笔记:高级封装——tf.Estimator

2021-09-24 15:06:01 浏览数 (1)

前言

Google官方给出了两个tensorflow的高级封装——keras和Estimator,本文主要介绍tf.Estimator的内容。tf.Estimator的特点是:既能在model_fn中灵活的搭建网络结构,也不至于像原生tensorflow那样复杂繁琐。相比于原生tensorflow更便捷、相比与keras更灵活,属于二者的中间态。

实现一个tf.Estimator主要分三个部分:input_fn、model_fn、main三个函数。其中input_fn负责处理输入数据、model_fn负责构建网络结构、main来决定要进行什么样的任务(train、eval、earlystop等等)。本文我们就通过MNIST数据集的例子,介绍一下tf.Estimator是怎么用的。

1. input_fn

读过我的另一篇文章:Tensorflow笔记:TFRecord的制作与读取 的同学应该记得那里面的read_and_decode函数,其实就和这里的input_fn逻辑是类似的,都是通过tf.data每次调用会产生一个batch的数据。

代码语言:javascript复制
def input_fn(filenames, batch_size=32, num_epochs=None, perform_shuffle=False):
    """
    每次调用,从TFRecord文件中读取一个大小为batch_size的batch
    Args:
        filenames: TFRecord文件
        batch_size: batch_size大小
        num_epochs: 将TFRecord中的数据重复几遍,如果是None,则永远循环读取不会停止
        perform_shuffle: 是否乱序

    Returns:
        tensor格式的,一个batch的数据
    """
    def _parse_fn(record):
        features = {
            "label": tf.FixedLenFeature([], tf.int64),
            "image": tf.FixedLenFeature([], tf.string),
        }
        parsed = tf.parse_single_example(record, features)
        # image
        image = tf.decode_raw(parsed["image"], tf.uint8)
        image = tf.reshape(image, [28, 28])
        # label
        label = tf.cast(parsed["label"], tf.int64)
        return {"image": image}, label

    # Extract lines from input files using the Dataset API, can pass one filename or filename list
    dataset = tf.data.TFRecordDataset(filenames).map(_parse_fn, num_parallel_calls=10).prefetch(500000)    # multi-thread pre-process then prefetch

    # Randomizes input using a window of 256 elements (read into memory)
    if perform_shuffle:
        dataset = dataset.shuffle(buffer_size=256)

    # epochs from blending together.
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size) # Batch size to use

    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels

如果遇到不同的问题,其实只需要改动tf.data.TFRecordDataset这一行和_parse_fn函数即可。比如如果输入数据不是TFRecord格式,而是一个LIBSVM格式:

代码语言:javascript复制
def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
    def _parse_fn(line):
        columns = tf.string_split([line], ' ')
        labels = tf.string_to_number(columns.values[0], out_type=tf.float32)
        splits = tf.string_split(columns.values[5:], ':')  # filed_size=280 feature_size=6500000
        id_vals = tf.reshape(splits.values, splits.dense_shape)
        feat_ids, feat_vals = tf.split(id_vals, num_or_size_splits=2, axis=1)
        feat_ids = tf.string_to_number(feat_ids, out_type=tf.int32)
        feat_vals = tf.string_to_number(feat_vals, out_type=tf.float32)
        # feat_vals = tf.sign(feat_vals) * tf.math.log(tf.abs(feat_vals)   1)  # do log manual
        return {"feat_ids": feat_ids, "feat_vals": feat_vals}, labels

    # Extract lines from input files using the Dataset API, can pass one filename or filename list
    dataset = tf.data.TextLineDataset(filenames).map(_parse_fn, num_parallel_calls=10).prefetch(500000)

    # Randomizes input using a window of 256 elements (read into memory)
    if perform_shuffle:
        dataset = dataset.shuffle(buffer_size=256)

    # epochs from blending together.
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)  # Batch size to use

    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels

只是修改了_parse_fn的内容,并用tf.data.TextLineDataset替换tf.data.TFRecordDataset即可。总之这种形式的input_fn其实类似一种迭代器,每次调用都会返回一个batch的数据。但是这里面的_parse_fn函数的内容,就要根据实际情况来编写了。

2. model_fn

model_fn是Estimator中最核心,也是最复杂的一个部分,在这里面需要定义网络结构、损失、train_op、评估结果等各种与网路结构有关的内容。下面依然通过《Tensorflow笔记:TFRecord的制作与读取》中的例子:通过简单的DNN网络来预测label来说明(这一段代码虽然长,但是也是结构化的,不要嫌麻烦一个part一个part的看,其实不复杂的)。

代码语言:javascript复制
def model_fn(features, labels, mode, params):
    # ==========  解析参数部分  ========== #
    learning_rate = params["learning_rate"]

    # ==========  网络结构部分  ========== #
    # input
    X = tf.cast(features["image"], tf.float32, name="input_image")
    X = tf.reshape(X, [-1, 28*28]) / 255
    # DNN
    deep_inputs = X
    deep_inputs = tf.contrib.layers.fully_connected(inputs=deep_inputs, num_outputs=128)
    deep_inputs = tf.contrib.layers.fully_connected(inputs=deep_inputs, num_outputs=64)
    y_deep = tf.contrib.layers.fully_connected(inputs=deep_inputs, num_outputs=10)
    # output
    y = tf.reshape(y_deep, shape=[-1, 10])
    pred = tf.nn.softmax(y, name="soft_max")

    
    # ==========  如果是 predict 任务  ========== #
    predictions={"prob": pred}
    export_outputs = {tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: tf.estimator.export.PredictOutput(predictions)}
    # Provide an estimator spec for `ModeKeys.PREDICT`
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
                mode=mode,
                predictions=predictions,
                export_outputs=export_outputs)
    

    # ==========  如果是 eval 任务  ========== #
    one_hot_label = tf.one_hot(tf.cast(labels, tf.int32, name="input_label"), depth=10, name="label")
    # 构建损失
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=one_hot_label))
    eval_metric_ops = {
        "accuracy": tf.metrics.accuracy(tf.math.argmax(one_hot_label, axis=1), tf.math.argmax(pred, axis=1))
    }
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
                mode=mode,
                predictions=predictions,
                loss=loss,
                eval_metric_ops=eval_metric_ops)
    

    # ==========  如果是 train 任务  ========== #
    # 构建train_op
    train_op = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=0.9, beta2=0.999, epsilon=1e-8).minimize(loss, global_step=tf.train.get_global_step())
    # Provide an estimator spec for `ModeKeys.TRAIN` modes
    if mode == tf.estimator.ModeKeys.TRAIN:
        return tf.estimator.EstimatorSpec(
                mode=mode,
                predictions=predictions,
                loss=loss,
                train_op=train_op)

介绍一下model_fn的结构:

  • Part1:解析参数部分,本例中以learning_rate为例,展示如何通过param来将参数传递进来,其他参数为了简便,直接用了数值型。
  • Part2:网络结构部分。这部分只是负责构建网络结构,从input到pred,不涉及label部分,所以不要把对labels的处理写在这里,因为如果在predict任务中,可能没有label的数据,就会报错。(在这里其实是支持通过tf.keras来构造网络结构,关于tf.keras的用法我在《Tensorflow笔记:高级封装——Keras》中有详细介绍)
  • Part3:predict任务部分。如果任务目的是predict,那么可以直接通过网络结构计算pred,不需要其他操作。设置好export_outputs,并以tf.estimator.EstimatorSpec形式返回即可。
  • Part4:eval任务部分。如果是eval任务,除了网络结构以外还需要计算此时的损失、正确率等指标,所以对于loss的定义要放在这一部分。同时设置好评价指标eval_metric_ops,并以tf.estimator.EstimatorSpec形式返回。
  • Part5:train任务部分。最后如果是train任务,除了网络结构、loss,还需要优化器、学习率等内容,所以定义train_op的部分在这里进行。最后以tf.estimator.EstimatorSpec形式返回。

model_fn部分虽然看起来长,但是对于不同的任务,只需要改动网络结构部分、loss以及train_op就可以了,说白了还是复制粘贴那点事。

3. main

最后就到了main函数这里,已经有了input_fn负责数据,model_fn负责模型,main这部分管的就是,我要怎么用这个模型。

代码语言:javascript复制
def main():
    # ==========  准备参数 ========== #
    task_type = "train"
    model_params = {
        "learning_rate": 0.001,
    }

    # ==========  构建Estimator  ========== #
    config = tf.estimator.RunConfig().replace(
        session_config=tf.ConfigProto(device_count={'GPU': 0, 'CPU': 1}),
        log_step_count_steps=100,
        save_summary_steps=100,
        save_checkpoints_secs=None,
        save_checkpoints_steps=500,
        keep_checkpoint_max=1
    )
    estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir="./model_ckpt/", params=model_params, config=config)

    # ==========  执行任务  ========== #
    if task_type == "train":
        # early_stop_hook 是控制模型早停的控件,下面两个分别是 tf 1.x 和 tf 2.x 的写法
        # early_stop_hook = tf.contrib.estimator.stop_if_no_increase_hook(estimator, metric_name="accuracy",
        early_stop_hook=tf.estimator.experimental.stop_if_no_increase_hook(estimator, metric_name="accuracy", max_steps_without_increase=1000, min_steps=500)
        train_spec = tf.estimator.TrainSpec(input_fn=lambda: input_fn(tr_files, num_epochs=10, batch_size=32), hooks=[early_stop_hook])
        eval_spec = tf.estimator.EvalSpec(input_fn=lambda: input_fn(va_files, num_epochs=1, batch_size=32), steps=None, start_delay_secs=1000, throttle_secs=1)
        tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
    elif task_type == "eval":
        estimator.evaluate(input_fn=lambda: input_fn(va_files, num_epochs=1, batch_size=32))
    elif task_type == "infer":
        preds = estimator.predict(input_fn=lambda: input_fn(te_files, num_epochs=1, batch_size=32), predict_keys="prob")
        with open("./pred.txt", "w") as fo:
            for prob in preds:
                fo.write("%fn" % (np.argmax(prob['prob'])))
    if task_type == "export":
        feature_spec = {
            "image": tf.placeholder(dtype=tf.float32, shape=[None, 28, 28], name="image"),
        }
        serving_input_receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(feature_spec)
        Estimator.export_savedmodel("./saved_model/", serving_input_receiver_fn)

其实main中主要做三件事:1. 通过tf.estimator.RunConfig()配置构建Estimator对象;2. 初始化estimator(model_dir如果非空则自动热启动);3. 执行train/eval/infer/export任务。

  • train任务中初始化好TrainSpec和EvalSpec之后可以直接调用tf.estimator.train。也可以使用train_and_evaluate来一边训练一边输出验证集效果。hook可以看作是在训练验证基础上可以实现其他复杂功能的“插件”,比如本例中的early___stop,其他功能还包括热启动、Fine-tune等等,关于hook的用法比较复杂,以后单独写一篇文章。
  • eval任务输出的就是在model_fn函数中eval_metric_ops定义的指标。
  • infer任务就是调用estimator.predict获取在model_fn中定义的export_outputs作为预测值。
  • export就是将定义Estimator时候模型路径 model_dir="./model_ckpt/" 下的模型导出为可部署模型,也就是常说的saved_model。关于saved_model和模型部署方面,我也会单独写一篇文章来介绍。另外feature_spec指的是一个请求过来所带的数据应该长什么样,对应了model_fn里面的features(即features"image"),所以这里feature_spec用的是字典的形式,建议model_fn中的features也用字典形式,哪怕是只有一个元素。

最后,直接跑main函数,或者通过tf.app.run()来运行脚本都可以:

代码语言:javascript复制
# 直接运行 main 函数
main()

# 通过 tf.app.run() 来运行
if __name__ == "__main__":
    tf.app.run()

4. 分布式训练

对于单机单卡和单机多卡的情况,可以通过tf.device('/gpu:0')来手动控制,这里介绍一下在多机分布式情况下Estimator如何进行分布式训练。Estimator的分布式训练和原生Tensorflow的分布式训练类似,都需要提供一份“集群名单”,并且告诉每一台机器他是名单中的谁,并在每台机器上运行脚本。下面看一个例子

代码语言:javascript复制
import os
import json
import numpy as np
import tensorflow as tf


FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string("job_name", "worker", "chief/ps/worker")
tf.app.flags.DEFINE_integer("task_id", 0, "Task ID of the worker running the train")

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'chief': ["localhost:2221"],
        'ps':  ["localhost:2222"],
        'worker': ["localhost:2223", "localhost:2224"]
    },
    'task': {'type': FLAGS.job_name, 'index': FLAGS.task_id}
})

本例采用本地机的两个端口模拟集群中的两个机器,"cluster"表示集群的“名单”信息。"task"表示该机器的信息,"type"表示该机器的角色,"index"表示该机器是列表中的第几个。tf.Estimator中需要指定一个chief机器,ps机也只是在特定的策略下才需要指定(这一点下文介绍)。

除此之外,只需要在tf.ConfigProto中配置train_distribute就可以了:

代码语言:javascript复制
strategy = tf.distribute.experimental.ParameterServerStrategy()
config = tf.estimator.RunConfig().replace(
    session_config=tf.ConfigProto(device_count={'GPU': 0, 'CPU': 1}),
    log_step_count_steps=100,
    save_summary_steps=100,
    save_checkpoints_secs=None,
    save_checkpoints_steps=500,
    keep_checkpoint_max=1,
    train_distribute=strategy
)
estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir="./model_ckpt/", params=model_params, config=config)

接下来只需要在每台机器上运行脚本,就可以完成Esitmator的分布式训练了。实际上可以声明不同的strategy,来实现不同的并行策略:

  • tf.distribute.MirroredStrategy:单机多卡情况,每一个GPU都保存变量副本。
  • tf.distribute.experimental.CentralStorageStrategy:单机多卡情况,GPU不保存变量副本,变量都保存在CPU上。
  • tf.distribute.experimental.MultiWorkerMirroredStrategy :在所有机器的每台设备上创建模型层中所有变量的副本。它使用CollectiveOps,一个用于集体通信的 TensorFlow 操作,来聚合梯度并使变量保持同步。
  • tf.distribute.experimental.TPUStrategy:在TPU上训练模型
  • tf.distribute.experimental.ParameterServerStrategy:本例中采用的策略,有专门的ps机负责处理变量和梯度,worker机专门负责训练,计算梯度。所以只有在这种策略下,才需要在os.environ'TF_CONFIG'中设置ps机
  • tf.distribute.OneDeviceStrategy:用单独的设备来训练。

0 人点赞