PyTorch入门笔记-判断张量是否连续

2021-01-18 15:31:45 浏览数 (1)

判断张量是否连续

nD 张量底层实现是使用一块连续内存的一维数组,由于 PyTorch 底层实现是 C 语言 (C/C 使用行优先的存储方式),所以 PyTorch 中的 nD 张量也按照行优先的顺序进行存储的。

下图为一个形状为 (2times 3) 的 2D 张量,为了方便将其命名为 A

张量 A内存中实际以一维数组的形式进行存储,并且使用行优先的顺序进行存储,其中一维数组的形式存储比较好理解,而行优先指的就是存储顺序按照张量 A行依次存储。 张量 A 在内存中的实际存储形式如下所示。

张量 A常称为存储的逻辑结构,而实际存储的一维数组形式称为存储的物理结构。

  • 如果元素在存储的逻辑结构上相邻,在存储的物理结构中也相邻,则称为连续存储的张量;
  • 如果元素在存储的逻辑结构上相邻,但是在存储的物理结构中不相邻,则称为不连续存储的张量;

在 "改变张量形状" 中提到过,交换维度的操作能够将连续存储的张量转变成不连续存储的张量。在 PyTorch 中对于张量是否连续有一个等式。nD 张量,对于任意一个维度 i (i = 0, ...,n-1 但是 i ne n-1)都满足下面的等式则说明 nD 张量连续,不满足则说明 nD 张量不连续。

stride[i] = stride[i 1] times size[i 1]

其中 stride[i] 表示逻辑结构中第 i 个维度上相邻的元素在物理结构中间隔的元素个数,size[i] 表示逻辑结构中第 i 个维度的元素个数。

下面使用公式来判断张量 A 是否连续?2D 张量一共有两个维度,因此 i 只能取 0 (因为 ine (2-1)=1),接下来只需要判断下面等式是否成立。

stride[0] = stride[1] times size[1]

其中:

  • stride[0] 为张量 A (逻辑结构) 的第 0 个维度上相邻的元素在一维数组 (物理结构) 中间隔的元素个数。张量 A 中第 0 个维度上相邻的元素有 (0, 3) (1, 4) (2, 5),这些在张量 A 中相邻的元素,在一维数组中这些相邻元素的间隔数都为 3 (计数包含本身),即 stride[0] = 3
  • stride[1] 为张量 A (逻辑结构) 的第 1 个维度上相邻的元素在一维数组 (物理结构) 中间隔的元素个数。张量 A 中第 1 个维度上相邻的元素有 (0, 1) (1, 2) (3, 4) (4, 5),这些在张量 A 中相邻的元素,在一维数组中这些相邻元素的间隔数都为 1 (计数包含本身),即 stride[1] = 1
  • size[1] 为张量 A (逻辑结构) 中第 1 个维度上的元素个数,即 size[1] = 3

将这些对应的值代入等式 stride[0] = stride[1] times size[1] 中,即 3 = 1 times 3,等式成立,则张量 A 是连续的。

在 PyTorch 中,使用维度变换的操作能够将连续存储的张量转变成不连续存储的张量,接下来使用等式判断交换维度后的张量 A 是否还是连续存储的张量?张量 A 交换维度后的结果如下。

这里需要注意,我们是通过张量 A 交换维度后得到的是 (3times 2) 的 2D 张量,为了方便将其命名为 A^T在 PyTorch 中交换维度的操作并没有改变其实际的存储,换句话说,交换维度后的张量与原始张量共享同一块内存,因此交换维度后的张量 A^T 底层存储和原始张量 A 都是相同的一维数组。

下面来使用公式判断张量 A^T 是否连续?2D 张量一共有两个维度,因此 i 只能取 0 (因为 ine (2-1)=1),接下来只需要判断下面等式是否成立。

stride[0] = stride[1] times size[1]

其中:

  • 为张量 (逻辑结构) 的第 0 个维度上相邻的元素在一维数组 (物理结构) 中间隔的元素个数。张量 中第 0 个维度上相邻的元素有 (0, 1) (1, 2) (3, 4) (4, 5),这些在张量 中相邻的元素,在一维数组中这些相邻元素的间隔数都为 1 (计数包含本身),即 ;
  • stride[1] 为张量 A^T (逻辑结构) 的第 1 个维度上相邻的元素在一维数组 (物理结构) 中间隔的元素个数。张量 A^T 中第 1 个维度上相邻的元素有 (0, 3) (1, 4) (2, 5),这些在张量 A^T中相邻的元素,在一维数组中这些相邻元素的间隔数都为 3 (计数包含本身),即 stride[1] = 3
  • size[1] 为张量 A^T (逻辑结构) 中第 1 个维度上的元素个数,即 size[1] = 2

将这些对应的值代入公式 stride[0] = stride[1] times size[1] 中,即 1 = 3 times 2,等式不成立,则张量 A^T 是不连续的。

由于 2D 张量比较容易理解,所以这里都是以 2D 张量为例进行介绍的,2D 张量只需要满足 1 个等式即可判断是否连续,而如果是 nD 张量,则需要判断 (n-1) 个等式。

0 人点赞