view和reshape
PyTorch 中改变张量形状有 view、reshape 和 resize_ (没有原地操作的resize方法未来会被丢弃) 三种方式,「其中 resize_ 比较特殊,它能够在修改张量形状的同时改变张量的大小,而 view 和 reshape 方法不能改变张量的大小,只能够重新调整张量形状。」
resize_ 方法比较特殊,后续用到的时候再详细介绍。本文主要介绍 view 和 reshape 方法,在 PyTorch 中 view 方法存在很长时间,reshape 方法是在 PyTorch0.4 的版本中引入,两种方法功能上相似,但是一些细节上稍有不同,因此这里介绍两个方法的不同之处。
- view 只能用于数据连续存储的张量,而 reshape 则不需要考虑张量中的数据是否连续存储
nD 张量底层实现是使用一块连续内存的一维数组,由于 PyTorch 底层实现是 C 语言 (C/C 使用行优先方式),所以n维张量也使用行优先方式。比如对于下面形状为 (3 x 3) 的 2D 张量:
2D 张量在内存中实际以一维数组的形式进行存储,行优先的方式指的是存储的顺序按照 2D 张量的行依次存储。
上面形状为 (3 x 3) 的 2D 张量通常称为存储的逻辑结构,而实际存储的一维数组形式称为存储的物理结构。
- 如果元素在存储的逻辑结构上相邻,在存储的物理结构中也相邻,则称为连续存储的张量;
- 如果元素在存储的逻辑结构上相邻,但是在存储的物理结构中不相邻,则称为不连续存储的张量;
为了方便理解代码,先来熟悉一些方法。
- 可以通过
tensor.is_contiguous()
来查看 tensor 是否为连续存储的张量; - PyTorch 中的转置操作能够将连续存储的张量变成不连续存储的张量;
>>> import torch
>>> a = torch.tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
>>> # 判断a是否为连续存储的张量
>>> print(a.is_contiguous())
True
>>> view_a = a.view(1, 9)
>>> reshape_a = a.reshape(9, 1)
>>> # 通过转置操作将a变成不连续存储的张量
>>> a.t_()
>>> # 判断a是否为连续存储的张量
>>> print(a.is_contiguous())
False
>>> # view_t_a = a.view(1, 9) error
>>> reshape_t_a = a.reshape(1, 9)
其中 view_t_a = a.view(1, 9)
会抛出异常,再次验证了 view 只能用于数据连续存储的张量,而 reshape 则不需要考虑张量中的数据是否连续存储。
- view 方法会返回原始张量的视图,而 reshape 方法可能返回的是原始张量的视图或者拷贝
原始张量的视图简单来说就是和原始张量共享数据,因此如果改变使用 view 方法返回的新张量,原始张量也会发生相对应的改变。
代码语言:javascript复制>>> import torch
>>> a = torch.tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
>>> view_a = a.view(1, 9)
>>> print(view_a)
tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8]])
>>> # 更改张量中的元素值
>>> view_a[:, 1] = 100
>>> print(a)
tensor([[ 0, 100, 2],
[ 3, 4, 5],
[ 6, 7, 8]])
>>> print(view_a)
tensor([[ 0, 100, 2, 3, 4, 5, 6, 7, 8]])
reshape 方法可能返回的是原始张量的视图或者拷贝,当处理连续存储的张量 reshape 返回的是原始张量的视图,而当处理不连续存储的张量 reshape 返回的是原始张量的拷贝。
代码语言:javascript复制>>> import torch
>>> a = torch.tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
>>> # 判断a是否为连续存储的张量
>>> print(a.is_contiguous())
True
>>> reshape_a = a.reshape(1, 9)
>>> # 更改新张量的元素值
>>> reshape_a[:, 1] = 100
>>> print(a)
tensor([[ 0, 100, 2],
[ 3, 4, 5],
[ 6, 7, 8]])
>>> print(reshape_a)
tensor([[ 0, 100, 2, 3, 4, 5, 6, 7, 8]])
代码语言:javascript复制>>> import torch
>>> a = torch.tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
>>> # 通过转置将a变成不连续存储的张量
>>> a.t_()
>>> # 判断a是否为连续存储的张量
>>> print(a.is_contiguous())
False
>>> reshape_a = a.reshape(1, 9)
>>> # 更改新张量的元素值
>>> reshape_a[:, 1] = 100
>>> print(a)
tensor([[0, 3, 6],
[1, 4, 7],
[2, 5, 8]])
>>> print(reshape_a)
tensor([[ 0, 100, 6, 1, 4, 7, 2, 5, 8]])