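Converting pretrained models between deep learning frameworks is a common task. Today I happened to need to convert pretrained DPN (Dual Path Network) models from MXNet to PyTorch, so I am recording the process here.
The core conversion function is shown below: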
```python
import mxnet
import torch


def _convert_bn(k):
    # Map a PyTorch BatchNorm parameter/buffer name to its MXNet name.
    # MXNet keeps gamma/beta in the arg params and moving_mean/moving_var in the aux params.
    aux = False
    if k == 'bias':
        add = 'beta'
    elif k == 'weight':
        add = 'gamma'
    elif k == 'running_mean':
        aux = True
        add = 'moving_mean'
    elif k == 'running_var':
        aux = True
        add = 'moving_var'
    else:
        assert False, 'Unknown key: %s' % k
    return aux, add


def convert_from_mxnet(model, checkpoint_prefix, debug=False):
    # load_checkpoint returns (symbol, arg_params, aux_params) for epoch 0
    _, mxnet_weights, mxnet_aux = mxnet.model.load_checkpoint(checkpoint_prefix, 0)
    remapped_state = {}
    for state_key in model.state_dict().keys():
        k = state_key.split('.')
        aux = False
        mxnet_key = ''
        if k[-1] == 'num_batches_tracked':
            # BatchNorm counter in newer PyTorch versions; MXNet has no equivalent,
            # and BatchNorm fills in a default when the key is absent from the dict
            continue
        if k[0] == 'features':
            if k[1] == 'conv1_1':
                # input block
                mxnet_key += 'conv1_x_1__'
                if k[2] == 'bn':
                    mxnet_key += 'relu-sp__bn_'
                    aux, key_add = _convert_bn(k[3])
                    mxnet_key += key_add
                else:
                    assert k[3] == 'weight'
                    mxnet_key += 'conv_' + k[3]
            elif k[1] == 'conv5_bn_ac':
                # bn + ac at end of features block
                mxnet_key += 'conv5_x_x__relu-sp__bn_'
                assert k[2] == 'bn'
                aux, key_add = _convert_bn(k[3])
                mxnet_key += key_add
            else:
                # middle blocks
                if model.b and 'c1x1_c' in k[2]:
                    bc_block = True  # b-variant split c-block special treatment
                else:
                    bc_block = False
                ck = k[1].split('_')
                mxnet_key += ck[0] + '_x__' + ck[1] + '_'
                ck = k[2].split('_')
                mxnet_key += ck[0] + '-' + ck[1]
                if ck[1] == 'w' and len(ck) > 2:
                    mxnet_key += '(s/2)' if ck[2] == 's2' else '(s/1)'
                mxnet_key += '__'
                if k[3] == 'bn':
                    mxnet_key += 'bn_' if bc_block else 'bn__bn_'
                    aux, key_add = _convert_bn(k[4])
                    mxnet_key += key_add
                else:
                    ki = 3 if bc_block else 4
                    assert k[ki] == 'weight'
                    mxnet_key += 'conv_' + k[ki]
        elif k[0] == 'classifier':
            if 'fc6-1k_weight' in mxnet_weights:
                mxnet_key += 'fc6-1k_'
            else:
                mxnet_key += 'fc6_'
            mxnet_key += k[1]
        else:
            assert False, 'Unexpected token'

        if debug:
            print(mxnet_key, '=> ', state_key, end=' ')

        # MXNet NDArray -> NumPy -> torch.Tensor
        mxnet_array = mxnet_aux[mxnet_key] if aux else mxnet_weights[mxnet_key]
        torch_tensor = torch.from_numpy(mxnet_array.asnumpy())
        if k[0] == 'classifier' and k[1] == 'weight':
            # (num_classes, in_features) -> (num_classes, in_features, 1, 1)
            torch_tensor = torch_tensor.view(torch_tensor.size() + (1, 1))
        remapped_state[state_key] = torch_tensor

        if debug:
            print(list(torch_tensor.size()), torch_tensor.mean(), torch_tensor.std())

    model.load_state_dict(remapped_state)
    return model
```
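The classifier branch reshapes the MXNet `fc6` weight from `(num_classes, in_features)` to `(num_classes, in_features, 1, 1)`. This presumably matches a PyTorch DPN whose classifier is a 1×1 convolution rather than an `nn.Linear` (an assumption, since the model definition is not shown here). The plain-Python sketch below, with hypothetical helper names `fully_connected`, `to_conv1x1` and `conv1x1`, checks that a fully connected layer and a 1×1 convolution using the reshaped weight give the same result on a 1×1 feature map.

```python
from typing import List

Matrix = List[List[float]]


def fully_connected(weight: Matrix, bias: List[float], x: List[float]) -> List[float]:
    # y[o] = sum_i weight[o][i] * x[i] + bias[o]; weight layout is (out, in)
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def to_conv1x1(weight: Matrix) -> list:
    # (out, in) -> (out, in, 1, 1), the same reshape as view(size + (1, 1))
    return [[[[w]] for w in row] for row in weight]


def conv1x1(weight4: list, bias: List[float], fmap: list) -> list:
    # fmap has shape (in, H, W); the result has shape (out, H, W)
    height, width = len(fmap[0]), len(fmap[0][0])
    out = []
    for kernel, b in zip(weight4, bias):
        plane = [[b + sum(kernel[c][0][0] * fmap[c][h][w] for c in range(len(fmap)))
                  for w in range(width)]
                 for h in range(height)]
        out.append(plane)
    return out


weight = [[1.0, 2.0, 3.0], [-1.0, 0.5, 4.0]]  # 2 classes, 3 input features
bias = [0.1, -0.2]
features = [2.0, -1.0, 0.5]

fc_out = fully_connected(weight, bias, features)
conv_out = conv1x1(to_conv1x1(weight), bias, [[[v]] for v in features])
print(fc_out)
print([plane[0][0] for plane in conv_out])
```

Both lines print the same values. On a larger feature map, a 1×1 convolution simply applies the same fully connected layer independently at every spatial position.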
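From this, the conversion proceeds as follows:
(1) Build the network in PyTorch; call it `model`.
(2) Use MXNet to load the stored pretrained checkpoint, which yields `mxnet_weights` (the arg params) and `mxnet_aux` (the aux params, i.e. the BatchNorm moving statistics).
(3) Iterate over the keys of the PyTorch model's `model.state_dict()`.
(4) Translate each key into the corresponding MXNet parameter name, with special handling for the BatchNorm parameters, the middle blocks and the classifier.
(5) Convert the MXNet array found under that name to a PyTorch tensor through NumPy (`asnumpy()` followed by `torch.from_numpy()`), then load the collected tensors with `load_state_dict()`.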
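Steps (3) to (5) boil down to a dictionary remap driven by the target model's keys. The sketch below isolates that pattern in plain Python so it runs without MXNet or PyTorch installed. `remap_state`, `bn_rename` and the toy dictionaries are hypothetical stand-ins for `model.state_dict()`, the MXNet arg/aux params and the real DPN renaming rules, and the NumPy-to-tensor conversion of step (5) is reduced to a plain copy.

```python
from typing import Callable, Dict, Iterable, Tuple


def bn_rename(torch_key: str) -> Tuple[str, bool]:
    # Hypothetical rule: 'stem.bn.weight' -> ('stem__bn_gamma', False).
    # Returns the MXNet-style name and whether it lives in the aux params.
    prefix, _, leaf = torch_key.rpartition('.')
    table = {
        'weight': ('gamma', False),
        'bias': ('beta', False),
        'running_mean': ('moving_mean', True),
        'running_var': ('moving_var', True),
    }
    suffix, is_aux = table[leaf]
    return prefix.replace('.', '__') + '_' + suffix, is_aux


def remap_state(target_keys: Iterable[str],
                arg_params: Dict[str, list],
                aux_params: Dict[str, list],
                rename: Callable[[str], Tuple[str, bool]]) -> Dict[str, list]:
    remapped = {}
    for key in target_keys:                  # step (3): walk the target model's keys
        source_key, is_aux = rename(key)     # step (4): translate the name
        source = aux_params if is_aux else arg_params
        if source_key not in source:
            raise KeyError(f'{key!r} maps to missing source key {source_key!r}')
        remapped[key] = source[source_key]   # step (5): copy the array over
    return remapped


target_keys = ['stem.bn.weight', 'stem.bn.bias', 'stem.bn.running_mean', 'stem.bn.running_var']
arg_params = {'stem__bn_gamma': [1.0, 1.0], 'stem__bn_beta': [0.0, 0.0]}
aux_params = {'stem__bn_moving_mean': [0.5, -0.5], 'stem__bn_moving_var': [1.0, 2.0]}
state = remap_state(target_keys, arg_params, aux_params, bn_rename)
print(state['stem.bn.running_var'])  # -> [1.0, 2.0]
```

Driving the loop from the target keys, as `convert_from_mxnet` does, means a PyTorch parameter with no MXNet counterpart fails loudly with a `KeyError` instead of silently keeping its random initialization.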
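To carry out the conversion, first install MXNet with pip (`pip install mxnet`); recent MXNet releases are quite easy to install.
Second, run the conversion script to convert the pretrained model.
The converted model then appears in the corresponding folder.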