本文接着介绍了Mask Rcnn目标分割算法如何训练自己数据集,对训练所需的文件以及训练代码进行详细的说明。
本文详细介绍在只有样本图片数据时,如果建立Mask Rcnn目标分割训练数据集的步骤。过程中用到的所有代码均已提供。
一、制作自己的数据集
1、labelme安装
自己的数据和上面数据的区别就在于没有.json标签文件,所以训练自己的数据关键步骤就是获取标签文件,制作标签需要用到labelme软件。我们在当前虚拟环境下直接安装:
activate py37_torch(这是我的虚拟环境)
pip install pyqt5 -i https://pypi.tuna.tsinghua.edu.cn/simple/
pip install labelme
安装好后,直接在命令行输入labelme就可以打开标注软件了
具体安装和使用方法可以参考以下链接
https://zhuanlan.zhihu.com/p/436614108
2、labelme制作标签
在labelme软件中打开图片所在文件夹,一次对图片边缘进行画点制作标签,并保存,每张图片的标签以.json文件保存在图片所在目录
标签保存到与图片同一路径下,对所有图片标注后,得到下面所示的数据集(每张图片下面为对应的标签.json文件)
3、将标签转换为coco数据集格式(一)(可直接进行第4步,这一步仅作为探索中间过程的记录)
(1)单个json文件转换为coco格式
在利用mask rcnn进行自己的数据集训练时,数据集的格式要采用coco格式,所以利用labelme自带的json_to_dataset将自己的.json文件转换。该文件所在路径如下图所示:
打开json_to_dataset.py文件,对保存路径进行修改,修改为自己转换后的路径即可。
生成的文件夹下一共包含5个文件(这里缺少了一个yaml,后面会介绍如何获取yaml)
(2)批量转换
但是这样一个一个文件转太麻烦,因此我们需要写一个程序自动将所有.json文件一次性转换。
这里我写了一个批量转换的程序
My_json_to_dataset.py
代码语言:javascript复制import argparse
import base64
import json
import os
import os.path as osp
import imgviz
import PIL.Image
from labelme.logger import logger
from labelme import utils
import glob
# 最前面加入导包
import yaml
def main():
logger.warning(
"This script is aimed to demonstrate how to convert the "
"JSON file to a single image dataset."
)
logger.warning(
"It won't handle multiple JSON files to generate a "
"real-use dataset."
)
parser = argparse.ArgumentParser()
###############################################增加的语句##############################
# parser.add_argument("json_file")
parser.add_argument("--json_dir",default="D:/2021file/Biye/Mask_RCNN-master/samples/Mydata")
###############################################end###################################
parser.add_argument("-o", "--out", default=None)
args = parser.parse_args()
###############################################增加的语句##############################
assert args.json_dir is not None and len(args.json_dir) > 0
# json_file = args.json_file
json_dir = args.json_dir
if osp.isfile(json_dir):
json_list = [json_dir] if json_dir.endswith('.json') else []
else:
json_list = glob.glob(os.path.join(json_dir, '*.json'))
###############################################end###################################
for json_file in json_list:
json_name = osp.basename(json_file).split('.')[0]
out_dir = args.out if (args.out is not None) else osp.join(osp.dirname(json_file), json_name)
###############################################end###################################
if not osp.exists(out_dir):
os.makedirs(out_dir)
data = json.load(open(json_file))
imageData = data.get("imageData")
if not imageData:
imagePath = os.path.join(os.path.dirname(json_file), data["imagePath"])
with open(imagePath, "rb") as f:
imageData = f.read()
imageData = base64.b64encode(imageData).decode("utf-8")
img = utils.img_b64_to_arr(imageData)
label_name_to_value = {"_background_": 0}
for shape in sorted(data["shapes"], key=lambda x: x["label"]):
label_name = shape["label"]
if label_name in label_name_to_value:
label_value = label_name_to_value[label_name]
else:
label_value = len(label_name_to_value)
label_name_to_value[label_name] = label_value
lbl, _ = utils.shapes_to_label(
img.shape, data["shapes"], label_name_to_value
)
label_names = [None] * (max(label_name_to_value.values()) 1)
for name, value in label_name_to_value.items():
label_names[value] = name
lbl_viz = imgviz.label2rgb(
lbl, imgviz.asgray(img), label_names=label_names, loc="rb"
)
PIL.Image.fromarray(img).save(osp.join(out_dir, "img.png"))
utils.lblsave(osp.join(out_dir, "label.png"), lbl)
PIL.Image.fromarray(lbl_viz).save(osp.join(out_dir, "label_viz.png"))
with open(osp.join(out_dir, "label_names.txt"), "w") as f:
for lbl_name in label_names:
f.write(lbl_name "n")
logger.info("Saved to: {}".format(out_dir))
if __name__ == "__main__":
main()
程序运行结果:
最后每个json文件都生成了一个新的文件夹
Labelme版本不同生成的mask图片位数也不同,要求是8位的,检查一下生成的label.png图片位数,如果你的不是8位的,则需要进一步转化,我的生成的是8位的,所以这里不需要转换了。
4、将标签转换为coco数据集格式(二)
修改My_json_to_dataset.py,增加了生成yaml文件代码
上面说到,生成的文件夹里没有.yaml文件,于是对上面的My_json_to_dataset.py进行修改。
修改后的文件如下:
My_json_to_dataset.py
代码语言:javascript复制import argparse
import base64
import json
import os
import os.path as osp
import imgviz
import PIL.Image
from labelme.logger import logger
from labelme import utils
import glob
# 最前面加入导包
import yaml
def main():
logger.warning(
"This script is aimed to demonstrate how to convert the "
"JSON file to a single image dataset."
)
logger.warning(
"It won't handle multiple JSON files to generate a "
"real-use dataset."
)
parser = argparse.ArgumentParser()
###############################################增加的语句##############################
# parser.add_argument("json_file")
parser.add_argument("--json_dir",default="D:/2021file/Biye/Mask_RCNN-master/samples/Mydata")
###############################################end###################################
parser.add_argument("-o", "--out", default=None)
args = parser.parse_args()
###############################################增加的语句##############################
assert args.json_dir is not None and len(args.json_dir) > 0
# json_file = args.json_file
json_dir = args.json_dir
if osp.isfile(json_dir):
json_list = [json_dir] if json_dir.endswith('.json') else []
else:
json_list = glob.glob(os.path.join(json_dir, '*.json'))
###############################################end###################################
for json_file in json_list:
json_name = osp.basename(json_file).split('.')[0]
out_dir = args.out if (args.out is not None) else osp.join(osp.dirname(json_file), json_name)
###############################################end###################################
if not osp.exists(out_dir):
os.makedirs(out_dir)
data = json.load(open(json_file))
imageData = data.get("imageData")
if not imageData:
imagePath = os.path.join(os.path.dirname(json_file), data["imagePath"])
with open(imagePath, "rb") as f:
imageData = f.read()
imageData = base64.b64encode(imageData).decode("utf-8")
img = utils.img_b64_to_arr(imageData)
label_name_to_value = {"_background_": 0}
for shape in sorted(data["shapes"], key=lambda x: x["label"]):
label_name = shape["label"]
if label_name in label_name_to_value:
label_value = label_name_to_value[label_name]
else:
label_value = len(label_name_to_value)
label_name_to_value[label_name] = label_value
lbl, _ = utils.shapes_to_label(
img.shape, data["shapes"], label_name_to_value
)
label_names = [None] * (max(label_name_to_value.values()) 1)
for name, value in label_name_to_value.items():
label_names[value] = name
lbl_viz = imgviz.label2rgb(
lbl, imgviz.asgray(img), label_names=label_names, loc="rb"
)
PIL.Image.fromarray(img).save(osp.join(out_dir, "img.png"))
utils.lblsave(osp.join(out_dir, "label.png"), lbl)
PIL.Image.fromarray(lbl_viz).save(osp.join(out_dir, "label_viz.png"))
with open(osp.join(out_dir, "label_names.txt"), "w") as f:
for lbl_name in label_names:
f.write(lbl_name "n")
logger.info("Saved to: {}".format(out_dir))
#######
#增加了yaml生成部分
logger.warning('info.yaml is being replaced by label_names.txt')
info = dict(label_names=label_names)
with open(osp.join(out_dir, 'info.yaml'), 'w') as f:
yaml.safe_dump(info, f, default_flow_style=False)
logger.info('Saved to: {}'.format(out_dir))
if __name__ == "__main__":
main()
这时就有了yaml文件
5、将数据整理成模型认可的形式
首先我们建立train_data文件夹,里面新建四个文件夹,
cv2_mask、json、labelme_json、pic
json:labelme生成的json文件
labelme_json:My_json_to_dataset生成的文件夹
pic:原图
cv2_mask:My_json_to_datase生成文件夹中的png格式label文件
除了cv2_mask文件夹,其它三个文件夹的内容在前面已经都生成了,放入到对应文件夹下即可
6、生成cv2_mask文件内容
先别着急把生成的label.png图存到cv2_mask文件夹里面,在此之前,为了适应模型内部默认的路径格式,需要对label.png进行简单的重命名(否则就要去代码里边改,比较麻烦)。
比如我的原图叫KS001,那这个png的图就应该改成KS001.png,这里写了个脚本进行批量转换。
rename.py
代码语言:javascript复制# 把label.png改名为原图名.png
import os
for root, dirs, names in os.walk("D:/2021file/Biye/Mask_RCNN-master/samples/Mydata/train_data/labelme_json"): # 改成你自己的json文件夹所在的目录
for dr in dirs:
file_dir = os.path.join(root, dr)
# print(dr)
file = os.path.join(file_dir, 'label.png')
# print(file)
new_name = dr.split('_')[0] '.png'
new_file_name = os.path.join(file_dir, new_name)
os.rename(file, new_file_name)
运行完代码后发现文件夹下的图片名称已经得到了修改
将名称修改后的mask图放入cv2_mask文件夹中,这里同样写了个脚本进行批量复制:
creat_mask.py
代码语言:javascript复制import os
from shutil import copyfile
for root, dirs, names in os.walk("D:/2021file/Biye/Mask_RCNN-master/samples/Mydata/train_data/labelme_json"): # 改成你自己的json文件夹所在的目录
for dr in dirs:
file_dir = os.path.join(root, dr)
print(dr)
file = os.path.join(file_dir, dr '.png')
print(file)
new_name = dr.split('_')[0] '.png'
new_file_name = os.path.join(file_dir, new_name)
print(new_file_name)
tar_root = 'D:/2021file/Biye/Mask_RCNN-master/samples/Mydata/train_data/cv2_mask' # 目标路径
tar_file = os.path.join(tar_root, new_name)
copyfile(new_file_name, tar_file)
运行:
运行完发现cv2_mask文件夹下已经有了全部mask图
二、训练
准备好以上数据集,即可以开始进行训练了
Mytrain.py
代码语言:javascript复制# -*- coding: utf-8 -*-
import os
import sys
import random
import math
import re
import time
import numpy as np
import cv2
# import matplotlib
# import matplotlib.pyplot as plt
import tensorflow as tf
from mrcnn.config import Config
# import utils
from mrcnn import model as modellib, utils
from mrcnn import visualize
import yaml
from mrcnn.model import log
from PIL import Image
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Root directory of the project
ROOT_DIR = os.getcwd()
# ROOT_DIR = os.path.abspath("../")
# Directory to save logs and trained model
MODEL_DIR = os.path.join(ROOT_DIR, "logs")
iter_num = 0
# Local path to trained weights file
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")
# Download COCO trained weights from Releases if needed
# if not os.path.exists(COCO_MODEL_PATH):
# utils.download_trained_weights(COCO_MODEL_PATH)
class ShapesConfig(Config):
"""Configuration for training on the toy shapes dataset.
Derives from the base Config class and overrides values specific
to the toy shapes dataset.
"""
# Give the configuration a recognizable name
NAME = "shapes"
# Train on 1 GPU and 8 images per GPU. We can put multiple images on each
# GPU because the images are small. Batch size is 8 (GPUs * images/GPU).
GPU_COUNT = 1
IMAGES_PER_GPU = 1
# Number of classes (including background)
NUM_CLASSES = 2 1 # background 1 shapes 注意这里我是2类,所以是2 1
# Use small images for faster training. Set the limits of the small side
# the large side, and that determines the image shape.
IMAGE_MIN_DIM = 256
IMAGE_MAX_DIM = 1024
# Use smaller anchors because our image and objects are small
# RPN_ANCHOR_SCALES = (8 * 6, 16 * 6, 32 * 6, 64 * 6, 128 * 6) # anchor side in pixels
RPN_ANCHOR_SCALES = (16 * 6, 32 * 6, 64 * 6, 128 * 6, 256 * 6) # 我的图片中目标比较大,所以我把anchor的尺寸也设置的大了一点
# Reduce training ROIs per image because the images are small and have
# few objects. Aim to allow ROI sampling to pick 33% positive ROIs.
TRAIN_ROIS_PER_IMAGE = 100
# Use a small epoch since the data is simple
STEPS_PER_EPOCH = 50 # 每个epoch中迭代的step,最好不要改动
# use small validation steps since the epoch is small
VALIDATION_STEPS = 50
config = ShapesConfig()
config.display()
class DrugDataset(utils.Dataset):
# 得到该图中有多少个实例(物体)
def get_obj_index(self, image):
n = np.max(image)
return n
# 解析labelme中得到的yaml文件,从而得到mask每一层对应的实例标签
def from_yaml_get_class(self, image_id):
info = self.image_info[image_id]
with open(info['yaml_path']) as f:
#temp = yaml.load(f.read())
temp = yaml.load(f.read(),Loader=yaml.FullLoader)
labels = temp['label_names']
del labels[0]
return labels
# 重新写draw_mask
def draw_mask(self, num_obj, mask, image, image_id):
# print("draw_mask-->",image_id)
# print("self.image_info",self.image_info)
info = self.image_info[image_id]
# print("info-->",info)
# print("info[width]----->",info['width'],"-info[height]--->",info['height'])
for index in range(num_obj):
for i in range(info['width']):
for j in range(info['height']):
# print("image_id-->",image_id,"-i--->",i,"-j--->",j)
# print("info[width]----->",info['width'],"-info[height]--->",info['height'])
at_pixel = image.getpixel((i, j))
if at_pixel == index 1:
mask[j, i, index] = 1
return mask
# 重新写load_shapes,里面包含自己的自己的类别
# 并在self.image_info信息中添加了path、mask_path 、yaml_path
# yaml_pathdataset_root_path = "/dateset/"
# img_floder = dataset_root_path "rgb"
# mask_floder = dataset_root_path "mask"
# dataset_root_path = "/tongue_dateset/"
def load_shapes(self, count, img_floder, mask_floder, imglist, dataset_root_path):
"""Generate the requested number of synthetic images.
count: number of images to generate.
height, width: the size of the generated images.
"""
# Add classes
self.add_class("shapes", 1, "KS")
self.add_class("shapes", 2, "MS")
# self.add_class("shapes", 3, "leibie3")
# self.add_class("shapes", 4, "leibie4")
for i in range(count):
# 获取图片宽和高
print(i)
filestr = imglist[i].split(".")[0]
# print(imglist[i],"-->",cv_img.shape[1],"--->",cv_img.shape[0])
# print("id-->", i, " imglist[", i, "]-->", imglist[i],"filestr-->",filestr)
# filestr = filestr.split("_")[1]
mask_path = mask_floder "/" filestr ".png"
yaml_path = dataset_root_path "labelme_json/" filestr "/info.yaml"
print(dataset_root_path "labelme_json/" filestr "/img.png")
cv_img = cv2.imread(dataset_root_path "labelme_json/" filestr "/img.png")
print(type(cv_img))
self.add_image("shapes", image_id=i, path=img_floder "/" imglist[i],
width=cv_img.shape[1], height=cv_img.shape[0], mask_path=mask_path, yaml_path=yaml_path)
# 重写load_mask
def load_mask(self, image_id):
"""Generate instance masks for shapes of the given image ID.
"""
global iter_num
print("image_id", image_id)
info = self.image_info[image_id]
count = 1 # number of object
img = Image.open(info['mask_path'])
num_obj = self.get_obj_index(img)
mask = np.zeros([info['height'], info['width'], num_obj], dtype=np.uint8)
mask = self.draw_mask(num_obj, mask, img, image_id)
occlusion = np.logical_not(mask[:, :, -1]).astype(np.uint8)
for i in range(count - 2, -1, -1):
mask[:, :, i] = mask[:, :, i] * occlusion
occlusion = np.logical_and(occlusion, np.logical_not(mask[:, :, i]))
labels = []
labels = self.from_yaml_get_class(image_id)
labels_form = []
for i in range(len(labels)):
if labels[i].find("KS") != -1:
labels_form.append("KS")
elif labels[i].find("MS") != -1:
labels_form.append("MS")
# elif labels[i].find("leibie3") != -1:
# labels_form.append("leibie3")
# elif labels[i].find("leibie4") != -1:
# labels_form.append("leibie4")
class_ids = np.array([self.class_names.index(s) for s in labels_form])
return mask, class_ids.astype(np.int32)
'''
def get_ax(rows=1, cols=1, size=8):
"""Return a Matplotlib Axes array to be used in
all visualizations in the notebook. Provide a
central point to control graph sizes.
Change the default size attribute to control the size
of rendered images
"""
_, ax = plt.subplots(rows, cols, figsize=(size * cols, size * rows))
return ax
'''
# 基础设置
dataset_root_path = "D:/2021file/Biye/Mask_RCNN-master/samples/Mydata/train_data/" # 你的数据的路径
img_floder = dataset_root_path "pic"
mask_floder = dataset_root_path "cv2_mask"
# yaml_floder = dataset_root_path
imglist = os.listdir(img_floder)
count = len(imglist)
# train与val数据集准备
dataset_train = DrugDataset()
dataset_train.load_shapes(count, img_floder, mask_floder, imglist, dataset_root_path)
dataset_train.prepare()
# print("dataset_train-->",dataset_train._image_ids)
dataset_val = DrugDataset()
dataset_val.load_shapes(count, img_floder, mask_floder, imglist, dataset_root_path)
dataset_val.prepare()
# print("dataset_val-->",dataset_val._image_ids)
# Load and display random samples
# image_ids = np.random.choice(dataset_train.image_ids, 4)
# for image_id in image_ids:
# image = dataset_train.load_image(image_id)
# mask, class_ids = dataset_train.load_mask(image_id)
# visualize.display_top_masks(image, mask, class_ids, dataset_train.class_names)
# Create model in training mode
model = modellib.MaskRCNN(mode="training", config=config,
model_dir=MODEL_DIR)
# Which weights to start with?
init_with = "coco" # imagenet, coco, or last
if init_with == "imagenet":
model.load_weights(model.get_imagenet_weights(), by_name=True)
elif init_with == "coco":
# Load weights trained on MS COCO, but skip layers that
# are different due to the different number of classes
# See README for instructions to download the COCO weights
# print(COCO_MODEL_PATH)
model.load_weights(COCO_MODEL_PATH, by_name=True,
exclude=["mrcnn_class_logits", "mrcnn_bbox_fc",
"mrcnn_bbox", "mrcnn_mask"])
elif init_with == "last":
# Load the last model you trained and continue training
model.load_weights(model.find_last()[1], by_name=True)
# Train the head branches
# Passing layers="heads" freezes all layers except the head
# layers. You can also pass a regular expression to select
# which layers to train by name pattern.
model.train(dataset_train, dataset_val,
learning_rate=config.LEARNING_RATE,
epochs=10,
layers='heads') # 固定其他层,只训练head,epoch为10
# Fine tune all layers
# Passing layers="all" trains all layers. You can also
# pass a regular expression to select which layers to
# train by name pattern.
model.train(dataset_train, dataset_val,
learning_rate=config.LEARNING_RATE / 10,
epochs=10,
layers="all") # 微调所有层的参数,epoch为10
代码中部分数据相关描述需要修改成你自己的数据描述
(1)首先修改数据集路径:
修改类别名称,定位到def load_shapes 120行,加入数据集中的类别
(2)定位到NUM_CLASSES,55行
修改为对应的类别数,注意要在类别数上 1,这里认为背景为一类
# Number of classes (including background)
NUM_CLASSES = 2 1 # background 1 shapes 注意这里我是2类,所以是2 1
(3)定位到def load_mask,其中的类别也要做相应修改
(4)正常运行,开始训练
训练过程日志以及权重保存在logs中的最新一个文件夹中,同样可以使用tensorboard对训练过程进行查看
tensorboard –logdir=”D:/2021file/Biye/Mask_RCNN-master/logs/shapes20220417T2356″
运行后在浏览器打开下面地址:(有时候运行框给我的地址是http://ComputerDing:6006,并不可用,所以我直接手动进入下面地址 )
http://localhost:6006/
四、训练时可能出现的报错及解决方法
1、报错:
AttributeError: ‘NoneType’ object has no attribute ‘shape’
报错原因是没有找到文件,仔细一看它输出的文件路径和我的实际的有一点不一样,我的是
D:/2021file/Biye/Mask_RCNN-master/samples/Mydata/train_data/labelme_json/KS001/img.png
它的是
D:/2021file/Biye/Mask_RCNN-master/samples/Mydata/train_data/labelme_json/KS001_json/img.png
解决方法:
注意这里三处都要修改过来
2、报错:
OSError: Unable to open file (unable to open file: name = ‘D:2021fileBiyeMask_RCNN-mastermask_rcnn_coco.h5’, errno = 2, error message = ‘No such file or directory’, flags = 0, o_flags = 0)
将下载的mask_rcnn_coco.h5权重文件放到’D:2021fileBiyeMask_RCNN-master目录下即可
3、报错:TypeError: load() missing 1 required positional argument: ‘Loader’
原因分析:
由于Yaml 5.1版本后弃用了 yaml.load(file) 这个用法。Yaml 5.1版本之后就修改了需要指定Loader,通过默认加载器(FullLoader)禁止执行任意函数,使得此load函数的安全得到加强。
解决方法:
temp = yaml.load(f.read())
替换为:
temp = yaml.load(f.read(),Loader=yaml.FullLoader)
4、报错:IndexError: boolean index did not match indexed array along dimension 0; dimension is 0 but corresponding boolean dimension is 1
原因分析:自己数据集的类别没有在程序中加入
定位到def load_shapes 120行,加入数据集中的类别
注意def load_mask中的类别也要做相应修改
5、报错:ValueError: Error when checking input: expected input_image_meta to have shape (17,) but got array with shape (15,)
原因分析:报这个错误,应该是你的类别数忘记修改过来了
# Number of classes (including background)
NUM_CLASSES = 2 1 # background 1 shapes 注意这里我是2类,所以是2 1
五、测试
Mytest.py
代码语言:javascript复制# -*- coding: utf-8 -*-
import os
import sys
import random
import math
import numpy as np
import skimage.io
import matplotlib
import matplotlib.pyplot as plt
import cv2
import time
from mrcnn.config import Config
from datetime import datetime
# Root directory of the project
ROOT_DIR = os.getcwd()
# Import Mask RCNN
sys.path.append(ROOT_DIR) # To find local version of the library
from mrcnn import utils
import mrcnn.model as modellib
from mrcnn import visualize
# Import COCO config
# sys.path.append(os.path.join(ROOT_DIR, "samples/coco/")) # To find local version
# from samples.coco import coco
# Directory to save logs and trained model
MODEL_DIR = os.path.join(ROOT_DIR, "logs")
# Local path to trained weights file
COCO_MODEL_PATH = "D:/2021file/BiyeMask_RCNN-master/logs/shapes20220417T2356/mask_rcnn_shapes_0010.h5" # 模型保存目录
# Download COCO trained weights from Releases if needed
if not os.path.exists(COCO_MODEL_PATH):
utils.download_trained_weights(COCO_MODEL_PATH)
print("***********************")
# Directory of images to run detection on
#随机检测时用
#IMAGE_DIR = os.path.join(ROOT_DIR, "images")
class ShapesConfig(Config):
"""Configuration for training on the toy shapes dataset.
Derives from the base Config class and overrides values specific
to the toy shapes dataset.
"""
# Give the configuration a recognizable name
NAME = "shapes"
# Train on 1 GPU and 8 images per GPU. We can put multiple images on each
# GPU because the images are small. Batch size is 8 (GPUs * images/GPU).
GPU_COUNT = 1
IMAGES_PER_GPU = 1
# Number of classes (including background)
NUM_CLASSES = 2 1 # background 2 shapes
# Use small images for faster training. Set the limits of the small side
# the large side, and that determines the image shape.
IMAGE_MIN_DIM = 320
IMAGE_MAX_DIM = 384
# Use smaller anchors because our image and objects are small
RPN_ANCHOR_SCALES = (8 * 6, 16 * 6, 32 * 6, 64 * 6, 128 * 6) # anchor side in pixels
# Reduce training ROIs per image because the images are small and have
# few objects. Aim to allow ROI sampling to pick 33% positive ROIs.
TRAIN_ROIS_PER_IMAGE = 100
# Use a small epoch since the data is simple
STEPS_PER_EPOCH = 50
# use small validation steps since the epoch is small
VALIDATION_STEPS = 50
# import train_tongue
# class InferenceConfig(coco.CocoConfig):
class InferenceConfig(ShapesConfig):
# Set batch size to 1 since we'll be running inference on
# one image at a time. Batch size = GPU_COUNT * IMAGES_PER_GPU
GPU_COUNT = 1
IMAGES_PER_GPU = 1
config = InferenceConfig()
model = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR, config=config)
# Create model object in inference mode.
model = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR, config=config)
# Load weights trained on MS-COCO
model.load_weights(COCO_MODEL_PATH, by_name=True)
# COCO Class names
# Index of the class in the list is its ID. For example, to get ID of
# the teddy bear class, use: class_names.index('teddy bear')
# 注意修改类别名称
class_names = ['BG','KS', 'MS']
# Load a random image from the images folder
#有两种方式,一种是从文件夹中随机读取图片,另一种是指定图片路径
#file_names = next(os.walk(IMAGE_DIR))[2]
image = skimage.io.imread("F:/jk/KS/22_Color.png") # 你想要测试的图片
a = datetime.now()
# Run detection
results = model.detect([image], verbose=1)
b = datetime.now()
# Visualize results
print("Dec_time:", (b - a).seconds)
r = results[0]
#print(r)
visualize.display_instances(image, r['rois'], r['masks'], r['class_ids'],
class_names, r['scores'])
需修改以下地方
(1)训练得到的模型文件***.h5路径
(2)类别数
(3)类别名称(注意加上背景BG)
检测图片路径
若报错:IndexError: list index out of range
原因是类别名称里面忘记加了背景
(4)运行测试结果
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/186031.html原文链接:https://javaforall.cn