OpenVINO中的图象修复模型与示例演示

2021-12-29 07:45:40 浏览数 (1)

模型介绍与转换

在OpenVINO的公开模型库中有一个图象修复的模型的,它支持使用mask作为参考,实现对输入的修复。模型来自:

代码语言:javascript复制
https://github.com/shepnerd/inpainting_gmcnn.git

模型结果如下:

下载模型之后,通过OpenVINO的脚本支持一键转换为IR格式。下载命令行:

代码语言:javascript复制
python downloader.py --name gmcnn-places2-tf

转换命令行:

代码语言:javascript复制
python converter.py --name gmcnn-places2-tf

转换之后的模型输入格式如下:

代码语言:javascript复制
Placeholder - [BCHW] = 1x3x512x680 BGR, 图象
Placeholder_1 - [BCHW] = 1x1x512x680 mask 单通道二值图象

输出格式如下:

代码语言:javascript复制
1x3x512x680 BGR, 图象

图象修复代码演示

使用转换之后的模型,实现图象修复的代码演示。运行结果如下:

模型推理与输出解析的各步如下:

加载模型

代码语言:javascript复制
ie = IECore()
net = ie.read_network(model=gmcnn_xml, weights=gmcnn_bin)
its = iter(net.input_info)
input_blob1 = next(its)
input_blob2 = next(its)
print(input_blob1, input_blob2)
out_blob = next(iter(net.outputs))

_, c1, h1, w1 = net.input_info[input_blob1].input_data.shape # 三通道
_, c2, mh1, mw1 = net.input_info[input_blob2].input_data.shape # 单通道
exec_net = ie.load_network(network=net, device_name="CPU")
print(c1, c2)

预处理输入图象

代码语言:javascript复制
# 处理输入图象
# src = cv.imread("D:/images/grad.png")
# mask = cv.imread("D:/mask.png")
src = cv.imread("D:/images/1024.png")
mask = cv.imread("D:/images/1024_mask.png")
h, w, c = src.shape

# 生成待修复图象
dst = cv.add(src, mask)
cv.imshow("input", dst)
cv.imshow("mask", mask)

# 输入预处理,BGR三通道输入
image = cv.resize(dst, (w1, h1))
image = image.transpose(2, 0, 1)

# 单通道输入
gray_m = cv.cvtColor(mask, cv.COLOR_BGR2GRAY)
m = cv.resize(gray_m, (w1, h1))
ret, bin = cv.threshold(m, 1, 1, cv.THRESH_BINARY)
m = np.expand_dims(bin, axis=2)
m = m.transpose(2, 0, 1)

推理与解析输出并显示

代码语言:javascript复制
t0 = cv.getTickCount()
out_prob = exec_net.infer(inputs={input_blob1: [image], input_blob2: [m]})
infer_time = (cv.getTickCount() - t0) / cv.getTickFrequency()
output = out_prob[out_blob]
result = np.transpose(output, (0, 2, 3, 1)).astype(np.uint8)
result = np.squeeze(result, axis=0)
result_img = cv.resize(result, (w, h))
cv.putText(result_img, 'infer time: {:.2f} FPS'.format(float(1 / infer_time)), (5, 35), cv.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 200))
cv.imshow("image inpaint demo", result_img)
cv.waitKey(0)
cv.destroyAllWindows()

0 人点赞