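ONNX (Open Neural Network Exchange) is an open format for exchanging neural network models between frameworks. It serializes models in the protobuf binary format (for more on protobuf serialization, see Netty整合Protobuffer), a compact encoding that is efficient to transmit and parse. Official GitHub: GitHub - onnx/onnx at f2daca5e9b9315a2034da61c662d2a7ac28a9488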
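ONNX represents each layer of a network, or more precisely each operator, as a node (Node). These Nodes are assembled into a Graph, which corresponds to the network itself. Finally, the Graph is combined with the model's other information (metadata such as the IR version and producer) to form a model, which is the final ONNX model. An example follows.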
Creating an ONNX Model
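There are two ways to create an ONNX model. One is to convert it from another framework such as PyTorch or PaddlePaddle. For converting PyTorch to ONNX, see "Pytorch 权重 pth 转换 onnx" (converting PyTorch .pth weights to ONNX) in 模型部署篇 (Model Deployment). For converting PaddlePaddle to ONNX, see the Paddle2ONNX section of PaddleOCR使用指南 (PaddleOCR User Guide).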
First, let's generate an ONNX file.
```python
import torch
import torch.nn as nn


class Network(nn.Module):
    def __init__(self):
        super().__init__()
        # 1 input channel, 1 output channel, 1x1 kernel
        self.conv = nn.Conv2d(1, 1, 1)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.conv(x))


if __name__ == '__main__':
    net = Network()
    # Dummy input of shape [N, C, H, W] = [1, 1, 1, 1]; wrapping it in
    # torch.autograd.Variable is no longer necessary
    dummy_input = torch.randn(1, 1, 1, 1)
    # Trace the network and save it as net.onnx using opset 10
    torch.onnx.export(net, dummy_input, 'net.onnx', opset_version=10)
```
Next, let's print the structure of this ONNX file.
```python
import onnx

if __name__ == '__main__':
    # net.onnx was produced by the export script above
    model = onnx.load('./net.onnx')
    # Printing a ModelProto shows its protobuf text representation
    print(model)
```
Output:
    ir_version: 5
    producer_name: "pytorch"
    producer_version: "1.12.1"
    graph {
      node {
        input: "input.1"
        input: "conv.weight"
        input: "conv.bias"
        output: "input"
        name: "Conv_0"
        op_type: "Conv"
        attribute {
          name: "dilations"
          ints: 1
          ints: 1
          type: INTS
        }
        attribute {
          name: "group"
          i: 1
          type: INT
        }
        attribute {
          name: "kernel_shape"
          ints: 1
          ints: 1
          type: INTS
        }
        attribute {
          name: "pads"
          ints: 0
          ints: 0
          ints: 0
          ints: 0
          type: INTS
        }
        attribute {
          name: "strides"
          ints: 1
          ints: 1
          type: INTS
        }
      }
      node {
        input: "input"
        output: "4"
        name: "Relu_1"
        op_type: "Relu"
      }
      name: "torch_jit"
      initializer {
        dims: 1
        dims: 1
        dims: 1
        dims: 1
        data_type: 1
        name: "conv.weight"
        raw_data: "