YOLO算法:统计标注的xml文件中包含的标记框信息。
创建parse_dataset_annotation.py:
import os
import sys
# Resolve the directory that contains this script. abspath() is applied first
# because sys.argv[0] may be a bare file name (e.g. `python script.py`); in
# that case dirname() alone returns '' and os.chdir('') raises.
filedir = os.path.dirname(os.path.abspath(sys.argv[0]))
os.chdir(filedir)  # make the script's directory the working directory
wdir = os.getcwd()
print('当前工作目录:{}\n'.format(wdir))  # report the current working directory
from xml.dom.minidom import parse
def _first_text(node, tag):
    """Return the data of the first child node of the first `tag` element under `node`."""
    return node.getElementsByTagName(tag)[0].childNodes[0].data


def xml_parser(xml_file):
    """Parse an annotation xml file and return the annotation info it holds.

    :param xml_file: the xml file to parse; passed straight to
        ``xml.dom.minidom.parse``.
    :return: ``(file_name, width, height, objects)`` where

        * ``file_name`` is the text of the first ``<filename>`` element,
        * ``width`` / ``height`` are the ``<size>`` values converted to int,
        * ``objects`` is a list with one
          ``[object_name, xmin, ymin, xmax, ymax]`` entry per ``<object>``
          element; the name is a string and the four bounding-box
          coordinates are floats.
    :raises ValueError: if the document has no ``<size>`` element.
    :raises IndexError: if ``<filename>``, or a required tag inside
        ``<size>`` / ``<object>``, is missing or empty.
    """
    root = parse(xml_file).documentElement  # root node of the xml document

    file_name = _first_text(root, 'filename')

    size_nodes = root.getElementsByTagName('size')
    if not size_nodes:
        # Without this check width/height would simply be unbound at return.
        raise ValueError('no <size> element found in {}'.format(xml_file))
    # When several <size> elements exist the last one wins.
    size_node = size_nodes[-1]
    width = _first_text(size_node, 'width')
    height = _first_text(size_node, 'height')

    objects = []
    for object_node in root.getElementsByTagName('object'):
        object_name = _first_text(object_node, 'name')
        bndbox = object_node.getElementsByTagName('bndbox')[0]
        corners = [
            float(_first_text(bndbox, tag))
            for tag in ('xmin', 'ymin', 'xmax', 'ymax')
        ]
        objects.append([object_name] + corners)

    return file_name, int(width), int(height), objects
# Directories (relative to the working directory) holding the images and
# their xml annotation files.
image_dir = 'images'
xml_dir = 'labels'
xml_files = os.listdir(xml_dir)
image_files = os.listdir(image_dir)
# Extension of the image files, taken from the first listed entry. Falls back
# to '' when the image directory is empty instead of raising IndexError.
image_ext = image_files[0].split('.')[-1] if image_files else ''
print(image_ext)
# Sanity check: the two directories should hold the same number of entries.
if len(image_files) == len(xml_files):
    print('共有{:d}个xml文件。'.format(len(xml_files)))
else:
    print('图片数量和xml文件数量不一致。')
# Per-class statistics, keyed by object name:
#   [count, mean centre x, mean centre y, mean box width, mean box height]
obj_dict = {}
for xml_file in xml_files:
    annotation = xml_parser(os.path.join(xml_dir, xml_file))

    # Look for an image with the same stem, trying the extension in lower and
    # upper case. splitext() keeps stems that themselves contain dots intact.
    stem = os.path.splitext(xml_file)[0]
    name_1 = stem + '.' + image_ext.lower()
    name_2 = stem + '.' + image_ext.upper()
    if name_1 not in image_files and name_2 not in image_files:
        print('{:s}没有对应的图片。'.format(xml_file))

    # annotation[-1] is the list of [name, xmin, ymin, xmax, ymax] entries.
    for obj in annotation[-1]:
        key = obj[0]
        x = (obj[1] + obj[3]) / 2  # box centre
        y = (obj[2] + obj[4]) / 2
        width = obj[3] - obj[1]    # box size
        height = obj[4] - obj[2]
        box = [x, y, width, height]
        if key in obj_dict:
            obj_dict[key][0] += 1
            n = obj_dict[key][0]
            # Incremental mean: new_mean = (old_mean * (n - 1) + sample) / n
            obj_dict[key][1:5] = [
                (i * (n - 1) + j) / n
                for i, j in zip(obj_dict[key][1:5], box)
            ]
        else:
            # First occurrence: count of 1 followed by the box as the mean.
            obj_dict[key] = [1] + box

for key, value in obj_dict.items():
    print('一共有 {:4d} 个 {:20s},其边框平均位置为{:4.0f} *{:4.0f};平均尺寸为{:3.0f} *{:3.0f}。'.format(value[0],key,*value[1:]))
发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/2159.html 原文链接: