tf.train.piecewise_constant

2022-09-04 21:01:12 浏览数 (2)

代码语言:javascript复制
tf.train.piecewise_constant(
    x,
    boundaries,
    values,
    name=None
)

分段常数来自边界和区间值。示例:对前100001步使用1.0的学习率,对后10000步使用0.5的学习率,对任何其他步骤使用0.1的学习率。

代码语言:javascript复制
global_step = tf.Variable(0, trainable=False)
boundaries = [100000, 110000]
values = [1.0, 0.5, 0.1]
learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)

# Later, whenever we perform an optimization step, we increment global_step.

参数:

  • x: 一个0-D标量张量。必须是下列类型之一:float32、float64、uint8、int8、int16、int32、int64。
  • boundaries: 张量、int或浮点数的列表,其条目严格递增,且所有元素具有与x相同的类型。
  • values: 张量、浮点数或整数的列表,指定边界定义的区间的值。它应该比边界多一个元素,并且所有元素应该具有相同的类型。
  • name: 一个字符串。操作的可选名称。默认为“PiecewiseConstant”。

返回值:

一个0维的张量。

当x <= boundries[0],值为values[0];

当x > boundries[0] && x<= boundries[1],值为values[1];

......

当x > boundries[-1],值为values[-1]

异常:

  • ValueError: if types of x and boundaries do not match, or types of all values do not match or the number of elements in the lists does not match.

0 人点赞