小伙伴们好呀,不久前我们推出了模型部署入门系列教程,受到了大家的一致好评,也收到了很多小伙伴的催更,后续教程正在准备中,将在不久后跟大家见面,敬请期待哦~
今天,我们又将开启新的 TorchScript 解读系列教程,带领大家玩转 PyTorch 模型部署。感兴趣的小伙伴一起往下看吧~
什么是 TorchScript✦
PyTorch 无疑是现在最成功的深度学习训练框架之一,是各种顶会顶刊论文实验的大热门。比起其他的框架,PyTorch 最大的卖点是它对动态网络的支持,比其他需要构建静态网络的框架拥有更低的学习成本。PyTorch 源码 Readme 中还专门为此做了一张动态图:
对研究员而言,PyTorch 能极大地提高想 idea、做实验、发论文的效率,是训练框架中的豪杰,但是它不适合部署。动态建图带来的优势对于性能要求更高的应用场景而言更像是缺点,非固定的网络结构给网络结构分析并进行优化带来了困难,多数参数都能以 Tensor 形式传输也让资源分配变成一件闹心的事。另外由于图是由 python 代码来构建的,一方面部署要依赖 python 环境,另一方面模型也毫无保密性可言。
而 TorchScript 就是为了解决这个问题而诞生的工具。包括代码的追踪及解析、中间表示的生成、模型优化、序列化等各种功能,可以说是覆盖了模型部署的方方面面。今天我们先简要地介绍一些 TorchScript 的功能,让大家有一个初步的认识,进阶的解读会陆续推出~
模型转换
作为模型部署的一个范式,通常我们都需要生成一个模型的中间表示(IR),这个 IR 拥有相对固定的图结构,所以更容易优化,让我们看一个例子:
代码语言:javascript复制import torch
from torchvision.models import resnet18
# 使用PyTorch model zoo中的resnet18作为例子
model = resnet18()
model.eval()
# 通过trace的方法生成IR需要一个输入样例
dummy_input = torch.rand(1, 3, 224, 224)
# IR生成
with torch.no_grad():
jit_model = torch.jit.trace(model, dummy_input)
到这里就将 PyTorch 的模型转换成了 TorchScript 的 IR。这里我们使用了 trace 模式来生成 IR,所谓 trace 指的是进行一次模型推理,在推理的过程中记录所有经过的计算,将这些记录整合成计算图。关于 trace 的过程我们会在未来的分享中进行解读。
那么这个 IR 中到底都有些什么呢?我们可以可视化一下其中的 layer1 看看:
代码语言:javascript复制jit_layer1 = jit_model.layer1
print(jit_layer1.graph)
# graph(%self.6 : __torch__.torch.nn.modules.container.Sequential,
# %4 : Float(1, 64, 56, 56, strides=[200704, 3136, 56, 1], requires_grad=0, device=cpu)):
# %1 : __torch__.torchvision.models.resnet.___torch_mangle_10.BasicBlock = prim::GetAttr[name="1"](%self.6)
# %2 : __torch__.torchvision.models.resnet.BasicBlock = prim::GetAttr[name="0"](%self.6)
# %6 : Tensor = prim::CallMethod[name="forward"](%2, %4)
# %7 : Tensor = prim::CallMethod[name="forward"](%1, %6)
# return (%7)
是不是有点摸不着头脑?TorchScript 有它自己对于 Graph 以及其中元素的定义,对于第一次接触的人来说可能比较陌生,但是没关系,我们还有另一种可视化方式:
代码语言:javascript复制print(jit_layer1.code)
# def forward(self,
# argument_1: Tensor) -> Tensor:
# _0 = getattr(self, "1")
# _1 = (getattr(self, "0")).forward(argument_1, )
# return (_0).forward(_1, )
没错,就是代码!TorchScript 的 IR 是可以还原成 python 代码的,如果你生成了一个 TorchScript 模型并且想知道它的内容对不对,那么可以通过这样的方式来做一些简单的检查。
刚才的例子中我们使用 trace 的方法生成 IR。除了 trace 之外,PyTorch 还提供了另一种生成 TorchScript 模型的方法:script。这种方式会直接解析网络定义的 python 代码,生成抽象语法树 AST,因此这种方法可以解决一些 trace 无法解决的问题,比如对 branch/loop 等数据流控制语句的建图。script 方式的建图有很多有趣的特性,会在未来的分享中做专题分析,敬请期待。
模型优化
聪明的同学可能发现了,上面的可视化中只有 resnet18 里 forward 的部分,其中的子模块信息是不是丢失了呢?如果没有丢失,那么怎么样才能确定子模块的内容是否正确呢?别担心,还记得我们说过 TorchScript 支持对网络的优化吗,这里我们就可以用一个 pass 解决这个问题:
代码语言:javascript复制# 调用inline pass,对graph做变换
torch._C._jit_pass_inline(jit_layer1.graph)
print(jit_layer1.code)
# def forward(self,
# argument_1: Tensor) -> Tensor:
# _0 = getattr(self, "1")
# _1 = getattr(self, "0")
# _2 = _1.bn2
# _3 = _1.conv2
# _4 = _1.bn1
# input = torch._convolution(argument_1, _1.conv1.weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True, True)
# _5 = _4.running_var
# _6 = _4.running_mean
# _7 = _4.bias
# input0 = torch.batch_norm(input, _4.weight, _7, _6, _5, False, 0.10000000000000001, 1.0000000000000001e-05, True)
# input1 = torch.relu_(input0)
# input2 = torch._convolution(input1, _3.weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True, True)
# _8 = _2.running_var
# _9 = _2.running_mean
# _10 = _2.bias
# out = torch.batch_norm(input2, _2.weight, _10, _9, _8, False, 0.10000000000000001, 1.0000000000000001e-05, True)
# input3 = torch.add_(out, argument_1, alpha=1)
# input4 = torch.relu_(input3)
# _11 = _0.bn2
# _12 = _0.conv2
# _13 = _0.bn1
# input5 = torch._convolution(input4, _0.conv1.weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True, True)
# _14 = _13.running_var
# _15 = _13.running_mean
# _16 = _13.bias
# input6 = torch.batch_norm(input5, _13.weight, _16, _15, _14, False, 0.10000000000000001, 1.0000000000000001e-05, True)
# input7 = torch.relu_(input6)
# input8 = torch._convolution(input7, _12.weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True, True)
# _17 = _11.running_var
# _18 = _11.running_mean
# _19 = _11.bias
# out0 = torch.batch_norm(input8, _11.weight, _19, _18, _17, False, 0.10000000000000001, 1.0000000000000001e-05, True)
# input9 = torch.add_(out0, input4, alpha=1)
# return torch.relu_(input9)
这里我们就能看到卷积、batch_norm、relu 等熟悉的算子了。
上面代码中我们使用了一个名为 inline 的 pass,将所有子模块进行内联,这样我们就能看见更完整的推理代码。pass 是一个来源于编译原理的概念,一个 TorchScript 的 pass 会接收一个图,遍历图中所有元素进行某种变换,生成一个新的图。我们这里用到的 inline 起到的作用就是将模块调用展开,尽管这样做并不能直接影响执行效率,但是它其实是很多其他 pass 的基础。PyTorch 中定义了非常多的 pass 来解决各种优化任务,未来我们会做一些更详细的介绍。
序列化
不管是哪种方法创建的 TorchScript 都可以进行序列化,比如:
代码语言:javascript复制# 将模型序列化
jit_model.save('jit_model.pth')
# 加载序列化后的模型
jit_model = torch.jit.load('jit_model.pth')
序列化后的模型不再与 python 相关,可以被部署到各种平台上。
PyTorch 提供了可以用于 TorchScript 模型推理的 c API,序列化后的模型终于可以不依赖 python 进行推理了:
代码语言:javascript复制// 加载生成的torchscript模型
auto module = torch::jit::load('jit_model.pth');
// 根据任务需求读取数据
std::vector<torch::jit::IValue> inputs = ...;
// 计算推理结果
auto output = module.forward(inputs).toTensor();
与其他组件的关系✦
与 torch.onnx 的关系
ONNX 是业界广泛使用的一种神经网络中间表示,PyTorch 自然也对 ONNX 提供了支持。torch.onnx.export 函数可以帮助我们把 PyTorch 模型转换成 ONNX 模型,这个函数会使用 trace 的方式记录 PyTorch 的推理过程。聪明的同学可能已经想到了,没错,ONNX 的导出,使用的正是 TorchScript 的 trace 工具。具体步骤如下:
1. 使用 trace 的方式先生成一个 TorchScipt 模型,如果你转换的本身就是 TorchScript 模型,则可以跳过这一步。
2. 使用许多 pass 对 1 中生成的模型进行变换,其中对 ONNX 导出最重要的一个 pass 就是ToONNX,这个 pass 会进行一个映射,将 TorchScript 中 prim、aten 空间下的算子映射到onnx空间下的算子。
3. 使用 ONNX 的 proto 格式对模型进行序列化,完成 ONNX 的导出。
关于 ONNX 导出的实现以及算子映射的方式将会在未来的分享中详细展开。
与 torch.fx 的关系
PyTorch1.9 开始添加了 torch.fx 工具,根据官方的介绍,它由符号追踪器 (symbolic tracer),中间表示(IR), Python 代码生成 (Python code generation) 等组件组成,实现了 python->python 的翻译。是不是和 TorchScript 看起来有点像?
其实他们之间联系不大,可以算是互相垂直的两个工具,为解决两个不同的任务而诞生。
TorchScript 的主要用途是进行模型部署,需要记录生成一个便于推理优化的 IR,对计算图的编辑通常都是面向性能提升等等,不会给模型本身添加新的功能。
FX 的主要用途是进行 python->python 的翻译,它的 IR 中节点类型更简单,比如函数调用、属性提取等等,这样的 IR 学习成本更低更容易编辑。使用 FX 来编辑图通常是为了实现某种特定功能,比如给模型插入量化节点等,避免手动编辑网络造成的重复劳动。
这两个工具可以同时使用,比如使用 FX 工具编辑模型来让训练更便利、功能更强大;然后用 TorchScript 将模型加速部署到特定平台。
希望通过以上的分享,大家对 TorchScript 有了一个初步的认识,未来我们将会为大家带来更进阶的解读,欢迎大家持续关注。另外值得分享的是,MMDeploy 已开始对 TorchScript 提供支持,