PyTorch autograd

2024-05-27 17:31:16

Killing time at the hospital, so here is a note on a small parameter called retain_graph.

Let's start with two code snippets and compare how they behave.

import torch
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

a = torch.add(w, x)      # a = w + x
b = torch.add(w, 1)      # b = w + 1
y = torch.mul(a, b)      # y = a * b

y.backward(retain_graph=True)   # keep the graph alive for a second pass
print(w.grad)
y.backward()                    # reuse the retained graph; gradients accumulate into w.grad
print(w.grad)

Output:

tensor([5.])
tensor([10.])

The second snippet:

import torch
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

y.backward()     # the graph is freed after this pass
print(w.grad)
y.backward()     # error: the graph has already been freed
print(w.grad)

But this one raises an error:

tensor([5.])
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 12
     10 y.backward()
     11 print(w.grad)
---> 12 y.backward()
     13 print(w.grad)

File ~\.conda\envs\torchgpu\lib\site-packages\torch\_tensor.py:492, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    482 if has_torch_function_unary(self):
    483     return handle_torch_function(
    484         Tensor.backward,
    485         (self,),
   (...)
    490         inputs=inputs,
    491     )
--> 492 torch.autograd.backward(
    493     self, gradient, retain_graph, create_graph, inputs=inputs
    494 )

File ~\.conda\envs\torchgpu\lib\site-packages\torch\autograd\__init__.py:251, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    246     retain_graph = create_graph
    248 # The reason we repeat the same comment below is that
    249 # some Python versions print out the first line of a multi-line function
    250 # calls in the traceback and some print out the last line
--> 251 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    252     tensors,
    253     grad_tensors_,
    254     retain_graph,
    255     create_graph,
    256     inputs,
    257     allow_unreachable=True,
    258     accumulate_grad=True,
    259 )

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Why does this happen? This is exactly where the parameter we are studying comes in.

First, look at the signature of the underlying function:

torch.autograd.backward(
        tensors, 
        grad_tensors=None, 
        retain_graph=None, 
        create_graph=False, 
        grad_variables=None, 
        inputs=None)
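As the traceback above already shows, Tensor.backward() is just a thin wrapper that forwards its arguments to this module-level function, so the first snippet could equally be written through torch.autograd.backward directly. A minimal sketch, using the same tensors as above:

import torch

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
y = (w + x) * (w + 1)

# Equivalent to y.backward(retain_graph=True): the method simply forwards
# its arguments to torch.autograd.backward.
torch.autograd.backward(y, retain_graph=True)
print(w.grad)                 # tensor([5.])

torch.autograd.backward(y)    # second pass works because the graph was kept
print(w.grad)                 # tensor([10.])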

This time we focus on retain_graph.

As we all know, PyTorch builds a dynamic computation graph, and retain_graph is a boolean flag: its value directly determines whether that graph is kept after the backward pass.

retain_graph (bool, optional) –
Whether to keep the computation graph. By default PyTorch frees the graph
at the end of the backward pass to save memory; try calling loss.backward()
twice in a row and you will get an error. If you need to backpropagate
multiple times, pass retain_graph=True to backward().

Our second snippet computes the gradient of w twice, which is exactly why it fails: by the second call the graph has already been freed. If we want to run backward more than once, we have to set this parameter to True on every call except the last.
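A more typical situation where retain_graph is needed (a minimal sketch with made-up losses loss1 and loss2) is backpropagating two losses that share part of the same graph:

import torch

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

a = w * x                    # shared intermediate; its backward needs the saved w and x
loss1 = (a + 1).sum()
loss2 = (a * a).sum()

# Both losses run backward through the node that produced `a`, so the first
# call must keep the graph (and its saved tensors) alive for the second one.
loss1.backward(retain_graph=True)
loss2.backward()             # last pass: the graph may be freed now
print(w.grad)                # tensor([10.]) = 2 (from loss1) + 8 (from loss2)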

Also note that gradients accumulate across backward calls (that is why w.grad grows from 5 to 10 above). This is why we routinely call zero_grad() when training a model: otherwise stale gradients from previous iterations would pile up into the current update.
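A minimal sketch of clearing gradients by hand with Tensor.grad.zero_(); in a real training loop you would normally call optimizer.zero_grad() instead:

import torch

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
y = (w + x) * (w + 1)

y.backward(retain_graph=True)
print(w.grad)        # tensor([5.])

# Reset the accumulated gradients before the next backward pass.
w.grad.zero_()
x.grad.zero_()

y.backward()
print(w.grad)        # tensor([5.]) again instead of tensor([10.])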

Finally, here is the calculation worked out by hand. It is very simple, so experts, please go easy on me.
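With w = 1 and x = 2 we have y = a·b, where a = w + x and b = w + 1, so by the chain rule
∂y/∂w = ∂y/∂a·∂a/∂w + ∂y/∂b·∂b/∂w = b·1 + a·1 = (w + 1) + (w + x) = 2 + 3 = 5,
which matches the first tensor([5.]) printed above; the second backward pass adds another 5 on top, giving tensor([10.]).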
