ONNX Notes

2022-10-31 10:35:55

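ONNX (Open Neural Network Exchange) is a model exchange format shared across deep learning frameworks. It serializes models in the protobuf binary format (for protobuf serialization, see the article "Integrating Protobuffer with Netty"), which keeps models compact and efficient to transfer. Official GitHub: GitHub - onnx/onnx at f2daca5e9b9315a2034da61c662d2a7ac28a9488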

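ONNX treats each layer of a network, or more precisely each operator, as a Node. These Nodes are assembled into a Graph, which corresponds to the network. Finally, the Graph is combined with the model's other metadata into a Model, which is the final ONNX model. An example follows.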

Creating ONNX models

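There are two ways to create an ONNX model. The first is to convert a model from another framework such as PyTorch or PaddlePaddle. For PyTorch to ONNX, see "Converting PyTorch .pth weights to ONNX" in the article "Model Deployment"; for PaddlePaddle to ONNX, see the Paddle2ONNX section of the "PaddleOCR User Guide".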

First, let's generate an ONNX file by exporting a small PyTorch network.

import torch
import torch.nn as nn

class Network(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 1)  # in_channels=1, out_channels=1, kernel_size=1
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.conv(x))

if __name__ == '__main__':

    net = Network()
    # torch.autograd.Variable is deprecated; a plain tensor works as the example input
    dummy_input = torch.randn([1, 1, 1, 1])
    # export traces the network with the example input and writes net.onnx using opset 10
    torch.onnx.export(net, dummy_input, 'net.onnx', opset_version=10)

Next, print the structure of this ONNX file.

import onnx

if __name__ == '__main__':

    # Load net.onnx (exported above) and print the whole ModelProto in protobuf text format
    print(onnx.load("./net.onnx"))

Output:

ir_version: 5
producer_name: "pytorch"
producer_version: "1.12.1"
graph {
  node {
    input: "input.1"
    input: "conv.weight"
    input: "conv.bias"
    output: "input"
    name: "Conv_0"
    op_type: "Conv"
    attribute {
      name: "dilations"
      ints: 1
      ints: 1
      type: INTS
    }
    attribute {
      name: "group"
      i: 1
      type: INT
    }
    attribute {
      name: "kernel_shape"
      ints: 1
      ints: 1
      type: INTS
    }
    attribute {
      name: "pads"
      ints: 0
      ints: 0
      ints: 0
      ints: 0
      type: INTS
    }
    attribute {
      name: "strides"
      ints: 1
      ints: 1
      type: INTS
    }
  }
  node {
    input: "input"
    output: "4"
    name: "Relu_1"
    op_type: "Relu"
  }
  name: "torch_jit"
  initializer {
    dims: 1
    dims: 1
    dims: 1
    dims: 1
    data_type: 1
    name: "conv.weight"
    raw_data: "14317B?"
  }
  initializer {
    dims: 1
    data_type: 1
    name: "conv.bias"
    raw_data: "344n26277"
  }
  input {
    name: "input.1"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 1
          }
        }
      }
    }
  }
  output {
    name: "4"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 1
          }
        }
      }
    }
  }
}
opset_import {
  version: 10
}

First comes the IR version of the ONNX format, here ir_version: 5. Next is the framework that produced the model, PyTorch (producer_name: "pytorch"), and its version (producer_version: "1.12.1").

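Then comes graph->node. The first node is a 2D convolution (Conv) and the second is the ReLU activation. A node's op_type is its operator type; the full list of operators is at https://github.com/onnx/onnx/blob/f2daca5e9b9315a2034da61c662d2a7ac28a9488/docs/Operators.md. name is the node's own identifier (Conv_0, Relu_1) and is distinct from op_type. attribute holds the operator's attributes; for Conv_0 these are the convolution settings, for example "group" (the number of groups in a grouped convolution) and "kernel_shape" (the kernel size). Nodes are connected only through tensor names: Conv_0 writes "input", which Relu_1 then reads.

At graph level, initializer holds the constant tensors, here the values of the convolution weight and bias, stored as raw bytes in raw_data (data_type: 1 means FLOAT). input and output describe the graph's inputs and outputs, including their element type (elem_type: 1, again FLOAT) and shape. Finally, opset_import lists the operator domains and versions the model depends on (opset 10 here, matching opset_version=10 in the export call).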

Finally, validate the model with onnx.checker.check_model. It runs without errors, so the model passes the checks.

import onnx

if __name__ == '__main__':

    model = onnx.load("./net.onnx")
    # Raises onnx.checker.ValidationError if the model is malformed; returns None on success
    onnx.checker.check_model(model)

The other way is to build the ONNX model directly with ONNX's own API (onnx.helper).

import onnx
import onnx.helper as helper
import numpy as np

if __name__ == '__main__':

    # Graph input/output declarations: tensor name, element type and a fixed (N, C, H, W) shape
    input_info = helper.make_tensor_value_info(name='input', elem_type=onnx.TensorProto.FLOAT, shape=[1, 3, 244, 244])
    output_info = helper.make_tensor_value_info(name='output', elem_type=onnx.TensorProto.FLOAT, shape=[1, 3, 244, 244])
    # Initializers: 1x1 conv weight with shape (C_out, C_in, kH, kW) and bias with shape (C_out,)
    weight = helper.make_tensor(name='weight', data_type=onnx.TensorProto.FLOAT, dims=[3, 3, 1, 1], vals=np.random.randn(3, 3, 1, 1))
    bias = helper.make_tensor(name='bias', data_type=onnx.TensorProto.FLOAT, dims=[3], vals=np.random.randn(3))
    # The Conv node refers to the tensors above by name; extra keyword arguments become node attributes
    node = helper.make_node(op_type='Conv', inputs=['input', 'weight', 'bias'], outputs=['output'], kernel_shape=[1, 1], strides=[1, 1],
                            group=1, pads=[0, 0, 0, 0])
    graph = helper.make_graph(name='graph', nodes=[node], inputs=[input_info], outputs=[output_info], initializer=[weight, bias])

    # Wrap the graph in a model; ir_version and opset_import are filled in automatically
    model = helper.make_model(graph)
    onnx.checker.check_model(model)
    print(model)
    onnx.save_model(model, 'model.onnx')

Output:

ir_version: 8
graph {
  node {
    input: "input"
    input: "weight"
    input: "bias"
    output: "output"
    op_type: "Conv"
    attribute {
      name: "group"
      i: 1
      type: INT
    }
    attribute {
      name: "kernel_shape"
      ints: 1
      ints: 1
      type: INTS
    }
    attribute {
      name: "pads"
      ints: 0
      ints: 0
      ints: 0
      ints: 0
      type: INTS
    }
    attribute {
      name: "strides"
      ints: 1
      ints: 1
      type: INTS
    }
  }
  name: "graph"
  initializer {
    dims: 3
    dims: 3
    dims: 1
    dims: 1
    data_type: 1
    float_data: 0.45837152004241943
    float_data: 0.10209446400403976
    float_data: 1.0382566452026367
    float_data: -0.09292714297771454
    float_data: 1.58871591091156
    float_data: 0.3746287226676941
    float_data: -0.35588690638542175
    float_data: 0.7165427207946777
    float_data: 0.10244251787662506
    name: "weight"
  }
  initializer {
    dims: 3
    data_type: 1
    float_data: -0.36782845854759216
    float_data: 2.305680513381958
    float_data: -0.13051341474056244
    name: "bias"
  }
  input {
    name: "input"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 3
          }
          dim {
            dim_value: 244
          }
          dim {
            dim_value: 244
          }
        }
      }
    }
  }
  output {
    name: "output"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 3
          }
          dim {
            dim_value: 244
          }
          dim {
            dim_value: 244
          }
        }
      }
    }
  }
}
opset_import {
  version: 17
}

Setting batch_size dynamically

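In the output above, every dimension of input is fixed (1, 3, 244, 244). We now want to make the batch dimension dynamic so it can vary at inference time. First, let's run the model as it is.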

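import onnxruntime
import numpy as np

if __name__ == '__main__':

    # Create an inference session from the model saved in the previous step
    sess = onnxruntime.InferenceSession('./model.onnx')
    # The input must match the declared shape (1, 3, 244, 244) and dtype (float32)
    x = np.random.randn(1, 3, 244, 244).astype(np.float32)
    # run(output_names, {input_name: array}) returns a list with one array per requested output
    print(sess.run(['output'], {'input': x}))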

Output:

[array([[[[-7.4062514e-01,  2.5951520e-01, -3.5876265e-01, ...,
          -2.0852795e+00, -1.0078001e-01, -4.9386564e-01],
         [-6.0379845e-01,  9.2830718e-01, -4.2096943e-02, ...,
          -1.9139317e-01,  1.6547061e+00,  1.4468774e+00],
         [ 2.6494553e+00, -9.6209788e-01,  8.2099646e-02, ...,
          -1.5899204e+00, -1.3295431e+00,  1.1512205e-01],
         ...,
         [ 1.4135087e+00,  6.4077592e-01, -5.6514746e-01, ...,
           2.1367333e+00,  2.6012421e+00, -1.3565271e+00],
         [ 6.9879985e-01,  1.2454928e+00,  6.0045028e-01, ...,
          -6.1302024e-01, -4.3026954e-02, -7.2975445e-01],
         [-2.1020520e+00, -1.2499222e+00, -9.3896770e-01, ...,
          -4.6129468e-01,  5.4580927e-01, -7.4599540e-01]],

        [[ 5.6230574e+00,  2.6218858e+00,  7.1071947e-01, ...,
           3.6510468e-02,  2.5771899e+00,  2.0060635e+00],
         [ 4.2759910e+00,  2.5261867e+00,  1.0787441e+00, ...,
           3.3373690e+00,  4.5090003e+00,  3.5535808e+00],
         [ 1.6522924e+00,  1.5206050e+00,  3.6905313e+00, ...,
           1.5963824e+00,  5.1875353e-02,  3.4248161e+00],
         ...,
         [ 1.0295208e+00,  4.5397396e+00,  4.3366423e+00, ...,
           1.2408195e+00,  3.1239326e+00,  1.7476916e+00],
         [ 9.7080982e-01,  1.9692242e+00,  3.7690439e+00, ...,
          -1.6770840e-01,  1.1871569e+00,  4.2690439e+00],
         [ 4.4730301e+00,  1.5573008e+00,  7.2707558e+00, ...,
           4.7898588e+00,  2.9080591e+00,  7.2294927e-01]],

        [[ 1.3509388e+00, -1.9160898e-01, -1.3318433e+00, ...,
          -1.0562456e+00,  1.0652192e-01, -4.4993240e-01],
         [ 7.3106253e-01, -4.0714890e-03, -5.3625894e-01, ...,
          -6.2385768e-02,  3.3464909e-01,  2.7667671e-01],
         [-7.8517151e-01, -7.1918708e-01,  5.5366117e-01, ...,
          -4.7982591e-01, -1.0322813e+00,  8.0901492e-01],
         ...,
         [-1.0904443e+00,  4.7577775e-01,  9.5288980e-01, ...,
          -9.8435390e-01, -5.1632053e-01,  2.4581529e-01],
         [-6.4627886e-01, -9.8449951e-01,  1.6146483e-01, ...,
          -1.2009792e+00, -7.3006052e-01,  7.0891309e-01],
         [ 1.3855783e+00, -8.9338100e-01,  2.4704218e+00, ...,
           6.8950468e-01,  1.7709453e-01, -7.6678610e-01]]]],
      dtype=float32)]
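With kernel_shape=[1, 1], stride 1, no padding and group=1, this Conv is just a per-pixel matrix multiply across channels plus a bias. A NumPy sketch of that computation follows; conv1x1 is a hypothetical helper (not part of onnx or onnxruntime) that could be used to sanity-check the session output. The commented lines show one possible comparison against model.onnx, assuming that file and onnxruntime are available.

```python
import numpy as np

def conv1x1(x: np.ndarray, w: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Hypothetical reference for Conv with kernel_shape=[1, 1], strides=[1, 1],
    pads=[0, 0, 0, 0] and group=1.

    x: (N, C_in, H, W), w: (C_out, C_in, 1, 1), b: (C_out,)
    out[n, o, i, j] = sum_c w[o, c, 0, 0] * x[n, c, i, j] + b[o]
    """
    w2d = w[:, :, 0, 0]                         # drop the 1x1 spatial dims -> (C_out, C_in)
    out = np.einsum("oc,ncij->noij", w2d, x)    # mix channels independently at every pixel
    return out + b.reshape(1, -1, 1, 1)         # broadcast the bias over N, H and W

# Possible check against the model built earlier (needs model.onnx and onnxruntime):
# import onnx, onnxruntime
# from onnx import numpy_helper
# params = {t.name: numpy_helper.to_array(t) for t in onnx.load("model.onnx").graph.initializer}
# x = np.random.randn(1, 3, 244, 244).astype(np.float32)
# y = onnxruntime.InferenceSession("model.onnx").run(["output"], {"input": x})[0]
# print(np.allclose(y, conv1x1(x, params["weight"], params["bias"]), atol=1e-5))
```

This also explains why the values in the second channel above sit noticeably higher than in the other two: every pixel of that channel gets the same bias added.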

Now let's change the input batch_size to 2.

import onnx
import onnxruntime
import numpy as np

if __name__ == '__main__':

    model = onnx.load("./model.onnx")
    inputs = model.graph.input
    outputs = model.graph.output
    # Overwrite dimension 0 (the batch dimension) of every graph input and output with the fixed value 2
    for i in inputs:
        i.type.tensor_type.shape.dim[0].dim_value = 2
    for o in outputs:
        o.type.tensor_type.shape.dim[0].dim_value = 2
    onnx.checker.check_model(model)
    onnx.save_model(model, 'dynamic_model.onnx')
    sess = onnxruntime.InferenceSession('./dynamic_model.onnx')
    x = np.random.randn(2, 3, 244, 244).astype(np.float32)
    print(sess.run(['output'], {'input': x}))

Output:

[array([[[[-2.10871696e-02, -1.32871771e+00, -1.22335061e-01, ...,
           4.77721721e-01, -4.10815179e-01, -1.37511027e+00],
         [-1.09181249e+00, -2.02204657e+00,  1.54176390e+00, ...,
          -1.88722742e+00, -2.00726366e+00,  4.24929589e-01],
         [-7.14685619e-01,  3.82802397e-01, -2.30412316e+00, ...,
           7.06834435e-01, -2.36892438e+00, -2.11947155e+00],
         ...,
         [-9.51929450e-01, -1.22408187e+00, -1.35213524e-01, ...,
           5.55669367e-02, -5.95110297e-01, -2.15206313e+00],
         [ 8.90325904e-01, -1.89442956e+00,  8.34725618e-01, ...,
          -2.34860206e+00, -1.09965193e+00, -4.96994108e-01],
         [ 1.56639183e+00,  5.97145438e-01, -5.28750658e-01, ...,
           5.77995658e-01, -1.46205699e+00,  2.80693078e+00]],

        [[ 3.09728765e+00, -1.42589498e+00,  7.58970022e-01, ...,
           3.48910093e+00,  2.95971513e+00,  1.96736765e+00],
         [ 2.76622701e+00,  1.58350587e+00,  2.41761374e+00, ...,
           3.68322372e-01,  3.05963039e-01,  2.99718475e+00],
         [-1.75151324e+00,  2.79870439e+00, -3.03543806e-01, ...,
           2.86027908e+00,  1.78771615e+00,  4.79569674e+00],
         ...,
         [ 1.30739605e+00,  1.83714139e+00,  4.55001736e+00, ...,
           1.44066858e+00,  4.87037659e+00,  2.10291076e+00],
         [ 9.44083452e-01, -8.11131001e-02,  2.89160919e+00, ...,
           2.34788847e+00,  1.95467031e+00,  3.87145948e+00],
         [ 2.71238947e+00,  1.46723819e+00,  7.61192560e-01, ...,
           2.69581342e+00,  2.11386037e+00,  4.08577728e+00]],

        [[ 3.52043629e-01, -1.83945060e+00, -9.97831583e-01, ...,
          -2.60245442e-01,  3.69277894e-01,  1.17505208e-01],
         [ 2.62015522e-01, -6.50106370e-01, -7.36498535e-01, ...,
          -3.72626394e-01, -9.92001474e-01,  1.87904552e-01],
         [-2.00427341e+00, -2.67415404e-01, -1.00334084e+00, ...,
           8.22970718e-02,  1.41485706e-01,  1.49001801e+00],
         ...,
         [-9.85595703e-01, -9.74879414e-03,  1.27501774e+00, ...,
          -7.10564435e-01,  1.17551017e+00, -4.15902734e-01],
         [-9.80473995e-01, -1.07735765e+00,  2.39617974e-02, ...,
          -1.93872005e-02,  2.48361230e-02,  7.19040394e-01],
         [-6.61614537e-02, -4.85614896e-01, -7.31452227e-01, ...,
          -9.65917259e-02, -2.94267178e-01,  1.87805906e-01]]],


       [[[-1.05941308e+00,  8.10959578e-01, -9.29054856e-01, ...,
          -1.33419132e+00, -5.62950134e-01,  3.15277368e-01],
         [-2.45844007e+00, -5.31174302e-01,  8.06264520e-01, ...,
          -1.37343729e+00, -1.26287377e+00, -1.79255664e+00],
         [ 5.01155496e-01,  2.53203034e+00, -9.11398768e-01, ...,
          -2.61194611e+00, -6.27550602e-01, -1.04612875e+00],
         ...,
         [ 5.64767838e-01,  1.82380235e+00, -9.87865806e-01, ...,
          -1.48546624e+00,  5.00284791e-01, -1.14099467e+00],
         [-1.48488015e-01, -3.75306606e-03,  2.05217457e+00, ...,
          -4.82964367e-01,  6.37757182e-01,  5.87742925e-01],
         [-7.62285709e-01,  5.78535438e-01, -9.07517672e-01, ...,
          -1.40203249e+00,  3.13063234e-01,  9.46564317e-01]],

        [[ 2.21778965e+00,  1.17825162e+00,  1.17773283e+00, ...,
           4.21785736e+00,  1.93207061e+00,  6.90674305e+00],
         [ 5.16840172e+00,  4.03573513e-02,  3.72957373e+00, ...,
           2.57324958e+00,  3.23857665e-01,  8.98278236e-01],
         [ 1.18916261e+00,  4.03137350e+00,  1.54717636e+00, ...,
           5.73142242e+00,  2.54209590e+00,  3.02691102e+00],
         ...,
         [ 2.02949071e+00,  4.00444984e+00,  3.55739307e+00, ...,
           5.54533482e-01,  3.57894540e+00,  7.03547835e-01],
         [ 2.57975435e+00,  2.32062602e+00,  4.18669128e+00, ...,
           2.15663671e+00,  2.39567637e+00,  7.93485880e-01],
         [ 3.32399893e+00,  3.12817383e+00,  3.60134292e+00, ...,
           1.70791423e+00,  7.71586776e-01,  3.58140349e+00]],

        [[-6.90246701e-01, -8.55753422e-01, -1.35433823e-01, ...,
           9.99482393e-01, -2.96287388e-01,  2.49611807e+00],
         [ 1.56937921e+00, -9.95752215e-01,  5.38442284e-02, ...,
           1.63274094e-01, -9.27845955e-01, -6.64922059e-01],
         [-5.40241778e-01,  2.26666585e-01, -2.95405626e-01, ...,
           1.90356636e+00,  4.94795978e-01,  1.35599896e-01],
         ...,
         [-4.09579694e-01,  1.26961544e-01,  5.97525239e-01, ...,
          -9.00853217e-01,  8.11160445e-01, -8.88532698e-01],
         [-5.75763881e-01, -1.15364529e-01,  2.42510274e-01, ...,
           1.83168098e-01, -3.83193374e-01, -1.10992551e+00],
         [ 9.75027233e-02,  1.07848495e-02,  2.93477297e-01, ...,
          -2.67393768e-01, -8.09366763e-01,  6.60410523e-03]]]],
      dtype=float32)]

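However, batch_size is still a fixed value (now 2): feeding an input whose first dimension is anything else raises an error. To accept an arbitrary batch_size, mark the dimension as symbolic with dim_param instead, as follows.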

import onnx
import onnxruntime
import numpy as np

if __name__ == '__main__':

    model = onnx.load("./model.onnx")
    inputs = model.graph.input
    outputs = model.graph.output
    # A string dim_param marks dimension 0 as symbolic, so any batch size is accepted at run time
    for i in inputs:
        i.type.tensor_type.shape.dim[0].dim_param = 'batchsize'
    for o in outputs:
        o.type.tensor_type.shape.dim[0].dim_param = 'batchsize'
    onnx.checker.check_model(model)
    onnx.save_model(model, 'dynamic_model.onnx')
    sess = onnxruntime.InferenceSession('./dynamic_model.onnx')
    x = np.random.randn(3, 3, 244, 244).astype(np.float32)  # batch size 3 now works
    print(sess.run(['output'], {'input': x}))

Output:

[array([[[[-5.6308472e-01, -2.8269453e+00, -2.7103744e+00, ...,
           4.2550400e-01,  6.5147376e-01, -4.7779888e-02],
         [-2.5536952e+00,  1.1469245e-01,  3.4514198e-01, ...,
          -1.8919052e+00, -5.7445437e-01, -1.5864235e+00],
         [-1.7443299e-02, -8.9739335e-01, -2.9766396e-01, ...,
           2.7872375e-01, -8.8234627e-01, -2.3681331e+00],
         ...,
         [-1.3148707e+00, -5.4888296e-01,  4.1061863e-01, ...,
          -1.0763314e+00, -9.6379507e-01,  1.3077673e+00],
         [-6.3514382e-02, -5.1493609e-01, -1.5793841e+00, ...,
          -2.2589236e-02, -2.2170777e+00,  1.2437304e+00],
         [ 7.4394345e-01,  7.8581774e-01,  2.0062235e-01, ...,
          -1.4014708e+00,  5.5377036e-02,  3.6608991e-01]],

        [[ 5.4129419e+00,  1.7448205e+00,  3.4165416e+00, ...,
           1.0320716e+00,  1.6988618e+00,  5.1501741e+00],
         [-1.3918903e+00,  1.7199724e+00,  2.1343894e+00, ...,
           8.0553353e-01,  4.7985373e+00,  2.5783958e+00],
         [ 2.3555427e+00,  6.3222194e-01,  2.9314611e+00, ...,
           4.3459427e-01,  1.3417060e+00,  1.6852837e+00],
         ...,
         [ 5.7537341e-01,  3.0654173e+00, -5.7629395e-01, ...,
           1.0968879e+00,  3.7861698e+00,  1.4928346e+00],
         [ 3.1267416e+00,  2.0358701e+00,  2.2204084e+00, ...,
           5.2084265e+00,  3.9166064e+00,  6.4575119e+00],
         [-7.2486067e-01,  2.3311584e+00,  2.0912974e+00, ...,
           1.8693907e+00,  3.2796674e+00,  3.8991761e+00]],

        [[ 1.1898929e+00,  8.9648962e-03,  8.5148907e-01, ...,
          -4.7057205e-01, -7.9108685e-01,  1.0573645e+00],
         [-1.5732453e+00, -3.8554335e-01, -1.6086581e-01, ...,
          -8.1125468e-01,  1.2085729e+00,  5.6812420e-02],
         [-3.0767348e-01, -8.5083431e-01,  4.9003422e-02, ...,
          -7.8210533e-01, -5.2408022e-01, -2.3199841e-02],
         ...,
         [-7.3540843e-01, -2.9384446e-01, -1.6465921e+00, ...,
          -3.8980949e-01,  7.1137357e-01, -8.0783540e-01],
         [ 2.3953258e-01, -1.7050017e-01, -1.3933203e-01, ...,
           1.6591790e+00,  1.0759927e+00,  1.7683787e+00],
         [-1.6956003e+00, -4.6602386e-01, -3.4259117e-01, ...,
          -1.0014131e-01,  2.6990986e-01,  8.6363363e-01]]],


       [[[ 6.4478827e-01, -8.1067204e-01, -1.2237258e+00, ...,
          -1.2951733e+00, -6.2070227e-01, -1.2906476e+00],
         [-6.6038930e-01, -2.8674665e-01, -1.0612940e+00, ...,
           4.6769258e-01,  4.8500946e-01, -5.6188315e-01],
         [ 1.0600269e-02, -1.4934481e+00,  9.1430867e-01, ...,
          -6.1285675e-01, -3.0706315e+00, -9.9033105e-01],
         ...,
         [ 1.7771789e+00, -1.3830042e+00, -1.4351614e+00, ...,
          -2.6786397e+00,  3.7956804e-02,  6.7189908e-01],
         [-2.1517308e+00, -5.8123243e-01, -7.7163374e-01, ...,
           1.6774191e+00,  7.2239363e-01,  1.3373801e+00],
         [-8.6465418e-01, -1.3932706e+00, -2.2982714e+00, ...,
           1.9587449e+00, -6.2718022e-01, -1.1754386e+00]],

        [[ 5.2605295e+00,  6.8119764e-01,  1.6433215e+00, ...,
           1.4899890e+00,  7.7494907e-01,  1.0885936e+00],
         [ 1.7135508e+00,  1.7890544e+00,  1.5538380e+00, ...,
           4.2714515e+00,  3.4532502e+00,  4.0540075e+00],
         [ 3.2757509e-01,  2.8093519e+00,  4.4473543e+00, ...,
           1.6302650e+00,  2.0791094e+00, -2.7314346e+00],
         ...,
         [ 3.1872306e+00,  2.1063502e+00,  4.4839258e+00, ...,
           8.6034179e-01,  3.7707591e+00,  3.9809742e+00],
         [-2.0055294e-02, -4.3134212e-02,  2.1313593e+00, ...,
           3.0318618e+00,  2.2852294e+00,  3.9968524e+00],
         [ 2.1781492e+00,  3.6937137e+00,  1.5003638e+00, ...,
           3.5955300e+00,  1.7056749e+00,  1.9585730e+00]],

        [[ 1.1795213e+00, -7.1754062e-01, -1.3523299e-01, ...,
          -5.6350648e-01, -1.3417213e+00, -5.0127864e-02],
         [-5.1167816e-01, -7.7823803e-02, -3.1461412e-01, ...,
           7.8631788e-01,  5.9256524e-01,  6.9275266e-01],
         [-1.3142396e+00,  7.9331988e-01,  5.0062788e-01, ...,
          -6.4525604e-03, -3.3234254e-02, -2.1546085e+00],
         ...,
         [-2.0651843e-01,  1.1771068e-02,  1.3835690e+00, ...,
           2.8883666e-03,  4.5511311e-01,  2.9804629e-01],
         [-9.0822458e-01, -1.3634090e+00, -4.2348909e-01, ...,
           2.9903316e-01, -5.9180021e-01,  5.1938176e-01],
         [-3.7974668e-01,  6.5785772e-01, -4.8025602e-01, ...,
           1.5578230e-01, -8.5666311e-01, -8.2990326e-02]]],


       [[[ 2.7270940e-01,  1.6803369e-01,  6.4784336e-01, ...,
          -8.6817765e-01,  2.4317000e+00,  9.9560642e-01],
         [-1.0902294e+00, -1.5418210e+00, -6.4213789e-01, ...,
           3.8346985e-01, -2.2009264e-01, -1.4083362e+00],
         [-1.2999996e+00, -1.0029310e+00, -8.0927563e-01, ...,
          -9.6844232e-01,  4.7647089e-02, -1.7528368e+00],
         ...,
         [ 7.9181468e-01, -7.1245348e-01, -1.2355906e+00, ...,
          -4.4910422e-01,  7.0296872e-01, -1.8157486e+00],
         [ 8.5229218e-01, -3.9036795e-01,  3.7029549e-01, ...,
          -2.0579123e+00,  9.2259049e-03, -1.2485095e+00],
         [-1.0421257e+00,  9.6360290e-01, -1.9165359e+00, ...,
          -1.5525728e+00, -2.7757692e+00,  5.9844279e-01]],

        [[ 2.0120070e+00,  2.5763493e+00,  2.5311258e+00, ...,
           2.0375581e+00,  1.6430848e+00,  4.5296006e+00],
         [-6.4119029e-01,  3.2270002e-01,  2.7286339e+00, ...,
           3.4792902e+00,  4.8433290e+00,  1.8760866e+00],
         [ 5.2160606e+00,  5.8354855e-01,  1.9910555e+00, ...,
           3.8761294e-01,  3.4568546e+00,  2.2840927e+00],
         ...,
         [ 2.4697292e+00,  3.1099756e+00,  4.5984769e+00, ...,
           3.1638999e+00,  1.7895203e+00,  5.1426482e-01],
         [ 2.0174649e+00,  3.7343421e+00,  1.3838698e+00, ...,
           6.8948352e-01,  1.9830887e+00, -1.2911747e+00],
         [ 2.2970469e+00,  2.8243198e+00,  8.7906146e-01, ...,
           3.2837601e+00,  1.0420291e+00,  4.1244802e+00]],

        [[-4.5218289e-01, -1.1248827e-02, -3.9010030e-01, ...,
          -2.1441557e-01, -8.8925439e-01,  1.0432711e+00],
         [-1.5277631e+00, -6.0763943e-01,  8.2450414e-01, ...,
           5.1565582e-01,  9.1227055e-01, -4.1257131e-01],
         [ 1.1678007e+00, -8.4806198e-01, -4.1370481e-01, ...,
          -9.4888353e-01,  2.4556525e-01,  2.7058780e-02],
         ...,
         [-2.6444227e-01,  6.4803612e-01,  1.4935874e+00, ...,
           1.9097075e-02, -6.0670388e-01, -3.2186458e-01],
         [-5.2368152e-01,  6.9923353e-01, -6.0641676e-01, ...,
          -3.2536793e-01, -3.0933461e-01, -1.7596698e+00],
         [-5.7884902e-01, -9.0141267e-02, -4.4471401e-01, ...,
           3.5021925e-01, -1.7998603e-01,  6.4285696e-01]]]],
      dtype=float32)]

Now the first dimension of input (the batch_size) can be set to any value and the program still runs. Let's print the model information at this point.

import onnx

if __name__ == '__main__':

    model = onnx.load("./model.onnx")
    # Same edit as before: make dimension 0 of every graph input and output symbolic
    for i in model.graph.input:
        i.type.tensor_type.shape.dim[0].dim_param = 'batchsize'
    for o in model.graph.output:
        o.type.tensor_type.shape.dim[0].dim_param = 'batchsize'
    # Print the edited model instead of checking, saving and running it
    print(model)

Output:

ir_version: 8
graph {
  node {
    input: "input"
    input: "weight"
    input: "bias"
    output: "output"
    op_type: "Conv"
    attribute {
      name: "group"
      i: 1
      type: INT
    }
    attribute {
      name: "kernel_shape"
      ints: 1
      ints: 1
      type: INTS
    }
    attribute {
      name: "pads"
      ints: 0
      ints: 0
      ints: 0
      ints: 0
      type: INTS
    }
    attribute {
      name: "strides"
      ints: 1
      ints: 1
      type: INTS
    }
  }
  name: "graph"
  initializer {
    dims: 3
    dims: 3
    dims: 1
    dims: 1
    data_type: 1
    float_data: 0.45837152004241943
    float_data: 0.10209446400403976
    float_data: 1.0382566452026367
    float_data: -0.09292714297771454
    float_data: 1.58871591091156
    float_data: 0.3746287226676941
    float_data: -0.35588690638542175
    float_data: 0.7165427207946777
    float_data: 0.10244251787662506
    name: "weight"
  }
  initializer {
    dims: 3
    data_type: 1
    float_data: -0.36782845854759216
    float_data: 2.305680513381958
    float_data: -0.13051341474056244
    name: "bias"
  }
  input {
    name: "input"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_param: "batchsize"
          }
          dim {
            dim_value: 3
          }
          dim {
            dim_value: 244
          }
          dim {
            dim_value: 244
          }
        }
      }
    }
  }
  output {
    name: "output"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_param: "batchsize"
          }
          dim {
            dim_value: 3
          }
          dim {
            dim_value: 244
          }
          dim {
            dim_value: 244
          }
        }
      }
    }
  }
}
opset_import {
  version: 17
}

Here we can see that the first dim of input, and likewise of output, has become dim_param: "batchsize".

Adding and deleting nodes

  • Adding a node
import onnx
import onnx.helper as helper
import onnxruntime
import numpy as np

if __name__ == '__main__':

    model = onnx.load('./model.onnx')
    nodes = model.graph.node
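    # New Relu node: reads the intermediate tensor 'conv1' and writes the graph output 'output'
    new_node = helper.make_node(op_type='Relu', name='relu1', inputs=['conv1'], outputs=['output'])
    # Appending after the Conv node keeps the node list topologically sorted
    nodes.append(new_node)
    # Rename the Conv output from 'output' to 'conv1' so that it feeds the Relu
    nodes[0].output[0] = 'conv1'
    onnx.checker.check_model(model)
    onnx.save_model(model, 'add_model.onnx')

    x = np.random.randn(1, 3, 244, 244).astype(np.float32)
    sess = onnxruntime.InferenceSession('./add_model.onnx')
    print(sess.run(['output'], {'input': x}))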

Output:

[array([[[[1.5453527 , 0.        , 0.        , ..., 0.04255658,
          0.        , 0.40214583],
         [0.        , 0.5019511 , 0.        , ..., 0.34235588,
          0.36859825, 0.        ],
         [0.        , 0.        , 0.34334645, ..., 0.        ,
          0.        , 0.        ],
         ...,
         [1.1857387 , 1.0710502 , 0.        , ..., 0.        ,
          1.8497316 , 0.        ],
         [0.37889728, 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.73697627, 0.        , 0.4978644 , ..., 0.        ,
          0.        , 0.32394186]],

        [[1.2723072 , 0.        , 0.66669345, ..., 5.6399436 ,
          1.4827138 , 2.7300682 ],
         [4.5705633 , 2.9856906 , 2.9005556 , ..., 3.505543  ,
          4.7502317 , 0.        ],
         [1.5251542 , 3.3182473 , 3.8036246 , ..., 0.        ,
          1.6024959 , 1.4051957 ],
         ...,
         [1.7204559 , 4.551407  , 4.172427  , ..., 0.9121852 ,
          3.3593512 , 4.6163626 ],
         [0.2845726 , 0.13289118, 3.3601975 , ..., 3.9331636 ,
          0.3700601 , 1.5711328 ],
         [3.3283763 , 2.128338  , 2.1621299 , ..., 1.7635765 ,
          0.        , 2.1479769 ]],

        [[0.        , 0.        , 0.        , ..., 1.4292918 ,
          0.        , 0.46683455],
         [1.0534286 , 0.        , 0.02258705, ..., 0.4342987 ,
          1.1339298 , 0.        ],
         [0.        , 0.50237906, 0.20627443, ..., 0.        ,
          0.        , 0.        ],
         ...,
         [0.        , 0.78040606, 1.003104  , ..., 0.        ,
          0.        , 1.0389903 ],
         [0.        , 0.        , 0.        , ..., 0.74816215,
          0.        , 0.02678718],
         [0.26068228, 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ]]]], dtype=float32)]

Let's print the model information again.

import onnx
import onnx.helper as helper

if __name__ == '__main__':

    model = onnx.load('./model.onnx')
    nodes = model.graph.node
    new_node = helper.make_node(op_type='Relu', name='relu1', inputs=['conv1'], outputs=['output'])
    nodes.append(new_node)
    nodes[0].output[0] = 'conv1'
    # Print the edited model instead of checking, saving and running it
    print(model)

Output:

ir_version: 8
graph {
  node {
    input: "input"
    input: "weight"
    input: "bias"
    output: "conv1"
    op_type: "Conv"
    attribute {
      name: "group"
      i: 1
      type: INT
    }
    attribute {
      name: "kernel_shape"
      ints: 1
      ints: 1
      type: INTS
    }
    attribute {
      name: "pads"
      ints: 0
      ints: 0
      ints: 0
      ints: 0
      type: INTS
    }
    attribute {
      name: "strides"
      ints: 1
      ints: 1
      type: INTS
    }
  }
  node {
    input: "conv1"
    output: "output"
    name: "relu1"
    op_type: "Relu"
  }
  name: "graph"
  initializer {
    dims: 3
    dims: 3
    dims: 1
    dims: 1
    data_type: 1
    float_data: 0.45837152004241943
    float_data: 0.10209446400403976
    float_data: 1.0382566452026367
    float_data: -0.09292714297771454
    float_data: 1.58871591091156
    float_data: 0.3746287226676941
    float_data: -0.35588690638542175
    float_data: 0.7165427207946777
    float_data: 0.10244251787662506
    name: "weight"
  }
  initializer {
    dims: 3
    data_type: 1
    float_data: -0.36782845854759216
    float_data: 2.305680513381958
    float_data: -0.13051341474056244
    name: "bias"
  }
  input {
    name: "input"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 3
          }
          dim {
            dim_value: 244
          }
          dim {
            dim_value: 244
          }
        }
      }
    }
  }
  output {
    name: "output"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 3
          }
          dim {
            dim_value: 244
          }
          dim {
            dim_value: 244
          }
        }
      }
    }
  }
}
opset_import {
  version: 17
}

Here we can see the new node relu1: the first node's output is now conv1, and the second node takes conv1 as its input and produces output.

  • Deleting a node
import onnx
import onnxruntime
import numpy as np

if __name__ == '__main__':

    model = onnx.load('./add_model.onnx')
    nodes = model.graph.node
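    # Iterate over a copy: removing items from a repeated field while looping over it can skip elements
    for node in list(nodes):
        if node.name == 'relu1':
            nodes.remove(node)
    # The Relu that wrote the graph output is gone, so Conv must write 'output' again
    nodes[0].output[0] = 'output'
    onnx.checker.check_model(model)
    onnx.save_model(model, 'del_model.onnx')

    x = np.random.randn(1, 3, 244, 244).astype(np.float32)
    sess = onnxruntime.InferenceSession('./del_model.onnx')
    print(sess.run(['output'], {'input': x}))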

Output:

[array([[[[-8.5923064e-01, -4.2249173e-01,  3.8687822e-01, ...,
          -4.8348337e-02,  3.1652334e-01, -5.7166600e-01],
         [ 3.1469372e-01, -9.4796360e-01, -2.4245100e+00, ...,
           4.1007617e-01, -1.4098099e+00,  6.7472184e-01],
         [-1.2910874e+00,  1.6070822e-01, -1.0217074e+00, ...,
           7.1467435e-01,  1.5835044e-01, -6.4228356e-01],
         ...,
         [-2.5442154e+00, -8.8969648e-01,  1.1389736e+00, ...,
           1.7202379e+00, -1.1968368e+00, -3.3861694e-01],
         [-9.0216339e-01,  4.8469666e-01, -9.5050204e-01, ...,
           4.0511075e-01, -1.0113320e-01,  1.8743831e+00],
         [ 3.2901958e-01,  4.3780953e-02,  1.4250931e+00, ...,
          -1.4544667e+00,  9.0659869e-01,  1.7170597e+00]],

        [[ 1.3439684e+00,  3.0856354e+00,  2.7811766e+00, ...,
           4.1714394e-01, -3.3547878e-02,  1.1771207e+00],
         [ 2.1574910e+00,  2.1122241e+00, -5.8333945e-01, ...,
           1.9629711e+00,  3.4840956e+00,  6.1747317e+00],
         [ 5.2136226e+00,  4.8688288e+00,  1.4613919e+00, ...,
           4.1095753e+00,  1.4553337e+00,  3.5171165e+00],
         ...,
         [ 7.3736429e-02,  8.4109855e-01,  5.7113109e+00, ...,
           3.6336284e+00,  4.4551125e+00,  3.4602299e+00],
         [ 1.1054695e+00,  2.7417006e+00,  4.9065466e+00, ...,
           2.1775680e+00,  4.4132576e+00,  2.3781679e+00],
         [-1.2788355e+00,  2.5300267e+00,  3.2560487e+00, ...,
           2.2025514e+00,  4.2551570e+00,  3.5148311e+00]],

        [[-8.5124874e-01,  3.1858414e-01,  3.3686757e-03, ...,
          -1.1497847e+00, -1.1996644e+00, -9.6176589e-01],
         [-4.2057925e-01, -1.8098265e-01, -7.4302059e-01, ...,
          -3.5920531e-01,  7.0454830e-01,  1.8304255e+00],
         [ 1.4177717e+00,  8.4456313e-01, -1.6396353e-01, ...,
           4.2133337e-01, -4.6482396e-01,  6.6906375e-01],
         ...,
         [-1.0060047e+00, -1.2088763e+00,  1.2608007e+00, ...,
           5.1739502e-01,  8.9526463e-01,  7.2866821e-01],
         [-3.5698372e-01, -5.9943002e-01,  1.0040566e+00, ...,
           3.1322885e-01,  3.4513384e-01, -6.2404698e-01],
         [-2.0622578e+00,  3.9633280e-01,  2.1701033e-01, ...,
           2.6992482e-01,  4.4787437e-01,  2.1187775e-01]]]],
      dtype=float32)]

Replacing a node

Now we replace the Conv node in add_model.onnx with a Squeeze node (which removes dimensions of size 1).
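When no axes are given, ONNX's Squeeze removes every dimension of size 1. That is why the model's output ends up as a 3-D array of shape (3, 244, 244) instead of (1, 3, 244, 244). numpy's np.squeeze follows the same rule, so it offers a quick way to predict the resulting shape. It is used here only as an illustration:

```python
import numpy as np

x = np.random.randn(1, 3, 244, 244).astype(np.float32)
y = np.squeeze(x)                   # no axis given: every size-1 dimension is dropped
print(x.shape, '->', y.shape)       # (1, 3, 244, 244) -> (3, 244, 244)

z = np.squeeze(np.zeros((1, 3, 1, 5)), axis=0)  # with an axis, only that dimension goes
print(z.shape)                      # (3, 1, 5)
```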

import onnx
import onnx.helper as helper
import onnxruntime
import numpy as np

if __name__ == '__main__':

    model = onnx.load('./add_model.onnx')
    # squeeze1 reads the graph input and writes conv1, the tensor relu1 consumes
    new_node = helper.make_node(op_type='Squeeze', inputs=['input'], outputs=['conv1'], name='squeeze1')
    nodes = model.graph.node
    # append() puts squeeze1 at the end of the node list, i.e. after relu1
    nodes.append(new_node)
    # Iterate over a copy so removing the Conv node does not disturb the loop
    for node in list(nodes):
        if node.op_type == 'Conv':
            nodes.remove(node)
    # onnx.checker.check_model(model)
    onnx.save_model(model, 'replace.onnx')
    input_data = np.random.randn(1, 3, 244, 244).astype(np.float32)
    sess = onnxruntime.InferenceSession('./replace.onnx')
    print(sess.run(['output'], {'input': input_data}))

Output

[array([[[0.        , 1.1258854 , 0.        , ..., 0.54984987,
         0.19069785, 0.        ],
        [1.1481465 , 0.        , 1.9025986 , ..., 0.        ,
         0.11273875, 0.        ],
        [1.57912   , 0.        , 0.        , ..., 0.        ,
         1.7471381 , 0.        ],
        ...,
        [0.42386332, 0.30908984, 0.        , ..., 0.        ,
         0.        , 1.8173866 ],
        [0.07642962, 0.31224537, 0.        , ..., 1.6805407 ,
         2.0282576 , 0.        ],
        [0.        , 0.2521538 , 0.        , ..., 0.        ,
         0.6431213 , 0.5844705 ]],

       [[0.        , 0.        , 0.        , ..., 0.23725364,
         0.22994171, 0.316093  ],
        [0.85044146, 1.2757416 , 0.28854838, ..., 0.        ,
         0.        , 0.        ],
        [0.        , 0.        , 1.1362596 , ..., 1.8543358 ,
         1.1296074 , 0.5114057 ],
        ...,
        [0.        , 0.00810617, 0.        , ..., 1.0819261 ,
         1.707781  , 0.        ],
        [0.        , 0.6385371 , 0.        , ..., 0.6565783 ,
         1.457183  , 0.        ],
        [0.        , 0.8315589 , 1.4111192 , ..., 1.0682058 ,
         0.17328343, 2.3547616 ]],

       [[0.2426068 , 0.        , 0.        , ..., 0.89054537,
         0.98760164, 0.        ],
        [1.1344411 , 0.8732987 , 0.        , ..., 0.        ,
         0.        , 0.        ],
        [0.        , 0.3664789 , 1.4099371 , ..., 0.        ,
         0.0588427 , 0.5932818 ],
        ...,
        [0.        , 0.68438137, 0.8869638 , ..., 0.        ,
         0.        , 1.4681839 ],
        [0.        , 0.        , 0.        , ..., 0.16630006,
         1.9389246 , 0.        ],
        [0.        , 0.        , 0.03726726, ..., 0.86296386,
         0.        , 0.        ]]], dtype=float32)]

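Note that if we uncomment the line # onnx.checker.check_model(model), the script raises an error. The new node squeeze1 was appended after relu1, so relu1 reads conv1 before any node has produced it. The model fails the check, yet ONNX Runtime can still run it, as the output above shows. So how can we make the model both runnable and able to pass the check?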

import onnx
import onnx.helper as helper
import onnxruntime
import numpy as np

if __name__ == '__main__':

    model = onnx.load('./add_model.onnx')
    new_node = helper.make_node(op_type='Squeeze', inputs=['input'], outputs=['conv1'], name='squeeze1')
    nodes = model.graph.node
    # nodes.append(new_node)
    # Find the Conv node's position first, then edit the list outside any loop
    idx = [node.op_type for node in nodes].index('Conv')
    # Put squeeze1 into the slot the Conv occupied, so it still runs before relu1
    del nodes[idx]
    nodes.insert(idx, new_node)
    onnx.checker.check_model(model)
    onnx.save_model(model, 'replace.onnx')
    input_data = np.random.randn(1, 3, 244, 244).astype(np.float32)
    sess = onnxruntime.InferenceSession('./replace.onnx')
    print(sess.run(['output'], {'input': input_data}))

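The only real change is where the new node goes: instead of appending it to the end of the node list, we insert it at the index the Conv node occupied, so the nodes stay in topological order.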

Introduction to ONNXRuntime

ONNXRuntime is an inference framework from Microsoft that makes it easy to run ONNX models. Official GitHub: https://github.com/microsoft/onnxruntime
