slim.max_pool2d()

2022-09-04 21:22:39 浏览数 (1)

代码语言:javascript复制
def max_pool2d(inputs,
               kernel_size,
               stride=2,
               padding='VALID',
               data_format=DATA_FORMAT_NHWC,
               outputs_collections=None,
               scope=None):
  if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
    raise ValueError('data_format has to be either NCHW or NHWC.')
  with ops.name_scope(scope, 'MaxPool2D', [inputs]) as sc:
    inputs = ops.convert_to_tensor(inputs)
    df = ('channels_first'
          if data_format and data_format.startswith('NC') else 'channels_last')
    layer = pooling_layers.MaxPooling2D(
        pool_size=kernel_size,
        strides=stride,
        padding=padding,
        data_format=df,
        _scope=sc)
    outputs = layer.apply(inputs)
    return utils.collect_named_outputs(outputs_collections, sc, outputs)

添加了一个2D最大池化操作,它假设池化是按每张图像完成的,但不是按批处理或通道完成的。

参数:

  • inputs:一个形状' [batch_size, height, width, channels] '的4-D张量,如果' data_format '是' NHWC ',那么' [batch_size, channels, height, width] '如果' data_format '是' NCHW '
  • kernel_size:计算op的池内核的长度2:[kernel_height, kernel_width]的列表。如果两个值相同,则可以是int
  • stride:一个长度为2的列表:[stride_height, stride_width]。如果两个步骤相同,则可以是int。注意,目前这两个步骤必须具有相同的值
  • padding:填充方法,要么“有效”,要么“相同”
  • data_format:一个字符串。支持' NHWC '(默认值)和' NCHW '
  • outputs_collections:将输出添加到其中的集合
  • scope:name_scope的可选作用域

返回值:

  • 表示池操作结果的“张量”

可能产生的异常:

  • ValueError: If `data_format` is neither `NHWC` nor `NCHW`.
  • ValueError: If 'kernel_size' is not a 2-D list

0 人点赞