torch.jit.trace与torch.jit.script的区别

2022-08-07 12:37:03 浏览数 (1)

文章目录

  • 术语
  • 什么时候用torch.jit.trace(结论:首选)
    • 优点
  • 什么时候用torch.jit.script(结论:必要时)
  • 错误举例
    • 动态控制
    • 输入和输出有丰富类型的模型需要格外注意
  • QA
  • 解决错误的方法

术语

  1. Tochscript:狭义概念导出图形的表示/格式;广义概念为导出模型的方法;
  2. (Torch)Scriptable:可以用torch.jit.script导出模型
  3. Traceable:可以用torch.jit.trace导出模型

什么时候用torch.jit.trace(结论:首选)

  1. torch.jit.trace一种导出方法;它运行具有某些张量输入的模型,并“跟踪/记录”所有执行到图形中的操作。
  2. 在模型内部的数据类型只有张量,且没有for if while等控制流,选择torch.jit.trace
  3. 支持python的预处理和动态行为;
  4. torch.jit.trace编译function并返回一个可执行文件,该可执行文件将使用即时编译进行优化。
  5. 大项目优先选择torch.jit.trace,特别是是图像检测和分割的算法;

优点

  1. 不会损害代码质量;
  2. 2.它的主要限制可以通过与torch.jit.script混合来解决

什么时候用torch.jit.script(结论:必要时)

  1. 定义:一种模型导出方法,其实编译python的模型源码,得到可执行的图;
  2. 在模型内部的数据类型只有张量,且没有for if while等控制流,也可以选择torch.jit.script
  3. 不支持python的预处理和动态行为;
  4. 必须做一下类型标注;
  5. torch.jit.script在编译function或 nn.Module 脚本将检查源代码,使用 TorchScript 编译器将其编译为 TorchScript 代码。

错误举例

代码语言:javascript复制
import torch
from torch import nn


class MyModule(nn.Module):
    def __init__(self, return_b=False):
        super().__init__()
        self.return_b = return_b

    def forward(self, x):
        a = x   2
        if self.return_b:  #属于静态控制
            b = x   3
            return a, b
        return a


model = MyModule(return_b=True)

# Will work  成功
traced = torch.jit.trace(model, (torch.randn(10, ), ))

# Will fail 失败
scripted = torch.jit.script(model)
  • 总结:控制流是静态的,torch.jit.trace将正常工作

动态控制

  1. if x[0] == 4: x = 1 is a dynamic control flow.
代码语言:javascript复制
model: nn.Sequential = ...
for m in model:  # 动态控制
  x = m(x) 

输入和输出有丰富类型的模型需要格外注意

  • detectron2模型的jit转化
代码语言:javascript复制
outputs = model(inputs)   # inputs/outputs are rich structure
# torch.jit.trace(model, inputs)  # FAIL! unsupported format
adapter = TracingAdapter(model, inputs)
traced = torch.jit.trace(adapter, adapter.flattened_inputs)  # Can now trace the model

# Traced model can only produce flattened outputs (tuple of tensors):
flattened_outputs = traced(*adapter.flattened_inputs)
# Adapter knows how to convert it back to the rich structure (new_outputs == outputs):
new_outputs = adapter.outputs_schema(flattened_outputs)

QA

    1. JIT要求python的代码要是低级的;详情 因为更多动态高级的python语法,jit不支持.具体哪些支持哪些没支持官方也没有详细的列表; JIT should not force users to write ugly code #48108
    1. 错误示例:动态控制流:对于动态控制流torch.jit.trace只会编译一个分支,在其他分支处理的时候会报错;
代码语言:javascript复制
def f(x):
    return torch.sqrt(x) if x.sum() > 0 else torch.square(x)
m = torch.jit.trace(f, torch.tensor(3))
print(m.code) # 可以打印出trace的情况
#--------------------------------------------
def f(x: Tensor) -> Tensor:
  return torch.sqrt(x)
    1. 错误示例:将变量视为常量
代码语言:javascript复制
import torch

a, b = torch.rand(1), torch.rand(2)
print(a,b)

def f1(x): return torch.arange(x.shape[0])
def f2(x): return torch.arange(len(x))
result = torch.jit.trace(f1, a)(b)
print(result)

result =torch.jit.trace(f2, a)(b) # TracerWarning
print(result) #

print(torch.jit.trace(f1, a).code, torch.jit.trace(f2, a).code)
  • 错误示例:获取设备

解决错误的方法

    1. 严格消除警告信息,才C 运行的时候会报错
    1. 局部单元测试
    • 单元测试一样要做在导出模型后,这样避免在应用模型的时候(C 运行)出错;
代码语言:javascript复制
assert allclose(torch.jit.trace(model, input1)(input2), model(input2))
    1. 避免非必要的动态控制,例如:
代码语言:javascript复制
if x.numel() > 0:
  output = self.layers(x)
else:
  output = torch.zeros((0, C, H, W))  # Create empty outputs

0 人点赞