深入理解Pytroch中的hook机制

2021-01-12 10:30:41 浏览数 (1)

【GiantPandaCV导语】Pytorch 中的 hook 机制可以很方便的让用户往计算图中注入控制代码,这样就可以通过自定义各种操作来修改计算图中的张量。

点击小程序观看视频(时长22分)

视频太长不看版:

Pytorch 中的 hook 机制可以很方便的让用户往计算图中注入控制代码(注入的代码也可以删除),这样用户就可以通过自定义各种操作来修改计算图中的张量。

Pytroch 中主要有两种hook,分别是注册在Tensor上的hook和注册在Module上的 hook。

注册在 Tensor 上的 hook,可以在反向回传过程中对梯度作修改,分为两种:

  • 叶子节点上的hook

会在 AccumulateGrad 之前对梯度做一些操作

  • 中间张量上的hook 在输出梯度传入 backward 函数计算输入梯度之前,调用注册的hook的函数对梯度做一些操作

注意:

最好不要在hook函数中对梯度做 inplace 修改,因为会直接修改该梯度张量,

如果该op有多个输入,比如 add op,那么在反向阶段,如果其中一个张量上注册的hook函数对梯度做了inplace修改,那么就会有可能影响到另一个输入张量的梯度。

注册在 Module 上的 hook,则可以在前后过程中对张量作修改,主要有三种:

  • 在module的前向被调用之前调用的hook函数

对Module的输入张量做一些操作

  • 在module的前向被调用之后调用的hook函数

对Module的输入和输出张量做一些操作

  • 后向过程会调用的hook

可以打印module输入张量的梯度,但是目前还有bug,建议不要用。

github上相关的讨论:https://github.com/pytorch/pytorch/issues/598


0 人点赞