详解 tf.slice 函数

2022-05-25 14:02:57 浏览数 (1)

TensorFlow 张量的索引切片方式和 NumPy 模块差不多。与此同时,TensorFlow2.X 也提供了一些比较高级的切片函数,比如:

  • 对张量进行不规则切片提取的 tf.gathertf.gather_ndtf.boolean_mask
  • 对张量的连续子区域进行切片提取的 tf.slice

相比于对张量进行不规则的切片提取的三个函数,tf.slice 的实现方式比较特殊,所以本文来详细的介绍 tf.slice 函数。

代码语言:javascript复制
tf.slice(
    input_, begin, size, name=None
)

tf.slice 函数主要有三个参数:

  • input_: 待切片提取的张量
  • begin: 张量每个维度进行切片操作的起始位置
  • size: 张量每个维度取出的元素个数

为了理解 tf.slice 函数的实现方式,首先创建一个形状为 (3, 2, 3) 的三维的张量 X。

代码语言:javascript复制
import tensorflow as tf

X = tf.constant([[[1, 1, 1], [2, 2, 2]],
                 [[3, 3, 3], [4, 4, 4]],
                 [[5, 5, 5], [6, 6, 6]]])

print(X.shape) # (3, 2, 3)

我们知道 n 维数组可以看成是每个元素是 n - 1 维数组的一维数组,有点类似复合函数,多维张量同样如此。我们用类似复合函数的方式将形状为 (3, 2, 3) 的三维张量进行分解。

  1. 第一个维度有 3 个元素,用 A, B, C 表示,即 X = [[A], [B], [C]]
  2. 第二个维度有 2 个元素,第一个维度的 3 个元素,每个元素都有 2 个元素,用 i, j, k, l, m, n 表示,即 A = [i, j]B =[k, l]C = [m, n]
  3. 第三个维度有 3 个元素,第二个维度的 2 个元素,每个元素都有 3 个元素,即 i = [1, 1, 1]j = [2, 2, 2]k = [3, 3, 3]l = [4, 4, 4]m = [5, 5, 5]n = [6, 6, 6]

为了直观,我们可以将其绘制成层次结构:

有了这些准备,我们直接在 X 上使用 tf.slice 函数:

代码语言:javascript复制
print(tf.slice(X, [1, 0, 0], [1, 1, 3]))
'''
[[[3, 3, 3]]]
'''

此时 begin 和 size 两个参数分别是 [1, 0, 0][1, 1, 3],begin 参数为张量每个维度进行切片操作的起始位置,对于 [1, 0, 0],我们可以理解为:

  • 第一个维度从位置 1 开始
  • 第二个维度从位置 0 开始
  • 第三个维度从位置 0 开始

size 参数为张量每个维度取出元素的个数,对于 [1, 1, 3],我们可以理解为:

  • 第一个维度取出 1 个元素
  • 第二个维度取出 1 个元素
  • 第三个维度取出 3 个元素

我们按照维度整合 begin 和 size 参数:

  • 第一个维度,从位置 1 开始,并且取出 1 个元素
  • 第二个维度,从位置 0 开始,并且取出 1 个元素
  • 第三个维度,从位置 0 开始,并且取出 3 个元素

不过这里有个需要注意的地方,按照上面的说法,此时可能有两种选取方式:

  1. 第一种方式:每次选取都是独立的;
  2. 第二种方式:按照层次结构逐层进行选取。

比如,按照第一种方式,第一个维度选择 B,第二个维度选择 i, j,第三个维度选择 [5, 5, 5],这种每次选取都独立的方式显然是不合理的。tf.slice 显然使用第二种方式,这也是为什么说 tf.slice 能够对张量的连续子区域进行切片。

接下来,就可以将上面对 tf.slice 的理解对应到三维张量 X 中,为了更直观的理解,我们使用上面的层次结构图,图中红色的部分表示已经被选中的元素。对于 begin 和 size 两个参数分别是 [1, 0, 0][1, 1, 3]

  • 第一个维度,从位置 1 开始,并且取出 1 个元素(Python 的索引从 0 开始)
  • 在选中的基础上,我们继续在第二个维度,从位置 0 开始,并且取出 1 个元素
  • 在选中的基础上,我们继续在第三个维度,从位置 0 开始,并且取出 3 个元素

明白了 tf.slice 函数,下面再来几个例子。

代码语言:javascript复制
print(tf.slice(t, [1, 0, 0], [1, 2, 3]))
'''
tf.Tensor(
[[[3 3 3]
  [4 4 4]]], shape=(1, 2, 3), dtype=int32)
'''
代码语言:javascript复制
print(tf.slice(X, [1, 0, 0], [2, 1, 3]))
'''
tf.Tensor(
[[[3 3 3]]
 [[5 5 5]]], shape=(2, 1, 3), dtype=int32)
'''

References:

  1. https://www.quora.com/How-does-tf-slice-work-in-TensorFlow
  2. https://www.tensorflow.org/api_docs/python/tf/slice

0 人点赞