pytorch tensor与numpy转换

2022-09-02 13:31:17 浏览数 (1)

tensor to numpy

代码语言:javascript复制
a = torch.ones(5)
print(a)

输出

代码语言:javascript复制
tensor([1., 1., 1., 1., 1.])

进行转换

代码语言:javascript复制
b = a.numpy()
print(b)

输出

代码语言:javascript复制
[1. 1. 1. 1. 1.]

注意,转换后的tensor与numpy指向同一地址,所以,对一方的值改变另一方也随之改变

代码语言:javascript复制
a.add_(1)
print(a)
print(b)

numpy to tensor

代码语言:javascript复制
import numpy as np
a = np.ones(5)
b = torch.from_numpy(a)
np.add(a, 1, out=a)
print(a)
print(b)

输出

代码语言:javascript复制
[2. 2. 2. 2. 2.]
tensor([2., 2., 2., 2., 2.], dtype=torch.float64)

除chartensor外所有tensor都可以转换为numpy

0 人点赞