Marker 源码解析(一)

2024-03-09 08:45:58 浏览数 (3)

.markerbenchmark.py

代码语言:javascript复制
import argparse
import tempfile
import time
from collections import defaultdict

from tqdm import tqdm

from marker.convert import convert_single_pdf
from marker.logger import configure_logging
from marker.models import load_all_models
from marker.benchmark.scoring import score_text
from marker.extract_text import naive_get_text
import json
import os
import subprocess
import shutil
import fitz as pymupdf
from tabulate import tabulate

# 配置日志记录
configure_logging()

# 定义函数,使用 Nougat 进行预测
def nougat_prediction(pdf_filename, batch_size=1):
    # 创建临时目录
    out_dir = tempfile.mkdtemp()
    # 运行 Nougat 命令行工具进行预测
    subprocess.run(["nougat", pdf_filename, "-o", out_dir, "--no-skipping", "--recompute", "--batchsize", str(batch_size)], check=True)
    # 获取生成的 Markdown 文件
    md_file = os.listdir(out_dir)[0]
    with open(os.path.join(out_dir, md_file), "r") as f:
        data = f.read()
    # 删除临时目录
    shutil.rmtree(out_dir)
    return data

# 主函数
def main():
    # 创建参数解析器
    parser = argparse.ArgumentParser(description="Benchmark PDF to MD conversion.  Needs source pdfs, and a refernece folder with the correct markdown.")
    # 添加参数:输入 PDF 文件夹
    parser.add_argument("in_folder", help="Input PDF files")
    # 添加参数:参考 Markdown 文件夹
    parser.add_argument("reference_folder", help="Reference folder with reference markdown files")
    # 添加参数:输出文件名
    parser.add_argument("out_file", help="Output filename")
    # 添加参数:是否运行 Nougat 并比较
    parser.add_argument("--nougat", action="store_true", help="Run nougat and compare", default=False)
    # 添加参数:Nougat 批处理大小,默认为 1
    parser.add_argument("--nougat_batch_size", type=int, default=1, help="Batch size to use for nougat when making predictions.")
    # 添加参数:Marker 并行因子,默认为 1
    parser.add_argument("--marker_parallel_factor", type=int, default=1, help="How much to multiply default parallel OCR workers and model batch sizes by.")
    # 添加参数:生成的 Markdown 文件输出路径
    parser.add_argument("--md_out_path", type=str, default=None, help="Output path for generated markdown files")
    # 解析参数
    args = parser.parse_args()

    # 定义方法列表
    methods = ["naive", "marker"]
    if args.nougat:
        methods.append("nougat")

    # 加载所有模型
    model_lst = load_all_models()

    # 初始化得分字典
    scores = defaultdict(dict)
    # 获取指定文件夹中的所有文件列表
    benchmark_files = os.listdir(args.in_folder)
    # 筛选出以".pdf"结尾的文件列表
    benchmark_files = [b for b in benchmark_files if b.endswith(".pdf")]
    # 初始化存储时间信息的字典
    times = defaultdict(dict)
    # 初始化存储页数信息的字典
    pages = defaultdict(int)

    # 遍历每个 PDF 文件
    for fname in tqdm(benchmark_files):
        # 生成对应的 markdown 文件名
        md_filename = fname.rsplit(".", 1)[0]   ".md"

        # 获取参考文件的路径并读取内容
        reference_filename = os.path.join(args.reference_folder, md_filename)
        with open(reference_filename, "r") as f:
            reference = f.read()

        # 获取 PDF 文件的路径并打开
        pdf_filename = os.path.join(args.in_folder, fname)
        doc = pymupdf.open(pdf_filename)
        # 记录该 PDF 文件的页数
        pages[fname] = len(doc)

        # 遍历不同的方法
        for method in methods:
            start = time.time()
            # 根据不同方法进行处理
            if method == "marker":
                full_text, out_meta = convert_single_pdf(pdf_filename, model_lst, parallel_factor=args.marker_parallel_factor)
            elif method == "nougat":
                full_text = nougat_prediction(pdf_filename, batch_size=args.nougat_batch_size)
            elif method == "naive":
                full_text = naive_get_text(doc)
            else:
                raise ValueError(f"Unknown method {method}")

            # 计算处理时间并记录
            times[method][fname] = time.time() - start

            # 计算得分并记录
            score = score_text(full_text, reference)
            scores[method][fname] = score

            # 如果指定了 markdown 输出路径,则将处理结果写入文件
            if args.md_out_path:
                md_out_filename = f"{method}_{md_filename}"
                with open(os.path.join(args.md_out_path, md_out_filename), "w ") as f:
                    f.write(full_text)

    # 计算总页数
    total_pages = sum(pages.values())
    # 打开输出文件,以写入模式打开,如果文件不存在则创建
    with open(args.out_file, "w ") as f:
        # 创建一个默认字典,用于存储数据
        write_data = defaultdict(dict)
        # 遍历每个方法
        for method in methods:
            # 计算每个方法的总时间
            total_time = sum(times[method].values())
            # 为每个文件创建统计信息字典
            file_stats = {
                fname:
                {
                    "time": times[method][fname],
                    "score": scores[method][fname],
                    "pages": pages[fname]
                }
                for fname in benchmark_files
            }
            # 将文件统计信息和方法的平均分数、每页时间、每个文档时间存储到 write_data 中
            write_data[method] = {
                "files": file_stats,
                "avg_score": sum(scores[method].values()) / len(scores[method]),
                "time_per_page": total_time / total_pages,
                "time_per_doc": total_time / len(scores[method])
            }

        # 将 write_data 写入到输出文件中,格式化为 JSON 格式,缩进为 4
        json.dump(write_data, f, indent=4)

    # 创建两个空列表用于存储汇总表和分数表
    summary_table = []
    score_table = []
    # 分数表的表头为 benchmark_files
    score_headers = benchmark_files
    # 遍历每个方法
    for method in methods:
        # 将方法、平均分数、每页时间、每个文档时间添加到汇总表中
        summary_table.append([method, write_data[method]["avg_score"], write_data[method]["time_per_page"], write_data[method]["time_per_doc"]])
        # 将方法和每个文件的分数添加到分数表中
        score_table.append([method, *[write_data[method]["files"][h]["score"] for h in score_headers]])

    # 打印汇总表,包括方法、平均分数、每页时间、每个文档时间
    print(tabulate(summary_table, headers=["Method", "Average Score", "Time per page", "Time per document"]))
    print("")
    print("Scores by file")
    # 打印分数表,包括方法和每个文件的分数
    print(tabulate(score_table, headers=["Method", *score_headers]))
# 如果当前脚本被直接执行,则调用主函数
if __name__ == "__main__":
    main()

.markerchunk_convert.py

代码语言:javascript复制
# 导入 argparse 模块,用于解析命令行参数
import argparse
# 导入 subprocess 模块,用于执行外部命令
import subprocess

# 定义主函数
def main():
    # 创建 ArgumentParser 对象,设置描述信息
    parser = argparse.ArgumentParser(description="Convert a folder of PDFs to a folder of markdown files in chunks.")
    # 添加命令行参数,指定输入文件夹路径
    parser.add_argument("in_folder", help="Input folder with pdfs.")
    # 添加命令行参数,指定输出文件夹路径
    parser.add_argument("out_folder", help="Output folder")
    # 解析命令行参数
    args = parser.parse_args()

    # 构造要执行的 shell 命令
    cmd = f"./chunk_convert.sh {args.in_folder} {args.out_folder}"

    # 执行 shell 脚本
    subprocess.run(cmd, shell=True, check=True)

# 如果当前脚本作为主程序运行,则调用主函数
if __name__ == "__main__":
    main()

.markerconvert.py

代码语言:javascript复制
# 导入必要的库
import argparse
import os
from typing import Dict, Optional

import ray
from tqdm import tqdm
import math

# 导入自定义模块
from marker.convert import convert_single_pdf, get_length_of_text
from marker.models import load_all_models
from marker.settings import settings
from marker.logger import configure_logging
import traceback
import json

# 配置日志记录
configure_logging()

# 定义一个远程函数,用于处理单个 PDF 文件
@ray.remote(num_cpus=settings.RAY_CORES_PER_WORKER, num_gpus=.05 if settings.CUDA else 0)
def process_single_pdf(fname: str, out_folder: str, model_refs, metadata: Optional[Dict] = None, min_length: Optional[int] = None):
    # 构建输出文件名和元数据文件名
    out_filename = fname.rsplit(".", 1)[0]   ".md"
    out_filename = os.path.join(out_folder, os.path.basename(out_filename))
    out_meta_filename = out_filename.rsplit(".", 1)[0]   "_meta.json"
    
    # 如果输出文件已存在,则直接返回
    if os.path.exists(out_filename):
        return
    
    try:
        # 如果指定了最小文本长度,检查文件文本长度是否符合要求
        if min_length:
            length = get_length_of_text(fname)
            if length < min_length:
                return
        
        # 转换 PDF 文件为 Markdown 格式,并获取转换后的文本和元数据
        full_text, out_metadata = convert_single_pdf(fname, model_refs, metadata=metadata)
        
        # 如果转换后的文本不为空,则写入到文件中
        if len(full_text.strip()) > 0:
            with open(out_filename, "w ", encoding='utf-8') as f:
                f.write(full_text)
            with open(out_meta_filename, "w ") as f:
                f.write(json.dumps(out_metadata, indent=4))
        else:
            print(f"Empty file: {fname}.  Could not convert.")
    except Exception as e:
        # 捕获异常并打印错误信息
        print(f"Error converting {fname}: {e}")
        print(traceback.format_exc())

# 主函数
def main():
    # 创建命令行参数解析器
    parser = argparse.ArgumentParser(description="Convert multiple pdfs to markdown.")
    
    # 添加输入文件夹和输出文件夹参数
    parser.add_argument("in_folder", help="Input folder with pdfs.")
    parser.add_argument("out_folder", help="Output folder")
    # 添加命令行参数,指定要转换的块索引
    parser.add_argument("--chunk_idx", type=int, default=0, help="Chunk index to convert")
    # 添加命令行参数,指定并行处理的块数
    parser.add_argument("--num_chunks", type=int, default=1, help="Number of chunks being processed in parallel")
    # 添加命令行参数,指定要转换的最大 pdf 数量
    parser.add_argument("--max", type=int, default=None, help="Maximum number of pdfs to convert")
    # 添加命令行参数,指定要使用的工作进程数
    parser.add_argument("--workers", type=int, default=5, help="Number of worker processes to use")
    # 添加命令行参数,指定要使用的元数据 json 文件进行过滤
    parser.add_argument("--metadata_file", type=str, default=None, help="Metadata json file to use for filtering")
    # 添加命令行参数,指定要转换的 pdf 的最小长度
    parser.add_argument("--min_length", type=int, default=None, help="Minimum length of pdf to convert")

    # 解析命令行参数
    args = parser.parse_args()

    # 获取输入文件夹的绝对路径
    in_folder = os.path.abspath(args.in_folder)
    # 获取输出文件夹的绝对路径
    out_folder = os.path.abspath(args.out_folder)
    # 获取输入文件夹中所有文件的路径列表
    files = [os.path.join(in_folder, f) for f in os.listdir(in_folder)]
    # 如果输出文件夹不存在,则创建输出文件夹
    os.makedirs(out_folder, exist_ok=True)

    # 处理并行处理时的块
    # 确保将所有文件放入一个块中
    chunk_size = math.ceil(len(files) / args.num_chunks)
    start_idx = args.chunk_idx * chunk_size
    end_idx = start_idx   chunk_size
    files_to_convert = files[start_idx:end_idx]

    # 如果需要,限制要转换的文件数量
    if args.max:
        files_to_convert = files_to_convert[:args.max]

    metadata = {}
    # 如果指定了元数据文件,则加载元数据
    if args.metadata_file:
        metadata_file = os.path.abspath(args.metadata_file)
        with open(metadata_file, "r") as f:
            metadata = json.load(f)

    # 确定要使用的进程数
    total_processes = min(len(files_to_convert), args.workers)

    # 初始化 Ray,设置 CPU 和 GPU 数量,存储路径等参数
    ray.init(
        num_cpus=total_processes,
        num_gpus=1 if settings.CUDA else 0,
        storage=settings.RAY_CACHE_PATH,
        _temp_dir=settings.RAY_CACHE_PATH,
        log_to_driver=settings.DEBUG
    )

    # 加载所有模型
    model_lst = load_all_models()
    # 将模型列表放入 Ray 中
    model_refs = ray.put(model_lst)

    # 根据 GPU 内存动态设置每个任务的 GPU 分配比例
    gpu_frac = settings.VRAM_PER_TASK / settings.INFERENCE_RAM if settings.CUDA else 0
    # 打印正在转换的 PDF 文件数量、当前处理的块索引、总块数、使用的进程数以及输出文件夹路径
    print(f"Converting {len(files_to_convert)} pdfs in chunk {args.chunk_idx   1}/{args.num_chunks} with {total_processes} processes, and storing in {out_folder}")
    
    # 为每个需要转换的 PDF 文件创建一个 Ray 任务,并指定使用的 GPU 分数
    futures = [
        process_single_pdf.options(num_gpus=gpu_frac).remote(
            filename,
            out_folder,
            model_refs,
            metadata=metadata.get(os.path.basename(filename)),
            min_length=args.min_length
        ) for filename in files_to_convert
    ]

    # 运行所有的 Ray 转换任务
    progress_bar = tqdm(total=len(futures))
    while len(futures) > 0:
        # 等待所有任务完成,超时时间为 7 秒
        finished, futures = ray.wait(
            futures, timeout=7.0
        )
        finished_lst = ray.get(finished)
        # 更新进度条
        if isinstance(finished_lst, list):
            progress_bar.update(len(finished_lst))
        else:
            progress_bar.update(1)

    # 关闭 Ray 以释放资源
    ray.shutdown()
# 如果当前脚本被直接执行,则调用主函数
if __name__ == "__main__":
    main()

.markerconvert_single.py

代码语言:javascript复制
# 导入必要的模块
import argparse  # 用于解析命令行参数
from marker.convert import convert_single_pdf  # 导入 convert_single_pdf 函数
from marker.logger import configure_logging  # 导入 configure_logging 函数
from marker.models import load_all_models  # 导入 load_all_models 函数
import json  # 导入 json 模块

# 配置日志记录
configure_logging()

# 主函数
def main():
    # 创建参数解析器
    parser = argparse.ArgumentParser()
    # 添加命令行参数
    parser.add_argument("filename", help="PDF file to parse")  # PDF 文件名
    parser.add_argument("output", help="Output file name")  # 输出文件名
    parser.add_argument("--max_pages", type=int, default=None, help="Maximum number of pages to parse")  # 最大解析页数
    parser.add_argument("--parallel_factor", type=int, default=1, help="How much to multiply default parallel OCR workers and model batch sizes by.")  # 并行因子
    # 解析命令行参数
    args = parser.parse_args()

    # 获取文件名
    fname = args.filename
    # 加载所有模型
    model_lst = load_all_models()
    # 调用 convert_single_pdf 函数,解析 PDF 文件并返回全文和元数据
    full_text, out_meta = convert_single_pdf(fname, model_lst, max_pages=args.max_pages, parallel_factor=args.parallel_factor)

    # 将全文写入输出文件
    with open(args.output, "w ", encoding='utf-8') as f:
        f.write(full_text)

    # 生成元数据文件名
    out_meta_filename = args.output.rsplit(".", 1)[0]   "_meta.json"
    # 将元数据写入元数据文件
    with open(out_meta_filename, "w ") as f:
        f.write(json.dumps(out_meta, indent=4))

# 如果当前脚本被直接执行,则调用主函数
if __name__ == "__main__":
    main()

.markermarkerbbox.py

代码语言:javascript复制
import fitz as pymupdf

# 判断两个矩形框是否应该合并
def should_merge_blocks(box1, box2, tol=5):
    # 在 tol y 像素内,并且在右侧在 tol 像素内
    merge = [
        box2[0] > box1[0], # 在 x 坐标上在后面
        abs(box2[1] - box1[1]) < tol, # 在 y 坐标上在 tol 像素内
        abs(box2[3] - box1[3]) < tol, # 在 y 坐标上在 tol 像素内
        abs(box2[0] - box1[2]) < tol, # 在 x 坐标上在 tol 像素内
    ]
    return all(merge)

# 合并两个矩形框
def merge_boxes(box1, box2):
    return (min(box1[0], box2[0]), min(box1[1], box2[1]), max(box2[2], box1[2]), max(box1[3], box2[3]))

# 判断两个矩形框是否相交
def boxes_intersect(box1, box2):
    # 矩形框1与矩形框2相交
    return box1[0] < box2[2] and box1[2] > box2[0] and box1[1] < box2[3] and box1[3] > box2[1]

# 判断两个矩形框的相交面积占比是否大于给定百分比
def boxes_intersect_pct(box1, box2, pct=.9):
    # 确定相交矩形的坐标
    x_left = max(box1[0], box2[0])
    y_top = max(box1[1], box2[1])
    x_right = min(box1[2], box2[2])
    y_bottom = min(box1[3], box2[3])

    if x_right < x_left or y_bottom < y_top:
        return 0.0

    # 两个轴对齐边界框的交集始终是一个轴对齐边界框
    intersection_area = (x_right - x_left) * (y_bottom - y_top)

    # 计算两个边界框的面积
    bb1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    bb2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])

    iou = intersection_area / float(bb1_area   bb2_area - intersection_area)
    return iou > pct

# 判断一个矩形框是否与多个矩形框相交
def multiple_boxes_intersect(box1, boxes):
    for box2 in boxes:
        if boxes_intersect(box1, box2):
            return True
    return False

# 判断一个矩形框是否包含在另一个矩形框内
def box_contained(box1, box2):
    # 矩形框1在矩形框2内部
    return box1[0] > box2[0] and box1[1] > box2[1] and box1[2] < box2[2] and box1[3] < box2[3]

# 将归一化的矩形框坐标还原为原始坐标
def unnormalize_box(bbox, width, height):
    return [
        width * (bbox[0] / 1000),
        height * (bbox[1] / 1000),
        width * (bbox[2] / 1000),
        height * (bbox[3] / 1000),
    ]

# 修正矩形框的旋转
def correct_rotation(bbox, page):
    #bbox base is (x0, y0, x1, y1)
    # 获取页面的旋转角度
    rotation = page.rotation
    # 如果旋转角度为0,则直接返回原始边界框
    if rotation == 0:
        return bbox

    # 计算旋转后的左上角和右下角坐标
    tl = pymupdf.Point(bbox[0], bbox[1]) * page.rotation_matrix
    br = pymupdf.Point(bbox[2], bbox[3]) * page.rotation_matrix

    # 根据不同的旋转角度进行边界框的调整
    if rotation == 90:
        bbox = [br[0], tl[1], tl[0], br[1]]
    elif rotation == 180:
        bbox = [br[0], br[1], tl[0], tl[1]]
    elif rotation == 270:
        bbox = [tl[0], br[1], br[0], tl[1]]

    # 返回调整后的边界框
    return bbox

.markermarkerbenchmarkscoring.py

代码语言:javascript复制
# 导入 math 模块
import math

# 从 rapidfuzz 模块中导入 fuzz 和 distance 函数
from rapidfuzz import fuzz, distance
# 导入 re 模块
import re

# 定义最小分块字符数
CHUNK_MIN_CHARS = 25


def tokenize(text):
    # 定义正则表达式模式
    pattern = r'([^wsd'])|([w'] )|(d )|(n )|(  )'
    # 使用正则表达式模式匹配文本
    result = re.findall(pattern, text)
    # 将匹配结果扁平化并过滤掉空字符串
    flattened_result = [item for sublist in result for item in sublist if item]
    return flattened_result


def chunk_text(text):
    # 将文本按换行符分割成块
    chunks = text.split("n")
    # 过滤掉空白块和长度小于最小分块字符数的块
    chunks = [c for c in chunks if c.strip() and len(c) > CHUNK_MIN_CHARS]
    return chunks


def overlap_score(hypothesis_chunks, reference_chunks):
    # 计算长度修正因子
    length_modifier = len(hypothesis_chunks) / len(reference_chunks)
    # 计算搜索距离
    search_distance = max(len(reference_chunks) // 5, 10)
    chunk_scores = []
    chunk_weights = []
    for i, hyp_chunk in enumerate(hypothesis_chunks):
        max_score = 0
        chunk_weight = 1
        i_offset = int(i * length_modifier)
        chunk_range = range(max(0, i_offset-search_distance), min(len(reference_chunks), i_offset search_distance))
        for j in chunk_range:
            ref_chunk = reference_chunks[j]
            # 计算相似度得分
            score = fuzz.ratio(hyp_chunk, ref_chunk, score_cutoff=30) / 100
            if score > max_score:
                max_score = score
                chunk_weight = math.sqrt(len(ref_chunk))
        chunk_scores.append(max_score)
        chunk_weights.append(chunk_weight)
    chunk_scores = [chunk_scores[i] * chunk_weights[i] for i in range(len(chunk_scores))]
    return chunk_scores, chunk_weights


def score_text(hypothesis, reference):
    # 返回一个0-1的对齐分数
    hypothesis_chunks = chunk_text(hypothesis)
    reference_chunks = chunk_text(reference)
    chunk_scores, chunk_weights = overlap_score(hypothesis_chunks, reference_chunks)
    return sum(chunk_scores) / sum(chunk_weights)

.markermarkercleanersbullets.py

代码语言:javascript复制
# 导入正则表达式模块
import re

# 定义函数,用于替换文本中的特殊符号为 -
def replace_bullets(text):
    # 定义匹配特殊符号的正则表达式模式
    bullet_pattern = r"(^|[n ])[•●○■▪▫–—]( )"
    # 使用正则表达式替换特殊符号为 -
    replaced_string = re.sub(bullet_pattern, r"1-2", text)
    # 返回替换后的文本
    return replaced_string

.markermarkercleanerscode.py

代码语言:javascript复制
# 导入所需的模块和类
from marker.schema import Span, Line, Page
import re
from typing import List
import fitz as pymupdf

# 判断代码行的长度是否符合阈值
def is_code_linelen(lines, thresh=60):
    # 计算所有代码行中的字母数字字符总数
    total_alnum_chars = sum(len(re.findall(r'w', line.prelim_text)) for line in lines)
    # 计算总行数
    total_newlines = max(len(lines) - 1, 1)

    # 如果没有字母数字字符,则返回 False
    if total_alnum_chars == 0:
        return False

    # 计算字母数字字符与行数的比率
    ratio = total_alnum_chars / total_newlines
    return ratio < thresh

# 统计代码行中包含注释的行数
def comment_count(lines):
    # 定义匹配注释的正则表达式模式
    pattern = re.compile(r"^(//|#|'|--|/*|'''|"""|--[[|<!--|%|%{|(*)")
    # 统计匹配到的注释行数
    return sum([1 for line in lines if pattern.match(line)])

# 识别代码块
def identify_code_blocks(blocks: List[Page]):
    # 初始化代码块计数和字体信息
    code_block_count = 0
    font_info = None
    # 遍历每个页面
    for p in blocks:
        # 获取页面的字体统计信息
        stats = p.get_font_stats()
        # 如果是第一页,则将字体信息初始化为当前页面的字体信息
        if font_info is None:
            font_info = stats
        else:
            # 否则将当前页面的字体信息与之前页面的字体信息相加
            font_info  = stats
    try:
        # 获取最常见的字体
        most_common_font = font_info.most_common(1)[0][0]
    except IndexError:
        # 如果找不到最常见的字体,则打印提示信息
        print(f"Could not find most common font")
        most_common_font = None

    # 初始化最后一个代码块
    last_block = None
    # 遍历每一页的文本块
    for page in blocks:
        try:
            # 获取当前页最小行的起始位置
            min_start = page.get_min_line_start()
        except IndexError:
            # 如果出现索引错误,则跳过当前页
            continue

        # 遍历当前页的文本块
        for block in page.blocks:
            # 如果当前文本块的类型不是"Text",则跳过
            if block.most_common_block_type() != "Text":
                last_block = block
                continue

            # 初始化用于判断是否为代码的变量
            is_indent = []
            line_fonts = []
            # 遍历当前文本块的每一行
            for line in block.lines:
                # 获取每行中的字体信息
                fonts = [span.font for span in line.spans]
                line_fonts  = fonts
                # 获取每行的起始位置
                line_start = line.bbox[0]
                # 判断当前行是否缩进
                if line_start > min_start:
                    is_indent.append(True)
                else:
                    is_indent.append(False)
            # 统计每个文本块中的注释行数
            comment_lines = comment_count([line.prelim_text for line in block.lines])
            # 判断当前文本块是否为代码块
            is_code = [
                len(block.lines) > 3,  # 文本块行数大于3
                sum([f != most_common_font for f in line_fonts]) > len(line_fonts) * .8,  # 至少80%的字体不是最常见的字体,因为代码通常使用与主体文本不同的字体
                is_code_linelen(block.lines),  # 判断代码行长度是否符合规范
                (
                    sum(is_indent) > len(block.lines) * .2  # 20%的行有缩进
                    or
                    comment_lines > len(block.lines) * .2  # 20%的行是注释
                 ), 
            ]

            # 检查前一个文本块是否为代码块,当前文本块是否有缩进
            is_code_prev = [
                last_block and last_block.most_common_block_type() == "Code",  # 前一个文本块是代码块
                sum(is_indent) >= len(block.lines) * .8  # 至少80%的行有缩进
            ]

            # 如果当前文本块被判断为代码块,增加代码块计数并设置文本块类型为"Code"
            if all(is_code) or all(is_code_prev):
                code_block_count  = 1
                block.set_block_type("Code")

            last_block = block
    # 返回代码块计数
    return code_block_count
# 缩进代码块,将每个代码块的内容整理成一个新的 Span 对象
def indent_blocks(blocks: List[Page]):
    # 计数器,用于生成新的 Span 对象的 ID
    span_counter = 0
    # 遍历每一页的代码块
    for page in blocks:
        for block in page.blocks:
            # 获取当前代码块中所有行的块类型
            block_types = [span.block_type for line in block.lines for span in line.spans]
            # 如果当前代码块不是代码块,则跳过
            if "Code" not in block_types:
                continue

            # 初始化空列表用于存储处理后的行数据
            lines = []
            # 初始化最左边界和字符宽度
            min_left = 1000  # will contain x- coord of column 0
            col_width = 0  # width of 1 char
            # 遍历当前代码块的每一行
            for line in block.lines:
                text = ""
                # 更新最左边界
                min_left = min(line.bbox[0], min_left)
                # 拼接每行的文本内容
                for span in line.spans:
                    if col_width == 0 and len(span.text) > 0:
                        col_width = (span.bbox[2] - span.bbox[0]) / len(span.text)
                    text  = span.text
                lines.append((pymupdf.Rect(line.bbox), text))

            # 初始化空字符串用于存储处理后的代码块文本
            block_text = ""
            blank_line = False
            # 遍历处理后的每一行
            for line in lines:
                text = line[1]
                prefix = " " * int((line[0].x0 - min_left) / col_width)
                current_line_blank = len(text.strip()) == 0
                # 如果当前行和上一行都是空行,则跳过
                if blank_line and current_line_blank:
                    continue

                # 拼接处理后的代码块文本
                block_text  = prefix   text   "n"
                blank_line = current_line_blank

            # 创建新的 Span 对象,用于替换原有的代码块
            new_span = Span(
                text=block_text,
                bbox=block.bbox,
                color=block.lines[0].spans[0].color,
                span_id=f"{span_counter}_fix_code",
                font=block.lines[0].spans[0].font,
                block_type="Code"
            )
            span_counter  = 1
            # 替换原有的代码块内容为新的 Span 对象
            block.lines = [Line(spans=[new_span], bbox=block.bbox)]

.markermarkercleanersequations.py

代码语言:javascript复制
# 导入所需的库
import io
from copy import deepcopy
from functools import partial
from typing import List

import torch
from texify.inference import batch_inference
from texify.model.model import load_model
from texify.model.processor import load_processor
import re
from PIL import Image, ImageDraw

# 导入自定义模块
from marker.bbox import should_merge_blocks, merge_boxes
from marker.debug.data import dump_equation_debug_data
from marker.settings import settings
from marker.schema import Page, Span, Line, Block, BlockType
import os

# 设置环境变量,禁用 tokenizers 的并行处理
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# 加载处理器
processor = load_processor()

# 加载 Texify 模型
def load_texify_model():
    texify_model = load_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=settings.TORCH_DEVICE_MODEL, dtype=settings.TEXIFY_DTYPE)
    return texify_model

# 创建遮罩区域
def mask_bbox(png_image, bbox, selected_bboxes):
    # 创建一个与图片大小相同的灰度图像
    mask = Image.new('L', png_image.size, 0)  # 'L' mode for grayscale
    draw = ImageDraw.Draw(mask)
    first_x = bbox[0]
    first_y = bbox[1]
    bbox_height = bbox[3] - bbox[1]
    bbox_width = bbox[2] - bbox[0]

    for box in selected_bboxes:
        # 将框适配到选定区域
        new_box = (box[0] - first_x, box[1] - first_y, box[2] - first_x, box[3] - first_y)
        # 将遮罩适配到图像边界与 PDF 边界
        resized = (
           new_box[0] / bbox_width * png_image.size[0],
           new_box[1] / bbox_height * png_image.size[1],
           new_box[2] / bbox_width * png_image.size[0],
           new_box[3] / bbox_height * png_image.size[1]
        )
        draw.rectangle(resized, fill=255)

    # 通过遮罩创建结果图像
    result = Image.composite(png_image, Image.new('RGBA', png_image.size, 'white'), mask)
    return result

# 获取遮罩后的图像
def get_masked_image(page, bbox, selected_bboxes):
    # 获取页面的像素图
    pix = page.get_pixmap(dpi=settings.TEXIFY_DPI, clip=bbox)
    png = pix.pil_tobytes(format="PNG")
    png_image = Image.open(io.BytesIO(png))
    # 创建遮罩后的图像
    png_image = mask_bbox(png_image, bbox, selected_bboxes)
    png_image = png_image.convert("RGB")
    return png_image
# 批量处理 LaTeX 图像,根据指定的区域长度重新格式化,使用指定的模型进行转换
def get_latex_batched(images, reformat_region_lens, texify_model, batch_size):
    # 如果图像列表为空,则返回空列表
    if len(images) == 0:
        return []

    # 初始化预测结果列表
    predictions = [""] * len(images)

    # 按批次处理图像
    for i in range(0, len(images), batch_size):
        # 动态设置最大长度以节省推理时间
        min_idx = i
        max_idx = min(min_idx   batch_size, len(images))
        max_length = max(reformat_region_lens[min_idx:max_idx])
        max_length = min(max_length, settings.TEXIFY_MODEL_MAX)
        max_length  = settings.TEXIFY_TOKEN_BUFFER

        # 对图像批次进行推理
        model_output = batch_inference(images[min_idx:max_idx], texify_model, processor, max_tokens=max_length)

        # 遍历模型输出
        for j, output in enumerate(model_output):
            token_count = get_total_texify_tokens(output)
            # 如果 token 数量超过最大长度减一,则将输出置为空字符串
            if token_count >= max_length - 1:
                output = ""

            # 计算图像索引
            image_idx = i   j
            predictions[image_idx] = output
    return predictions


# 获取文本中的总 LaTeX token 数量
def get_total_texify_tokens(text):
    tokenizer = processor.tokenizer
    tokens = tokenizer(text)
    return len(tokens["input_ids"])


# 查找页面中的数学公式区域
def find_page_equation_regions(pnum, page, block_types):
    i = 0
    # 提取数学公式区域的边界框
    equation_boxes = [b.bbox for b in block_types[pnum] if b.block_type == "Formula"]
    reformatted_blocks = set()
    reformat_regions = []
    block_lens = []
    return reformat_regions, block_lens


# 获取区域内的边界框
def get_bboxes_for_region(page, region):
    bboxes = []
    merged_box = None
    for idx in region:
        block = page.blocks[idx]
        bbox = block.bbox
        if merged_box is None:
            merged_box = bbox
        else:
            merged_box = merge_boxes(merged_box, bbox)
        bboxes.append(bbox)
    return bboxes, merged_box


# 替换页面块中的文本块为 LaTeX
def replace_blocks_with_latex(page_blocks: Page, merged_boxes, reformat_regions, predictions, pnum):
    new_blocks = []
    converted_spans = []
    current_region = 0
    idx = 0
    success_count = 0
    fail_count = 0
    # 当索引小于页面块列表的长度时,继续循环
    while idx < len(page_blocks.blocks):
        # 获取当前索引对应的页面块
        block = page_blocks.blocks[idx]
        # 如果当前区域索引超过重新格式化区域列表的长度,或者当前索引小于重新格式化区域的起始索引
        if current_region >= len(reformat_regions) or idx < reformat_regions[current_region][0]:
            # 将当前页面块添加到新的块列表中
            new_blocks.append(block)
            # 索引加一
            idx  = 1
            # 继续下一次循环
            continue

        # 获取重新格式化区域的原始文本
        orig_block_text = " ".join([page_blocks.blocks[i].prelim_text for i in reformat_regions[current_region]])
        # 获取预测的 LaTeX 文本
        latex_text = predictions[current_region]
        # 定义条件列表
        conditions = [
            len(latex_text) > 0,
            get_total_texify_tokens(latex_text) < settings.TEXIFY_MODEL_MAX,  # 确保没有达到总体令牌最大值
            len(latex_text) > len(orig_block_text) * .8,
            len(latex_text.strip()) > 0
        ]

        # 更新索引为重新格式化区域的结束索引加一
        idx = reformat_regions[current_region][-1]   1
        # 如果条件不满足
        if not all(conditions):
            # 失败计数加一
            fail_count  = 1
            # 将转换后的区域添加为 None
            converted_spans.append(None)
            # 将重新格式化区域中的页面块添加到新的块列表中
            for i in reformat_regions[current_region]:
                new_blocks.append(page_blocks.blocks[i])
        else:
            # 成功计数加一
            success_count  = 1
            # 创建一个包含 LaTeX 文本的行对象
            block_line = Line(
                spans=[
                    Span(
                        text=latex_text,
                        bbox=merged_boxes[current_region],
                        span_id=f"{pnum}_{idx}_fixeq",
                        font="Latex",
                        color=0,
                        block_type="Formula"
                    )
                ],
                bbox=merged_boxes[current_region]
            )
            # 深拷贝第一个 span 对象并添加到转换后的区域列表中
            converted_spans.append(deepcopy(block_line.spans[0]))
            # 创建一个新的块对象,包含上述行对象
            new_blocks.append(Block(
                lines=[block_line],
                bbox=merged_boxes[current_region],
                pnum=pnum
            ))
        # 更新当前区域索引
        current_region  = 1
    # 返回新的块列表、成功计数、失败计数和转换后的区域列表
    return new_blocks, success_count, fail_count, converted_spans
def replace_equations(doc, blocks: List[Page], block_types: List[List[BlockType]], texify_model, batch_size=settings.TEXIFY_BATCH_SIZE):
    # 初始化未成功 OCR 的计数和成功 OCR 的计数
    unsuccessful_ocr = 0
    successful_ocr = 0

    # 查找潜在的方程区域,并计算每个区域中文本的长度
    reformat_regions = []
    reformat_region_lens = []
    for pnum, page in enumerate(blocks):
        regions, region_lens = find_page_equation_regions(pnum, page, block_types)
        reformat_regions.append(regions)
        reformat_region_lens.append(region_lens)

    # 计算方程的总数
    eq_count = sum([len(x) for x in reformat_regions])

    # 获取每个区域的图像
    flat_reformat_region_lens = [item for sublist in reformat_region_lens for item in sublist]
    images = []
    merged_boxes = []
    for page_idx, reformat_regions_page in enumerate(reformat_regions):
        page_obj = doc[page_idx]
        for reformat_region in reformat_regions_page:
            bboxes, merged_box = get_bboxes_for_region(blocks[page_idx], reformat_region)
            png_image = get_masked_image(page_obj, merged_box, bboxes)
            images.append(png_image)
            merged_boxes.append(merged_box)

    # 进行批量预测
    predictions = get_latex_batched(images, flat_reformat_region_lens, texify_model, batch_size)

    # 替换区域中的文本块为预测结果
    page_start = 0
    converted_spans = []
    # 遍历重排后的区域列表,获取每一页的预测结果和合并后的框
    for page_idx, reformat_regions_page in enumerate(reformat_regions):
        # 获取当前页的预测结果和合并后的框
        page_predictions = predictions[page_start:page_start   len(reformat_regions_page)]
        page_boxes = merged_boxes[page_start:page_start   len(reformat_regions_page)]
        # 替换块内容为 LaTeX,并返回新的块列表、成功计数、失败计数和转换的跨度
        new_page_blocks, success_count, fail_count, converted_span = replace_blocks_with_latex(
            blocks[page_idx],
            page_boxes,
            reformat_regions_page,
            page_predictions,
            page_idx
        )
        # 将转换的跨度添加到列表中
        converted_spans.extend(converted_span)
        # 更新当前页的块列表
        blocks[page_idx].blocks = new_page_blocks
        # 更新页起始位置
        page_start  = len(reformat_regions_page)
        # 更新成功 OCR 计数和失败 OCR 计数
        successful_ocr  = success_count
        unsuccessful_ocr  = fail_count

    # 如果调试模式开启,输出转换结果以供比较
    dump_equation_debug_data(doc, images, converted_spans)

    # 返回更新后的块列表和 OCR 结果统计信息
    return blocks, {"successful_ocr": successful_ocr, "unsuccessful_ocr": unsuccessful_ocr, "equations": eq_count}

.markermarkercleanersheaders.py

代码语言:javascript复制
# 导入所需的模块
import re
from collections import Counter, defaultdict
from itertools import chain
from thefuzz import fuzz
from sklearn.cluster import DBSCAN
import numpy as np
from marker.schema import Page, FullyMergedBlock
from typing import List, Tuple

# 过滤出现频率高于给定阈值的文本块
def filter_common_elements(lines, page_count):
    # 提取所有文本内容
    text = [s.text for line in lines for s in line.spans if len(s.text) > 4]
    # 统计文本内容出现的次数
    counter = Counter(text)
    # 选取出现频率高于阈值的文本内容
    common = [k for k, v in counter.items() if v > page_count * .6]
    # 获取包含常见文本内容的文本块的 span_id
    bad_span_ids = [s.span_id for line in lines for s in line.spans if s.text in common]
    return bad_span_ids

# 过滤页眉页脚文本块
def filter_header_footer(all_page_blocks, max_selected_lines=2):
    first_lines = []
    last_lines = []
    for page in all_page_blocks:
        nonblank_lines = page.get_nonblank_lines()
        first_lines.extend(nonblank_lines[:max_selected_lines])
        last_lines.extend(nonblank_lines[-max_selected_lines:])

    # 获取页眉页脚文本块的 span_id
    bad_span_ids = filter_common_elements(first_lines, len(all_page_blocks))
    bad_span_ids  = filter_common_elements(last_lines, len(all_page_blocks))
    return bad_span_ids

# 对文本块进行分类
def categorize_blocks(all_page_blocks: List[Page]):
    # 提取所有非空文本块的 span
    spans = list(chain.from_iterable([p.get_nonblank_spans() for p in all_page_blocks]))
    # 构建特征矩阵
    X = np.array(
        [(*s.bbox, len(s.text)) for s in spans]
    )

    # 使用 DBSCAN 进行聚类
    dbscan = DBSCAN(eps=.1, min_samples=5)
    dbscan.fit(X)
    labels = dbscan.labels_
    label_chars = defaultdict(int)
    for i, label in enumerate(labels):
        label_chars[label]  = len(spans[i].text)

    # 选择出现次数最多的类别作为主要类别
    most_common_label = None
    most_chars = 0
    for i in label_chars.keys():
        if label_chars[i] > most_chars:
            most_common_label = i
            most_chars = label_chars[i]

    # 将非主要类别标记为 1
    labels = [0 if label == most_common_label else 1 for label in labels]
    # 获取非主要类别的文本块的 span_id
    bad_span_ids = [spans[i].span_id for i in range(len(spans)) if labels[i] == 1]

    return bad_span_ids

# 替换字符串开头的数字
def replace_leading_trailing_digits(string, replacement):
    string = re.sub(r'^d ', replacement, string)
    # 使用正则表达式替换字符串中最后的数字
    string = re.sub(r'd $', replacement, string)
    # 返回替换后的字符串
    return string
# 定义一个函数,用于查找重叠元素
def find_overlap_elements(lst: List[Tuple[str, int]], string_match_thresh=.9, min_overlap=.05) -> List[int]:
    # 初始化一个列表,用于存储符合条件的元素
    result = []
    # 从输入列表中提取所有元组的第一个元素,即标题
    titles = [l[0] for l in lst]

    # 遍历输入列表中的元素
    for i, (str1, id_num) in enumerate(lst):
        overlap_count = 0  # 计算至少80%重叠的元素数量

        # 再次遍历标题列表,检查元素之间的相似度
        for j, str2 in enumerate(titles):
            if i != j and fuzz.ratio(str1, str2) >= string_match_thresh * 100:
                overlap_count  = 1

        # 检查元素是否与至少50%的其他元素重叠
        if overlap_count >= max(3.0, len(lst) * min_overlap):
            result.append(id_num)

    return result


# 定义一个函数,用于过滤常见标题
def filter_common_titles(merged_blocks: List[FullyMergedBlock]) -> List[FullyMergedBlock]:
    titles = []
    # 遍历合并块列表中的块
    for i, block in enumerate(merged_blocks):
        # 如果块类型为"Title"或"Section-header"
        if block.block_type in ["Title", "Section-header"]:
            text = block.text
            # 如果文本以"#"开头,则去除所有"#"
            if text.strip().startswith("#"):
                text = re.sub(r'# ', '', text)
            text = text.strip()
            # 去除文本开头和结尾的页码
            text = replace_leading_trailing_digits(text, "").strip()
            titles.append((text, i))

    # 查找重叠标题的块的索引
    bad_block_ids = find_overlap_elements(titles)

    new_blocks = []
    # 遍历合并块列表中的块
    for i, block in enumerate(merged_blocks):
        # 如果块的索引在重叠块的索引列表中,则跳过该块
        if i in bad_block_ids:
            continue
        new_blocks.append(block)

    return new_blocks

.markermarkercleanerstable.py

代码语言:javascript复制
# 从 marker.bbox 模块中导入 merge_boxes 函数
# 从 marker.schema 模块中导入 Line, Span, Block, Page 类
# 从 copy 模块中导入 deepcopy 函数
# 从 tabulate 模块中导入 tabulate 函数
# 从 typing 模块中导入 List 类型
# 导入 re 模块
# 导入 textwrap 模块
from marker.bbox import merge_boxes
from marker.schema import Line, Span, Block, Page
from copy import deepcopy
from tabulate import tabulate
from typing import List
import re
import textwrap


# 合并表格块
def merge_table_blocks(blocks: List[Page]):
    # 初始化当前行列表和当前边界框
    current_lines = []
    current_bbox = None
    # 遍历每一页
    for page in blocks:
        new_page_blocks = []
        pnum = page.pnum
        # 遍历每个块
        for block in page.blocks:
            # 如果块的最常见类型不是表格
            if block.most_common_block_type() != "Table":
                # 如果当前行列表不为空
                if len(current_lines) > 0:
                    # 创建新的块对象,包含当前行列表和当前页码
                    new_block = Block(
                        lines=deepcopy(current_lines),
                        pnum=pnum,
                        bbox=current_bbox
                    )
                    new_page_blocks.append(new_block)
                    current_lines = []
                    current_bbox = None

                # 将当前块添加到新页块列表中
                new_page_blocks.append(block)
                continue

            # 将块的行添加到当前行列表中
            current_lines.extend(block.lines)
            # 如果当前边界框为空,则设置为块的边界框,否则合并边界框
            if current_bbox is None:
                current_bbox = block.bbox
            else:
                current_bbox = merge_boxes(current_bbox, block.bbox)

        # 如果当前行列表不为空
        if len(current_lines) > 0:
            # 创建新的块对象,包含当前行列表和当前页码
            new_block = Block(
                lines=deepcopy(current_lines),
                pnum=pnum,
                bbox=current_bbox
            )
            new_page_blocks.append(new_block)
            current_lines = []
            current_bbox = None

        # 更新当前页的块列表
        page.blocks = new_page_blocks


# 创建新的表格
def create_new_tables(blocks: List[Page]):
    # 初始化表格索引和正则表达式模式
    table_idx = 0
    dot_pattern = re.compile(r'(s*.s*){4,}')
    dot_multiline_pattern = re.compile(r'.*(s*.s*){4,}.*', re.DOTALL)
    # 遍历每一页中的文本块
    for page in blocks:
        # 遍历每个文本块中的块
        for block in page.blocks:
            # 如果块类型不是表格或者行数小于3,则跳过
            if block.most_common_block_type() != "Table" or len(block.lines) < 3:
                continue

            # 初始化表格行列表和y坐标
            table_rows = []
            y_coord = None
            row = []
            # 遍历每行文本
            for line in block.lines:
                # 遍历每个文本块
                for span in line.spans:
                    # 如果y坐标不同于当前span的起始y坐标
                    if y_coord != span.y_start:
                        # 如果当前行有内容,则添加到表格行列表中
                        if len(row) > 0:
                            table_rows.append(row)
                            row = []
                        y_coord = span.y_start

                    # 获取文本内容并处理多行文本
                    text = span.text
                    if dot_multiline_pattern.match(text):
                        text = dot_pattern.sub(' ', text)
                    row.append(text)
            # 如果当前行有内容,则添加到表格行列表中
            if len(row) > 0:
                table_rows.append(row)

            # 如果表格行字符总长度大于300,或者第一行列数大于8或小于2,则跳过
            if max([len("".join(r)) for r in table_rows]) > 300 or len(table_rows[0]) > 8 or len(table_rows[0]) < 2:
                continue

            # 格式化表格行数据并创建新的Span和Line对象
            new_text = tabulate(table_rows, headers="firstrow", tablefmt="github")
            new_span = Span(
                bbox=block.bbox,
                span_id=f"{table_idx}_fix_table",
                font="Table",
                color=0,
                block_type="Table",
                text=new_text
            )
            new_line = Line(
                bbox=block.bbox,
                spans=[new_span]
            )
            # 替换原有文本块的行为新的行
            block.lines = [new_line]
            table_idx  = 1
    # 返回处理过的表格数量
    return table_idx

.markermarkerconvert.py

代码语言:javascript复制
# 导入所需的库
import fitz as pymupdf

# 导入自定义模块
from marker.cleaners.table import merge_table_blocks, create_new_tables
from marker.debug.data import dump_bbox_debug_data
from marker.extract_text import get_text_blocks
from marker.cleaners.headers import filter_header_footer, filter_common_titles
from marker.cleaners.equations import replace_equations
from marker.ordering import order_blocks
from marker.postprocessors.editor import edit_full_text
from marker.segmentation import detect_document_block_types
from marker.cleaners.code import identify_code_blocks, indent_blocks
from marker.cleaners.bullets import replace_bullets
from marker.markdown import merge_spans, merge_lines, get_full_text
from marker.schema import Page, BlockType
from typing import List, Dict, Tuple, Optional
import re
import magic
from marker.settings import settings

# 定义函数,根据文件路径获取文件类型
def find_filetype(fpath):
    # 获取文件的 MIME 类型
    mimetype = magic.from_file(fpath).lower()

    # 根据 MIME 类型判断文件类型
    if "pdf" in mimetype:
        return "pdf"
    elif "epub" in mimetype:
        return "epub"
    elif "mobi" in mimetype:
        return "mobi"
    elif mimetype in settings.SUPPORTED_FILETYPES:
        return settings.SUPPORTED_FILETYPES[mimetype]
    else:
        # 输出非标准文件类型信息
        print(f"Found nonstandard filetype {mimetype}")
        return "other"

# 定义函数,为文本块添加标注
def annotate_spans(blocks: List[Page], block_types: List[BlockType]):
    for i, page in enumerate(blocks):
        page_block_types = block_types[i]
        page.add_block_types(page_block_types)

# 定义函数,获取文本文件的长度
def get_length_of_text(fname: str) -> int:
    # 获取文件类型
    filetype = find_filetype(fname)
    # 如果文件类型为其他,则返回长度为0
    if filetype == "other":
        return 0

    # 使用 pymupdf 打开文件
    doc = pymupdf.open(fname, filetype=filetype)
    full_text = ""
    # 遍历每一页,获取文本内容并拼接
    for page in doc:
        full_text  = page.get_text("text", sort=True, flags=settings.TEXT_FLAGS)

    return len(full_text)
def convert_single_pdf(
        fname: str,  # 定义函数,将单个 PDF 文件转换为文本
        model_lst: List,  # 模型列表
        max_pages=None,  # 最大页数,默认为 None
        metadata: Optional[Dict]=None,  # 元数据,默认为 None
        parallel_factor: int = 1  # 并行因子,默认为 1
) -> Tuple[str, Dict]:  # 返回类型为元组,包含字符串和字典

    lang = settings.DEFAULT_LANG  # 设置默认语言为系统默认语言
    if metadata:  # 如果有元数据
        lang = metadata.get("language", settings.DEFAULT_LANG)  # 获取元数据中的语言信息,如果不存在则使用系统默认语言

    # 使用 Tesseract 语言,如果可用
    tess_lang = settings.TESSERACT_LANGUAGES.get(lang, "eng")  # 获取 Tesseract 语言设置
    spell_lang = settings.SPELLCHECK_LANGUAGES.get(lang, None)  # 获取拼写检查语言设置
    if "eng" not in tess_lang:  # 如果英语不在 Tesseract 语言中
        tess_lang = f"eng {tess_lang}"  # 添加英语到 Tesseract 语言中

    # 输出元数据
    out_meta = {"language": lang}  # 设置输出元数据的语言信息

    filetype = find_filetype(fname)  # 查找文件类型
    if filetype == "other":  # 如果文件类型为其他
        return "", out_meta  # 返回空字符串和输出元数据

    out_meta["filetype"] = filetype  # 设置输出元数据的文件类型

    doc = pymupdf.open(fname, filetype=filetype)  # 打开文件
    if filetype != "pdf":  # 如果文件类型不是 PDF
        conv = doc.convert_to_pdf()  # 将文件转换为 PDF 格式
        doc = pymupdf.open("pdf", conv)  # 打开 PDF 文件

    blocks, toc, ocr_stats = get_text_blocks(
        doc,
        tess_lang,
        spell_lang,
        max_pages=max_pages,
        parallel=int(parallel_factor * settings.OCR_PARALLEL_WORKERS)
    )  # 获取文本块、目录和 OCR 统计信息

    out_meta["toc"] = toc  # 设置输出元数据的目录信息
    out_meta["pages"] = len(blocks)  # 设置输出元数据的页数
    out_meta["ocr_stats"] = ocr_stats  # 设置输出元数据的 OCR 统计信息
    if len([b for p in blocks for b in p.blocks]) == 0:  # 如果没有提取到任何文本块
        print(f"Could not extract any text blocks for {fname}")  # 打印无法提取文本块的消息
        return "", out_meta  # 返回空字符串和输出元数据

    # 解包模型列表
    texify_model, layoutlm_model, order_model, edit_model = model_lst  # 解包模型列表

    block_types = detect_document_block_types(
        doc,
        blocks,
        layoutlm_model,
        batch_size=int(settings.LAYOUT_BATCH_SIZE * parallel_factor)
    )  # 检测文档的块类型

    # 查找页眉和页脚
    bad_span_ids = filter_header_footer(blocks)  # 过滤页眉和页脚
    out_meta["block_stats"] = {"header_footer": len(bad_span_ids)}  # 设置输出元数据的块统计信息

    annotate_spans(blocks, block_types)  # 标注文本块

    # 如果设置了标志,则转储调试数据
    dump_bbox_debug_data(doc, blocks)  # 转储边界框调试数据
    # 根据指定的参数对文档中的块进行排序
    blocks = order_blocks(
        doc,
        blocks,
        order_model,
        batch_size=int(settings.ORDERER_BATCH_SIZE * parallel_factor)
    )

    # 识别代码块数量并更新元数据
    code_block_count = identify_code_blocks(blocks)
    out_meta["block_stats"]["code"] = code_block_count
    # 缩进代码块
    indent_blocks(blocks)

    # 合并表格块
    merge_table_blocks(blocks)
    # 创建新的表格块并更新元数据
    table_count = create_new_tables(blocks)
    out_meta["block_stats"]["table"] = table_count

    # 遍历每个页面的块
    for page in blocks:
        for block in page.blocks:
            # 过滤掉坏的 span id
            block.filter_spans(bad_span_ids)
            # 过滤掉坏的 span 类型
            block.filter_bad_span_types()

    # 替换方程式并更新元数据
    filtered, eq_stats = replace_equations(
        doc,
        blocks,
        block_types,
        texify_model,
        batch_size=int(settings.TEXIFY_BATCH_SIZE * parallel_factor)
    )
    out_meta["block_stats"]["equations"] = eq_stats

    # 复制以避免更改原始数据
    merged_lines = merge_spans(filtered)
    text_blocks = merge_lines(merged_lines, filtered)
    text_blocks = filter_common_titles(text_blocks)
    full_text = get_full_text(text_blocks)

    # 处理被连接的空块
    full_text = re.sub(r'n{3,}', 'nn', full_text)
    full_text = re.sub(r'(ns){3,}', 'nn', full_text)

    # 用 - 替换项目符号字符
    full_text = replace_bullets(full_text)

    # 使用编辑器模型后处理文本
    full_text, edit_stats = edit_full_text(
        full_text,
        edit_model,
        batch_size=settings.EDITOR_BATCH_SIZE * parallel_factor
    )
    out_meta["postprocess_stats"] = {"edit": edit_stats}

    # 返回处理后的文本和元数据
    return full_text, out_meta

.markermarkerdebugdata.py

代码语言:javascript复制
import base64
import json
import os
import zlib
from typing import List

from marker.schema import Page
from marker.settings import settings
from PIL import Image
import io

# 定义一个函数,用于将公式的调试数据转储到文件中
def dump_equation_debug_data(doc, images, converted_spans):
    # 如果未设置调试数据文件夹或调试级别为0,则直接返回
    if not settings.DEBUG_DATA_FOLDER or settings.DEBUG_LEVEL == 0:
        return

    # 如果图片列表为空,则直接返回
    if len(images) == 0:
        return

    # 断言每个图片都有对应的转换结果
    assert len(converted_spans) == len(images)

    data_lines = []
    # 遍历图片和对应的转换结果
    for idx, (pil_image, converted_span) in enumerate(zip(images, converted_spans)):
        # 如果转换结果为空,则跳过当前图片
        if converted_span is None:
            continue
        # 将 PIL 图像保存为 BytesIO 对象
        img_bytes = io.BytesIO()
        pil_image.save(img_bytes, format="WEBP", lossless=True)
        # 将图片数据进行 base64 编码
        b64_image = base64.b64encode(img_bytes.getvalue()).decode("utf-8")
        # 将图片数据、转换后的文本和边界框信息添加到数据行中
        data_lines.append({
            "image": b64_image,
            "text": converted_span.text,
            "bbox": converted_span.bbox
        })

    # 从文档名称中去除扩展名
    doc_base = os.path.basename(doc.name).rsplit(".", 1)[0]

    # 构建调试数据文件路径
    debug_file = os.path.join(settings.DEBUG_DATA_FOLDER, f"{doc_base}_equations.json")
    # 将数据行写入到 JSON 文件中
    with open(debug_file, "w ") as f:
        json.dump(data_lines, f)

# 定义一个函数,用于将边界框的调试数据转储到文件中
def dump_bbox_debug_data(doc, blocks: List[Page]):
    # 如果未设置调试数据文件夹或调试级别小于2,则直接返回
    if not settings.DEBUG_DATA_FOLDER or settings.DEBUG_LEVEL < 2:
        return

    # 从文档名称中去除扩展名
    doc_base = os.path.basename(doc.name).rsplit(".", 1)[0]

    # 构建调试数据文件路径
    debug_file = os.path.join(settings.DEBUG_DATA_FOLDER, f"{doc_base}_bbox.json")
    debug_data = []
    # 遍历每个页面的块索引和块数据
    for idx, page_blocks in enumerate(blocks):
        # 获取当前页面对象
        page = doc[idx]

        # 获取页面的像素图像
        pix = page.get_pixmap(dpi=settings.TEXIFY_DPI, annots=False, clip=page_blocks.bbox)
        # 将像素图像转换为 PNG 格式的字节流
        png = pix.pil_tobytes(format="PNG")
        # 从 PNG 字节流创建图像对象
        png_image = Image.open(io.BytesIO(png))
        # 获取图像的宽度和高度
        width, height = png_image.size
        # 设置最大尺寸
        max_dimension = 6000
        # 如果图像宽度或高度超过最大尺寸
        if width > max_dimension or height > max_dimension:
            # 计算缩放因子
            scaling_factor = min(max_dimension / width, max_dimension / height)
            # 缩放图像
            png_image = png_image.resize((int(width * scaling_factor), int(height * scaling_factor)), Image.ANTIALIAS)

        # 创建一个字节流对象
        img_bytes = io.BytesIO()
        # 将图像以 WEBP 格式保存到字节流中
        png_image.save(img_bytes, format="WEBP", lossless=True, quality=100)
        # 将字节流编码为 base64 字符串
        b64_image = base64.b64encode(img_bytes.getvalue()).decode("utf-8")

        # 获取页面块的模型数据
        page_data = page_blocks.model_dump()
        # 将图像数据添加到页面数据中
        page_data["image"] = b64_image
        # 将页面数据添加到调试数据列表中
        debug_data.append(page_data)

    # 将调试数据以 JSON 格式写入调试文件
    with open(debug_file, "w ") as f:
        json.dump(debug_data, f)

.markermarkerextract_text.py

代码语言:javascript复制
# 导入所需的模块
import os
from typing import Tuple, List, Optional

# 导入拼写检查器 SpellChecker
from spellchecker import SpellChecker

# 导入正确旋转的边界框函数
from marker.bbox import correct_rotation
# 导入整页 OCR 函数
from marker.ocr.page import ocr_entire_page
# 导入检测不良 OCR 的工具函数和字体标志分解器
from marker.ocr.utils import detect_bad_ocr, font_flags_decomposer
# 导入设置模块中的设置
from marker.settings import settings
# 导入 Span, Line, Block, Page 数据结构
from marker.schema import Span, Line, Block, Page
# 导入线程池执行器
from concurrent.futures import ThreadPoolExecutor

# 设置环境变量 TESSDATA_PREFIX 为设置模块中的 TESSDATA_PREFIX
os.environ["TESSDATA_PREFIX"] = settings.TESSDATA_PREFIX

# 根据垂直分组对旋转文本进行排序
def sort_rotated_text(page_blocks, tolerance=1.25):
    vertical_groups = {}
    for block in page_blocks:
        group_key = round(block.bbox[1] / tolerance) * tolerance
        if group_key not in vertical_groups:
            vertical_groups[group_key] = []
        vertical_groups[group_key].append(block)

    # 对每个组进行水平排序,并将组展平为一个列表
    sorted_page_blocks = []
    for _, group in sorted(vertical_groups.items()):
        sorted_group = sorted(group, key=lambda x: x.bbox[0])
        sorted_page_blocks.extend(sorted_group)

    return sorted_page_blocks

# 获取单个页面的块信息
def get_single_page_blocks(doc, pnum: int, tess_lang: str, spellchecker: Optional[SpellChecker] = None, ocr=False) -> Tuple[List[Block], int]:
    # 获取文档中指定页码的页面
    page = doc[pnum]
    # 获取页面的旋转角度
    rotation = page.rotation

    # 如果需要进行 OCR
    if ocr:
        # 对整个页面进行 OCR,使用指定的语言和拼写检查器
        blocks = ocr_entire_page(page, tess_lang, spellchecker)
    else:
        # 否则,获取页面的文本块信息,按照设置中的标志进行排序
        blocks = page.get_text("dict", sort=True, flags=settings.TEXT_FLAGS)["blocks"]

    # 初始化页面块列表和跨度 ID
    page_blocks = []
    span_id = 0
    # 遍历每个块的索引和块内容
    for block_idx, block in enumerate(blocks):
        # 初始化存储每个块中行的列表
        block_lines = []
        # 遍历每个块中的行
        for l in block["lines"]:
            # 初始化存储每个行中span的列表
            spans = []
            # 遍历每个span
            for i, s in enumerate(l["spans"]):
                # 获取span的文本内容和边界框
                block_text = s["text"]
                bbox = s["bbox"]
                # 创建Span对象,包括文本内容、边界框、span id、字体和颜色等信息
                span_obj = Span(
                    text=block_text,
                    bbox=correct_rotation(bbox, page),
                    span_id=f"{pnum}_{span_id}",
                    font=f"{s['font']}_{font_flags_decomposer(s['flags'])}", # 在字体后面添加字体标志
                    color=s["color"],
                    ascender=s["ascender"],
                    descender=s["descender"],
                )
                spans.append(span_obj)  # 将span对象添加到spans列表中
                span_id  = 1
            # 创建Line对象,包括spans列表和边界框
            line_obj = Line(
                spans=spans,
                bbox=correct_rotation(l["bbox"], page),
            )
            # 只选择有效的行,即边界框面积大于0的行
            if line_obj.area > 0:
                block_lines.append(line_obj)  # 将有效的行添加到block_lines列表中
        # 创建Block对象,包括lines列表和边界框
        block_obj = Block(
            lines=block_lines,
            bbox=correct_rotation(block["bbox"], page),
            pnum=pnum
        )
        # 只选择包含多行的块
        if len(block_lines) > 0:
            page_blocks.append(block_obj)  # 将包含多行的块添加到page_blocks列表中

    # 如果页面被旋转,重新对文本进行排序
    if rotation > 0:
        page_blocks = sort_rotated_text(page_blocks)
    return page_blocks  # 返回处理后的页面块列表
# 将单个页面转换为文本块,进行 OCR 处理
def convert_single_page(doc, pnum, tess_lang: str, spell_lang: Optional[str], no_text: bool, disable_ocr: bool = False, min_ocr_page: int = 2):
    # 初始化变量用于记录 OCR 页面数量、成功次数和失败次数
    ocr_pages = 0
    ocr_success = 0
    ocr_failed = 0
    spellchecker = None
    # 获取当前页面的边界框
    page_bbox = doc[pnum].bound()
    # 如果指定了拼写检查语言,则创建拼写检查器对象
    if spell_lang:
        spellchecker = SpellChecker(language=spell_lang)

    # 获取单个页面的文本块
    blocks = get_single_page_blocks(doc, pnum, tess_lang, spellchecker)
    # 创建页面对象,包含文本块、页面编号和边界框
    page_obj = Page(blocks=blocks, pnum=pnum, bbox=page_bbox)

    # 判断是否需要对页面进行 OCR 处理
    conditions = [
        (
            no_text  # 全文本为空,需要进行完整 OCR 处理
            or
            (len(page_obj.prelim_text) > 0 and detect_bad_ocr(page_obj.prelim_text, spellchecker))  # OCR 处理不佳
        ),
        min_ocr_page < pnum < len(doc) - 1,
        not disable_ocr
    ]
    if all(conditions) or settings.OCR_ALL_PAGES:
        # 获取当前页面对象
        page = doc[pnum]
        # 获取包含 OCR 处理的文本块
        blocks = get_single_page_blocks(doc, pnum, tess_lang, spellchecker, ocr=True)
        # 创建包含 OCR 处理的页面对象,包含文本块、页面编号、边界框和旋转信息
        page_obj = Page(blocks=blocks, pnum=pnum, bbox=page_bbox, rotation=page.rotation)
        ocr_pages = 1
        if len(blocks) == 0:
            ocr_failed = 1
        else:
            ocr_success = 1
    # 返回页面对象和 OCR 处理结果统计信息
    return page_obj, {"ocr_pages": ocr_pages, "ocr_failed": ocr_failed, "ocr_success": ocr_success}


# 获取文本块列表
def get_text_blocks(doc, tess_lang: str, spell_lang: Optional[str], max_pages: Optional[int] = None, parallel: int = settings.OCR_PARALLEL_WORKERS):
    all_blocks = []
    # 获取文档的目录
    toc = doc.get_toc()
    ocr_pages = 0
    ocr_failed = 0
    ocr_success = 0
    # 这是一个线程,因为大部分工作在一个单独的进程中进行(tesseract)
    range_end = len(doc)
    # 判断是否全文本为空
    no_text = len(naive_get_text(doc).strip()) == 0
    # 如果指定了最大页面数,则限制范围
    if max_pages:
        range_end = min(max_pages, len(doc))
    # 使用线程池执行并行任务,最大工作线程数为 parallel
    with ThreadPoolExecutor(max_workers=parallel) as pool:
        # 生成参数列表,包含文档、页数、Tesseract语言、拼写语言、是否无文本的元组
        args_list = [(doc, pnum, tess_lang, spell_lang, no_text) for pnum in range(range_end)]
        # 根据并行数选择使用 map 函数或线程池的 map 函数
        if parallel == 1:
            func = map
        else:
            func = pool.map
        # 执行函数并获取结果
        results = func(lambda a: convert_single_page(*a), args_list)
    
        # 遍历结果
        for result in results:
            # 获取页面对象和 OCR 统计信息
            page_obj, ocr_stats = result
            # 将页面对象添加到所有块列表中
            all_blocks.append(page_obj)
            # 更新 OCR 页面数、失败数和成功数
            ocr_pages  = ocr_stats["ocr_pages"]
            ocr_failed  = ocr_stats["ocr_failed"]
            ocr_success  = ocr_stats["ocr_success"]
    
    # 返回所有块列表、目录和 OCR 统计信息
    return all_blocks, toc, {"ocr_pages": ocr_pages, "ocr_failed": ocr_failed, "ocr_success": ocr_success}
# 定义一个函数,用于从文档中提取文本内容
def naive_get_text(doc):
    # 初始化一个空字符串,用于存储提取的文本内容
    full_text = ""
    # 遍历文档中的每一页
    for page in doc:
        # 获取当前页的文本内容,并按照指定的参数进行排序和处理
        full_text  = page.get_text("text", sort=True, flags=settings.TEXT_FLAGS)
        # 在每一页的文本内容后添加换行符
        full_text  = "n"
    # 返回整个文档的文本内容
    return full_text

.markermarkerlogger.py

代码语言:javascript复制
# 导入 logging 模块
import logging
# 导入 fitz 模块并重命名为 pymupdf
import fitz as pymupdf
# 导入 warnings 模块

# 配置日志记录
def configure_logging():
    # 设置日志级别为 WARNING
    logging.basicConfig(level=logging.WARNING)

    # 设置 pdfminer 模块的日志级别为 ERROR
    logging.getLogger('pdfminer').setLevel(logging.ERROR)
    # 设置 PIL 模块的日志级别为 ERROR
    logging.getLogger('PIL').setLevel(logging.ERROR)
    # 设置 fitz 模块的日志级别为 ERROR
    logging.getLogger('fitz').setLevel(logging.ERROR)
    # 设置 ocrmypdf 模块的日志级别为 ERROR
    logging.getLogger('ocrmypdf').setLevel(logging.ERROR)
    # 设置 fitz 模块的错误显示为 False
    pymupdf.TOOLS.mupdf_display_errors(False)
    # 忽略 FutureWarning 类别的警告
    warnings.simplefilter(action='ignore', category=FutureWarning)

.markermarkermarkdown.py

代码语言:javascript复制
# 从 marker.schema 模块中导入 MergedLine, MergedBlock, FullyMergedBlock, Page 类
from marker.schema import MergedLine, MergedBlock, FullyMergedBlock, Page
# 导入 re 模块,用于正则表达式操作
import re
# 从 typing 模块中导入 List 类型
from typing import List

# 定义一个函数,用于在文本两侧添加指定字符
def surround_text(s, char_to_insert):
    # 匹配字符串开头的空白字符
    leading_whitespace = re.match(r'^(s*)', s).group(1)
    # 匹配字符串结尾的空白字符
    trailing_whitespace = re.search(r'(s*)$', s).group(1)
    # 去除字符串两侧空白字符
    stripped_string = s.strip()
    # 在去除空白字符后的字符串两侧添加指定字符
    modified_string = char_to_insert   stripped_string   char_to_insert
    # 将添加指定字符后的字符串重新加上空白字符,形成最终字符串
    final_string = leading_whitespace   modified_string   trailing_whitespace
    return final_string

# 定义一个函数,用于合并块
def merge_spans(blocks):
    # 初始化一个空列表用于存储合并后的块
    merged_blocks = []
    return merged_blocks

# 定义一个函数,用于根据块类型对文本进行包围处理
def block_surround(text, block_type):
    if block_type == "Section-header":
        if not text.startswith("#"):
            text = "n## "   text.strip().title()   "n"
    elif block_type == "Title":
        if not text.startswith("#"):
            text = "# "   text.strip().title()   "n"
    elif block_type == "Table":
        text = "n"   text   "n"
    elif block_type == "List-item":
        pass
    elif block_type == "Code":
        text = "n"   text   "n"
    return text

# 定义一个函数,用于处理文本行之间的分隔符
def line_separator(line1, line2, block_type, is_continuation=False):
    # 包含拉丁衍生语言和俄语的小写字母
    lowercase_letters = "a-zà-öø-ÿа-яşćăâđêôơưþðæøå"
    # 包含拉丁衍生语言和俄语的大写字母
    uppercase_letters = "A-ZÀ-ÖØ-ßА-ЯŞĆĂÂĐÊÔƠƯÞÐÆØÅ"
    # 匹配当前行是否以连字符结尾,且下一行与当前行似乎连接在一起
    hyphen_pattern = re.compile(rf'.*[{lowercase_letters}][-]s?$', re.DOTALL)
    if line1 and hyphen_pattern.match(line1) and re.match(rf"^[{lowercase_letters}]", line2):
        # 从右侧分割连字符
        line1 = re.split(r"[-—]s?$", line1)[0]
        return line1.rstrip()   line2.lstrip()

    lowercase_pattern1 = re.compile(rf'.*[{lowercase_letters},]s?$', re.DOTALL)
    lowercase_pattern2 = re.compile(rf'^s?[{uppercase_letters}{lowercase_letters}]', re.DOTALL)
    end_pattern = re.compile(r'.*[.?!]s?$', re.DOTALL)
    # 如果块类型为标题或节标题,则返回去除右侧空格的line1和去除左侧空格的line2拼接的字符串
    if block_type in ["Title", "Section-header"]:
        return line1.rstrip()   " "   line2.lstrip()
    # 如果line1和line2都符合小写模式1和小写模式2,并且块类型为文本,则返回去除右侧空格的line1和去除左侧空格的line2拼接的字符串
    elif lowercase_pattern1.match(line1) and lowercase_pattern2.match(line2) and block_type == "Text":
        return line1.rstrip()   " "   line2.lstrip()
    # 如果是续行,则返回去除右侧空格的line1和去除左侧空格的line2拼接的字符串
    elif is_continuation:
        return line1.rstrip()   " "   line2.lstrip()
    # 如果块类型为文本且line1匹配结束模式,则返回line1后加上两个换行符和line2
    elif block_type == "Text" and end_pattern.match(line1):
        return line1   "nn"   line2
    # 如果块类型为公式,则返回line1后加上一个空格和line2
    elif block_type == "Formula":
        return line1   " "   line2
    # 其他情况下,返回line1后加上一个换行符和line2
    else:
        return line1   "n"   line2
# 定义一个函数,用于确定两个不同类型的文本块之间的分隔符
def block_separator(line1, line2, block_type1, block_type2):
    # 默认分隔符为换行符
    sep = "n"
    # 如果第一个块的类型是"Text",则分隔符为两个换行符
    if block_type1 == "Text":
        sep = "nn"

    # 返回第二行和分隔符
    return sep   line2


# 合并文本块中的行
def merge_lines(blocks, page_blocks: List[Page]):
    # 存储文本块的列表
    text_blocks = []
    prev_type = None
    prev_line = None
    block_text = ""
    block_type = ""
    # 存储每个页面的常见行高度统计信息
    common_line_heights = [p.get_line_height_stats() for p in page_blocks]
    # 遍历每个页面的文本块
    for page in blocks:
        for block in page:
            # 获取当前文本块的最常见类型
            block_type = block.most_common_block_type()
            # 如果当前类型与前一个类型不同且前一个类型存在,则将前一个文本块添加到列表中
            if block_type != prev_type and prev_type:
                text_blocks.append(
                    FullyMergedBlock(
                        text=block_surround(block_text, prev_type),
                        block_type=prev_type
                    )
                )
                block_text = ""

            prev_type = block_type
            # 将文本块中的行合并在一起
            for i, line in enumerate(block.lines):
                line_height = line.bbox[3] - line.bbox[1]
                prev_line_height = prev_line.bbox[3] - prev_line.bbox[1] if prev_line else 0
                prev_line_x = prev_line.bbox[0] if prev_line else 0
                prev_line = line
                is_continuation = line_height == prev_line_height and line.bbox[0] == prev_line_x
                if block_text:
                    block_text = line_separator(block_text, line.text, block_type, is_continuation)
                else:
                    block_text = line.text

    # 将最后一个文本块添加到列表中
    text_blocks.append(
        FullyMergedBlock(
            text=block_surround(block_text, prev_type),
            block_type=block_type
        )
    )
    return text_blocks


# 获取完整的文本
def get_full_text(text_blocks):
    full_text = ""
    prev_block = None
    # 遍历文本块列表
    for block in text_blocks:
        # 如果存在前一个文本块
        if prev_block:
            # 将前一个文本块、当前文本块、前一个文本块类型和当前文本块类型传入block_separator函数,将返回的结果添加到full_text中
            full_text  = block_separator(prev_block.text, block.text, prev_block.block_type, block.block_type)
        else:
            # 如果不存在前一个文本块,直接将当前文本块的内容添加到full_text中
            full_text  = block.text
        # 更新prev_block为当前文本块
        prev_block = block
    # 返回完整的文本
    return full_text

0 人点赞