一、实验介绍
本实验实现了一个简化版VGG网络,并基于此完成图像分类任务。
VGG网络是深度卷积神经网络中的经典模型之一,由牛津大学计算机视觉组(Visual Geometry Group)提出。它在2014年的ImageNet图像分类挑战中取得了优异的成绩(分类任务第二,定位任务第一),被广泛应用于图像分类、目标检测和图像生成等任务。 VGG网络的主要特点是使用了非常小的卷积核尺寸(通常为3x3)和更深的网络结构。该网络通过多个卷积层和池化层堆叠在一起,逐渐增加网络的深度,从而提取图像的多层次特征表示。VGG网络的基本构建块是由连续的卷积层组成,每个卷积层后面跟着一个ReLU激活函数。在每个卷积块的末尾,都会添加一个最大池化层来减小特征图的尺寸。VGG网络的这种简单而有效的结构使得它易于理解和实现,并且在不同的任务上具有很好的泛化性能。 VGG网络有几个不同的变体,如VGG11、VGG13、VGG16和VGG19,它们的数字代表网络的层数。这些变体在网络深度和参数数量上有所区别,较深的网络通常具有更强大的表示能力,但也更加复杂。
二、实验环境
本系列实验使用了PyTorch深度学习框架,相关操作如下:
1. 配置虚拟环境
代码语言:javascript复制conda create -n DL python=3.7
代码语言:javascript复制conda activate DL
代码语言:javascript复制pip install torch==1.8.1 cu102 torchvision==0.9.1 cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
代码语言:javascript复制conda install matplotlib
代码语言:javascript复制 conda install scikit-learn
2. 库版本介绍
软件包 | 本实验版本 | 目前最新版 |
---|---|---|
matplotlib | 3.5.3 | 3.8.0 |
numpy | 1.21.6 | 1.26.0 |
python | 3.7.16 | |
scikit-learn | 0.22.1 | 1.3.0 |
torch | 1.8.1 cu102 | 2.0.1 |
torchaudio | 0.8.1 | 2.0.2 |
torchvision | 0.9.1 cu102 | 0.15.2 |
三、实验内容
ChatGPT:
卷积神经网络(Convolutional Neural Network,简称CNN)是一种深度学习模型,广泛应用于图像识别、计算机视觉和模式识别等领域。它的设计灵感来自于生物学中视觉皮层的工作原理。 卷积神经网络通过多个卷积层、池化层和全连接层组成。
- 卷积层主要用于提取图像的局部特征,通过卷积操作和激活函数的处理,可以学习到图像的特征表示。
- 池化层则用于降低特征图的维度,减少参数数量,同时保留主要的特征信息。
- 全连接层则用于将提取到的特征映射到不同类别的概率上,进行分类或回归任务。
卷积神经网络在图像处理方面具有很强的优势,它能够自动学习到具有层次结构的特征表示,并且对平移、缩放和旋转等图像变换具有一定的不变性。这些特点使得卷积神经网络成为图像分类、目标检测、语义分割等任务的首选模型。除了图像处理,卷积神经网络也可以应用于其他领域,如自然语言处理和时间序列分析。通过将文本或时间序列数据转换成二维形式,可以利用卷积神经网络进行相关任务的处理。
0. 导入必要的工具包
代码语言:javascript复制import torch
from torch import nn
import torch.nn.functional as F
1. conv_layer(创建卷积块)
- 每个卷积块由三个层组成
nn.Conv2d
卷积层nn.BatchNorm2d
批量标准化层- ReLU激活层
def conv_layer(chann_in, chann_out, k_size, p_size):
layer = nn.Sequential(
nn.Conv2d(chann_in, chann_out, kernel_size=k_size, padding=p_size),
nn.BatchNorm2d(chann_out),
nn.ReLU()
)
return layer
-
nn.Conv2d(chann_in, chann_out, kernel_size=k_size, padding=p_size)
:二维卷积层,它将输入特征图进行卷积操作。chann_in
表示输入通道数,chann_out
表示输出通道数,kernel_size
表示卷积核尺寸,padding
表示填充大小。 -
nn.BatchNorm2d(chann_out)
:批量标准化层,用于对卷积层的输出进行标准化处理,加速网络训练过程,并增强网络的鲁棒性。 -
nn.ReLU()
:ReLU激活层,对卷积层输出进行非线性映射,引入非线性特征,增加网络的表达能力。
2. vgg_conv_block(卷积模块:卷积层、池化层)
由多个相同的卷积块和一个最大池化层组成。
代码语言:javascript复制def vgg_conv_block(in_list, out_list, k_list, p_list, pooling_k, pooling_s):
layers = [conv_layer(in_list[i], out_list[i], k_list[i], p_list[i]) for i in range(len(in_list)) ]
layers = [nn.MaxPool2d(kernel_size = pooling_k, stride = pooling_s)]
return nn.Sequential(*layers)
- 函数的输入参数包括:
in_list
、out_list
、k_list
、p_list
、pooling_k
和pooling_s
,分别表示每个卷积块的输入通道数、输出通道数、卷积核尺寸、填充大小,以及最大池化层的核大小和步长。
- 通过列表推导式和
conv_layer
函数创建了多个卷积块的层,并将它们按顺序存储在layers
列表中。然后,将最大池化层(nn.MaxPool2d
)的实例添加到layers
列表的末尾。 - 通过
nn.Sequential
将layers
列表中的层按顺序连接起来,并返回一个包含所有层的卷积模块。
3. vgg_fc_layer(全连接层)
全连接层由三个层组成:nn.Linear
线性层、nn.BatchNorm1d
批量标准化层和ReLU激活层。
def vgg_fc_layer(size_in, size_out):
layer = nn.Sequential(
nn.Linear(size_in, size_out),
nn.BatchNorm1d(size_out),
nn.ReLU()
)
return layer
- 函数的输入参数包括
size_in
和size_out
,它们分别表示输入特征的大小和输出特征的大小。 - 通过
nn.Sequential
将线性层、批量标准化层和ReLU激活层三个层按顺序连接起来,并返回一个全连接层的模块。
4. VGG_S(VGG模型简化版)
为了简化,我们少使用了几层卷积层。
代码语言:javascript复制class VGG_S(nn.Module):
def __init__ (self, num_classes):
super().__init__()
self.layer1 = vgg_conv_block([3,64], [64,64], [3,3], [1,1], 2, 2)
self.layer2 = vgg_conv_block([64,128], [128,128], [3,3], [1,1], 2, 2)
self.layer3 = vgg_conv_block([128,256,256], [256,256,256], [3,3,3], [1,1,1], 2, 2)
# 全连接层
self.layer4 = vgg_fc_layer(4096, 1024)
# Final layer
self.layer5 = nn.Linear(1024, num_classes)
def forward(self, x):
out = self.layer1(x)
out = self.layer2(out)
vgg16_features = self.layer3(out)
out = vgg16_features.view(out.size(0), -1)
out = self.layer4(out)
out = self.layer5(out)
return out
a. __init__
- 通过调用
vgg_conv_block
函数创建了三个卷积模块(layer1
、layer2
和layer3
),并指定了它们的输入通道数、输出通道数、卷积核尺寸、填充大小以及最大池化层的核大小和步长。 - 创建一个全连接层(
layer4
),其中输入特征的大小为4096,输出特征的大小为1024。 - 通过
nn.Linear
创建了最后一层(layer5
),将1024维的特征映射到预测类别的数量。
b. forward
输入数据经过卷积部分的三个卷积模块,然后通过view
函数将特征展平成一维向量。接着,特征向量通过全连接层和最后一层进行预测,最终输出预测结果。