Python ONNX 模型转换、加载、简化、推断

2024-02-05 08:31:47 浏览数 (1)

Microsoft 和合作伙伴社区创建了 ONNX 作为表示机器学习模型的开放标准。 本文记录 Python 下 pytorch 模型转换 ONNX 的相关内容。

简介

  • 官网:https://github.com/microsoft/onnxruntime

ONNX Runtime是一个跨平台的推理和训练机器学习加速器。

在 Pytorch 框架中训练好模型后,在部署时可以转成 onnx,再进行下一步部署。

模型转换

核心代码:

  • 生成 onnx 模型: torch.onnx.export
  • 简化 onnx 模型: onnxsim.simplify
代码语言:text复制
import torch
import onnxsim
import onnx

def export_to_onnx(model, output_path, input_shape, input_name, output_names):
    dummy_input = torch.rand(1, *input_shape)

    model.eval()

    temp_dict = dict()

    temp_onnx_path = output_path.replace('.onnx', '_temp.onnx')

    torch.onnx.export(model, 						# pytorch 模型
                    (dummy_input, 'ALL'),  			# 可以输入 tuple 
                    temp_onnx_path, 				# 输出 onnx 模型路径
                    verbose=False, 					# 聒噪
                    opset_version=11,				# onnx 版本
                    export_params=True, 			# 一个指示是否导出模型参数(权重)以及模型架构的标志。
                    do_constant_folding=True,   	# 一个指示是否在导出过程中折叠常量节点的标志
                    input_names=[input_name],		# 输入节点名称列表(可选)
                    output_names=output_names		# 输出节点名称列表(可选)
                    )

    input_data = {'image': dummy_input.cpu().numpy()}
    model_sim, flag = onnxsim.simplify(temp_onnx_path, input_data=input_data) # 简化 onnx 

    if flag:
        onnx.save(model_sim, output_path)
        print(f"simplify onnx model successfully !")
    else:
        print(f"simplify onnx model failed !!!")

  • 注意: torch.onnx.export 输入伪数据可以支持字符串,但是在 onnx 模型中仅会记录张量流转的路径,字符串、分支逻辑一般不会保存。
模型检查

onnx 加载模型后可以检测是否合法。

代码语言:text复制
# onnx check
onnx_model = onnx.load(onnx_model_path)
try:
    onnx.checker.check_model(onnx_model)
except onnx.checker.ValidationError as e:
    print('The model is invalid: %s' % e)
else:
    print('The model is valid!')

加载、运行 ONNX 模型

ONNXruntime 安装:

代码语言:text复制
pip install onnxruntime       # CPU build
pip install onnxruntime-gpu   # GPU build

推理代码:

代码语言:text复制
import onnxruntime

session = onnxruntime.InferenceSession("path to model")
session.get_modelmeta()
results = session.run(["output1", "output2"], {"input1": indata1, "input2": indata2})
results = session.run([], {"input1": indata1, "input2": indata2}) 

可以对比 onnx 模型结果与 pytorch 模型结果的差异来对转换结果进行验证。

参考资料

  • https://www.bilibili.com/read/cv10539136/
  • https://blog.csdn.net/hjxu2016/article/details/118419488

0 人点赞