Pytorch 张量tensor

2021-02-19 15:02:43 浏览数 (1)

文章目录

    • 1. tensor 张量
    • 2. 运算
    • 3. 切片、形状size()、改变形状view()
    • 4. item() 只能读取一个元素

参考 http://pytorch123.com/

1. tensor 张量

  • empty 不初始化
代码语言:javascript复制
import torch
x = torch.empty(5,3) # 不初始化
print(x)

tensor([[1.0010e-38, 4.2246e-39, 1.0286e-38],
        [1.0653e-38, 1.0194e-38, 8.4490e-39],
        [1.0469e-38, 9.3674e-39, 9.9184e-39],
        [8.7245e-39, 9.2755e-39, 8.9082e-39],
        [9.9184e-39, 8.4490e-39, 9.6429e-39]])
  • rand 随机初始化 0 - 1 之间
代码语言:javascript复制
x = torch.rand(5,3) # 随机初始化

tensor([[0.5931, 0.2422, 0.2738],
        [0.0949, 0.4755, 0.7422],
        [0.7418, 0.5980, 0.4837],
        [0.4228, 0.4489, 0.2633],
        [0.7277, 0.7254, 0.8932]])
  • zeros 初始化为0,dtype指定数据类型
代码语言:javascript复制
x = torch.zeros(5,3,dtype=torch.long)

tensor([[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]])
  • 直接赋值
代码语言:javascript复制
x = torch.tensor([[5.5, 3], [2,4]])

tensor([[5.5000, 3.0000],
        [2.0000, 4.0000]])
  • new_* 方法,继承之前张量的属性,也可以覆盖以前的属性
代码语言:javascript复制
x = x.new_ones(5,3,dtype=torch.double)
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], dtype=torch.float64)

x = x.new_zeros(2,4)
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.]], dtype=torch.float64) # 可见属性继承了之前的
  • rand_like 形状跟之前的一样
代码语言:javascript复制
x = torch.randn_like(x,dtype=torch.float)
print(x)
print(x.size())

tensor([[ 0.2575, -0.3525,  1.2242, -0.0641],
        [ 0.0307,  0.0433, -0.3609,  2.0844]])
torch.Size([2, 4])

2. 运算

代码语言:javascript复制
x = torch.eye(3)
y = torch.zeros(3,3)
print(x y) #  
print(torch.add(x,y)) # add

res = torch.empty(2,2)
print(res.size())  # torch.Size([2, 2])
torch.add(x,y,out=res) # out 为输出变量
print(res)
print(res.size()) # torch.Size([3, 3])

# in-place 加法
y.add_(x) # y = y x, y 会变, 注意是 add_ 有下划线
print(y)

3. 切片、形状size()、改变形状view()

切片跟numpy一样

代码语言:javascript复制
print(x[:,:1].size())  # torch.Size([3, 1])

x = torch.randn(4,4)
y = x.view(16)
z = x.view(-1,8) # -1 自动推断
print(x.size(), y.size(), z.size())

# torch.Size([4, 4]) torch.Size([16]) torch.Size([2, 8])

4. item() 只能读取一个元素

代码语言:javascript复制
x = torch.randn(1)
print(x)
print(x.item())
# tensor([-0.3280])
# -0.327981561422348

x = torch.randn(2,3)
print(x)
print(x[0,1].item()) # 只能获取一个元素

# tensor([[-1.2239,  0.3518,  1.1019],
#         [-0.1341,  1.0625,  0.2442]])
# 0.3518247902393341

0 人点赞