tf2-yolov3训练自己的数据集

2021-06-21 17:43:15 浏览数 (2)

tf2相比于tf1来说更加的友好,支持了Eager模式,代码和keras基本相同,所以代码也很简单,下面就如何用tf2-yolov3训练自己的数据集。 项目的代码包:链接: tf2-yolov3.需要自行下载 至于tf2-yolov3的原理可以参考这个链接,我觉得是讲的最好一个:链接: yolov3算法的一点理解.

tf2-yolov3训练自己的数据集

  • 1、配置相关的环境
  • 2、使用官方权重进行预测
  • 3、训练自己的模型文件,并且识别
    • 1)建立数据集文件夹
    • 2)添加图片并且标注(labelimg软件)
    • 3)建立.txt文件
    • 4)建立标签.names文件
    • 5)生成tfrecord文件(train和val)
    • 6)进行迁移训练
    • 7)进行模型测试

1、配置相关的环境

我是在linux上跑的,linux上配环境比较简单,相关windows配环境可以看这个博客: 链接: tensorflow-gpu环境搭建超级详细博客.

2、使用官方权重进行预测

1、进入到目标文件夹内

代码语言:javascript复制
cd yolo_tf2.1/

2、输入 python convert.py 生成tf可用的模型

输出的yolov3.tf 保存在checkpoint里面。 3、开始检测 1)检测照片:

代码语言:javascript复制
python detect.py --image ./data/people.jpg

这样便是成功的

2)打开摄像头进行预测:

代码语言:javascript复制
python detect_video.py --video 0

3) 对视频流进行预测

代码语言:javascript复制
python detect_video.py --video test.mp4 --output ./test_output.avi

经过以上测试,表示这个代码包可以正常的使用了,就可以利用TensorFlow2-yolov3来进行检测了,下一步我们来介绍一下如何训练自己的数据集。

3、训练自己的模型文件,并且识别

1)建立数据集文件夹

其中Annootation:存放标注好的**.xml**文件 JPEGImages : 自己搜集好的一些图片

2)添加图片并且标注(labelimg软件)

软件的下载地址:目标检测标注工具labelImg使用方法 记得要将图片保存到Annootation文件夹里面 …直到标注完所有的图片

3)建立.txt文件

//VOC2012//ImageSets//Main路径下

把你要训练的还有验证的数据文件都给写到.txt文件里面,方便程序对数据进行读取。

下面这段程序可以获取图片名称,因为每个人的图片的名称不一样,所以需要做相应的调整:

代码语言:javascript复制
import os,glob
path = r"C:UsersTSKDesktopyolo_tf2.1VOCdevkitVOC2012JPEGImages"
path_list=os.listdir(path)
path_list.sort()  #对列表进行格式化
for i in path_list[0:320] :  #训练的样本
       print(i[:-4] " -1")
for i in path_list[320:-1] :  #验证的样本
       print(i[:-4] " -1")

4)建立标签.names文件

在yolo_tf2.1/data文件夹下

里面写入的就是自己要训练的类别,有哪些类,就写入那些名称。

5)生成tfrecord文件(train和val)

这个文件的作用大概就是:这么多的图片,你让TensorFlow挨个去读取的话,很占内存,很费时间,原来很占内存,现在只用占一点点,终究一个还是节省内存,读取速度加快。 通过 .txt文件来读取

看自己的 .txt 文件是什么名字,这个地方得相应的改一下 训练集:

代码语言:javascript复制
python tools/voc2012.py --data_dir ./VOCdevkit_fire/VOC2012 --split train --output_file ./data/voc2012_train_dlsb.tfrecord --classes ./data/dlsb.names

先解释一下部分含义,感觉没啥好解释的,都是字面意思 (捂脸笑) 一开始可能会出现这种情况,转tfrecord文件的时候可能会出点问题

然后我百度了一下,发现是这样一个原因: 错误的意思是:Unicode的解码(Decode)出现错误了,以gbk编码的方式去解码(该字符串变成Unicode),但是此处通过gbk的方式,却无法解码(can’tdecode).’'illegalmultibyte sequence"的意思是非法的多字节序列,也就是说无法解码了。 我在源代码中添加了这个就可以正确的执行了,encoding = 'utf-8'如下:

我觉得还是那个.txt文件的格式不对,所以他读取不了,给它特定的格式就能够正确的读取了。

测试集:

代码语言:javascript复制
python tools/voc2012.py  --data_dir ./VOCdevkit_fire/VOC2012  --split val  --output_file  ./data/voc2012_val_dlsb.tfrecord --classes ./data/dlsb.names

出现这样表示已经转tfrecord成功。

6)进行迁移训练

代码语言:javascript复制
python train.py --dataset ./data/voc2012_train_dlsb.tfrecord --val_dataset ./data/voc2012_val_dlsb.tfrecord --classes ./data/dlsb.names --num_classes 3 --mode fit --transfer darknet --batch_size 4 --epochs 150 --weights ./checkpoints/yolov3.tf --weights_num_classes 80

先简单的进行150个epochs的训练:

静待结果。。。。

损失函数下降到16,还不是特别好

7)进行模型测试

代码语言:javascript复制
python detect.py --classes ./data/three.names --num_classes 3 --weights ./checkpoints/yolov3_train_150.tf --image ./000002.jpg --yolo_score_threshold 0.3

准确度还不是很高…正在改进中…

0 人点赞