【pytorch】固定(freeze)住部分网络

2021-12-06 21:17:42 浏览数 (1)

前言

最好、最高效、最简洁的,是 “方案一” 。

方案一

步骤一、固定基本网络

代码模板:

代码语言:javascript复制
# 获取要固定部分的state_dict:
pre_state_dict = torch.load(model_path, map_location=torch.device('cpu')

# 导入之(记得strict=False):
model.load_state_dict(pre_state_dict, strict=False)
print('Load model from %s success.' % model_path)

# 固定基本网络:
model = freeze_model(model=model, to_freeze_dict=pre_state_dict)
其中 freeze_model 函数如下: 
def freeze_model(model, to_freeze_dict, keep_step=None):

    for (name, param) in model.named_parameters():
        if name in to_freeze_dict:
            param.requires_grad = False
        else:
            pass

    # # 打印当前的固定情况(可忽略):
    # freezed_num, pass_num = 0, 0
    # for (name, param) in model.named_parameters():
    #     if param.requires_grad == False:
    #         freezed_num  = 1
    #     else:
    #         pass_num  = 1
    # print('n Total {} params, miss {} n'.format(freezed_num   pass_num, pass_num))

    return model

Note:

  • 如果预加载模型是在 model = nn.DataParallel(model) 模式下训练出来的分布式模型,那么每个参数名称会默认加上 .module 前缀。
  • 相应地,会导致无法对号导入单机模型。此时需要将如下语句:
代码语言:javascript复制
# 获取要固定部分的state_dict:
pre_state_dict = torch.load(model_path, map_location=torch.device('cpu')
改为: 
# 获取要固定部分的state_dict:
pre_state_dict = torch.load(model_path, map_location=torch.device('cpu')
pre_state_dict = {k.replace('module.', ''): v for k, v in pre_state_dict.items()}

步骤二、让optimizer回避要freeze的参数

代码模板:

代码语言:javascript复制
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1, momentum=0.9, weight_decay=1e-4)

步骤三、train时通过.eval()来freeze

因为:即使对bn设置了 requires_grad = False ,一旦 model.train() ,bn还是会偷偷开启update( model.eval()模式下就又停止update )。(详见【pytorch】bn)

所以:train每个epoch之前都要统一重新定义一下这块,否则容易出问题。

代码语言:javascript复制
model.eval()
model.stage4_xx.train()
model.pred_xx.train()

方案二

pytorch下进行freeze操作,一般需要经过以下四步。

步骤一、固定基本网络

代码模板:

代码语言:javascript复制
# 获取要固定部分的state_dict:
pre_state_dict = torch.load(model_path, map_location=torch.device('cpu')

# 导入之(记得strict=False):
model.load_state_dict(pre_state_dict, strict=False)
print('Load model from %s success.' % model_path)

# 固定基本网络:
model = freeze_model(model=model, to_freeze_dict=pre_state_dict)
其中 freeze_model 函数如下: 
def freeze_model(model, to_freeze_dict, keep_step=None):

    for (name, param) in model.named_parameters():
        if name in to_freeze_dict:
            param.requires_grad = False
        else:
            pass

    # # 打印当前的固定情况(可忽略):
    # freezed_num, pass_num = 0, 0
    # for (name, param) in model.named_parameters():
    #     if param.requires_grad == False:
    #         freezed_num  = 1
    #     else:
    #         pass_num  = 1
    # print('n Total {} params, miss {} n'.format(freezed_num   pass_num, pass_num))

    return model

步骤二、让optimizer回避要freeze的参数

代码模板:

代码语言:javascript复制
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1, momentum=0.9, weight_decay=1e-4)

步骤三、固定bn

(参考《bn》)即使通过步骤一对bn设置了 requires_grad = False ,一旦 model.train() ,bn还是会偷偷开启update( model.eval()模式下就又停止update )。

所以还需要额外地深入固定bn:

  • 固定 momentum :momentum=0.0
  • 掐灭 track_running_stats :track_running_stats=False

举例:

代码语言:javascript复制
self.bn = nn.BatchNorm2d(out_channels, eps=1e-5)

修改为:

代码语言:javascript复制
self.bn = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.0, track_running_stats=False)

但是 track_running_stats=False 会带来副作用:受波及的每个bn都会在state_dict中丢失三个对应的键值对(每组对应的key都为xx.xx.bn.running_mean、xx.xx.bn.running_var 和 xx.xx.bn.num_batches_tracked)

步骤四、正常训练

训练过程中,记得定时check一下被固定部分是否恒定不变:

  • 比如每次eval的时候,顺便check一下被固定部分的预测精度。

步骤五、后处理

4.1 重启track_running_stats

举例:

代码语言:javascript复制
self.bn = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.0, track_running_stats=False)

修改为:

代码语言:javascript复制
self.bn = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.0)

此时,之前受波及的每个bn,都会在state_dict中恢复所丢失三个对应的键值对(但是value为空,待填充)。

Note:

  • 线上训练虽然用freeze过的网络,但线下测试时,还是要老老实实换回未被freeze的网络。否则结果不仅会对不齐,被freeze和未被freeze的task都会表现更差!
4.2 复原缺失的value

为了克服 track_running_stats=False 带来的副作用,最终模型需要依赖 “原始state_dict” 和 “训好的state_dict” 合并。前者为后者补充缺失的value。

代码语言:javascript复制
# 原始state_dict:
origin_state_dict = torch.load(origin_model_path, map_location=torch.device('cpu'))
# 训好的state_dict:
new_state_dict = torch.load(new_model_path, map_location=torch.device('cpu'))

# 后者从前者中补充缺失的键值对:
final_dict = new_state_dict.copy()
for (key, val) in origin_state_dict.items():
    if key not in final_dict:
        final_dict[key] = val

# 载入合并好的 state_dict,这时候一定是可以通过 strict=True 的:
model.load_state_dict(final_dict, strict=True)
这时重新再save一遍model,就是可最终直接用的model文件了。 

0 人点赞