【GiantPandaCV导语】Pytorch 中的 hook 机制可以很方便的让用户往计算图中注入控制代码,这样就可以通过自定义各种操作来修改计算图中的张量。
点击小程序观看视频(时长22分)
视频太长不看版:
Pytorch 中的 hook 机制可以很方便的让用户往计算图中注入控制代码(注入的代码也可以删除),这样用户就可以通过自定义各种操作来修改计算图中的张量。
Pytroch 中主要有两种hook,分别是注册在Tensor上的hook和注册在Module上的 hook。
注册在 Tensor 上的 hook,可以在反向回传过程中对梯度作修改,分为两种:
- 叶子节点上的hook
会在 AccumulateGrad 之前对梯度做一些操作
- 中间张量上的hook 在输出梯度传入 backward 函数计算输入梯度之前,调用注册的hook的函数对梯度做一些操作
注意:
最好不要在hook函数中对梯度做 inplace 修改,因为会直接修改该梯度张量,
如果该op有多个输入,比如 add op,那么在反向阶段,如果其中一个张量上注册的hook函数对梯度做了inplace修改,那么就会有可能影响到另一个输入张量的梯度。
注册在 Module 上的 hook,则可以在前后过程中对张量作修改,主要有三种:
- 在module的前向被调用之前调用的hook函数
对Module的输入张量做一些操作
- 在module的前向被调用之后调用的hook函数
对Module的输入和输出张量做一些操作
- 后向过程会调用的hook
可以打印module输入张量的梯度,但是目前还有bug,建议不要用。
github上相关的讨论:https://github.com/pytorch/pytorch/issues/598