tf.slice

2022-09-03 22:03:01 浏览数 (2)

从张量中提取一个切片。

代码语言:javascript复制
tf.slice(
    input_,
    begin,
    size,
    name=None
)

这个操作从begin指定的位置开始,从张量输入中提取一个大小为size的切片。切片大小用张量形状表示,其中size[i]是要切片的输入的第i维的元素个数。切片的起始位置(begin)表示为输入每个维度中的偏移量。换句话说,begin[i]是要从中切片的输入的第i维的偏移量。注意,tf.Tensor。getitem通常是执行切片的一种更符合python风格的方法,因为它允许您编写foo[3:7,:-2]而不是tf。切片(foo, [3,0], [4, foo.get_shape()[1]-2])。开始是从零开始的;大小从1。如果size[i]为-1,则维度i中的所有剩余元素都包含在切片中。换句话说,这相当于设置:

  • size[i] = input.dim_size(i) - begin[i]

这项行动需要:

  • 0 <= begin[i] <= begin[i] size[i] <= Di for i in[0, n]

例:

代码语言:javascript复制
t = tf.constant([[[1, 1, 1], [2, 2, 2]],
                 [[3, 3, 3], [4, 4, 4]],
                 [[5, 5, 5], [6, 6, 6]]])
tf.slice(t, [1, 0, 0], [1, 1, 3])  # [[[3, 3, 3]]]
tf.slice(t, [1, 0, 0], [1, 2, 3])  # [[[3, 3, 3],
                                   #   [4, 4, 4]]]
tf.slice(t, [1, 0, 0], [2, 1, 3])  # [[[3, 3, 3]],
                                   #  [[5, 5, 5]]]

参数:

  • input_:张量。
  • begin:一个int32或int64张量。
  • 大小:一个int32或int64张量。
  • name:操作的名称(可选)。

返回值:

  • 与输入类型相同的张量。

0 人点赞