明月机器学习系列023:表格识别(二)

2021-10-28 15:29:27 浏览数 (1)

上图是表格识别的流程图,淡红色的在上一篇已经介绍过了,这次重点介绍淡绿色的部分:

  • 交点图像与聚合
  • 表格聚合
  • 聚合表格线

补充一点上次的曲线方程识别,对于我们要识别的是横线和竖线,而对于竖线可能会导致斜率无穷大,所以我们在实现是,对于横线我们使用曲线方程y=ax b,而对于竖线我们则反过来使用:x=ay b,这样不至于出现斜率的问题。

交点图像与聚合


我们已经有了二值化的横线和竖线图像,要求交点图像已经很简单,把两个图像叠加在一起即可:

代码语言:javascript复制
point = cv2.bitwise_and(col_img, row_img)

其实就是做and运算,两个图中都是白色的叠加之后才是白色,这样就出来交点图像了,如下:

图中的白点就是交点,不过和曲线一样,这些交点并不只是一个点,而是若干个点聚合在一起,具体跟交点线段的粗细有关。要想应用方便,我们需要先将其聚合成一个单一的点。这里我们还是使用DBSCAN算法:

使用一个中心点来替代相关的点即可。这里选择Manhattan距离,只是为了减少些计算量,使用默认的欧式距离也是一样的。

有了交点,还需要多做一步,就是要判断这些交点分别是在哪些线段上,因为我们已经有了每个曲线的方程和端点,

表格聚合


一个页面上的表格数量可能不会只有一个表格,所以在真正开始识别表格前,我们需要先清楚哪些表格线线段和交点是属于同一个表格的。

这里我们可能我们可以使用代码判断去讲有交点的线条都合并在同一个表格中,那样也不好维护。我们还是聚类算法,这里我们使用另一个Optics算法(DBSCAN算法的升级版本),这里选用Optics而不用DBSCAN的原因主要是我们之前已经实现过一次,支持自定义距离,而scikit-learn中的DBSCAN和Optics算法都不支持自定义距离。

聚类的关键点就是怎么计算不同线段之间的距离,显然这里使用欧式距离还是Manhattan等都是不行的了,我们需要定义自己的距离。两条线段之间的距离计算,如下:

  • 如果两个线段有交点,则距离为0;
  • 否则计算两个线段的两个端点之间的距离的最小值的和。如假设A线段有a和b两个端点,B线段有c和d两个端点,他们的交点是e,那么这两个线段的距离:
代码语言:javascript复制
# 其中dist是计算两个点距离的函数
min(dist(e, a), dist(e, b))   min(dist(e, c), dist(e, d))

不过对于我们自己的场景,是用于表格的横线和竖线上,计算交点和计算距离都可以进行简化。

代码语言:javascript复制
def distance(line1, line2):
    """计算两个线段的距离
    line_type: 线段类型,布尔值,True表示横线,False表示竖线
    a, b: 线段直线参数,y=ax b或者x=ay b,具体看line_type的值
    x1, y1, x2, y2: 线段的两个端点
    :param line1,line2: [line_type, a, b, x1, y1, x2, y2]
    """
    l_type1, a1, b1, x11, y11, x12, y12 = line1.data
    l_type2, a2, b2, x21, y21, x22, y22 = line2.data
    if l_type1 == l_type2:     # 平行
        return Optics.inf

    if l_type1 and not l_type2:
        # y=a1*x b1 and x=a2*y b2
        x0 = (x21 x22)/2   # line2
        y0 = (y11 y12)/2   # line1
    elif not l_type1 and l_type2:
        # x=a1*y b1 and y=a2*x b2
        x0 = (x11 x12)/2   # line1
        y0 = (y21 y22)/2   # line2

    def point_line_dist(l_type, x1, y1, x2, y2):
        """计算点到线的距离"""
        if l_type:
            if min(x1, x2)-5 <= x0 <= max(x1, x2) 5:
                return 0
            # 到两端点的最小距离
            return min(abs(x0-x1), abs(x0-x2))

        # 竖线
        if min(y1, y2)-5 <= y0 <= max(y1, y2) 5:
            return 0
        return min(abs(y0-y1), abs(y0-y2))

    # 计算到线1的距离
    dist1 = point_line_dist(l_type1, x11, y11, x12, y12)
    # 计算到线2的距离
    dist2 = point_line_dist(l_type2, x21, y21, x22, y22)
    return dist1   dist2

把这个距离函数传入Optics,作为一个参数即可:

代码语言:javascript复制
def do_table_cluster(lines, line_types, endpoints,
                     max_radius=3, min_samples=2, cluster_thr=2):
    """线段聚类
    线段方程类型为:
        True: y=ax b
        False: x=ay b
    :param lines list 线段方程参数[(a, b)]
    :param line_types list 线段方程的类型,跟lines参数对应,取值True or False
    :param endpoints list 线段的端点,注意每个线段有两个端点: [[(y1, x1), (y2, x2)]]
    :param min_samples, max_radius: 聚类参数
    :return labels
    """
    assert len(lines) == len(line_types) == len(endpoints)
    data = [(l_type, a, b, x1, y1, x2, y2)
            for ((a, b), l_type, ((y1, x1), (y2, x2))) in 
            zip(lines, line_types, endpoints)]
    optics = Optics(max_radius, min_samples, distance=distance)
    optics.fit(data)
    return optics.cluster(cluster_thr)

表格线聚合


我们已经知道了哪些线段和交点是属于同一个表格的,但是这些线段可能有些是属于同一行或者同一列的,如下的情形:

如果上图红色圈住的两条线段,是属于同一个列线,应该合并在一起,表格的行线可能也会出现这种情况。合并也比较不难,我们还是使用DBSCAN聚类算法,聚类时,只需要对横线或者竖线方程的b参数进行聚类即可,因为对于横线或者竖线他们的斜率正常来说是相差很小的,这里不再详述。

至此,我们已经知道这个表格n*m的表格线,已经其顶点及交点坐标,接下来就是进行识别了。

未完待续。。。

0 人点赞