明月机器学习系列028:一个机器学习问题的解决过程

2021-10-28 11:43:29 浏览数 (1)

最近几天解决了一个问题,觉得可以写一写,问题大概是这样子的:

1. 问题背景


最近一直在做文档识别与文档比对,总体上是先用OCR模型识别出文本行,每个文本行使用一个box来表示(box就是一个矩形,使用左上角和右下角的坐标来表示),但是文字检测模型出来的效果并不是很理想,类似下面的情况并不少见:

说明:上面三个是出现问题的截图,红色框所在的是识别的box。

显然本来是同一行的可能会被识别成多个box(最好的情况当然是一行文本识别成一个box)。

2. 旧的解决方案:聚类


原来的实现方案使用聚类算法,将可以合并的box聚成一个类,然后进行合并。这里当时不用分类的原因,是因为不确定需要聚多少个类。确定用聚类之后,怎么计算两个box之间的距离就是算法的关键之处,说干就干,实现起来并不复杂,距离函数最后的实现大概如下:

代码语言:javascript复制
def distance(box_a, box_b):
    """计算两个box的距离"""
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b
    if all(box_a == box_b):
        return 0.

    w1, w2 = ax2-ax1, bx2-bx1
    h1, h2 = ay2-ay1, by2-by1
    wh1, wh2 = w1/h1, w2/h2
    print(wh1, wh2)
    if wh1 > 1.5 and wh2 > 1.5:
        if abs(h1-h2)/min(h1, h2) > 0.1:
            # print('dist 1')
            return 1000.    # 高度差太大,不能合并
    if wh1 < 0.3 or wh2 < 0.3:
        # 很可能是两个竖排的文字
        # print('dist 2')
        return 1000.0
    if wh1 < 1.5 and wh2 < 1.5:
        # 两个都是高瘦型
        # 注意:两个竖排的长条形不应该合并在一起
        in_y = intersection_line((ay1, ay2), (by1, by2))
        # print(in_y/h1, in_y/h2)
        if in_y/h1 > 0.85 and in_y/h2 > 0.85:
            min_size = min(w1, w2, h1, h2)
            # print('dist 3')
            return abs(max(ax1, bx1)-min(ax2, bx2)) / min_size
        # print('dist 4')
        return 1000.0

    # 一个比较长,一个比较短
    hs = []
    if wh1 >= 1.5:
        hs.append(h1)
    if wh2 >= 1.5:
        hs.append(h2)

    dist = abs(max(ax1, bx1)-min(ax2, bx2)) / min(hs)
    # print(dist, hs)
    return dist

印象中当时调整这函数里面的各个参数也是调了不少时间的。当时感觉效果还可以,也就忙其他的事去了。

有了距离函数,其实聚类只是一个水到渠成的事情。

3. 问题的再次提起


最近在测试文本检测算法,这个问题又重新被捡起来了,其实效果真的就比较一般,这里不是说聚类算法效果不好,而是这个手动实现的距离算法,非常难以调试,很难判断调整一个参数之后,会造成多少影响。

既然这手动设计的距离算法效果不好,那就用分类算法吧,对于两个box我们完全可以进行分成两类:一类是可以合并的,另一类是不可以合并的,距离的作用其实也是于此。

4. 使用分类模型计算距离


于是距离的计算就变成了一个二分类问题。

开始有点犹豫要不要这么实现的,因为这里数据收集就是一个问题,效果还不太好判断(虽然直觉告诉我,使用分类算法替代距离效果会不错),不过相对于要去调那堆手动设置的距离算法,还是宁愿费点时间来收集与处理数据。

总体上分成三个步骤:

  1. 收集数据
  2. 处理数据
  3. 特征工程
  4. 训练模型

4.1 收集数据


在预先准备好的pdf文档上进行文字识别,将同一行上识别到多个box的记录下来,并且按可以是否可以合并记录到两个文件中,其中可以合并的数据(文件名merge.txt)大概格式如下:

代码语言:javascript复制
[[304, 1898, 868, 1934], [850, 1898, 1401, 1934]]
[[250, 316, 488, 345], [472, 311, 876, 348], [891, 309, 1221, 345], [1193, 309, 1405, 353]]
[[245, 238, 390, 277], [410, 238, 452, 275], [519, 241, 547, 275]]

每一行会有至少两个以上的box,这些box是可以进行合并的。

不可以进行合并的数据(文件名not-merge.txt)格式大概也类似:

代码语言:javascript复制
[[284, 1488, 423, 1522], [519, 1488, 609, 1520]]
[[281, 1651, 400, 1683], [516, 1649, 624, 1688]]
[[599, 911, 731, 940], [832, 911, 963, 940], [1059, 916, 1206, 938]]

这里的box是不可以进行合并的。

之所以这样设计这个原始数据格式,主要是为了方便后续收集数据,有新数据时,只要往这两个文件增加即可。

4.2 处理数据


现在只有原始数据,我们还不能进行模型训练,我们需要处理成指标的格式。

处理可以合并的数据是比较简单的:

代码语言:javascript复制
def parse_file(filename):
    with open(filename) as r:
        lines = r.readlines()

    lines = [text.strip() for text in lines]
    lines = [json.loads(t) for t in lines]
    lines = [sorted(boxes, key=lambda x: x[0]) for boxes in lines]
    return lines

data = []
# 处理可以合并的box数据
merge_lines = parse_file('./merge.txt')
for boxes in merge_lines:
    for idx in range(len(boxes)-1):
        box = [0]   boxes[idx]   boxes[idx 1]
        data.append(box)

就是将原始数据中的每一行的相邻的两个box组成一条记录,这两个box是可以合并的,这里使用label=0来表示,因为我们最终要的是距离。

而不可以进行合并的数据就比较复杂一点了,主要是不可以合并的原始数据收集更加困难,大多数识别到的box都是可以合并的。这里需要做数据扩增:

代码语言:javascript复制
not_merge_lines = parse_file('./not-merge.txt')
for boxes in not_merge_lines:
    for i in range(len(boxes)-1):
        for j in range(i 1, len(boxes)):
            box = [1]   boxes[i]   boxes[j]
            data.append(box)

            box = box.copy()
            box[1] -= random.randint(3, 7)
            data.append(box)

            box = box.copy()
            box[1]  = random.randint(1, 3)
            box[-2]  = random.randint(3, 6)
            data.append(box)

            box = box.copy()
            box[1]  = random.randint(1, 3)
            box[-2] -= random.randint(1, 4)
            hi = boxes[i][3] - boxes[i][1]
            hj = boxes[j][3] - boxes[j][1]
            if hi > hj:
                box[2] -= random.randint(1, 3)
                box[4]  = random.randint(1, 3)
            else:
                box[-3] -= random.randint(1, 3)
                box[-1]  = random.randint(1, 3)

            data.append(box)

不可以合并的原始数据里,同一行里的任意两个box都是不可以合并的,这也是这里有三个for循环的原因。

因为数据量不足,我们将一条记录扩增成了4条记录,不过扩增的时候要注意,不应该出现不合理的扩增,所以这里的扩增是三个方向:

  1. 左边的box可以往左
  2. 右边的box可以往右
  3. 高度比较大的box可以增大高度

其他的方向可能会导致原来不可以合并的,变成可以合并。

最后,将处理好的数据保存到csv文件:

代码语言:javascript复制
columns = ['label', 'ax1', 'ay1', 'ax2', 'ay2']
columns  = ['bx1', 'by1', 'bx2', 'by2']
data = pd.DataFrame(data, columns=columns)
data = data.drop_duplicates()
data.to_csv('data.csv')

4.3 特征工程


有了上面处理过的数据其实已经可以直接进行模型训练了,但是如果那样效果很可能是比较差的,当然如果我们有非常大量的样本数据,可能也可以直接进行训练,可是我们没有。于是,特征工程就变得很重要。

在特征工程时,有两个关键点是需要考虑的:

  1. 要使特征与位置无关,也就是同时平移两个box,结果应该是不变的;
  2. 要将重要特征识别出来。

最后实现大概如下:

代码语言:javascript复制
# 读取数据
data = pd.read_csv('./data.csv')

# 构造特征
data['ha'] = data.ay2 - data.ay1
data['hb'] = data.by2 - data.by1
data['h_min'] = data[['ha', 'hb']].min(axis=1)
data['ha'] = data.ha/data.h_min
data['hb'] = data.hb/data.h_min
data['wa'] = (data.ax2 - data.ax1)/data.h_min
data['wb'] = (data.bx2 - data.bx1)/data.h_min
data['h_rate'] = data.ha/data.hb
data['x_diff'] = (data.bx1 - data.ax2)/data.h_min
# 垂直方向的重叠程度
data['h_inter'] = (data[['ay2', 'by2']].min(axis=1) -
                   data[['ay1', 'by1']].max(axis=1))/data.h_min
total = len(data)
print("Total", total)
print('0: %d, 1: %d' % (total-data.label.sum(), data.label.sum()))

构造的特征并没有使用具体的坐标位置,所以也就具有了可平移性。指标做了标准化,可以适应放大缩小的情况。这里的标准化并没有使用0-1区间映射,或者正态分布映射,而是选择了一个比较的基准,这个基准就是两个box中的最小高度。

除了标准化之后的高度和宽度,还有三个额外的特征:

  1. h_rate:两个box的高度的比值,因为可以合并的box的高度应该是接近的;
  2. x_diff:右边box的左边的x坐标与左边box的右边的x坐标的差值的标准化值;
  3. h_inter:在垂直方向重叠的部分的标准化值。

这三个指标是分析得到的(事实上,原有手动设置的算法,主要考虑的点也是这些)。

4.4 训练模型


完成了上面的步骤之后,训练其实并没有太多可讲的了,就是一个调参与优化的过程:

代码语言:javascript复制
from joblib import dump

# 生成数据
data_cp = data.copy()
x_columns = ['ha', 'hb', 'wa', 'wb', 'h_rate', 'x_diff', 'h_inter']
X = data_cp[x_columns]
Y = data_cp[['label']]
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.25)

# 训练
lg = LogisticRegression(C=1.0)
lg.fit(x_train, y_train)
print(lg.coef_)

# 训练集评估
y_predict = lg.predict(x_train)
print("训练准确率:", lg.score(x_train, y_train))
print("评估指标:", classification_report(y_train, y_predict, labels=[0, 1],
                                     target_names=["合并", "不合并"]))

# 测试集评估
y_predict = lg.predict(x_test)
print("测试准确率:", lg.score(x_test, y_test))
print("评估指标:", classification_report(y_test, y_predict, labels=[0, 1],
                                     target_names=["合并", "不合并"]))

# 保存模型
dump(lg, 'merge_lines.joblib')

# 保存预测错误的数据
data['predict'] = lg.predict(X)
save_columns = ['ax1', 'ay1', 'ax2', 'ay2']
save_columns  = ['bx1', 'by1', 'bx2', 'by2']
save_columns  = ['label', 'predict']
data[save_columns].to_csv('output.csv', index=False)
data[save_columns][data.predict != data.label].to_csv('error.csv', index=False)

这里并没有使用太复杂的算法,而是直接使用常用的逻辑回归,毕竟简单的算法计算起来消耗的资源更少,也更快。关于模型训练,可以点击阅读原文。

训练好模型之后,应用就很简单了,不再细述。

5. 模型应用


应用反倒是最简单的了:

代码语言:javascript复制
from joblib import load

root_path = os.path.dirname(os.path.realpath(__file__))
model_path = os.path.join(root_path, 'merge_lines.joblib')
merge_lines_model = load(model_path)

def distance(box_a, box_b):
    """计算两个box的距离"""
    if all(box_a == box_b):
        return 0.
    if box_a[0] > box_b[0]:
        box_a, box_b = box_b, box_a

    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b
    ha, hb = ay2-ay1, by2-by1
    h_min = min(ha, hb)
    ha, hb = ha/h_min, hb/h_min
    wa, wb = (ax2-ax1)/h_min, (bx2-bx1)/h_min
    h_rate = ha/hb
    x_diff = (bx1-ax2)/h_min
    h_inter = (min(ay2, by2)-max(ay1, by1))/h_min
    x = [ha, hb, wa, wb, h_rate, x_diff, h_inter]
    # 是否可以合并
    label = merge_lines_model.predict([x])[0]
    # print(label)
    return label

相比于原来的解决方案,这段代码显然是更加容易维护的。

这个模型用来计算距离,其实外面还是套了一个聚类算法的:

代码语言:javascript复制
cluster = DBSCAN(eps=0.5, min_samples=2, metric=distance).fit(boxes)

6. 可以改进的地方


在写这个文章的时候,突然想到特征工程的时候,有两个特征应该是可以优化的:

  1. h_rate:这个值如果大于1,应该对其取倒数,这样这个值的取值范围就会落在0到1之间,其大小的意义也更加明确;
  2. h_inter:这个值应该类似IOU的计算可能会更好。

另外,还有几个点想到:

  • 我们直接使用了分类模型,其实使用回归模型应该是一个更好的选择,还可以通过一个参数选择来达到更好的控制。
  • 数据扩增可以优化,现在的随机性还不是太好。

0 人点赞