如何使用 PyTorch Hook

2020-11-10 11:07:59 浏览数 (1)

作者:Frank Odom 编译:McGL

什么是钩子(Hook)?

Hook 实际上在软件工程中相当常见,并不是 PyTorch 所独有的。一般来说,“hook”是在特定事件之后自动执行的函数。在现实世界中,你可能遇到过的一些 hook 的例子:

  • 网站在你访问 N 个不同页面后会显示一个广告。
  • 你的账户有资金入账时,银行 app 发送通知消息。
  • 当周围光线减弱时,手机屏幕亮度会变暗。

这些事情没有 hook 也可以实现,但是很多情况下,hook 使程序员的生活更轻松。

PyTorch 为每个张量或 nn.Module 对象注册 hook。hook 由对象的向前或向后传播触发。它们具有以下函数签名:

代码语言:javascript复制
from torch import nn, Tensor

def module_hook(module: nn.Module, input: Tensor, output: Tensor):
    # For nn.Module objects only.
    
def tensor_hook(grad: Tensor):
    # For Tensor objects only.
    # Only executed during the *backward* pass!

每个 hook 都可以修改输入、输出或内部模块参数。最常见的是用于调试目的。但我们将看到它们还有很多其他用途。

示例 #1: 模型执行详情

你自己有没有在模型中插入 print 语句,来试图找出错消息的原因?(我当然对此有罪恶感。)这是一个丑陋的调试实践,而且在很多情况下,我们在完成 print 语句时忘记删除它。导致我们的代码看起来很不专业,用户每次使用你的代码都会得到一些奇怪的信息。

以后再也不会了!让我们使用 hook 来调试模型,而不用以任何方式修改它们的实现。例如,假如你想知道每个层输出的形状。我们可以创建一个简单的 wrapper,使用 hook 打印输出形状。

代码语言:javascript复制
class VerboseExecution(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

        # Register a hook for each layer
        for name, layer in self.model.named_children():
            layer.__name__ = name
            layer.register_forward_hook(
                lambda layer, _, output: print(f"{layer.__name__}: {output.shape}")
            )

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)

最大的好处是: 它甚至可以用于不是我们创建的 PyTorch 模块!下面用 ResNet50 和一些虚拟输入来展示一下。

代码语言:javascript复制
import torch
from torchvision.models import resnet50

verbose_resnet = VerboseExecution(resnet50())
dummy_input = torch.ones(10, 3, 224, 224)

_ = verbose_resnet(dummy_input)
# conv1: torch.Size([10, 64, 112, 112])
# bn1: torch.Size([10, 64, 112, 112])
# relu: torch.Size([10, 64, 112, 112])
# maxpool: torch.Size([10, 64, 56, 56])
# layer1: torch.Size([10, 256, 56, 56])
# layer2: torch.Size([10, 512, 28, 28])
# layer3: torch.Size([10, 1024, 14, 14])
# layer4: torch.Size([10, 2048, 7, 7])
# avgpool: torch.Size([10, 2048, 1, 1])
# fc: torch.Size([10, 1000])

示例 #2: 特征提取

通常,我们希望从一个预先训练好的网络中生成特性,然后用它们来完成另一个任务(例如分类、相似度搜索等)。使用 hook,我们可以提取特征,而不需要重新创建现有模型或以任何方式修改它。

代码语言:javascript复制
from typing import Dict, Iterable, Callable

class FeatureExtractor(nn.Module):
    def __init__(self, model: nn.Module, layers: Iterable[str]):
        super().__init__()
        self.model = model
        self.layers = layers
        self._features = {layer: torch.empty(0) for layer in layers}

        for layer_id in layers:
            layer = dict([*self.model.named_modules()])[layer_id]
            layer.register_forward_hook(self.save_outputs_hook(layer_id))

    def save_outputs_hook(self, layer_id: str) -> Callable:
        def fn(_, __, output):
            self._features[layer_id] = output
        return fn

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        _ = self.model(x)
        return self._features

我们可以像使用其他 PyTorch 模块一样使用特征提取器。用之前同样的虚拟输入,运行得到:

代码语言:javascript复制
resnet_features = FeatureExtractor(resnet50(), layers=["layer4", "avgpool"])
features = resnet_features(dummy_input)

print({name: output.shape for name, output in features.items()})
# {'layer4': torch.Size([10, 2048, 7, 7]), 'avgpool': torch.Size([10, 2048, 1, 1])}

示例 #3: 梯度裁剪

梯度裁剪是处理梯度爆炸的一种著名方法。PyTorch 已经提供了梯度裁剪的工具方法,但是我们也可以很容易地使用 hook 来实现它。其他任何用于梯度裁剪/归一化/修改的方法都可以用同样的方式实现。

代码语言:javascript复制
def gradient_clipper(model: nn.Module, val: float) -> nn.Module:
    for parameter in model.parameters():
        parameter.register_hook(lambda grad: grad.clamp_(-val, val))
    
    return model

这个 hook 是后向传播时触发的,所以这次我们还计算了一个虚拟的损失度量。在执行 loss.backward() 之后,我们可以手动检查参数梯度,以确认它是否正常工作。

代码语言:javascript复制
clipped_resnet = gradient_clipper(resnet50(), 0.01)
pred = clipped_resnet(dummy_input)
loss = pred.log().mean()
loss.backward()

print(clipped_resnet.fc.bias.grad[:25])
# tensor([-0.0010, -0.0047, -0.0010, -0.0009, -0.0015,  0.0027,  0.0017, -0.0023,
#          0.0051, -0.0007, -0.0057, -0.0010, -0.0039, -0.0100, -0.0018,  0.0062,
#          0.0034, -0.0010,  0.0052,  0.0021,  0.0010,  0.0017, -0.0100,  0.0021,
#          0.0020])

「来源:」https://towardsdatascience.com/how-to-use-pytorch-hooks-5041d777f904

0 人点赞