Killing some time at the hospital, so here are a few notes on a small parameter called retain_graph.
Let's start with two code snippets and compare how they behave.
import torch

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)   # a = w + x
b = torch.add(w, 1)   # b = w + 1
y = torch.mul(a, b)   # y = (w + x) * (w + 1)

y.backward(retain_graph=True)  # keep the graph so we can call backward again
print(w.grad)
y.backward()                   # works, because the graph was retained
print(w.grad)
Output:
tensor([5.])
tensor([10.])
The second snippet:
import torch

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

y.backward()   # the graph is freed after this call
print(w.grad)
y.backward()   # error: the graph is already gone
print(w.grad)
This time, however, it raises an error:
tensor([5.])
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[1], line 12
10 y.backward()
11 print(w.grad)
---> 12 y.backward()
13 print(w.grad)
File ~\.conda\envs\torchgpu\lib\site-packages\torch\_tensor.py:492, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
482 if has_torch_function_unary(self):
483 return handle_torch_function(
484 Tensor.backward,
485 (self,),
(...)
490 inputs=inputs,
491 )
--> 492 torch.autograd.backward(
493 self, gradient, retain_graph, create_graph, inputs=inputs
494 )
File ~\.conda\envs\torchgpu\lib\site-packages\torch\autograd\__init__.py:251, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
246 retain_graph = create_graph
248 # The reason we repeat the same comment below is that
249 # some Python versions print out the first line of a multi-line function
250 # calls in the traceback and some print out the last line
--> 251 Variable._execution_engine.run_backward( # Calls into the C engine to run the backward pass
252 tensors,
253 grad_tensors_,
254 retain_graph,
255 create_graph,
256 inputs,
257 allow_unreachable=True,
258 accumulate_grad=True,
259 )
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
Why does this happen? This is where the parameter we are studying today comes in.
First, look at the prototype of the function:
torch.autograd.backward(
tensors,
grad_tensors=None,
retain_graph=None,
create_graph=False,
grad_variables=None,
inputs=None)
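Tensor.backward() is just a thin wrapper around this function, as the traceback above already shows, so the first snippet could equally be written with the module-level call. A minimal sketch:

import torch

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
y = torch.mul(torch.add(w, x), torch.add(w, 1))

# same effect as y.backward(retain_graph=True)
torch.autograd.backward(y, retain_graph=True)
print(w.grad)  # tensor([5.])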
This time we look at retain_graph.
As we all know, PyTorch is a classic dynamic-graph framework. retain_graph is a boolean value, and whether it is True or False directly decides whether the computation graph is kept after the backward pass.
retain_graph (bool, optional) –
Whether to keep the computation graph. PyTorch's policy is to free the graph
when the backward pass finishes, in order to save memory. Try calling loss.backward()
twice in a row and you will get an error. If you need to differentiate more than once, pass retain_graph=True to backward().
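A typical case where you genuinely need two backward passes is two losses that share part of the graph. Here is a minimal sketch of that situation; the names loss1 and loss2 are just made up for illustration:

import torch

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

a = torch.mul(w, x)                 # shared intermediate result
loss1 = torch.mul(a, torch.add(w, 1))
loss2 = torch.add(a, x)             # both losses depend on a

loss1.backward(retain_graph=True)   # keep the shared part of the graph alive
loss2.backward()                    # the last backward is allowed to free it
print(w.grad)                       # tensor([8.]), gradients from both losses add up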
Our second snippet above computes the derivative with respect to w twice, which is exactly why it reports an error. So whenever we want to differentiate through the same graph more than once, we have to set this parameter to True.
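Alternatively, because PyTorch builds the graph dynamically on every forward pass, simply re-running the forward computation also gives you a fresh graph to backpropagate through. A quick sketch:

import torch

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

y = torch.mul(torch.add(w, x), torch.add(w, 1))
y.backward()           # the graph built by this forward pass is freed here
print(w.grad)          # tensor([5.])

y = torch.mul(torch.add(w, x), torch.add(w, 1))   # a fresh forward pass, a fresh graph
y.backward()           # works without retain_graph
print(w.grad)          # tensor([10.]), gradients still accumulate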
Because backward() accumulates gradients (notice w.grad going from 5 to 10 above), we usually need to call zero_grad() when training a model, so that gradients left over from earlier iterations don't keep piling up and blowing up the updates.
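In a training loop this usually looks like the sketch below; the linear model, random data and SGD optimizer are only placeholders:

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for step in range(10):
    inputs = torch.randn(8, 4)
    targets = torch.randn(8, 1)

    optimizer.zero_grad()            # clear the gradients accumulated by the previous step
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()                  # one backward per graph, so no retain_graph needed
    optimizer.step()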
Below is a hand-worked diagram of the computation. It's very simple, experts please go easy on me.
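For reference, the chain rule behind that diagram can also be checked with a few lines of plain arithmetic (no autograd involved):

w, x = 1.0, 2.0
a = w + x            # 3.0
b = w + 1            # 2.0
# y = a * b, so dy/dw = dy/da * da/dw + dy/db * db/dw = b * 1 + a * 1
grad_w = b + a
print(grad_w)        # 5.0, matching w.grad after the first backward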
Done.