TF flags的简介

2019-10-28 11:50:49 浏览数 (1)

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。

本文链接:https://blog.csdn.net/HHTNAN/article/details/102743006

1、TF flags的简介

1、flags可以帮助我们通过命令行来动态的更改代码中的参数。Tensorflow 使用flags定义命令行参数的方法。ML的模型中有大量需要tuning的超参数,所以此方法,迎合了需要一种灵活的方式对代码某些参数进行调整的需求 (1)、比如,在这个py文件中,首先定义了一些参数,然后将参数统一保存到变量FLAGS中,相当于赋值,后边调用这些参数的时候直接使用FLAGS参数即可 (2)、基本参数类型有三种flags.DEFINE_integer、flags.DEFINE_float、flags.DEFINE_boolean。 (3)、第一个是参数名称,第二个参数是默认值,第三个是参数描述

2、使用过程

#第一步,调用flags = tf.app.flags,进行定义参数名称,并可给定初值、参数说明 #第二步,flags参数直接赋值 #第三步,运行tf.app.run()

代码语言:javascript复制
FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_string('name', 'default', 'name of the model')
tf.flags.DEFINE_integer('num_seqs', 100, 'number of seqs in one batch')
tf.flags.DEFINE_integer('num_steps', 100, 'length of one seq')
tf.flags.DEFINE_integer('lstm_size', 128, 'size of hidden state of lstm')
tf.flags.DEFINE_integer('num_layers', 2, 'number of lstm layers')
tf.flags.DEFINE_boolean('use_embedding', False, 'whether to use embedding')
tf.flags.DEFINE_integer('embedding_size', 128, 'size of embedding')
tf.flags.DEFINE_float('learning_rate', 0.001, 'learning_rate')
tf.flags.DEFINE_float('train_keep_prob', 0.5, 'dropout rate during training')
tf.flags.DEFINE_string('input_file', '', 'utf8 encoded text file')
tf.flags.DEFINE_integer('max_steps', 100000, 'max steps to train')
tf.flags.DEFINE_integer('save_every_n', 1000, 'save the model every n steps')
tf.flags.DEFINE_integer('log_every_n', 10, 'log to the screen every n steps')
tf.flags.DEFINE_integer('max_vocab', 3500, 'max char number')

示例如下:

代码语言:javascript复制
import tensorflow as tf
#取上述代码中一部分进行实验
tf.flags.DEFINE_integer('num_seqs', 100, 'number of seqs in one batch')
tf.flags.DEFINE_integer('num_steps', 100, 'length of one seq')
tf.flags.DEFINE_integer('lstm_size', 128, 'size of hidden state of lstm')

#通过print()确定下面内容的功能
FLAGS = tf.flags.FLAGS #FLAGS保存命令行参数的数据
FLAGS._parse_flags() #将其解析成字典存储到FLAGS.__flags中
print(FLAGS.__flags)

print(FLAGS.num_seqs)

print("nParameters:")
for attr, value in sorted(FLAGS.__flags.items()):
    print("{}={}".format(attr.upper(), value))
print("")

遇到问题可以参考:相关解决办法

0 人点赞