作者:Frank Odom 编译:McGL
什么是钩子(Hook)?
Hook 实际上在软件工程中相当常见,并不是 PyTorch 所独有的。一般来说,“hook”是在特定事件之后自动执行的函数。在现实世界中,你可能遇到过的一些 hook 的例子:
- 网站在你访问 N 个不同页面后会显示一个广告。
- 你的账户有资金入账时,银行 app 发送通知消息。
- 当周围光线减弱时,手机屏幕亮度会变暗。
这些事情没有 hook 也可以实现,但是很多情况下,hook 使程序员的生活更轻松。
PyTorch 为每个张量或 nn.Module 对象注册 hook。hook 由对象的向前或向后传播触发。它们具有以下函数签名:
代码语言:javascript复制from torch import nn, Tensor
def module_hook(module: nn.Module, input: Tensor, output: Tensor):
# For nn.Module objects only.
def tensor_hook(grad: Tensor):
# For Tensor objects only.
# Only executed during the *backward* pass!
每个 hook 都可以修改输入、输出或内部模块参数。最常见的是用于调试目的。但我们将看到它们还有很多其他用途。
示例 #1: 模型执行详情
你自己有没有在模型中插入 print 语句,来试图找出错消息的原因?(我当然对此有罪恶感。)这是一个丑陋的调试实践,而且在很多情况下,我们在完成 print 语句时忘记删除它。导致我们的代码看起来很不专业,用户每次使用你的代码都会得到一些奇怪的信息。
以后再也不会了!让我们使用 hook 来调试模型,而不用以任何方式修改它们的实现。例如,假如你想知道每个层输出的形状。我们可以创建一个简单的 wrapper,使用 hook 打印输出形状。
代码语言:javascript复制class VerboseExecution(nn.Module):
def __init__(self, model: nn.Module):
super().__init__()
self.model = model
# Register a hook for each layer
for name, layer in self.model.named_children():
layer.__name__ = name
layer.register_forward_hook(
lambda layer, _, output: print(f"{layer.__name__}: {output.shape}")
)
def forward(self, x: Tensor) -> Tensor:
return self.model(x)
最大的好处是: 它甚至可以用于不是我们创建的 PyTorch 模块!下面用 ResNet50 和一些虚拟输入来展示一下。
代码语言:javascript复制import torch
from torchvision.models import resnet50
verbose_resnet = VerboseExecution(resnet50())
dummy_input = torch.ones(10, 3, 224, 224)
_ = verbose_resnet(dummy_input)
# conv1: torch.Size([10, 64, 112, 112])
# bn1: torch.Size([10, 64, 112, 112])
# relu: torch.Size([10, 64, 112, 112])
# maxpool: torch.Size([10, 64, 56, 56])
# layer1: torch.Size([10, 256, 56, 56])
# layer2: torch.Size([10, 512, 28, 28])
# layer3: torch.Size([10, 1024, 14, 14])
# layer4: torch.Size([10, 2048, 7, 7])
# avgpool: torch.Size([10, 2048, 1, 1])
# fc: torch.Size([10, 1000])
示例 #2: 特征提取
通常,我们希望从一个预先训练好的网络中生成特性,然后用它们来完成另一个任务(例如分类、相似度搜索等)。使用 hook,我们可以提取特征,而不需要重新创建现有模型或以任何方式修改它。
代码语言:javascript复制from typing import Dict, Iterable, Callable
class FeatureExtractor(nn.Module):
def __init__(self, model: nn.Module, layers: Iterable[str]):
super().__init__()
self.model = model
self.layers = layers
self._features = {layer: torch.empty(0) for layer in layers}
for layer_id in layers:
layer = dict([*self.model.named_modules()])[layer_id]
layer.register_forward_hook(self.save_outputs_hook(layer_id))
def save_outputs_hook(self, layer_id: str) -> Callable:
def fn(_, __, output):
self._features[layer_id] = output
return fn
def forward(self, x: Tensor) -> Dict[str, Tensor]:
_ = self.model(x)
return self._features
我们可以像使用其他 PyTorch 模块一样使用特征提取器。用之前同样的虚拟输入,运行得到:
代码语言:javascript复制resnet_features = FeatureExtractor(resnet50(), layers=["layer4", "avgpool"])
features = resnet_features(dummy_input)
print({name: output.shape for name, output in features.items()})
# {'layer4': torch.Size([10, 2048, 7, 7]), 'avgpool': torch.Size([10, 2048, 1, 1])}
示例 #3: 梯度裁剪
梯度裁剪是处理梯度爆炸的一种著名方法。PyTorch 已经提供了梯度裁剪的工具方法,但是我们也可以很容易地使用 hook 来实现它。其他任何用于梯度裁剪/归一化/修改的方法都可以用同样的方式实现。
代码语言:javascript复制def gradient_clipper(model: nn.Module, val: float) -> nn.Module:
for parameter in model.parameters():
parameter.register_hook(lambda grad: grad.clamp_(-val, val))
return model
这个 hook 是后向传播时触发的,所以这次我们还计算了一个虚拟的损失度量。在执行 loss.backward() 之后,我们可以手动检查参数梯度,以确认它是否正常工作。
代码语言:javascript复制clipped_resnet = gradient_clipper(resnet50(), 0.01)
pred = clipped_resnet(dummy_input)
loss = pred.log().mean()
loss.backward()
print(clipped_resnet.fc.bias.grad[:25])
# tensor([-0.0010, -0.0047, -0.0010, -0.0009, -0.0015, 0.0027, 0.0017, -0.0023,
# 0.0051, -0.0007, -0.0057, -0.0010, -0.0039, -0.0100, -0.0018, 0.0062,
# 0.0034, -0.0010, 0.0052, 0.0021, 0.0010, 0.0017, -0.0100, 0.0021,
# 0.0020])
「来源:」https://towardsdatascience.com/how-to-use-pytorch-hooks-5041d777f904