关于张量的底层存储逻辑这一部分看的我有点头大,但是了解底层实现确实有助于理解tensor中的各种运算到底是怎么一个回事,当然大部分时间我们可以不太会用到这些存储操作,但是熟悉这些底层实现,我觉得一方面可以帮我屏蔽一些开发上的bug,或者说在查bug的时候会往这个方面思考;再一个就是如果真的有需要做比较硬核的优化的时候也能够有点想法。
张量的存储
前面我们说过,张量的存储空间是连续的,最开始我可能以为存储像张量的结构一样, 比如说像这样的方块区域
但是,实际上它是这样存储的
然后使用偏移量和步长来进行索引,关于这两个概念我们后面会讨论。
PyTorch提供了一个storage方法来访问内存,如下我们创建了一个三行二列的二维tensor,然后用storage()读取它的内存,我们可以看到结果,实际底层存储是一个size为6的连续数组,而我们的tensor方法所实现的就是怎么通过索引把数组转换成我们需要的张量以及各种运算的方法。
代码语言:javascript复制points = torch.tensor([[4.0, 1.0], [5.0, 3.0], [2.0, 1.0]])
points.storage()
outs:
4.0
1.0
5.0
3.0
2.0
1.0
[torch.FloatStorage of size 6]
我们可以使用索引来查询这个存储区,比如
代码语言:javascript复制points_storage = points.storage()
points_storage[0]
outs:
4.0
显而易见的是,我们不能用二维索引,因为这个存储区只是一个一维数组,同时,如果我们修改存储区的数据,那么tensor的数据自然而然会发生变化。
代码语言:javascript复制points = torch.tensor([[4.0, 1.0], [5.0, 3.0], [2.0, 1.0]])
points_storage = points.storage()
points_storage[0] = 2.0 #给存储区位置0赋值2
points
outs:tensor([[2., 1.],
[5., 3.],
[2., 1.]])
关于带下划线的操作
在tensor的操作中,有少量的方法是带下划线的,比如zero_(),这样的方法只作为tensor对象的方法,我们可以认为是原地操作的方法,也就是说这样的方法是直接修改输入然后返回结果,而对应不带下划线的方法不会去改变源tensor,而是返回一个新的tensor。 让我们看一段代码:
代码语言:javascript复制import torch
a = torch.ones(3, 2)
a
outs:tensor([[1., 1.],
[1., 1.],
[1., 1.]])
b = a.zero_()
b
outs:tensor([[0., 0.],
[0., 0.],
[0., 0.]])
a
outs:tensor([[0., 0.],
[0., 0.],
[0., 0.]])
可以看到使用了zero_()方法之后,虽然我们看起来赋值给了b,但实际上底层发生了变化,a的数值也都是0了。
元数据是如何计算的
既然我们已经知道了tensor的底层存储实际上是连续的一维数组,那么下面来了解一下tensor通过什么样的方式来把底层存储处理成上层实现。
大小、偏移量、步长
这里作者给了三个概念,就是张量的大小、偏移量和步长,作者手绘的图像如下
大小(size):大小这个概念很容易理解,比如说图中给的tensor在表现上来看是一个3*3的矩阵,tensor的大小就是一个元组,里面记录了每一个维度有多少元素。
偏移量(offset):偏移量指的是这个tensor的第一个元素在当前存储区上的位置索引。我理解是这样的,对于一个完整的tensor,offset都是0。但是在某些情况,比如说我们有一个4*4的tensor,我们从它的(1,1)的位置选取一个子tensor,这个时候这个子tensor的offset就不是0了,应该是5?
步长(stride):这个概念我在抽象层面能够理解,但是实际看了例子还是差了一点,花了好长时间才搞明白一些。为此还专门去查了stride的英文意思,stride有“跨过,步幅的意思”,在这里去理解它,是指的按照tensor的顺序,沿着一个维度获取下一个元素在实际存储区所需要跳过的元素数量。
比如说上面的例子里,沿着行这个维度获取下一个元素也就是5->1这个动作,在存储区需要跨过3个元素,而沿着列这个维度获取下一个元素5->7这个动作,只需要跨过1个元素就可以了。
我们可以通过代码来查看偏移量和步长。
代码语言:javascript复制points = torch.tensor([[4.0, 1.0, 3.0, 2.0], [5.0, 3.0, 7.0, 8.0], [2.0, 1.0, 9.0, 5.0],[3.0, 8.0, 4.0, 5.0]]) #先生成一个新的tensor
second_point = points[1:,1:] #从原始tensor中摘取一个子tensor
second_point #让我们看看截取的子tensor对不对
outs:tensor([[3., 7., 8.],
[1., 9., 5.],
[8., 4., 5.]])
points.storage_offset() #原tensor的偏移量
outs:0
second_point.storage_offset() #子tensor的偏移量
outs:5 #看起来跟我们猜测的一样
#再来看一下步长
points.stride() #原始tensor的步长
outs:(4,1)
second_point.stride() #子tensor的步长
outs:(4,1)
可以看到这里的原始tensor和子tensor的步长都是一样的,这是为什么呢,很容易理解啊,我们是从(1,1)开始截取的,在底层存储不变的情况下,子tensor要按维度跳到下一个元素位置所经过的元素跟原tensor是一样的!
因此,我们修改子tensor也会引起原tensor的变化。如果要开辟一块新的空间来存这个tensor可以使用clone方法,这时候second_point就在一个新的tensor存储空间,对其修改不会影响points
代码语言:javascript复制second_point = points[1:,1:].clone()
second_point[0,0] = 10.0
second_point
outs:tensor([[10., 7., 8.],
[ 1., 9., 5.],
[ 8., 4., 5.]])
points
outs:tensor([[4., 1., 3., 2.],
[5., 3., 7., 8.],
[2., 1., 9., 5.],
[3., 8., 4., 5.]])
如果说在这里似乎还看不出这个存储方案有什么神奇之处,下面我们看看对tensor进行操作之后的情况。
转置之后发生了啥
我们重新构建一个tensor
代码语言:javascript复制points = torch.tensor([[4.0, 1.0], [5.0, 3.0], [2.0, 1.0]])
points
outs:tensor([[4., 1.],
[5., 3.],
[2., 1.]])
points_t = points.t() #t()方法是用于二维张量转置时对transpose()方法的简写
points_t
outs:tensor([[4., 5., 2.],
[1., 3., 1.]])
转置之后发生了什么呢,其实什么都没有发生,存储区还是一个存储区,变的只是tensor对于存储区的索引结构
代码语言:javascript复制#验证这两个tensor是用的一个存储区
id(points.storage()) == id(points_t.storage())
我们来看一下步长的变化
代码语言:javascript复制points.stride()
outs:(2, 1)
points_t.stride()
out:(1, 2)
从上面的代码我们可以看出,转置之后不同维度的步长做了相应的调整,示例图如下(突然发现原图有问题,我重新画了一个),转置后的tensor按行维度找下一个元素也就是4->1,只需要跨过1个元素,同理,在列维度则需要跨过2个元素。
什么是连续张量
连续张量的概念貌似很拗口,反正我看翻译是没有看懂,所以我把原文放在下面了,大意是有这样一个张量,它的值以最右侧的维度开始按顺序在存储区间中排列,这种张量就是连续张量。虽然概念很拗口,但是理念是很简单的,这里举了一个例子:比如说一个二维tensor,沿着行移动。
A tensor whose values are laid out in the storage starting from the rightmost dimension onward (that is, moving along rows for a 2D tensor) is defined as contiguous
再来看实际的代码,就更容易理解了:
代码语言:javascript复制points.is_contiguous()
outs:True
points_t.is_contiguous()
outs:False
在tensor的顺序和存储区顺序一致的就是连续张量,否则就不是。在PyTorch中,有一些操作只针对连续张量起作用,如果我们对那些不是连续张量的张量实施这些操作就会报错。那么如果我们想用这些方法怎么办呢,PyTorch自然也给出了解决办法,那就是contiguous方法,使用这个方法会改变存储区存储顺序,使得存储区顺序符合当前tensor连续的要求。
代码语言:javascript复制points = torch.tensor([[4.0, 1.0], [5.0, 3.0], [2.0, 1.0]])
points_t = points.t()
points_t #查看转置后的tensor
outs:tensor([[4., 5., 2.],
[1., 3., 1.]])
points_t.storage() #查看存储区顺序
outs:
4.0
1.0
5.0
3.0
2.0
1.0
[torch.FloatStorage of size 6]
points_t.stride() #查看步长信息
outs:(1, 2)
points_t_cont = points_t.contiguous() #调用contiguous方法
points_t_cont #可以看到tensor的表示没有发生变化
outs:tensor([[4., 5., 2.],
[1., 3., 1.]])
points_t_cont.stride() #但是步长信息变了
outs:(3, 1)
points_t_cont.storage() #再看一下存储区,已经发生了变化
outs:
4.0
5.0
2.0
1.0
3.0
1.0
[torch.FloatStorage of size 6]
今天就看这么多吧。