Transfer learning means reusing a model that has already been trained on a similar dataset instead of training from scratch. After all, humans do not learn to recognize a new object by analyzing thousands of similar images from zero either.
VGG16 was one of the early architectures to succeed in the ImageNet competition. It is comparatively simple, so this post uses it both to introduce transfer learning and to train our fruit-and-vegetable image classifier.
The VGG16 model contains five VGG blocks (the features part). Each VGG block is a group of convolutional layers, each followed by a non-linear activation (ReLU), with a max-pooling layer at the end of the block. All of its parameters come pre-trained and can recognize 1000 ImageNet classes. The detailed structure of VGG16 is as follows:
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
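The classifier's first Linear layer expects 25088 inputs because the feature extractor ends with 512 channels on a 7×7 grid (512 × 7 × 7 = 25088); the AdaptiveAvgPool2d layer forces that 7×7 size regardless of the input resolution, which is why the 128×128 images used later in this post also work. A quick standalone sanity check of these shapes (a sketch, not part of the training script):
import torch
from torchvision import models

vgg = models.vgg16(pretrained=True)
x224 = torch.randn(1, 3, 224, 224)            # standard ImageNet input size
x128 = torch.randn(1, 3, 128, 128)            # the size used later in this post
print(vgg.features(x224).shape)               # torch.Size([1, 512, 7, 7])
print(vgg.features(x128).shape)               # torch.Size([1, 512, 4, 4])
print(vgg.avgpool(vgg.features(x128)).shape)  # torch.Size([1, 512, 7, 7]) -> flattens to 25088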
In transfer learning we keep what the network has already learned by freezing the learned parameters of most of its layers and fine-tuning only the last few. In this example we keep the learned parameters (weights) of the five VGG blocks (the features part) and only train the linear layers (the classifier part). The out_features of the last layer also has to be changed from 1000 to 36, because our fruit-and-vegetable dataset contains only 36 classes. The dataset can be downloaded from Kaggle; it contains a training set, a test set and a validation set, about 2 GB in total: https://www.kaggle.com/datasets/kritikseth/fruit-and-vegetable-image-recognition
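In the training script below, the features stay frozen simply because only the classifier parameters are handed to the optimizer. An equivalent and slightly more economical alternative (a minimal sketch, not the exact code used below) is to switch off gradients for the feature extractor explicitly and rebuild the last linear layer:
import torch
from torch import nn
from torchvision import models

vgg = models.vgg16(pretrained=True)

# Freeze the convolutional feature extractor: no gradients are computed for these weights,
# which also saves memory compared with merely leaving them out of the optimizer.
for param in vgg.features.parameters():
    param.requires_grad = False

# Rebuild the last fully connected layer: 1000 ImageNet classes -> 36 fruit/vegetable classes.
vgg.classifier[-1] = nn.Linear(vgg.classifier[-1].in_features, 36)

# Hand only the parameters that still require gradients to the optimizer.
optimizer = torch.optim.Adam((p for p in vgg.parameters() if p.requires_grad),
                             lr=1e-3, weight_decay=0.001)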
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    train_loss, correct = 0, 0
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        #X = X.to("cuda")
        #y = y.to("cuda")
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)
        correct += (pred.argmax(1) == y).type(torch.float).sum().item()
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()  # only for monitoring
        if batch % 10 == 0:  # training progress monitor
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
    train_loss /= num_batches
    correct /= size
    train_Avg_loss.append(train_loss)
    train_Accuracy.append(100*correct)
    print(f"training Accuracy: {(100*correct):>0.1f}")
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            #X = X.to("cuda")
            #y = y.to("cuda")
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    test_Avg_loss.append(test_loss)
    test_Accuracy.append(100*correct)
    print(f"Test Accuracy: {(100*correct):>0.2f}%, Test Avg loss: {test_loss:>8f}\n")
vgg = models.vgg16(pretrained=True)  # downloads the VGG16 weights the first time it runs
#if torch.cuda.is_available():
#    vgg = vgg.cuda()
print(vgg)
model = vgg
# VGG16 originally has 1000 output classes, but this problem only has 36.
# Rebuild the last linear layer (merely assigning .out_features would not change its weights).
model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, 36)
print(model)
transform1 = transforms.Compose([transforms.Resize((128, 128)),
                                 transforms.RandomHorizontalFlip(p=0.5),
                                 #transforms.RandomRotation(degrees=(-5, 5)),
                                 transforms.ToTensor(),        # scales the data into [0, 1]
                                 transforms.Normalize(0, 1)])  # mean=0, std=1 leaves the values unchanged
transform2 = transforms.Compose([transforms.Resize((128, 128)),
                                 transforms.ToTensor(),        # scales the data into [0, 1]
                                 transforms.Normalize(0, 1)])
if __name__ == "__main__":
    training_data = datasets.ImageFolder(r"F:\datasets\Fruits and Vegetables Image Recognition\train",
                                         transform=transform1)
    test_data = datasets.ImageFolder(r"F:\datasets\Fruits and Vegetables Image Recognition\test", transform=transform2)
    train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
    for X, y in test_dataloader:
        print(f"Shape of X [N, C, H, W]: {X.shape}")
        print(f"Shape of y: {y.shape} {y.dtype}")
        break
    #device = "cuda" if torch.cuda.is_available() else "cpu"
    #train_feat_loader = train_dataloader
    #test_feat_loader = test_dataloader
    loss_fn = nn.CrossEntropyLoss()
    # Only the classifier parameters are optimized, so the features part stays frozen.
    optimizer = torch.optim.Adam(vgg.classifier.parameters(), lr=1e-3, weight_decay=0.001)
    #optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.1, weight_decay=0.01)
    # Dynamic learning rate: after every step_size epochs, lr *= gamma.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)
    from matplotlib import pyplot as plt
    from matplotlib import ticker
    train_Accuracy = []
    train_Avg_loss = []
    test_Accuracy = []
    test_Avg_loss = []
    epochs = 60
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}:")
        #momentum = 0 if t < 10 else 0.9
        #optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=momentum, weight_decay=1e-5)
        ## Dynamic learning rate: after every step_size epochs, lr *= gamma.
        #optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
        #scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
        print(f"current learning rate is {scheduler.get_last_lr()[0]}")
        train(train_dataloader, model, loss_fn, optimizer)
        scheduler.step()
        test(test_dataloader, model, loss_fn)
    print("Done.")
    torch.save(model.state_dict(), "Fruits_and_Vegetables.pth")
    print("Saved PyTorch Model State to Fruits_and_Vegetables.pth")
    # Plot the accuracy and the average loss
    plt.subplot(2, 1, 1)
    plt.plot(range(1, epochs + 1), train_Accuracy, "r-", label="train_Accuracy")
    plt.plot(range(1, epochs + 1), test_Accuracy, "b-", label="test_Accuracy")
    plt.xlabel("Epoch")
    xticker_formatter = ticker.FuncFormatter(lambda x, pos: "%d" % x)
    plt.gca().xaxis.set_major_formatter(xticker_formatter)
    plt.ylabel("Accuracy[%]")
    plt.legend()
    plt.grid()
    plt.subplot(2, 1, 2)
    plt.plot(range(1, epochs + 1), train_Avg_loss, "r-", label="train_Avg_loss")
    plt.plot(range(1, epochs + 1), test_Avg_loss, "b-", label="test_Avg_loss")
    plt.xlabel("Epoch")
    plt.gca().xaxis.set_major_formatter(xticker_formatter)
    plt.ylabel("Avg_loss")
    plt.legend()
    plt.grid()
    plt.savefig("Accuracy and loss plot.png")
    plt.show()
This example does not use the GPU, because my GPU is too weak to handle the VGG16 model.
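If a capable CUDA GPU is available, the commented-out lines in the script above show where it would plug in; a minimal sketch of the change (assuming CUDA is present):
import torch
from torchvision import models

device = "cuda" if torch.cuda.is_available() else "cpu"
model = models.vgg16(pretrained=True).to(device)   # move the model to the GPU once
# inside train() and test(), every batch has to be moved as well:
# X, y = X.to(device), y.to(device)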
The accuracy and average loss are shown in the figure below. Somewhat oddly, the test loss starts out lower than the training loss; a likely explanation is that dropout and the random horizontal flip are only active during training, and the training loss is averaged over a whole epoch in which the model is still improving, whereas the test loss is measured at the end of the epoch.
Finally, the validation set is used to check the classification performance:
import os
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from Fruits_and_vegetables_basedOn_Vgg16 import model, transform2
model.load_state_dict(torch.load("Fruits_and_Vegetables.pth"))
dataset_train_path = r"F:\datasets\Fruits and Vegetables Image Recognition\train"
validation_data = datasets.ImageFolder(r"F:\datasets\Fruits and Vegetables Image Recognition\validation",
                                       transform=transform2)
validation_dataloader = DataLoader(validation_data, batch_size=64, shuffle=False)
classes = sorted(os.listdir(dataset_train_path))  # sorted to match ImageFolder's alphabetical class indices
print(classes)
model.eval()
with torch.no_grad():
    n = 0
    wrong = 0
    for i in range(len(validation_data)):
        x, y = validation_data[i][0], validation_data[i][1]
        x = x.unsqueeze(0)
        # Add a batch dimension. Without it the convolutional layers complain with an error like:
        # "Expected 4-dimensional input for 4-dimensional weight [16, 1, 2, 2],
        #  but got 3-dimensional input of size [1, 28, 28] instead"
        # (not needed for models that only contain fully connected layers)
        pred = model(x)
        predicted, actual = classes[pred[0].argmax(0)], classes[y]
        n += 1
        print(f'Predicted: "{predicted}", Actual: "{actual}"')
        if predicted != actual:
            wrong += 1
print(f"validation Accuracy[%] is {(n-wrong)*100.0/n:>0.2f}")
The validation set has 10 images per class, 360 images in total; its prediction accuracy is about the same as the training set's.
Predicted: "sweetpotato", Actual: "sweetpotato"
Predicted: "sweetpotato", Actual: "sweetpotato"
Predicted: "sweetpotato", Actual: "sweetpotato"
Predicted: "sweetpotato", Actual: "sweetpotato"
Predicted: "sweetpotato", Actual: "sweetpotato"
Predicted: "paprika", Actual: "sweetpotato"
Predicted: "sweetpotato", Actual: "sweetpotato"
Predicted: "sweetpotato", Actual: "sweetpotato"
Predicted: "sweetpotato", Actual: "sweetpotato"
Predicted: "sweetpotato", Actual: "sweetpotato"
Predicted: "tomato", Actual: "tomato"
Predicted: "tomato", Actual: "tomato"
Predicted: "tomato", Actual: "tomato"
Predicted: "tomato", Actual: "tomato"
Predicted: "tomato", Actual: "tomato"
Predicted: "tomato", Actual: "tomato"
Predicted: "tomato", Actual: "tomato"
Predicted: "tomato", Actual: "tomato"
Predicted: "tomato", Actual: "tomato"
Predicted: "tomato", Actual: "tomato"
Predicted: "turnip", Actual: "turnip"
Predicted: "turnip", Actual: "turnip"
Predicted: "turnip", Actual: "turnip"
Predicted: "turnip", Actual: "turnip"
Predicted: "turnip", Actual: "turnip"
Predicted: "turnip", Actual: "turnip"
Predicted: "turnip", Actual: "turnip"
Predicted: "turnip", Actual: "turnip"
Predicted: "turnip", Actual: "turnip"
Predicted: "turnip", Actual: "turnip"
Predicted: "watermelon", Actual: "watermelon"
Predicted: "watermelon", Actual: "watermelon"
Predicted: "watermelon", Actual: "watermelon"
Predicted: "watermelon", Actual: "watermelon"
Predicted: "watermelon", Actual: "watermelon"
Predicted: "watermelon", Actual: "watermelon"
Predicted: "watermelon", Actual: "watermelon"
Predicted: "watermelon", Actual: "watermelon"
Predicted: "watermelon", Actual: "watermelon"
Predicted: "watermelon", Actual: "watermelon"
validation Accuracy[%] is 96.30
Among the predictions shown above, one misclassification took a sweet potato for a paprika.
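To classify a single arbitrary image with the saved weights, the same transform2 and class list can be reused; a minimal sketch (the image file name here is hypothetical):
import os
import torch
from PIL import Image
from Fruits_and_vegetables_basedOn_Vgg16 import model, transform2

model.load_state_dict(torch.load("Fruits_and_Vegetables.pth"))
model.eval()

classes = sorted(os.listdir(r"F:\datasets\Fruits and Vegetables Image Recognition\train"))
img = Image.open("my_photo.jpg").convert("RGB")   # hypothetical image file
x = transform2(img).unsqueeze(0)                  # preprocess and add a batch dimension
with torch.no_grad():
    pred = model(x)
print(f'Predicted: "{classes[pred[0].argmax().item()]}"')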