import torch
from torch import nn
class MyModule(nn.Module):
def __init__(self, return_b=False):
super().__init__()
self.return_b = return_b
def forward(self, x):
a = x 2
if self.return_b: #属于静态控制
b = x 3
return a, b
return a
model = MyModule(return_b=True)
# Will work 成功
traced = torch.jit.trace(model, (torch.randn(10, ), ))
# Will fail 失败
scripted = torch.jit.script(model)
总结:控制流是静态的,torch.jit.trace将正常工作
动态控制
if x[0] == 4: x = 1 is a dynamic control flow.
代码语言:javascript复制
model: nn.Sequential = ...
for m in model: # 动态控制
x = m(x)
输入和输出有丰富类型的模型需要格外注意
detectron2模型的jit转化
代码语言:javascript复制
outputs = model(inputs) # inputs/outputs are rich structure
# torch.jit.trace(model, inputs) # FAIL! unsupported format
adapter = TracingAdapter(model, inputs)
traced = torch.jit.trace(adapter, adapter.flattened_inputs) # Can now trace the model
# Traced model can only produce flattened outputs (tuple of tensors):
flattened_outputs = traced(*adapter.flattened_inputs)
# Adapter knows how to convert it back to the rich structure (new_outputs == outputs):
new_outputs = adapter.outputs_schema(flattened_outputs)
QA
JIT要求python的代码要是低级的;详情 因为更多动态高级的python语法,jit不支持.具体哪些支持哪些没支持官方也没有详细的列表; JIT should not force users to write ugly code #48108