Pytorch Autograd 基础(三)

2022-04-15 08:28:37 浏览数 (1)

本篇介绍如何关闭和打开Autograd。

  • 关闭和打开Autograd的最简单的方法是更改tensor的requires_grad 属性。
代码语言:javascript复制
import torch

a = torch.ones(2, 3, requires_grad=True)
print(a)

b1 = 2 * a  # b1 由 a 计算得来,继承了 a 当前额 requires_grad属性
print(b1)

a.requires_grad = False # 关闭 Autograd,不再追踪计算历史
b2 = 2 * a  # b2 由 a 计算得来,继承了 a 当前额 requires_grad属性
print(b2) # b2 也 关闭了 Autograd
print(b2.requires_grad) 
代码语言:javascript复制
tensor([[1., 1., 1.],
        [1., 1., 1.]], requires_grad=True)
tensor([[2., 2., 2.],
        [2., 2., 2.]], grad_fn=<MulBackward0>)
tensor([[2., 2., 2.],
        [2., 2., 2.]])
False

再次打开a的Autograd,并不影响b2。

代码语言:javascript复制
a.requires_grad = True
print(b2)
代码语言:javascript复制
tensor([[2., 2., 2.],
        [2., 2., 2.]])

还是可以将b2的requires_grad 属性设为True

代码语言:javascript复制
b2.requires_grad= True
print(b2)
代码语言:javascript复制
tensor([[2., 2., 2.],
        [2., 2., 2.]], requires_grad=True)
  • 如果只是想临时的关闭Augograd,最好的方式是用torch.no_grad()。
代码语言:javascript复制
a = torch.ones(2, 3, requires_grad=True) * 2
b = torch.ones(2, 3, requires_grad=True) * 3
c1 = a   b  # Autograd 自动打开
print(c1)

with torch.no_grad(): # 在这个上下文中临时关闭 Autograd
    c2 = a   b
    
print(c2)
c3 = a * b  # Autograd 任然自动打开
print(c3)
代码语言:javascript复制
tensor([[5., 5., 5.],
        [5., 5., 5.]], grad_fn=<AddBackward0>)
tensor([[5., 5., 5.],
        [5., 5., 5.]])
tensor([[6., 6., 6.],
        [6., 6., 6.]], grad_fn=<MulBackward0>)
  • torch.no_grad() 可以用做函数或者方法的装饰器,来关闭Autograd
代码语言:javascript复制
def add_tensors1(x, y):
    return x   y


@torch.no_grad() # 关闭 Augograd
def add_tensors2(x, y):
    return x   y


a = torch.ones(2, 3, requires_grad=True) * 2
b = torch.ones(2, 3, requires_grad=True) * 3
c1 = add_tensors1(a, b) # c1由a和b计算而来,跟随a和b,打开Autograd
print(c1)

c2 = add_tensors2(a, b) # 由于有@torch.no_grad(),c2关闭了Autograd
print(c2)
代码语言:javascript复制
tensor([[5., 5., 5.],
        [5., 5., 5.]], grad_fn=<AddBackward0>)
tensor([[5., 5., 5.],
        [5., 5., 5.]])

0 人点赞