前言
最好、最高效、最简洁的,是 “方案一” 。
方案一
步骤一、固定基本网络
代码模板:
代码语言:javascript复制# 获取要固定部分的state_dict:
pre_state_dict = torch.load(model_path, map_location=torch.device('cpu')
# 导入之(记得strict=False):
model.load_state_dict(pre_state_dict, strict=False)
print('Load model from %s success.' % model_path)
# 固定基本网络:
model = freeze_model(model=model, to_freeze_dict=pre_state_dict)
其中 freeze_model 函数如下:
def freeze_model(model, to_freeze_dict, keep_step=None):
for (name, param) in model.named_parameters():
if name in to_freeze_dict:
param.requires_grad = False
else:
pass
# # 打印当前的固定情况(可忽略):
# freezed_num, pass_num = 0, 0
# for (name, param) in model.named_parameters():
# if param.requires_grad == False:
# freezed_num = 1
# else:
# pass_num = 1
# print('n Total {} params, miss {} n'.format(freezed_num pass_num, pass_num))
return model
Note:
- 如果预加载模型是在 model = nn.DataParallel(model) 模式下训练出来的分布式模型,那么每个参数名称会默认加上 .module 前缀。
- 相应地,会导致无法对号导入单机模型。此时需要将如下语句:
# 获取要固定部分的state_dict:
pre_state_dict = torch.load(model_path, map_location=torch.device('cpu')
改为:
# 获取要固定部分的state_dict:
pre_state_dict = torch.load(model_path, map_location=torch.device('cpu')
pre_state_dict = {k.replace('module.', ''): v for k, v in pre_state_dict.items()}
步骤二、让optimizer回避要freeze的参数
代码模板:
代码语言:javascript复制optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1, momentum=0.9, weight_decay=1e-4)
步骤三、train时通过.eval()来freeze
因为:即使对bn设置了 requires_grad = False ,一旦 model.train() ,bn还是会偷偷开启update( model.eval()模式下就又停止update )。(详见【pytorch】bn)
所以:train每个epoch之前都要统一重新定义一下这块,否则容易出问题。
代码语言:javascript复制model.eval()
model.stage4_xx.train()
model.pred_xx.train()
方案二
pytorch下进行freeze操作,一般需要经过以下四步。
步骤一、固定基本网络
代码模板:
代码语言:javascript复制# 获取要固定部分的state_dict:
pre_state_dict = torch.load(model_path, map_location=torch.device('cpu')
# 导入之(记得strict=False):
model.load_state_dict(pre_state_dict, strict=False)
print('Load model from %s success.' % model_path)
# 固定基本网络:
model = freeze_model(model=model, to_freeze_dict=pre_state_dict)
其中 freeze_model 函数如下:
def freeze_model(model, to_freeze_dict, keep_step=None):
for (name, param) in model.named_parameters():
if name in to_freeze_dict:
param.requires_grad = False
else:
pass
# # 打印当前的固定情况(可忽略):
# freezed_num, pass_num = 0, 0
# for (name, param) in model.named_parameters():
# if param.requires_grad == False:
# freezed_num = 1
# else:
# pass_num = 1
# print('n Total {} params, miss {} n'.format(freezed_num pass_num, pass_num))
return model
步骤二、让optimizer回避要freeze的参数
代码模板:
代码语言:javascript复制optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1, momentum=0.9, weight_decay=1e-4)
步骤三、固定bn
(参考《bn》)即使通过步骤一对bn设置了 requires_grad = False ,一旦 model.train() ,bn还是会偷偷开启update( model.eval()模式下就又停止update )。
所以还需要额外地深入固定bn:
- 固定 momentum :momentum=0.0
- 掐灭 track_running_stats :track_running_stats=False
举例:
代码语言:javascript复制self.bn = nn.BatchNorm2d(out_channels, eps=1e-5)
修改为:
代码语言:javascript复制self.bn = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.0, track_running_stats=False)
但是 track_running_stats=False 会带来副作用:受波及的每个bn都会在state_dict中丢失三个对应的键值对(每组对应的key都为xx.xx.bn.running_mean、xx.xx.bn.running_var 和 xx.xx.bn.num_batches_tracked)
步骤四、正常训练
训练过程中,记得定时check一下被固定部分是否恒定不变:
- 比如每次eval的时候,顺便check一下被固定部分的预测精度。
步骤五、后处理
4.1 重启track_running_stats
举例:
代码语言:javascript复制self.bn = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.0, track_running_stats=False)
修改为:
代码语言:javascript复制self.bn = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.0)
此时,之前受波及的每个bn,都会在state_dict中恢复所丢失三个对应的键值对(但是value为空,待填充)。
Note:
- 线上训练虽然用freeze过的网络,但线下测试时,还是要老老实实换回未被freeze的网络。否则结果不仅会对不齐,被freeze和未被freeze的task都会表现更差!
4.2 复原缺失的value
为了克服 track_running_stats=False 带来的副作用,最终模型需要依赖 “原始state_dict” 和 “训好的state_dict” 合并。前者为后者补充缺失的value。
代码语言:javascript复制# 原始state_dict:
origin_state_dict = torch.load(origin_model_path, map_location=torch.device('cpu'))
# 训好的state_dict:
new_state_dict = torch.load(new_model_path, map_location=torch.device('cpu'))
# 后者从前者中补充缺失的键值对:
final_dict = new_state_dict.copy()
for (key, val) in origin_state_dict.items():
if key not in final_dict:
final_dict[key] = val
# 载入合并好的 state_dict,这时候一定是可以通过 strict=True 的:
model.load_state_dict(final_dict, strict=True)
这时重新再save一遍model,就是可最终直接用的model文件了。