slim.variance_scaling_initializer()

2022-09-04 21:23:48

from tensorflow.python.framework import dtypes  # TF 1.x internal module

def xavier_initializer(uniform=True, seed=None, dtype=dtypes.float32):
  """Returns an initializer performing "Xavier" initialization for weights.

  This function implements the weight initialization from:

  Xavier Glorot and Yoshua Bengio (2010):
           [Understanding the difficulty of training deep feedforward neural
           networks. International conference on artificial intelligence and
           statistics.](
           http://www.jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf)

  This initializer is designed to keep the scale of the gradients roughly the
  same in all layers. In uniform distribution this ends up being the range:
  `x = sqrt(6. / (in + out)); [-x, x]` and for normal distribution a standard
  deviation of `sqrt(2. / (in + out))` is used.

  Args:
    uniform: Whether to use uniform or normal distributed random initialization.
    seed: A Python integer. Used to create random seeds. See
          `tf.set_random_seed` for behavior.
    dtype: The data type. Only floating point types are supported.
  Returns:
    An initializer for a weight matrix.
  Raises:
    TypeError: if `dtype` is not a floating point type.
  """
  if not dtype.is_floating:
    raise TypeError('Cannot create initializer for non-floating point type.')
  # Xavier initialization is variance scaling with factor=1.0 and
  # mode='FAN_AVG'; `mode` is validated inside variance_scaling_initializer,
  # which is defined in the same tf.contrib.layers module.
  return variance_scaling_initializer(factor=1.0, mode='FAN_AVG',
                                      uniform=uniform, seed=seed, dtype=dtype)

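Returns an initializer performing "Xavier" initialization for weights. This function implements the weight initialization from:

Xavier Glorot and Yoshua Bengio (2010): Understanding the difficulty of training deep feedforward neural networks (http://www.jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf)

This initializer is designed to keep the scale of the gradients roughly the same in all layers. With a uniform distribution, weights are drawn from the range [-x, x], where `x = sqrt(6. / (in + out))`; with a normal distribution, the standard deviation is `sqrt(2. / (in + out))`. Here `in` and `out` are the numbers of input and output units of the layer.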
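
As a rough check on the two formulas above, here is a minimal pure-Python sketch (standard library only, not the TensorFlow implementation). The helper names `xavier_uniform_limit`, `xavier_normal_stddev`, and `xavier_uniform_matrix`, and the 300 x 100 layer size, are hypothetical and chosen only for illustration.

```python
import math
import random


def xavier_uniform_limit(fan_in, fan_out):
    """Half-width x of the uniform range [-x, x]: sqrt(6 / (in + out))."""
    return math.sqrt(6.0 / (fan_in + fan_out))


def xavier_normal_stddev(fan_in, fan_out):
    """Standard deviation for the normal variant: sqrt(2 / (in + out))."""
    return math.sqrt(2.0 / (fan_in + fan_out))


def xavier_uniform_matrix(fan_in, fan_out, seed=None):
    """Draw a fan_in x fan_out weight matrix uniformly from [-x, x]."""
    rng = random.Random(seed)  # local generator, so the seed is reproducible
    limit = xavier_uniform_limit(fan_in, fan_out)
    return [[rng.uniform(-limit, limit) for _ in range(fan_out)]
            for _ in range(fan_in)]


# Hypothetical layer with 300 inputs and 100 outputs.
limit = xavier_uniform_limit(300, 100)    # sqrt(6 / 400)
stddev = xavier_normal_stddev(300, 100)   # sqrt(2 / 400)
weights = xavier_uniform_matrix(300, 100, seed=0)
```

A uniform distribution on [-x, x] has variance x^2 / 3, so the uniform limit sqrt(6 / (in + out)) and the normal standard deviation sqrt(2 / (in + out)) both give the weights the same variance, 2 / (in + out).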

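Args (`factor` and `mode` are accepted by `variance_scaling_initializer` only; `xavier_initializer` takes the remaining three):

  • factor: Float. A multiplicative factor.
  • mode: String. One of 'FAN_IN', 'FAN_OUT', 'FAN_AVG'.
  • uniform: Whether to use uniform or normal distributed random initialization.
  • seed: A Python integer. Used to create random seeds. See `tf.set_random_seed` for behavior.
  • dtype: The data type. Only floating point types are supported.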

Returns:

  • An initializer that generates tensors with unit variance.

Raises:

  • TypeError: if `dtype` is not a floating point type (the docstring says ValueError, but the code raises TypeError).
  • TypeError: if `mode` is not in ['FAN_IN', 'FAN_OUT', 'FAN_AVG'].
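
The roles of `factor`, `mode`, and `uniform` can be sketched in plain Python. This is an illustration under stated assumptions, not the library code: the function name `variance_scaling_scale` is hypothetical, the default values are assumed, and the normal branch returns a plain standard deviation, ignoring any truncation correction the library may apply.

```python
import math

VALID_MODES = ('FAN_IN', 'FAN_OUT', 'FAN_AVG')


def variance_scaling_scale(fan_in, fan_out, factor=2.0, mode='FAN_IN',
                           uniform=False):
    """Return the uniform limit (uniform=True) or the stddev (uniform=False).

    Hypothetical helper; the defaults are assumptions for illustration.
    """
    if mode not in VALID_MODES:
        # Format the message with %, so the mode actually appears in it.
        raise TypeError('Unknown mode %s [FAN_IN, FAN_OUT, FAN_AVG]' % mode)
    if mode == 'FAN_IN':
        n = fan_in                       # scale by the number of inputs
    elif mode == 'FAN_OUT':
        n = fan_out                      # scale by the number of outputs
    else:
        n = (fan_in + fan_out) / 2.0     # FAN_AVG: average of the two
    if uniform:
        # Uniform on [-limit, limit] has variance limit^2 / 3 = factor / n.
        return math.sqrt(3.0 * factor / n)
    # Normal distribution with variance factor / n.
    return math.sqrt(factor / n)


# factor=1.0 with mode='FAN_AVG' reproduces the Xavier formulas.
xavier_limit = variance_scaling_scale(300, 100, factor=1.0, mode='FAN_AVG',
                                      uniform=True)    # sqrt(6 / 400)
xavier_stddev = variance_scaling_scale(300, 100, factor=1.0, mode='FAN_AVG',
                                       uniform=False)  # sqrt(2 / 400)
default_stddev = variance_scaling_scale(300, 100)      # sqrt(2 / 300)
```

With `factor=1.0` and `mode='FAN_AVG'`, the scale depends on (in + out) / 2, which is how the Xavier range and standard deviation fall out as a special case of variance scaling.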
