Microsoft 和合作伙伴社区创建了 ONNX 作为表示机器学习模型的开放标准。 本文记录 Python 下 pytorch 模型转换 ONNX 的相关内容。
简介
- 官网:https://github.com/microsoft/onnxruntime
ONNX Runtime是一个跨平台的推理和训练机器学习加速器。
在 Pytorch 框架中训练好模型后,在部署时可以转成 onnx,再进行下一步部署。
模型转换
核心代码:
- 生成 onnx 模型:
torch.onnx.export
- 简化 onnx 模型:
onnxsim.simplify
:
import torch
import onnxsim
import onnx
def export_to_onnx(model, output_path, input_shape, input_name, output_names):
dummy_input = torch.rand(1, *input_shape)
model.eval()
temp_dict = dict()
temp_onnx_path = output_path.replace('.onnx', '_temp.onnx')
torch.onnx.export(model, # pytorch 模型
(dummy_input, 'ALL'), # 可以输入 tuple
temp_onnx_path, # 输出 onnx 模型路径
verbose=False, # 聒噪
opset_version=11, # onnx 版本
export_params=True, # 一个指示是否导出模型参数(权重)以及模型架构的标志。
do_constant_folding=True, # 一个指示是否在导出过程中折叠常量节点的标志
input_names=[input_name], # 输入节点名称列表(可选)
output_names=output_names # 输出节点名称列表(可选)
)
input_data = {'image': dummy_input.cpu().numpy()}
model_sim, flag = onnxsim.simplify(temp_onnx_path, input_data=input_data) # 简化 onnx
if flag:
onnx.save(model_sim, output_path)
print(f"simplify onnx model successfully !")
else:
print(f"simplify onnx model failed !!!")
- 注意:
torch.onnx.export
输入伪数据可以支持字符串,但是在 onnx 模型中仅会记录张量流转的路径,字符串、分支逻辑一般不会保存。
模型检查
onnx 加载模型后可以检测是否合法。
代码语言:text复制# onnx check
onnx_model = onnx.load(onnx_model_path)
try:
onnx.checker.check_model(onnx_model)
except onnx.checker.ValidationError as e:
print('The model is invalid: %s' % e)
else:
print('The model is valid!')
加载、运行 ONNX 模型
ONNXruntime
安装:
pip install onnxruntime # CPU build
pip install onnxruntime-gpu # GPU build
推理代码:
代码语言:text复制import onnxruntime
session = onnxruntime.InferenceSession("path to model")
session.get_modelmeta()
results = session.run(["output1", "output2"], {"input1": indata1, "input2": indata2})
results = session.run([], {"input1": indata1, "input2": indata2})
可以对比 onnx 模型结果与 pytorch 模型结果的差异来对转换结果进行验证。
参考资料
- https://www.bilibili.com/read/cv10539136/
- https://blog.csdn.net/hjxu2016/article/details/118419488