PyTorch: 计算图与动态图机制

2022-11-13 09:19:19 浏览数 (1)

本文已收录于Pytorch系列专栏: Pytorch入门与实践 专栏旨在详解Pytorch,精炼地总结重点,面向入门学习者,掌握Pytorch框架,为数据分析,机器学习及深度学习的代码能力打下坚实的基础。免费订阅,持续更新。

文章目录
  • 计算图
  • PyTorch的动态图机制

计算图

计算图是用来描述运算的有向无环图

计算图有两个主要元素:

  • 结点 Node
  • 边 Edge

结点表示数据:如向量,矩阵,张量

边表示运算:如加减乘除卷积等

用计算图表示:y = (x w) * (w 1) a = x w b = w 1 y = a * b

计算图与梯度求导

y = (x w) * (w 1) a = x w b = w 1 y = a * b

begin{aligned} frac{partial y}{partial w} &=frac{partial y}{partial a} frac{partial a}{partial w} frac{partial y}{partial b} frac{partial b}{partial w} \ &=b * 1 a * 1 \ &=b a \ &=(w 1) (x w) \ &=2 * w x 1 \ &=2 * 1 2 1=5 end{aligned}

可见,对于变量w的求导过程就是寻找它在计算图中的所有路径的求导之和。

code:

代码语言:javascript复制
import torch

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

a = torch.add(w, x)     # retain_grad()
b = torch.add(w, 1)
y = torch.mul(a, b)

y.backward()
print(w.grad)
代码语言:javascript复制
tensor([5.])

计算图与梯度求导 y = (x w) * (w 1)

叶子结点 :用户创建的结点称为叶子结点,如 X 与 W

is_leaf: 指示张量是否为叶子结点

叶子节点的作用是标志存储叶子节点的梯度,而清除在反向传播过程中的变量的梯度,以达到节省内存的目的。 当然,如果想要保存过程中变量的梯度值,可以采用retain_grad()

grad_fn: 记录创建该张量时所用的方法(函数)

  • y.grad_fn= <MulBackward0>
  • a.grad_fn= <AddBackward0>
  • b.grad_fn= <AddBackward0>

PyTorch的动态图机制

根据计算图搭建方式,可将计算图分为动态图静态图

  • 动态图 运算与搭建同时进行 灵活 易调节 例如动态图 PyTorch:
  • 静态 先搭建图, 后运算 高效 不灵活。 静态图 TensorFlow

0 人点赞