Tensorflow Object Detection API 终于支持tensorflow1.x与tensorflow2.x了

2020-09-08 16:54:48 浏览数 (1)

Tensorflow Object Detection API框架

基于tensorflow框架构建的快速对象检测模型构建、训练、部署框架,是针对计算机视觉领域对象检测任务的深度学习框架。之前tensorflow2.x一直不支持该框架,最近Tensorflow Object Detection API框架最近更新了,同时支持tensorflow1.x与tensorflow2.x。其中model zoo方面,tensorflow1.x基于COCO数据集预训练支持对象检测模型包括:

代码语言:javascript复制
SSD,支持MobileNetv1/MobileNetv2/MobileNetv3/ResNet50 基础网络
Faster-RCNN,支持MobileNet/ResNet101/Inception基础网络
Mask-RCNN,支持ResNet101/ResNet50/Inception基础网络

Tensorflow2.x版本的模型库不仅支持tensorflow1.x这几种对象检测网络,还支持:

代码语言:javascript复制
EfficientDet D0~EfficientDet D7
CenterNet HourGlass支持Box KeyPoint
CenterNet Resnet50 支持Box KeyPoint

此外还支持修改与注册自定义的对象检测模型训练。在框架安装方面也做了脚本优化,必以前更加的简介方便。官方github地址:

代码语言:javascript复制
https://github.com/tensorflow/models/tree/master/research/object_detection

安装与配置

基于tensorflow1.x的Tensorflow Object Detection API框架,Windows环境霞依赖的软件与版本信息如下

代码语言:javascript复制
-tensorflow1.15
-python3.6.5
-VS2015 VC  
-CUDA10.0(可选)
-Git-2.19.0-64-bit
-protoc-3.4.0-win32

01

下载源码

安装好上述依赖的软件与对应版本之后,首先需要的获取Tensorflow Object Detection API框架源码,执行下面的git命令行即可:

回车获取代码,最终得到的源码目录为:D:tensorflowmodels

02

编译源文件

编译protos文件,使用protoc-3.4.2-win32作为编译工具,执行如下的命令行:

这样就完成了编译。

03

安装依赖与运行测试

安装依赖python库,运行测试完成测试,执行如下代码:

回车执行,完成依赖包安装!然后再执行:

运行结果如下:

说明tensorflow1.x版本的Tensorflow Object Detection API框架正确安装完成。可以进行模型训练与测试、部署导出等。把上面的命令行中的tf1改成tf2就会完成tensorflow2.x版本的对象检测框架安装与配置。

运行代码测试

使用SSD MobileNet模型基于tensorflow1.x版本的对象检测框架,完成实时对象检测,代码实现如下:

代码语言:javascript复制
MODEL_NAME = 'ssd_mobilenet_v2_coco_2018_03_29'
MODEL_FILE = 'D:/tensorflow/'   MODEL_NAME   '.tar'

# Path to frozen detection graph
PATH_TO_CKPT = MODEL_NAME   '/frozen_inference_graph.pb'

# List of the strings that is used to add correct label for each box.
PATH_TO_LABELS = os.path.join('D:/tensorflow/models/research/object_detection/data', 'mscoco_label_map.pbtxt')

NUM_CLASSES = 90
tar_file = tarfile.open(MODEL_FILE)
for file in tar_file.getmembers():
    file_name = os.path.basename(file.name)
    if 'frozen_inference_graph.pb' in file_name:
        tar_file.extract(file, os.getcwd())

detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')

label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)


def load_image_into_numpy_array(image):
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape(
      (im_height, im_width, 3)).astype(np.uint8)


with detection_graph.as_default():
    with tf.Session(graph=detection_graph) as sess:
        while True:
            ret, image_np = cap.read()
            image_np = cv.flip(image_np, 1)
            image_np_expanded = np.expand_dims(image_np, axis=0)
            image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
            boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
            scores = detection_graph.get_tensor_by_name('detection_scores:0')
            classes = detection_graph.get_tensor_by_name('detection_classes:0')
            num_detections = detection_graph.get_tensor_by_name('num_detections:0')
            # Actual detection.
            (boxes, scores, classes, num_detections) = sess.run(
              [boxes, scores, classes, num_detections],
              feed_dict={image_tensor: image_np_expanded})
            # Visualization of the results of a detection.
            vis_util.visualize_boxes_and_labels_on_image_array(
              image_np,
              np.squeeze(boxes),
              np.squeeze(classes).astype(np.int32),
              np.squeeze(scores),
              category_index,
              use_normalized_coordinates=True,
              line_thickness=8)

            cv.imshow('object detection', image_np)
            c = cv.waitKey(1)
            if c == 27: # ESC
                cv.imwrite("D:/tensorflow/run_result.png", image_np)
                cv.destroyAllWindows()
                break

运行结果如下:

善始者实繁

克终者盖寡

0 人点赞