前言
Google官方给出了两个tensorflow的高级封装——keras和Estimator,本文主要介绍tf.Estimator的内容。tf.Estimator的特点是:既能在model_fn中灵活的搭建网络结构,也不至于像原生tensorflow那样复杂繁琐。相比于原生tensorflow更便捷、相比与keras更灵活,属于二者的中间态。
实现一个tf.Estimator主要分三个部分:input_fn、model_fn、main三个函数。其中input_fn负责处理输入数据、model_fn负责构建网络结构、main来决定要进行什么样的任务(train、eval、earlystop等等)。本文我们就通过MNIST数据集的例子,介绍一下tf.Estimator是怎么用的。
1. input_fn
读过我的另一篇文章:Tensorflow笔记:TFRecord的制作与读取 的同学应该记得那里面的read_and_decode函数,其实就和这里的input_fn逻辑是类似的,都是通过tf.data每次调用会产生一个batch的数据。
代码语言:javascript复制def input_fn(filenames, batch_size=32, num_epochs=None, perform_shuffle=False):
"""
每次调用,从TFRecord文件中读取一个大小为batch_size的batch
Args:
filenames: TFRecord文件
batch_size: batch_size大小
num_epochs: 将TFRecord中的数据重复几遍,如果是None,则永远循环读取不会停止
perform_shuffle: 是否乱序
Returns:
tensor格式的,一个batch的数据
"""
def _parse_fn(record):
features = {
"label": tf.FixedLenFeature([], tf.int64),
"image": tf.FixedLenFeature([], tf.string),
}
parsed = tf.parse_single_example(record, features)
# image
image = tf.decode_raw(parsed["image"], tf.uint8)
image = tf.reshape(image, [28, 28])
# label
label = tf.cast(parsed["label"], tf.int64)
return {"image": image}, label
# Extract lines from input files using the Dataset API, can pass one filename or filename list
dataset = tf.data.TFRecordDataset(filenames).map(_parse_fn, num_parallel_calls=10).prefetch(500000) # multi-thread pre-process then prefetch
# Randomizes input using a window of 256 elements (read into memory)
if perform_shuffle:
dataset = dataset.shuffle(buffer_size=256)
# epochs from blending together.
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size) # Batch size to use
iterator = dataset.make_one_shot_iterator()
batch_features, batch_labels = iterator.get_next()
return batch_features, batch_labels
如果遇到不同的问题,其实只需要改动tf.data.TFRecordDataset这一行和_parse_fn函数即可。比如如果输入数据不是TFRecord格式,而是一个LIBSVM格式:
代码语言:javascript复制def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
def _parse_fn(line):
columns = tf.string_split([line], ' ')
labels = tf.string_to_number(columns.values[0], out_type=tf.float32)
splits = tf.string_split(columns.values[5:], ':') # filed_size=280 feature_size=6500000
id_vals = tf.reshape(splits.values, splits.dense_shape)
feat_ids, feat_vals = tf.split(id_vals, num_or_size_splits=2, axis=1)
feat_ids = tf.string_to_number(feat_ids, out_type=tf.int32)
feat_vals = tf.string_to_number(feat_vals, out_type=tf.float32)
# feat_vals = tf.sign(feat_vals) * tf.math.log(tf.abs(feat_vals) 1) # do log manual
return {"feat_ids": feat_ids, "feat_vals": feat_vals}, labels
# Extract lines from input files using the Dataset API, can pass one filename or filename list
dataset = tf.data.TextLineDataset(filenames).map(_parse_fn, num_parallel_calls=10).prefetch(500000)
# Randomizes input using a window of 256 elements (read into memory)
if perform_shuffle:
dataset = dataset.shuffle(buffer_size=256)
# epochs from blending together.
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size) # Batch size to use
iterator = dataset.make_one_shot_iterator()
batch_features, batch_labels = iterator.get_next()
return batch_features, batch_labels
只是修改了_parse_fn的内容,并用tf.data.TextLineDataset替换tf.data.TFRecordDataset即可。总之这种形式的input_fn其实类似一种迭代器,每次调用都会返回一个batch的数据。但是这里面的_parse_fn函数的内容,就要根据实际情况来编写了。
2. model_fn
model_fn是Estimator中最核心,也是最复杂的一个部分,在这里面需要定义网络结构、损失、train_op、评估结果等各种与网路结构有关的内容。下面依然通过《Tensorflow笔记:TFRecord的制作与读取》中的例子:通过简单的DNN网络来预测label来说明(这一段代码虽然长,但是也是结构化的,不要嫌麻烦一个part一个part的看,其实不复杂的)。
代码语言:javascript复制def model_fn(features, labels, mode, params):
# ========== 解析参数部分 ========== #
learning_rate = params["learning_rate"]
# ========== 网络结构部分 ========== #
# input
X = tf.cast(features["image"], tf.float32, name="input_image")
X = tf.reshape(X, [-1, 28*28]) / 255
# DNN
deep_inputs = X
deep_inputs = tf.contrib.layers.fully_connected(inputs=deep_inputs, num_outputs=128)
deep_inputs = tf.contrib.layers.fully_connected(inputs=deep_inputs, num_outputs=64)
y_deep = tf.contrib.layers.fully_connected(inputs=deep_inputs, num_outputs=10)
# output
y = tf.reshape(y_deep, shape=[-1, 10])
pred = tf.nn.softmax(y, name="soft_max")
# ========== 如果是 predict 任务 ========== #
predictions={"prob": pred}
export_outputs = {tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: tf.estimator.export.PredictOutput(predictions)}
# Provide an estimator spec for `ModeKeys.PREDICT`
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(
mode=mode,
predictions=predictions,
export_outputs=export_outputs)
# ========== 如果是 eval 任务 ========== #
one_hot_label = tf.one_hot(tf.cast(labels, tf.int32, name="input_label"), depth=10, name="label")
# 构建损失
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=one_hot_label))
eval_metric_ops = {
"accuracy": tf.metrics.accuracy(tf.math.argmax(one_hot_label, axis=1), tf.math.argmax(pred, axis=1))
}
if mode == tf.estimator.ModeKeys.EVAL:
return tf.estimator.EstimatorSpec(
mode=mode,
predictions=predictions,
loss=loss,
eval_metric_ops=eval_metric_ops)
# ========== 如果是 train 任务 ========== #
# 构建train_op
train_op = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=0.9, beta2=0.999, epsilon=1e-8).minimize(loss, global_step=tf.train.get_global_step())
# Provide an estimator spec for `ModeKeys.TRAIN` modes
if mode == tf.estimator.ModeKeys.TRAIN:
return tf.estimator.EstimatorSpec(
mode=mode,
predictions=predictions,
loss=loss,
train_op=train_op)
介绍一下model_fn的结构:
- Part1:解析参数部分,本例中以learning_rate为例,展示如何通过param来将参数传递进来,其他参数为了简便,直接用了数值型。
- Part2:网络结构部分。这部分只是负责构建网络结构,从input到pred,不涉及label部分,所以不要把对labels的处理写在这里,因为如果在predict任务中,可能没有label的数据,就会报错。(在这里其实是支持通过tf.keras来构造网络结构,关于tf.keras的用法我在《Tensorflow笔记:高级封装——Keras》中有详细介绍)
- Part3:predict任务部分。如果任务目的是predict,那么可以直接通过网络结构计算pred,不需要其他操作。设置好export_outputs,并以tf.estimator.EstimatorSpec形式返回即可。
- Part4:eval任务部分。如果是eval任务,除了网络结构以外还需要计算此时的损失、正确率等指标,所以对于loss的定义要放在这一部分。同时设置好评价指标eval_metric_ops,并以tf.estimator.EstimatorSpec形式返回。
- Part5:train任务部分。最后如果是train任务,除了网络结构、loss,还需要优化器、学习率等内容,所以定义train_op的部分在这里进行。最后以tf.estimator.EstimatorSpec形式返回。
model_fn部分虽然看起来长,但是对于不同的任务,只需要改动网络结构部分、loss以及train_op就可以了,说白了还是复制粘贴那点事。
3. main
最后就到了main函数这里,已经有了input_fn负责数据,model_fn负责模型,main这部分管的就是,我要怎么用这个模型。
代码语言:javascript复制def main():
# ========== 准备参数 ========== #
task_type = "train"
model_params = {
"learning_rate": 0.001,
}
# ========== 构建Estimator ========== #
config = tf.estimator.RunConfig().replace(
session_config=tf.ConfigProto(device_count={'GPU': 0, 'CPU': 1}),
log_step_count_steps=100,
save_summary_steps=100,
save_checkpoints_secs=None,
save_checkpoints_steps=500,
keep_checkpoint_max=1
)
estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir="./model_ckpt/", params=model_params, config=config)
# ========== 执行任务 ========== #
if task_type == "train":
# early_stop_hook 是控制模型早停的控件,下面两个分别是 tf 1.x 和 tf 2.x 的写法
# early_stop_hook = tf.contrib.estimator.stop_if_no_increase_hook(estimator, metric_name="accuracy",
early_stop_hook=tf.estimator.experimental.stop_if_no_increase_hook(estimator, metric_name="accuracy", max_steps_without_increase=1000, min_steps=500)
train_spec = tf.estimator.TrainSpec(input_fn=lambda: input_fn(tr_files, num_epochs=10, batch_size=32), hooks=[early_stop_hook])
eval_spec = tf.estimator.EvalSpec(input_fn=lambda: input_fn(va_files, num_epochs=1, batch_size=32), steps=None, start_delay_secs=1000, throttle_secs=1)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
elif task_type == "eval":
estimator.evaluate(input_fn=lambda: input_fn(va_files, num_epochs=1, batch_size=32))
elif task_type == "infer":
preds = estimator.predict(input_fn=lambda: input_fn(te_files, num_epochs=1, batch_size=32), predict_keys="prob")
with open("./pred.txt", "w") as fo:
for prob in preds:
fo.write("%fn" % (np.argmax(prob['prob'])))
if task_type == "export":
feature_spec = {
"image": tf.placeholder(dtype=tf.float32, shape=[None, 28, 28], name="image"),
}
serving_input_receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(feature_spec)
Estimator.export_savedmodel("./saved_model/", serving_input_receiver_fn)
其实main中主要做三件事:1. 通过tf.estimator.RunConfig()配置构建Estimator对象;2. 初始化estimator(model_dir如果非空则自动热启动);3. 执行train/eval/infer/export任务。
- train任务中初始化好TrainSpec和EvalSpec之后可以直接调用tf.estimator.train。也可以使用train_and_evaluate来一边训练一边输出验证集效果。hook可以看作是在训练验证基础上可以实现其他复杂功能的“插件”,比如本例中的early___stop,其他功能还包括热启动、Fine-tune等等,关于hook的用法比较复杂,以后单独写一篇文章。
- eval任务输出的就是在model_fn函数中eval_metric_ops定义的指标。
- infer任务就是调用estimator.predict获取在model_fn中定义的export_outputs作为预测值。
- export就是将定义Estimator时候模型路径 model_dir="./model_ckpt/" 下的模型导出为可部署模型,也就是常说的saved_model。关于saved_model和模型部署方面,我也会单独写一篇文章来介绍。另外feature_spec指的是一个请求过来所带的数据应该长什么样,对应了model_fn里面的features(即features"image"),所以这里feature_spec用的是字典的形式,建议model_fn中的features也用字典形式,哪怕是只有一个元素。
最后,直接跑main函数,或者通过tf.app.run()来运行脚本都可以:
代码语言:javascript复制# 直接运行 main 函数
main()
# 通过 tf.app.run() 来运行
if __name__ == "__main__":
tf.app.run()
4. 分布式训练
对于单机单卡和单机多卡的情况,可以通过tf.device('/gpu:0')来手动控制,这里介绍一下在多机分布式情况下Estimator如何进行分布式训练。Estimator的分布式训练和原生Tensorflow的分布式训练类似,都需要提供一份“集群名单”,并且告诉每一台机器他是名单中的谁,并在每台机器上运行脚本。下面看一个例子
代码语言:javascript复制import os
import json
import numpy as np
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string("job_name", "worker", "chief/ps/worker")
tf.app.flags.DEFINE_integer("task_id", 0, "Task ID of the worker running the train")
os.environ['TF_CONFIG'] = json.dumps({
'cluster': {
'chief': ["localhost:2221"],
'ps': ["localhost:2222"],
'worker': ["localhost:2223", "localhost:2224"]
},
'task': {'type': FLAGS.job_name, 'index': FLAGS.task_id}
})
本例采用本地机的两个端口模拟集群中的两个机器,"cluster"表示集群的“名单”信息。"task"表示该机器的信息,"type"表示该机器的角色,"index"表示该机器是列表中的第几个。tf.Estimator中需要指定一个chief机器,ps机也只是在特定的策略下才需要指定(这一点下文介绍)。
除此之外,只需要在tf.ConfigProto中配置train_distribute就可以了:
代码语言:javascript复制strategy = tf.distribute.experimental.ParameterServerStrategy()
config = tf.estimator.RunConfig().replace(
session_config=tf.ConfigProto(device_count={'GPU': 0, 'CPU': 1}),
log_step_count_steps=100,
save_summary_steps=100,
save_checkpoints_secs=None,
save_checkpoints_steps=500,
keep_checkpoint_max=1,
train_distribute=strategy
)
estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir="./model_ckpt/", params=model_params, config=config)
接下来只需要在每台机器上运行脚本,就可以完成Esitmator的分布式训练了。实际上可以声明不同的strategy,来实现不同的并行策略:
tf.distribute.MirroredStrategy
:单机多卡情况,每一个GPU都保存变量副本。tf.distribute.experimental.CentralStorageStrategy
:单机多卡情况,GPU不保存变量副本,变量都保存在CPU上。tf.distribute.experimental.MultiWorkerMirroredStrategy
:在所有机器的每台设备上创建模型层中所有变量的副本。它使用CollectiveOps
,一个用于集体通信的 TensorFlow 操作,来聚合梯度并使变量保持同步。tf.distribute.experimental.TPUStrategy
:在TPU上训练模型tf.distribute.experimental.ParameterServerStrategy
:本例中采用的策略,有专门的ps机负责处理变量和梯度,worker机专门负责训练,计算梯度。所以只有在这种策略下,才需要在os.environ'TF_CONFIG'中设置ps机。tf.distribute.OneDeviceStrategy
:用单独的设备来训练。