torch tensor的repeat和expand

2020-01-14 17:48:11 浏览数 (2)

一般情况下,如果expand和repeat都能得到目标矩阵,则在不更改目标矩阵元素(只读用法)时使用expand, 其他情况时使用repeat.

知识准备:

numpy.may_share_memory()查看是否指向同一个数据存储的内存地址

torch.Tensor.expand

代码语言:txt复制
import torch
x = torch.tensor([1, 2,4])
y = x.expand(3, -1)
# In [8]: y
# Out[8]:
# tensor([[1, 2, 4],
#        [1, 2, 4],
#        [1, 2, 4]])

y[1, 0] = 5
# In [10]: y
# Out[10]:
# tensor([[5, 2, 4],
#        [5, 2, 4],
#        [5, 2, 4]])

import numpy as np
np.may_share_memory(y[0,:], y[1,:])
# Out[11]: True

torch.Tensor.expand()不拷贝数据,只是一种view。如果更改expand生成的数据中的某元素,则相关位置元素都会发生改变。help(torch.Tensor.expand)可以了解更多。

torch.Tensor.repeat

代码语言:txt复制
import torch
x = torch.tensor(1, 2,4)
y = x.repeat(3,2)

# In 17: y
# Out17:
# tensor([1, 2, 4, 1, 2, 4,
# 1, 2, 4, 1, 2, 4,
# 1, 2, 4, 1, 2, 4])

import numpy as np
np.may_share_memory(y[0,:], y[1, :])
# Out[19]: False

y[1, 0] = 5

# In 21: y
# Out21:
# tensor([1, 2, 4, 1, 2, 4,
#            5, 2, 4, 1, 2, 4,
#            1, 2, 4, 1, 2, 4])

torch.Tensor.repeat拷贝数据,与numpy.tail功能类似。

0 人点赞