前言
网上关于VGG模型的文章有很多,有介绍算法本身的,也有代码实现,但是很多代码只给出了模型的结构实现,并不包含数据准备的部分,这让人很难愉快的将代码迁移自己的任务中。为此,这篇博客接下来围绕着如何使用VGG实现自己的图像分类任务,从数据准备到实验验证。代码基于Python与TensorFlow实现,模型结构采用VGG-16,并且将很少的出现算法和理论相关的东西。
数据准备
下载数据和转换代码
大多数人自己的训练数据,一般都是传统的图片形式,如.jpg,.png等等,而图像分类任务的话,这些图片的天然组织形式就是一个类别放在一个文件夹里,那么有啥大众化的数据集是这样的组织形式呢?TensorFlow的FlowersData,它下载下来是这个样子:
一共有五类,每一类中都有几百张图,我们把这些数据组织成TFrecord形式,对应的博客在这里,源码的github在这里,FlowersData数据集在这里。 有上面这三个东西之后,就可以生成TFrecord文件了。
组织图片数据
首先将FlowersData文件夹下的数据分成两个部分,训练数据和测试数据,我把原文件五个类别中都拿出大概100张图左右,数据的构成和路径如下:
生成训练TFrecord
代码语言:javascript复制#图片路径
cwd = 'F:\flowersdata\trainimages\'
代码语言:javascript复制#文件路径
filepath = 'F:\flowersdata\tfrecord\train\'
代码语言:javascript复制classes=['daisy',
'dandelion',
'roses',
'sunflowers',
'tulips']
代码语言:javascript复制#tfrecords格式文件名
ftrecordfilename = ("traindata.tfrecords-%.3d" % recordfilenum)
代码语言:javascript复制#tfrecords格式文件名
ftrecordfilename = ("traindata.tfrecords-%.3d" % recordfilenum)
生成效果:
生成预测TFrecord
代码语言:javascript复制#图片路径
cwd = 'F:\flowersdata\testimages\'
代码语言:javascript复制#文件路径
filepath = 'F:\flowersdata\tfrecord\test\'
代码语言:javascript复制classes=['daisy',
'dandelion',
'roses',
'sunflowers',
'tulips']
代码语言:javascript复制#tfrecords格式文件名
ftrecordfilename = ("testdata.tfrecords-%.3d" % recordfilenum)
代码语言:javascript复制#tfrecords格式文件名
ftrecordfilename = ("testdata.tfrecords-%.3d" % recordfilenum)
生成效果:
训练模型
初始权重与源码下载
VGG-16的初始权重我上传到了百度云,在这里下载; VGG-16源码我上传到了github,在这里下载;
在源码中: train_and_val.py文件是最终要执行的文件,它定了训练和预测的过程; input_data.py是将上一步中生成的TFRecord文件组织成batch的过程; VGG.py定义了VGG-16的网络结构; tool.py是最底层,定义了一些卷积池化等操作。
训练模型
train_and_val.py文件修改:
代码语言:javascript复制if __name__=="__main__":
train()
#evaluate()
根据自己的路径修改:
代码语言:javascript复制#初始权重路径
pre_trained_weights = 'vgg16_pretrain/vgg16.npy'
#训练数据路径
train_data_dir = 'F:\flowersdata\tfrecord\train\traindata.tfrecords*'
test_data_dir =
#预测数据路径
'F:\flowersdata\tfrecord\test\testdata.tfrecords*'
#训练生成文件路径
train_log_dir = 'logs/train/'
#预测生成文件路径
val_log_dir = 'logs/val/'
根据自己的显存容量修改:
代码语言:javascript复制IMG_W = 224
IMG_H = 224
BATCH_SIZE = 8
训练过程每50个step打印loss; 每200个step计算一个batch中的准确率; 每1000个step保存一次权重。
预测
train_and_val.py文件修改:
代码语言:javascript复制if __name__=="__main__":
#train()
evaluate()
代码语言:javascript复制#训练过程中生成的权重
log_dir = 'logs/train/'
#预测数据集路径
test_data_dir = 'F:\flowersdata\tfrecord\test\testdata.tfrecords*'
#用于生成tf文件的图片数量
n_test = 502
打印测试样本总数; 打印正确预测的样本总数; 打印top_1。