将Pytorch模型移植到C++详细教程(附代码演练)

2023-08-29 08:19:50 浏览数 (1)

说明

在本文中,我们将看到如何将Pytorch模型移植到C 中。Pytorch通常用于研究和制作新模型以及系统的原型。该框架很灵活,因此易于使用。主要的问题是我们如何将Pytorch模型移植到更适合的格式C 中,以便在生产中使用。

我们将研究不同的管道,如何将PyTrac模型移植到C 中,并使用更合适的格式应用到生产中。

1) TorchScript脚本

2) 开放式神经网络交换

3) TFLite(Tensorflow Lite)

TorchScript脚本

TorchScript是PyTorch模型(nn.Module的子类)的中间表示,可以在高性能环境(例如C )中运行。它有助于创建可序列化和可优化的模型。在Python中训练这些模型之后,它们可以在Python或C 中独立运行。因此,可以使用Python轻松地在PyTorch中训练模型,然后通过torchscript将模型导出到无法使用Python的生产环境中。它基本上提供了一个工具来捕获模型的定义。

跟踪模块:

代码语言:javascript复制
class DummyCell(torch.nn.Module):    def __init__(self):        super(DummyCell, self).__init__()        self.linear = torch.nn.Linear(4, 4)    def forward(self, x):        out = self.linear(x)        return out
dummy_cell = DummyCell()x =  torch.rand(2, 4)traced_cell = torch.jit.trace(dummy_cell, (x))
# Print Traced Graphprint(traced_cell.graph)
# Print Traced Codeprint(traced_cell.code)

在这里,torchscript调用了模块,将执行的操作记录到称为图的中间表示中。traced_cell.graph提供了一个非常低级的表示,并且图形中的大部分信息最终对用户没有用处。traced_cell.code 提供了更多的python语法解释代码。

上述代码的输出(traced_cell.graph和traced_cell.code) :

代码语言:javascript复制
graph(%self.1 : __torch__.DummyCell,      %input : Float(2, 4, strides=[4, 1], requires_grad=0, device=cpu)):   : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)   : Tensor = prim::CallMethod[name="forward"](, %input)  return ()def forward(self,    input: Tensor) -> Tensor:  return (self.linear).forward(input, )
TorchScript的优点

1) TorchScript代码可以在自己的解释器中调用。所保存的图形也可以在C 中加载用于生产。

2) TorchScript为我们提供了一种表示,在这种表示中,我们可以对代码进行编译器优化,以提供更高效的执行。

ONNX(开放式神经网络交换)

ONNX是一种开放格式,用于表示机器学习模型。ONNX定义了一组通用的操作符、机器学习和深度学习模型的构建块以及一种通用的文件格式,使AI开发人员能够将模型与各种框架、工具、运行时和编译器一起使用。它定义了一个可扩展的计算图模型,以及内置操作符和标准数据类型的定义。

可以使用以下代码将上述DummyCell模型导出到onnx:

代码语言:javascript复制
torch.onnx.export(dummy_cell, x, "dummy_model.onnx", export_params=True, verbose=True)

输出:

代码语言:javascript复制
graph(%input : Float(2, 4, strides=[4, 1], requires_grad=0, device=cpu),      %linear.weight : Float(4, 4, strides=[4, 1], requires_grad=1, device=cpu),      %linear.bias : Float(4, strides=[1], requires_grad=1, device=cpu)):  %3 : Float(2, 4, strides=[4, 1], requires_grad=1, device=cpu) = onnx::Gemm[alpha=1., beta=1., transB=1](%input, %linear.weight, %linear.bias)  return (%3)

它将模型保存到文件名“dummy_model.onnx“中,可以使用python模块onnx加载该模型。为了在python中进行推理,可以使用ONNX运行时。ONNX运行时是一个针对ONNX模型的以性能为中心的引擎,它可以跨多个平台和硬件高效地进行推断。查看此处了解有关性能的更多详细信息。

https://cloudblogs.microsoft.com/opensource/2019/05/22/onnx-runtime-machine-learning-inferencing-0-4-release/

C 中的推理

为了从C 中执行ONNX模型,首先,我们必须使用tract库在Rust中编写推理代码。现在,我们有了用于推断ONNX模型的rust库。我们现在可以使用cbindgen将rust库导出为公共C头文件。

tract:https://github.com/sonos/tract

cbindgen:https://github.com/eqrion/cbindgen

现在,此头文件以及从Rust生成的共享库或静态库可以包含在C 中以推断ONNX模型。在从rust生成共享库的同时,我们还可以根据不同的硬件提供许多优化标志。Rust也可以轻松实现针对不同硬件类型的交叉编译。

Tensorflow Lite

Tensorflow Lite是一个用于设备上推理的开源深度学习框架。它是一套帮助开发人员在移动、嵌入式和物联网设备上运行Tensorflow模型的工具。它使在设备上的机器学习推理具有低延迟和小二进制大小。它有两个主要组成部分:

1) Tensorflow Lite解释器:它在许多不同的硬件类型上运行特别优化的模型,包括移动电话、嵌入式Linux设备和微控制器。

2) Tensorflow Lite转换器:它将Tensorflow模型转换为一种有效的形式,供解释器使用。

将PyTorch模型转换为TensorFlow lite的主管道如下:

1) 构建PyTorch模型

2) 以ONNX格式导模型

3) 将ONNX模型转换为Tensorflow(使用ONNX tf)

在这里,我们可以使用以下命令将ONNX模型转换为TensorFlow protobuf模型:

代码语言:javascript复制
!onnx-tf convert -i "dummy_model.onnx" -o  'dummy_model_tensorflow'

4) 将Tensorflow模型转换为Tensorflow Lite(tflite)

TFLITE模型(Tensorflow Lite模型)现在可以在C 中使用。这里请参考如何在C 中对TFLITE模型进行推理。

https://www.tensorflow.org/lite/guide/inference#load_and_run_a_model_in_c

尾注

我希望你觉得这篇文章有用。我们试图简单地解释一下,我们可以用不同的方式将PyTorch训练过的模型部署到生产中。

参考文献

1)TorchScript简介:https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html

2)在C 中加载TorchScript模型:https://pytorch.org/tutorials/advanced/cpp_export.html

3)将Pytorch模型导出到ONNX:https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html

4)Rust中的Tract神经网络推理工具包:https://github.com/sonos/tract

5)在C 中的TfLite模型上运行推理:https://www.tensorflow.org/lite/guide/inference#load_and_run_a_model_in_c

6)Colab - 在Android设备上进行Pytorch训练的模型:https://colab.research.google.com/drive/1MwFVErmqU9Z6cTDWLoTvLgrAEBRZUEsA

0 人点赞