代码语言:javascript复制
def flatten(inputs, outputs_collections=None, scope=None):
"""Flattens the input while maintaining the batch_size.
Assumes that the first dimension represents the batch.
Args:
inputs: A tensor of size [batch_size, ...].
outputs_collections: Collection to add the outputs.
scope: Optional scope for name_scope.
Returns:
A flattened tensor with shape [batch_size, k].
Raises:
ValueError: If inputs rank is unknown or less than 2.
"""
with ops.name_scope(scope, 'Flatten', [inputs]) as sc:
inputs = ops.convert_to_tensor(inputs)
outputs = core_layers.flatten(inputs)
return utils.collect_named_outputs(outputs_collections, sc, outputs)
在保持batch_size的同时,将输入压扁。假设第一个维度表示批处理。
参数:
- inputs:一个大小张量[batch_size,…]
- outputs_collections:用于添加输出的集合
- scope:name_scope的可选作用域
返回值:
- 一个具有形状[batch_size, k]的平坦张量。
可能产生的异常:
- ValueError: If inputs rank is unknown or less than 2.