译者:yuange250
最佳方案
保存模型的推荐方法
Pytorch主要有两种方法可用于序列化和保存一个模型。
第一种只存取模型的参数(更为推荐): 保存参数:
代码语言:javascript复制torch.save(the_model.state_dict(), PATH)
读取参数:
代码语言:javascript复制the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
第二种方法则将整个模型都保存下来:
代码语言:javascript复制torch.save(the_model, PATH)
读取的时候也是读取整个模型:
代码语言:javascript复制the_model = torch.load(PATH)
阅读全文/改进本文