一般情况下,如果expand和repeat都能得到目标矩阵,则在不更改目标矩阵元素(只读用法)时使用expand, 其他情况时使用repeat.
知识准备:
numpy.may_share_memory()查看是否指向同一个数据存储的内存地址
torch.Tensor.expand
代码语言:txt复制import torch
x = torch.tensor([1, 2,4])
y = x.expand(3, -1)
# In [8]: y
# Out[8]:
# tensor([[1, 2, 4],
# [1, 2, 4],
# [1, 2, 4]])
y[1, 0] = 5
# In [10]: y
# Out[10]:
# tensor([[5, 2, 4],
# [5, 2, 4],
# [5, 2, 4]])
import numpy as np
np.may_share_memory(y[0,:], y[1,:])
# Out[11]: True
torch.Tensor.expand()不拷贝数据,只是一种view。如果更改expand生成的数据中的某元素,则相关位置元素都会发生改变。help(torch.Tensor.expand)可以了解更多。
torch.Tensor.repeat
代码语言:txt复制import torch
x = torch.tensor(1, 2,4)
y = x.repeat(3,2)
# In 17: y
# Out17:
# tensor([1, 2, 4, 1, 2, 4,
# 1, 2, 4, 1, 2, 4,
# 1, 2, 4, 1, 2, 4])
import numpy as np
np.may_share_memory(y[0,:], y[1, :])
# Out[19]: False
y[1, 0] = 5
# In 21: y
# Out21:
# tensor([1, 2, 4, 1, 2, 4,
# 5, 2, 4, 1, 2, 4,
# 1, 2, 4, 1, 2, 4])
torch.Tensor.repeat拷贝数据,与numpy.tail功能类似。