The TensorFlow Object Detection API Framework
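The TensorFlow Object Detection API is a framework built on TensorFlow for quickly building, training, and deploying object detection models. It is a deep learning framework for the object detection task in computer vision. Earlier releases of the framework did not support TensorFlow 2.x. A recent update added support for both TensorFlow 1.x and TensorFlow 2.x. In the TensorFlow 1.x model zoo, the COCO-pretrained object detection models include: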
- SSD, with MobileNetV1/MobileNetV2/MobileNetV3/ResNet50 backbones
- Faster-RCNN, with MobileNet/ResNet101/Inception backbones
- Mask-RCNN, with ResNet101/ResNet50/Inception backbones
The TensorFlow 2.x model zoo includes these TensorFlow 1.x detection networks and also adds:
- EfficientDet D0 to EfficientDet D7
- CenterNet HourGlass, with box and keypoint outputs
- CenterNet ResNet50, with box and keypoint outputs
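It also supports modifying and registering custom object detection models for training. The installation scripts have been reworked as well, which makes setup simpler than before. Official GitHub repository: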
https://github.com/tensorflow/models/tree/master/research/object_detection
Installation and Configuration
For the TensorFlow 1.x version of the TensorFlow Object Detection API on Windows, the required software and versions are:
- TensorFlow 1.15
- Python 3.6.5
- VS2015 VC (Visual C++ 2015)
- CUDA 10.0 (optional)
- Git-2.19.0-64-bit
- protoc-3.4.0-win32
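Before continuing, it can help to confirm that the tools above are reachable from Python. The following is a minimal sketch and not part of the framework: the function name `check_environment` is hypothetical. It checks only the Python major/minor version, whether `tensorflow` imports with a 1.15.x version, and whether `git` and `protoc` are on `PATH`. It does not check VS2015 or CUDA.

```python
import shutil
import sys


def check_environment(expected_python=(3, 6), expected_tf="1.15"):
    """Rough sanity check for the dependencies listed above (hypothetical helper)."""
    report = {
        # Compare only major.minor of the running interpreter
        "python_ok": tuple(sys.version_info[:2]) == tuple(expected_python),
        # git and protoc must be callable from the command line
        "git_found": shutil.which("git") is not None,
        "protoc_found": shutil.which("protoc") is not None,
    }
    try:
        import tensorflow as tf  # importing TensorFlow can take a few seconds
        report["tensorflow_version"] = tf.__version__
        report["tensorflow_ok"] = tf.__version__.startswith(expected_tf)
    except ImportError:
        report["tensorflow_version"] = None
        report["tensorflow_ok"] = False
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(key + ": " + str(value))
```

Each printed line corresponds to one check. A `False` value points to the tool that still needs to be installed or added to `PATH`.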
01
Download the Source Code
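After installing the software above in the listed versions, first obtain the source code of the TensorFlow Object Detection API by running the following git command:
Press Enter to fetch the code. The resulting source directory is D:\tensorflow\models.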
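The clone step can also be scripted from Python. This is a minimal sketch that assumes the code is cloned from the official `https://github.com/tensorflow/models.git` repository into an assumed parent folder `D:/tensorflow`, which produces the `D:\tensorflow\models` directory mentioned above. `clone_models` is a hypothetical helper name.

```python
import subprocess
from pathlib import Path

MODELS_REPO = "https://github.com/tensorflow/models.git"


def clone_models(parent_dir="D:/tensorflow", run=False):
    """Build, and optionally run, the git command that clones the models repo.

    parent_dir is an assumed location; the checkout ends up in <parent_dir>/models.
    """
    target = Path(parent_dir) / "models"
    cmd = ["git", "clone", MODELS_REPO, str(target)]
    if run:
        # Requires git on PATH; raises CalledProcessError if the clone fails
        subprocess.run(cmd, check=True)
    return cmd, target
```

With `run=False`, the function only returns the command, which is convenient for checking it first. `clone_models(run=True)` performs the actual clone. Adding git's `--depth 1` option to the list would make a shallow clone and reduce the download size.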
02
Compile the Source Files
Compile the .proto files with the protoc-3.4.0-win32 tool listed above by running the following command:
This completes the compilation.
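The compile step can also be driven from Python. This sketch makes a few assumptions. It expects the path of the `models/research` directory. It expands `object_detection/protos/*.proto` in Python instead of relying on the shell, because `cmd.exe` does not expand `*` wildcards itself. It passes `--python_out=.` and runs protoc from `models/research`, so the generated `*_pb2.py` files are written next to the `.proto` files. `build_protoc_command` and `compile_protos` are hypothetical helper names.

```python
import glob
import os
import subprocess


def build_protoc_command(research_dir, protoc="protoc"):
    """Return a protoc command listing every object_detection .proto file.

    Paths are made relative to research_dir and use forward slashes,
    which protoc accepts on Windows as well.
    """
    pattern = os.path.join(research_dir, "object_detection", "protos", "*.proto")
    protos = sorted(
        os.path.relpath(path, research_dir).replace(os.sep, "/")
        for path in glob.glob(pattern)
    )
    if not protos:
        raise FileNotFoundError("no .proto files found under " + pattern)
    return [protoc] + protos + ["--python_out=."]


def compile_protos(research_dir, protoc="protoc"):
    """Run protoc from research_dir so the *_pb2.py files land beside the protos."""
    cmd = build_protoc_command(research_dir, protoc)
    subprocess.run(cmd, cwd=str(research_dir), check=True)
```

For the layout above, the call would be `compile_protos("D:/tensorflow/models/research")`.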
03
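Install Dependencies and Run the Tests
Install the required Python libraries and run the tests by executing the following commands:
Press Enter to run them and install the dependency packages. Then run:
The output is as follows:
This output confirms that the TensorFlow 1.x version of the TensorFlow Object Detection API is installed correctly. You can now train and test models and export them for deployment. Replacing tf1 with tf2 in the commands above installs and configures the TensorFlow 2.x version of the framework instead.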
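For reference, the official repository keeps a separate `setup.py` for each TensorFlow major version, under `object_detection/packages/tf1` and `object_detection/packages/tf2`. It also has matching tests, `object_detection/builders/model_builder_tf1_test.py` and `model_builder_tf2_test.py`. The sketch below assumes that layout and runs from `models/research`. It shows how switching `tf1` to `tf2` selects the other set of files. `install_commands` and `install_and_test` are hypothetical helpers and not part of the API.

```python
import shutil
import subprocess
import sys
from pathlib import Path


def install_commands(tf_major="tf1"):
    """Return (setup.py source, pip command, test command) for 'tf1' or 'tf2'.

    Paths are relative to models/research and assume the official repo layout.
    """
    if tf_major not in ("tf1", "tf2"):
        raise ValueError("tf_major must be 'tf1' or 'tf2'")
    setup_src = Path("object_detection") / "packages" / tf_major / "setup.py"
    pip_cmd = [sys.executable, "-m", "pip", "install", "."]
    test_script = Path("object_detection") / "builders" / ("model_builder_" + tf_major + "_test.py")
    return setup_src, pip_cmd, [sys.executable, str(test_script)]


def install_and_test(research_dir, tf_major="tf1"):
    research = Path(research_dir)
    setup_src, pip_cmd, test_cmd = install_commands(tf_major)
    # Copy the version-specific setup.py into models/research, then install from there
    shutil.copy(str(research / setup_src), str(research / "setup.py"))
    # str() for cwd: Python 3.6 on Windows does not accept Path objects here
    subprocess.run(pip_cmd, cwd=str(research), check=True)
    subprocess.run(test_cmd, cwd=str(research), check=True)
```

`install_and_test("D:/tensorflow/models/research", "tf2")` would perform the TensorFlow 2.x variant.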
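Running a Test
The following code uses the SSD MobileNet model with the TensorFlow 1.x version of the framework to perform real-time object detection on camera frames: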
```python
import os
import tarfile

import cv2 as cv
import numpy as np
import tensorflow as tf  # TensorFlow 1.15: uses the TF1 Graph/Session API
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util

MODEL_NAME = 'ssd_mobilenet_v2_coco_2018_03_29'
MODEL_FILE = 'D:/tensorflow/' + MODEL_NAME + '.tar'
# Path to the frozen detection graph (extracted below into the current directory)
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'
# Label map used to attach the correct class name to each box
PATH_TO_LABELS = os.path.join('D:/tensorflow/models/research/object_detection/data', 'mscoco_label_map.pbtxt')
NUM_CLASSES = 90

# Extract only the frozen graph from the model archive
with tarfile.open(MODEL_FILE) as tar_file:
    for file in tar_file.getmembers():
        file_name = os.path.basename(file.name)
        if 'frozen_inference_graph.pb' in file_name:
            tar_file.extract(file, os.getcwd())

# Load the frozen graph into its own tf.Graph
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')

# Build {class_id: {'id': ..., 'name': ...}} from the COCO label map
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
    label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)


def load_image_into_numpy_array(image):
    # Converts a PIL image to a uint8 array; webcam frames are already arrays
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape(
        (im_height, im_width, 3)).astype(np.uint8)


cap = cv.VideoCapture(0)  # open the default camera
with detection_graph.as_default():
    with tf.Session(graph=detection_graph) as sess:
        # Input and output tensors of the exported graph, looked up once
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        scores = detection_graph.get_tensor_by_name('detection_scores:0')
        classes = detection_graph.get_tensor_by_name('detection_classes:0')
        num_detections = detection_graph.get_tensor_by_name('num_detections:0')
        while True:
            ret, image_np = cap.read()
            if not ret:  # camera unavailable or stream ended
                break
            image_np = cv.flip(image_np, 1)  # mirror horizontally
            # OpenCV delivers BGR frames; the COCO-trained model expects RGB
            rgb = cv.cvtColor(image_np, cv.COLOR_BGR2RGB)
            image_np_expanded = np.expand_dims(rgb, axis=0)
            # Actual detection.
            (out_boxes, out_scores, out_classes, out_num) = sess.run(
                [boxes, scores, classes, num_detections],
                feed_dict={image_tensor: image_np_expanded})
            # Draw the detections onto the displayed (BGR) frame in place.
            vis_util.visualize_boxes_and_labels_on_image_array(
                image_np,
                np.squeeze(out_boxes),
                np.squeeze(out_classes).astype(np.int32),
                np.squeeze(out_scores),
                category_index,
                use_normalized_coordinates=True,
                line_thickness=8)
            cv.imshow('object detection', image_np)
            c = cv.waitKey(1)
            if c == 27:  # ESC: save the current frame and quit
                cv.imwrite("D:/tensorflow/run_result.png", image_np)
                break
cap.release()
cv.destroyAllWindows()
```
The output is as follows:
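The `detection_boxes` output of the graph stores boxes as normalized `[ymin, xmin, ymax, xmax]` coordinates. This is why the script passes `use_normalized_coordinates=True` to the visualizer. To use the detections for more than drawing, such as logging labels or cropping regions, you have to filter them by score and scale them to pixels. Below is a minimal sketch of that step. `to_pixel_detections` is a hypothetical helper. Its 0.5 threshold matches the default `min_score_thresh` of `visualize_boxes_and_labels_on_image_array`, and like that function it keeps only scores strictly above the threshold.

```python
import numpy as np


def to_pixel_detections(boxes, scores, classes, category_index,
                        image_width, image_height, min_score=0.5):
    """Turn one image's detector outputs into labelled pixel boxes.

    boxes: (N, 4) array of normalized [ymin, xmin, ymax, xmax] values.
    Returns a list of (label, score, (left, top, right, bottom)) tuples.
    """
    results = []
    for box, score, cls in zip(boxes, scores, classes):
        if score <= min_score:  # keep only scores strictly above the threshold
            continue
        ymin, xmin, ymax, xmax = box
        pixel_box = (int(round(xmin * image_width)), int(round(ymin * image_height)),
                     int(round(xmax * image_width)), int(round(ymax * image_height)))
        # category_index maps class id -> {'id': ..., 'name': ...}
        label = category_index.get(int(cls), {}).get("name", "N/A")
        results.append((label, float(score), pixel_box))
    return results
```

Inside the detection loop, it would be called with the squeezed box, score, and class arrays returned by `sess.run`, plus the frame's width and height (`image_np.shape[1]` and `image_np.shape[0]`).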