最近几天解决了一个问题,觉得可以写一写,问题大概是这样子的:
1. 问题背景
最近一直在做文档识别与文档比对,总体上是先用OCR模型识别出文本行,每个文本行使用一个box来表示(box就是一个矩形,使用左上角和右下角的坐标来表示),但是文字检测模型出来的效果并不是很理想,类似下面的情况并不少见:
说明:上面三个是出现问题的截图,红色框所在的是识别的box。
显然本来是同一行的可能会被识别成多个box(最好的情况当然是一行文本识别成一个box)。
2. 旧的解决方案:聚类
原来的实现方案使用聚类算法,将可以合并的box聚成一个类,然后进行合并。这里当时不用分类的原因,是因为不确定需要聚多少个类。确定用聚类之后,怎么计算两个box之间的距离就是算法的关键之处,说干就干,实现起来并不复杂,距离函数最后的实现大概如下:
代码语言:javascript复制def distance(box_a, box_b):
"""计算两个box的距离"""
ax1, ay1, ax2, ay2 = box_a
bx1, by1, bx2, by2 = box_b
if all(box_a == box_b):
return 0.
w1, w2 = ax2-ax1, bx2-bx1
h1, h2 = ay2-ay1, by2-by1
wh1, wh2 = w1/h1, w2/h2
print(wh1, wh2)
if wh1 > 1.5 and wh2 > 1.5:
if abs(h1-h2)/min(h1, h2) > 0.1:
# print('dist 1')
return 1000. # 高度差太大,不能合并
if wh1 < 0.3 or wh2 < 0.3:
# 很可能是两个竖排的文字
# print('dist 2')
return 1000.0
if wh1 < 1.5 and wh2 < 1.5:
# 两个都是高瘦型
# 注意:两个竖排的长条形不应该合并在一起
in_y = intersection_line((ay1, ay2), (by1, by2))
# print(in_y/h1, in_y/h2)
if in_y/h1 > 0.85 and in_y/h2 > 0.85:
min_size = min(w1, w2, h1, h2)
# print('dist 3')
return abs(max(ax1, bx1)-min(ax2, bx2)) / min_size
# print('dist 4')
return 1000.0
# 一个比较长,一个比较短
hs = []
if wh1 >= 1.5:
hs.append(h1)
if wh2 >= 1.5:
hs.append(h2)
dist = abs(max(ax1, bx1)-min(ax2, bx2)) / min(hs)
# print(dist, hs)
return dist
印象中当时调整这函数里面的各个参数也是调了不少时间的。当时感觉效果还可以,也就忙其他的事去了。
有了距离函数,其实聚类只是一个水到渠成的事情。
3. 问题的再次提起
最近在测试文本检测算法,这个问题又重新被捡起来了,其实效果真的就比较一般,这里不是说聚类算法效果不好,而是这个手动实现的距离算法,非常难以调试,很难判断调整一个参数之后,会造成多少影响。
既然这手动设计的距离算法效果不好,那就用分类算法吧,对于两个box我们完全可以进行分成两类:一类是可以合并的,另一类是不可以合并的,距离的作用其实也是于此。
4. 使用分类模型计算距离
于是距离的计算就变成了一个二分类问题。
开始有点犹豫要不要这么实现的,因为这里数据收集就是一个问题,效果还不太好判断(虽然直觉告诉我,使用分类算法替代距离效果会不错),不过相对于要去调那堆手动设置的距离算法,还是宁愿费点时间来收集与处理数据。
总体上分成三个步骤:
- 收集数据
- 处理数据
- 特征工程
- 训练模型
4.1 收集数据
在预先准备好的pdf文档上进行文字识别,将同一行上识别到多个box的记录下来,并且按可以是否可以合并记录到两个文件中,其中可以合并的数据(文件名merge.txt)大概格式如下:
代码语言:javascript复制[[304, 1898, 868, 1934], [850, 1898, 1401, 1934]]
[[250, 316, 488, 345], [472, 311, 876, 348], [891, 309, 1221, 345], [1193, 309, 1405, 353]]
[[245, 238, 390, 277], [410, 238, 452, 275], [519, 241, 547, 275]]
每一行会有至少两个以上的box,这些box是可以进行合并的。
不可以进行合并的数据(文件名not-merge.txt)格式大概也类似:
代码语言:javascript复制[[284, 1488, 423, 1522], [519, 1488, 609, 1520]]
[[281, 1651, 400, 1683], [516, 1649, 624, 1688]]
[[599, 911, 731, 940], [832, 911, 963, 940], [1059, 916, 1206, 938]]
这里的box是不可以进行合并的。
之所以这样设计这个原始数据格式,主要是为了方便后续收集数据,有新数据时,只要往这两个文件增加即可。
4.2 处理数据
现在只有原始数据,我们还不能进行模型训练,我们需要处理成指标的格式。
处理可以合并的数据是比较简单的:
代码语言:javascript复制def parse_file(filename):
with open(filename) as r:
lines = r.readlines()
lines = [text.strip() for text in lines]
lines = [json.loads(t) for t in lines]
lines = [sorted(boxes, key=lambda x: x[0]) for boxes in lines]
return lines
data = []
# 处理可以合并的box数据
merge_lines = parse_file('./merge.txt')
for boxes in merge_lines:
for idx in range(len(boxes)-1):
box = [0] boxes[idx] boxes[idx 1]
data.append(box)
就是将原始数据中的每一行的相邻的两个box组成一条记录,这两个box是可以合并的,这里使用label=0来表示,因为我们最终要的是距离。
而不可以进行合并的数据就比较复杂一点了,主要是不可以合并的原始数据收集更加困难,大多数识别到的box都是可以合并的。这里需要做数据扩增:
代码语言:javascript复制not_merge_lines = parse_file('./not-merge.txt')
for boxes in not_merge_lines:
for i in range(len(boxes)-1):
for j in range(i 1, len(boxes)):
box = [1] boxes[i] boxes[j]
data.append(box)
box = box.copy()
box[1] -= random.randint(3, 7)
data.append(box)
box = box.copy()
box[1] = random.randint(1, 3)
box[-2] = random.randint(3, 6)
data.append(box)
box = box.copy()
box[1] = random.randint(1, 3)
box[-2] -= random.randint(1, 4)
hi = boxes[i][3] - boxes[i][1]
hj = boxes[j][3] - boxes[j][1]
if hi > hj:
box[2] -= random.randint(1, 3)
box[4] = random.randint(1, 3)
else:
box[-3] -= random.randint(1, 3)
box[-1] = random.randint(1, 3)
data.append(box)
不可以合并的原始数据里,同一行里的任意两个box都是不可以合并的,这也是这里有三个for循环的原因。
因为数据量不足,我们将一条记录扩增成了4条记录,不过扩增的时候要注意,不应该出现不合理的扩增,所以这里的扩增是三个方向:
- 左边的box可以往左
- 右边的box可以往右
- 高度比较大的box可以增大高度
其他的方向可能会导致原来不可以合并的,变成可以合并。
最后,将处理好的数据保存到csv文件:
代码语言:javascript复制columns = ['label', 'ax1', 'ay1', 'ax2', 'ay2']
columns = ['bx1', 'by1', 'bx2', 'by2']
data = pd.DataFrame(data, columns=columns)
data = data.drop_duplicates()
data.to_csv('data.csv')
4.3 特征工程
有了上面处理过的数据其实已经可以直接进行模型训练了,但是如果那样效果很可能是比较差的,当然如果我们有非常大量的样本数据,可能也可以直接进行训练,可是我们没有。于是,特征工程就变得很重要。
在特征工程时,有两个关键点是需要考虑的:
- 要使特征与位置无关,也就是同时平移两个box,结果应该是不变的;
- 要将重要特征识别出来。
最后实现大概如下:
代码语言:javascript复制# 读取数据
data = pd.read_csv('./data.csv')
# 构造特征
data['ha'] = data.ay2 - data.ay1
data['hb'] = data.by2 - data.by1
data['h_min'] = data[['ha', 'hb']].min(axis=1)
data['ha'] = data.ha/data.h_min
data['hb'] = data.hb/data.h_min
data['wa'] = (data.ax2 - data.ax1)/data.h_min
data['wb'] = (data.bx2 - data.bx1)/data.h_min
data['h_rate'] = data.ha/data.hb
data['x_diff'] = (data.bx1 - data.ax2)/data.h_min
# 垂直方向的重叠程度
data['h_inter'] = (data[['ay2', 'by2']].min(axis=1) -
data[['ay1', 'by1']].max(axis=1))/data.h_min
total = len(data)
print("Total", total)
print('0: %d, 1: %d' % (total-data.label.sum(), data.label.sum()))
构造的特征并没有使用具体的坐标位置,所以也就具有了可平移性。指标做了标准化,可以适应放大缩小的情况。这里的标准化并没有使用0-1区间映射,或者正态分布映射,而是选择了一个比较的基准,这个基准就是两个box中的最小高度。
除了标准化之后的高度和宽度,还有三个额外的特征:
- h_rate:两个box的高度的比值,因为可以合并的box的高度应该是接近的;
- x_diff:右边box的左边的x坐标与左边box的右边的x坐标的差值的标准化值;
- h_inter:在垂直方向重叠的部分的标准化值。
这三个指标是分析得到的(事实上,原有手动设置的算法,主要考虑的点也是这些)。
4.4 训练模型
完成了上面的步骤之后,训练其实并没有太多可讲的了,就是一个调参与优化的过程:
代码语言:javascript复制from joblib import dump
# 生成数据
data_cp = data.copy()
x_columns = ['ha', 'hb', 'wa', 'wb', 'h_rate', 'x_diff', 'h_inter']
X = data_cp[x_columns]
Y = data_cp[['label']]
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.25)
# 训练
lg = LogisticRegression(C=1.0)
lg.fit(x_train, y_train)
print(lg.coef_)
# 训练集评估
y_predict = lg.predict(x_train)
print("训练准确率:", lg.score(x_train, y_train))
print("评估指标:", classification_report(y_train, y_predict, labels=[0, 1],
target_names=["合并", "不合并"]))
# 测试集评估
y_predict = lg.predict(x_test)
print("测试准确率:", lg.score(x_test, y_test))
print("评估指标:", classification_report(y_test, y_predict, labels=[0, 1],
target_names=["合并", "不合并"]))
# 保存模型
dump(lg, 'merge_lines.joblib')
# 保存预测错误的数据
data['predict'] = lg.predict(X)
save_columns = ['ax1', 'ay1', 'ax2', 'ay2']
save_columns = ['bx1', 'by1', 'bx2', 'by2']
save_columns = ['label', 'predict']
data[save_columns].to_csv('output.csv', index=False)
data[save_columns][data.predict != data.label].to_csv('error.csv', index=False)
这里并没有使用太复杂的算法,而是直接使用常用的逻辑回归,毕竟简单的算法计算起来消耗的资源更少,也更快。关于模型训练,可以点击阅读原文。
训练好模型之后,应用就很简单了,不再细述。
5. 模型应用
应用反倒是最简单的了:
代码语言:javascript复制from joblib import load
root_path = os.path.dirname(os.path.realpath(__file__))
model_path = os.path.join(root_path, 'merge_lines.joblib')
merge_lines_model = load(model_path)
def distance(box_a, box_b):
"""计算两个box的距离"""
if all(box_a == box_b):
return 0.
if box_a[0] > box_b[0]:
box_a, box_b = box_b, box_a
ax1, ay1, ax2, ay2 = box_a
bx1, by1, bx2, by2 = box_b
ha, hb = ay2-ay1, by2-by1
h_min = min(ha, hb)
ha, hb = ha/h_min, hb/h_min
wa, wb = (ax2-ax1)/h_min, (bx2-bx1)/h_min
h_rate = ha/hb
x_diff = (bx1-ax2)/h_min
h_inter = (min(ay2, by2)-max(ay1, by1))/h_min
x = [ha, hb, wa, wb, h_rate, x_diff, h_inter]
# 是否可以合并
label = merge_lines_model.predict([x])[0]
# print(label)
return label
相比于原来的解决方案,这段代码显然是更加容易维护的。
这个模型用来计算距离,其实外面还是套了一个聚类算法的:
代码语言:javascript复制cluster = DBSCAN(eps=0.5, min_samples=2, metric=distance).fit(boxes)
6. 可以改进的地方
在写这个文章的时候,突然想到特征工程的时候,有两个特征应该是可以优化的:
- h_rate:这个值如果大于1,应该对其取倒数,这样这个值的取值范围就会落在0到1之间,其大小的意义也更加明确;
- h_inter:这个值应该类似IOU的计算可能会更好。
另外,还有几个点想到:
- 我们直接使用了分类模型,其实使用回归模型应该是一个更好的选择,还可以通过一个参数选择来达到更好的控制。
- 数据扩增可以优化,现在的随机性还不是太好。