TensorFlow 张量的索引切片方式和 NumPy 模块差不多。与此同时,TensorFlow2.X 也提供了一些比较高级的切片函数,比如:
- 对张量进行不规则切片提取的
tf.gather
、tf.gather_nd
和tf.boolean_mask
; - 对张量的连续子区域进行切片提取的
tf.slice
。
相比于对张量进行不规则的切片提取的三个函数,tf.slice
的实现方式比较特殊,所以本文来详细的介绍 tf.slice
函数。
tf.slice(
input_, begin, size, name=None
)
tf.slice
函数主要有三个参数:
- input_: 待切片提取的张量
- begin: 张量每个维度进行切片操作的起始位置
- size: 张量每个维度取出的元素个数
为了理解 tf.slice
函数的实现方式,首先创建一个形状为 (3, 2, 3) 的三维的张量 X。
import tensorflow as tf
X = tf.constant([[[1, 1, 1], [2, 2, 2]],
[[3, 3, 3], [4, 4, 4]],
[[5, 5, 5], [6, 6, 6]]])
print(X.shape) # (3, 2, 3)
我们知道 n 维数组可以看成是每个元素是 n - 1 维数组的一维数组,有点类似复合函数,多维张量同样如此。我们用类似复合函数的方式将形状为 (3, 2, 3) 的三维张量进行分解。
- 第一个维度有 3 个元素,用 A, B, C 表示,即
X = [[A], [B], [C]]
; - 第二个维度有 2 个元素,第一个维度的 3 个元素,每个元素都有 2 个元素,用 i, j, k, l, m, n 表示,即
A = [i, j]
、B =[k, l]
、C = [m, n]
; - 第三个维度有 3 个元素,第二个维度的 2 个元素,每个元素都有 3 个元素,即
i = [1, 1, 1]
、j = [2, 2, 2]
、k = [3, 3, 3]
、l = [4, 4, 4]
、m = [5, 5, 5]
、n = [6, 6, 6]
。
为了直观,我们可以将其绘制成层次结构:
有了这些准备,我们直接在 X 上使用 tf.slice
函数:
print(tf.slice(X, [1, 0, 0], [1, 1, 3]))
'''
[[[3, 3, 3]]]
'''
此时 begin 和 size 两个参数分别是 [1, 0, 0]
和 [1, 1, 3]
,begin 参数为张量每个维度进行切片操作的起始位置,对于 [1, 0, 0]
,我们可以理解为:
- 第一个维度从位置 1 开始
- 第二个维度从位置 0 开始
- 第三个维度从位置 0 开始
size 参数为张量每个维度取出元素的个数,对于 [1, 1, 3]
,我们可以理解为:
- 第一个维度取出 1 个元素
- 第二个维度取出 1 个元素
- 第三个维度取出 3 个元素
我们按照维度整合 begin 和 size 参数:
- 第一个维度,从位置 1 开始,并且取出 1 个元素
- 第二个维度,从位置 0 开始,并且取出 1 个元素
- 第三个维度,从位置 0 开始,并且取出 3 个元素
不过这里有个需要注意的地方,按照上面的说法,此时可能有两种选取方式:
- 第一种方式:每次选取都是独立的;
- 第二种方式:按照层次结构逐层进行选取。
比如,按照第一种方式,第一个维度选择 B,第二个维度选择 i, j,第三个维度选择 [5, 5, 5]
,这种每次选取都独立的方式显然是不合理的。tf.slice
显然使用第二种方式,这也是为什么说 tf.slice
能够对张量的连续子区域进行切片。
接下来,就可以将上面对 tf.slice
的理解对应到三维张量 X 中,为了更直观的理解,我们使用上面的层次结构图,图中红色的部分表示已经被选中的元素。对于 begin 和 size 两个参数分别是 [1, 0, 0]
和 [1, 1, 3]
:
- 第一个维度,从位置 1 开始,并且取出 1 个元素(Python 的索引从 0 开始)
- 在选中的基础上,我们继续在第二个维度,从位置 0 开始,并且取出 1 个元素
- 在选中的基础上,我们继续在第三个维度,从位置 0 开始,并且取出 3 个元素
明白了 tf.slice
函数,下面再来几个例子。
print(tf.slice(t, [1, 0, 0], [1, 2, 3]))
'''
tf.Tensor(
[[[3 3 3]
[4 4 4]]], shape=(1, 2, 3), dtype=int32)
'''
代码语言:javascript复制print(tf.slice(X, [1, 0, 0], [2, 1, 3]))
'''
tf.Tensor(
[[[3 3 3]]
[[5 5 5]]], shape=(2, 1, 3), dtype=int32)
'''
References:
- https://www.quora.com/How-does-tf-slice-work-in-TensorFlow
- https://www.tensorflow.org/api_docs/python/tf/slice