PyTorch 1.0 中文官方教程:Torchvision 模型微调

2022-05-07 14:00:50 浏览数 (1)

译者:ZHHAYO

作者: Nathan Inkawhich

在本教程中,我们将深入探讨如何微调和特征提取torchvision 模型,所有这些模型都已经预先在1000类的magenet数据集上训练完成。本程将深入介绍如何使用几个现代的CNN架构,并将为微调任意的PyTorch模型建立一个直觉。 由于每个模型架构是有差异的,因此没有可以在所有场景中使用的样板微调代码。 然而,研究人员必须查看现有架构并对每个模型进行自定义调整。

在本文档中,我们将执行两种类型的迁移学习:微调和特征提取。 在微调中,我们从一个预训练模型开始,然后为我们的新任务更新所有的模型参数,实质上就是重新训练整个模型。 在特征提取中,我们从预训练模型开始,只更新产生预测的最后一层的权重。它被称为特征提取是因为我们使用预训练的CNN作为固定的特征提取器,并且仅改变输出层。 有关迁移学习的更多技术信息,请参阅here和here.

通常,这两种迁移学习方法都遵循以下几个步骤:

  • 初始化预训练模型
  • 重组最后一层,使其具有与新数据集类别数相同的输出数
  • 为优化算法定义我们想要在训练期间更新的参数
  • 运行训练步骤
代码语言:javascript复制
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

阅读全文/改进本文

0 人点赞