slim.flatten()

2022-09-04 21:22:55 浏览数 (4)

代码语言:javascript复制
def flatten(inputs, outputs_collections=None, scope=None):
  """Flattens the input while maintaining the batch_size.

    Assumes that the first dimension represents the batch.

  Args:
    inputs: A tensor of size [batch_size, ...].
    outputs_collections: Collection to add the outputs.
    scope: Optional scope for name_scope.

  Returns:
    A flattened tensor with shape [batch_size, k].
  Raises:
    ValueError: If inputs rank is unknown or less than 2.
  """
  with ops.name_scope(scope, 'Flatten', [inputs]) as sc:
    inputs = ops.convert_to_tensor(inputs)
    outputs = core_layers.flatten(inputs)
    return utils.collect_named_outputs(outputs_collections, sc, outputs)

在保持batch_size的同时,将输入压扁。假设第一个维度表示批处理。

参数:

  • inputs:一个大小张量[batch_size,…]
  • outputs_collections:用于添加输出的集合
  • scope:name_scope的可选作用域

返回值:

  • 一个具有形状[batch_size, k]的平坦张量。

可能产生的异常:

  • ValueError: If inputs rank is unknown or less than 2.

0 人点赞