PyTorch入门笔记-改变张量的形状

2021-01-03 15:58:58 浏览数 (1)

view和reshape

PyTorch 中改变张量形状有 view、reshape 和 resize_ (没有原地操作的resize方法未来会被丢弃) 三种方式,「其中 resize_ 比较特殊,它能够在修改张量形状的同时改变张量的大小,而 view 和 reshape 方法不能改变张量的大小,只能够重新调整张量形状。」

resize_ 方法比较特殊,后续用到的时候再详细介绍。本文主要介绍 view 和 reshape 方法,在 PyTorch 中 view 方法存在很长时间,reshape 方法是在 PyTorch0.4 的版本中引入,两种方法功能上相似,但是一些细节上稍有不同,因此这里介绍两个方法的不同之处。

  • view 只能用于数据连续存储的张量,而 reshape 则不需要考虑张量中的数据是否连续存储

nD 张量底层实现是使用一块连续内存的一维数组,由于 PyTorch 底层实现是 C 语言 (C/C 使用行优先方式),所以n维张量也使用行优先方式。比如对于下面形状为 (3 x 3) 的 2D 张量:

2D 张量在内存中实际以一维数组的形式进行存储,行优先的方式指的是存储的顺序按照 2D 张量的行依次存储。

上面形状为 (3 x 3) 的 2D 张量通常称为存储的逻辑结构,而实际存储的一维数组形式称为存储的物理结构。

  1. 如果元素在存储的逻辑结构上相邻,在存储的物理结构中也相邻,则称为连续存储的张量;
  2. 如果元素在存储的逻辑结构上相邻,但是在存储的物理结构中不相邻,则称为不连续存储的张量;

为了方便理解代码,先来熟悉一些方法。

  • 可以通过 tensor.is_contiguous() 来查看 tensor 是否为连续存储的张量;
  • PyTorch 中的转置操作能够将连续存储的张量变成不连续存储的张量;
代码语言:javascript复制
>>> import torch
>>> a = torch.tensor([[0, 1, 2],
                      [3, 4, 5],
                      [6, 7, 8]])
>>> # 判断a是否为连续存储的张量
>>> print(a.is_contiguous())

True

>>> view_a = a.view(1, 9)
>>> reshape_a = a.reshape(9, 1)
>>> # 通过转置操作将a变成不连续存储的张量
>>> a.t_()
>>> # 判断a是否为连续存储的张量
>>> print(a.is_contiguous())

False

>>> # view_t_a = a.view(1, 9) error
>>> reshape_t_a = a.reshape(1, 9)

其中 view_t_a = a.view(1, 9) 会抛出异常,再次验证了 view 只能用于数据连续存储的张量,而 reshape 则不需要考虑张量中的数据是否连续存储。

  • view 方法会返回原始张量的视图,而 reshape 方法可能返回的是原始张量的视图或者拷贝

原始张量的视图简单来说就是和原始张量共享数据,因此如果改变使用 view 方法返回的新张量,原始张量也会发生相对应的改变。

代码语言:javascript复制
>>> import torch
>>> a = torch.tensor([[0, 1, 2],
                  [3, 4, 5],
                  [6, 7, 8]])
>>> view_a = a.view(1, 9)
>>> print(view_a)

tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8]])

>>> # 更改张量中的元素值
>>> view_a[:, 1] = 100
>>> print(a)

tensor([[  0, 100,   2],
        [  3,   4,   5],
        [  6,   7,   8]])

>>> print(view_a)

tensor([[  0, 100,   2,   3,   4,   5,   6,   7,   8]])

reshape 方法可能返回的是原始张量的视图或者拷贝,当处理连续存储的张量 reshape 返回的是原始张量的视图,而当处理不连续存储的张量 reshape 返回的是原始张量的拷贝。

代码语言:javascript复制
>>> import torch
>>> a = torch.tensor([[0, 1, 2],
                      [3, 4, 5],
                      [6, 7, 8]])
>>> # 判断a是否为连续存储的张量
>>> print(a.is_contiguous())

True

>>> reshape_a = a.reshape(1, 9)
>>> # 更改新张量的元素值
>>> reshape_a[:, 1] = 100
>>> print(a)

tensor([[  0, 100,   2],
        [  3,   4,   5],
        [  6,   7,   8]])

>>> print(reshape_a)

tensor([[  0, 100,   2,   3,   4,   5,   6,   7,   8]])
代码语言:javascript复制
>>> import torch
>>> a = torch.tensor([[0, 1, 2],
                      [3, 4, 5],
                      [6, 7, 8]])
>>> # 通过转置将a变成不连续存储的张量
>>> a.t_()
>>> # 判断a是否为连续存储的张量
>>> print(a.is_contiguous())

False

>>> reshape_a = a.reshape(1, 9)
>>> # 更改新张量的元素值
>>> reshape_a[:, 1] = 100
>>> print(a)

tensor([[0, 3, 6],
        [1, 4, 7],
        [2, 5, 8]])

>>> print(reshape_a)

tensor([[  0, 100,   6,   1,   4,   7,   2,   5,   8]])

0 人点赞