Marker 源码解析(二)

2024-03-09 08:46:42 浏览数 (1)

.markermarkermodels.py

代码语言:javascript复制
# 从 marker.cleaners.equations 模块中导入 load_texify_model 函数
from marker.cleaners.equations import load_texify_model
# 从 marker.ordering 模块中导入 load_ordering_model 函数
from marker.ordering import load_ordering_model
# 从 marker.postprocessors.editor 模块中导入 load_editing_model 函数
from marker.postprocessors.editor import load_editing_model
# 从 marker.segmentation 模块中导入 load_layout_model 函数
from marker.segmentation import load_layout_model

# 定义一个函数用于加载所有模型
def load_all_models():
    # 调用 load_editing_model 函数,加载编辑模型
    edit = load_editing_model()
    # 调用 load_ordering_model 函数,加载排序模型
    order = load_ordering_model()
    # 调用 load_layout_model 函数,加载布局模型
    layout = load_layout_model()
    # 调用 load_texify_model 函数,加载 TeXify 模型
    texify = load_texify_model()
    # 将加载的模型按顺序存储在列表中
    model_lst = [texify, layout, order, edit]
    # 返回模型列表
    return model_lst

.markermarkerocrpage.py

代码语言:javascript复制
import io  # 导入io模块
from typing import List, Optional  # 导入类型提示相关模块

import fitz as pymupdf  # 导入fitz模块并重命名为pymupdf
import ocrmypdf  # 导入ocrmypdf模块
from spellchecker import SpellChecker  # 从spellchecker模块导入SpellChecker类

from marker.ocr.utils import detect_bad_ocr  # 从marker.ocr.utils模块导入detect_bad_ocr函数
from marker.schema import Block  # 从marker.schema模块导入Block类
from marker.settings import settings  # 从marker.settings模块导入settings变量

ocrmypdf.configure_logging(verbosity=ocrmypdf.Verbosity.quiet)  # 配置ocrmypdf的日志记录级别为quiet

# 对整个页面进行OCR识别,返回Block对象列表
def ocr_entire_page(page, lang: str, spellchecker: Optional[SpellChecker] = None) -> List[Block]:
    # 如果OCR_ENGINE设置为"tesseract",则调用ocr_entire_page_tess函数
    if settings.OCR_ENGINE == "tesseract":
        return ocr_entire_page_tess(page, lang, spellchecker)
    # 如果OCR_ENGINE设置为"ocrmypdf",则调用ocr_entire_page_ocrmp函数
    elif settings.OCR_ENGINE == "ocrmypdf":
        return ocr_entire_page_ocrmp(page, lang, spellchecker)
    else:
        raise ValueError(f"Unknown OCR engine {settings.OCR_ENGINE}")  # 抛出数值错误异常,显示未知的OCR引擎

# 使用tesseract对整个页面进行OCR识别,返回Block对象列表
def ocr_entire_page_tess(page, lang: str, spellchecker: Optional[SpellChecker] = None) -> List[Block]:
    try:
        # 获取页面的完整OCR文本页
        full_tp = page.get_textpage_ocr(flags=settings.TEXT_FLAGS, dpi=settings.OCR_DPI, full=True, language=lang)
        # 获取页面的文本块列表
        blocks = page.get_text("dict", sort=True, flags=settings.TEXT_FLAGS, textpage=full_tp)["blocks"]
        # 获取页面的完整文本
        full_text = page.get_text("text", sort=True, flags=settings.TEXT_FLAGS, textpage=full_tp)

        # 如果完整文本长度为0,则返回空列表
        if len(full_text) == 0:
            return []

        # 检查OCR是否成功。如果失败,返回空列表
        # 例如,如果有一张扫描的空白页上有一些淡淡的文本印记,OCR可能会失败
        if detect_bad_ocr(full_text, spellchecker):
            return []
    except RuntimeError:
        return []
    return blocks  # 返回文本块列表

# 使用ocrmypdf对整个页面进行OCR识别,返回Block对象列表
def ocr_entire_page_ocrmp(page, lang: str, spellchecker: Optional[SpellChecker] = None) -> List[Block]:
    # 使用ocrmypdf获取整个页面的OCR文本
    src = page.parent  # 页面所属文档
    blank_doc = pymupdf.open()  # 创建临时的1页文档
    blank_doc.insert_pdf(src, from_page=page.number, to_page=page.number, annots=False, links=False)  # 插入PDF页面
    pdfbytes = blank_doc.tobytes()  # 获取文档字节流
    inbytes = io.BytesIO(pdfbytes)  # 转换为BytesIO对象
    # 创建一个字节流对象,用于存储 ocrmypdf 处理后的结果 PDF
    outbytes = io.BytesIO()  # let ocrmypdf store its result pdf here
    # 使用 ocrmypdf 进行 OCR 处理
    ocrmypdf.ocr(
        inbytes,
        outbytes,
        language=lang,
        output_type="pdf",
        redo_ocr=None if settings.OCR_ALL_PAGES else True,
        force_ocr=True if settings.OCR_ALL_PAGES else None,
        progress_bar=False,
        optimize=False,
        fast_web_view=1e6,
        skip_big=15, # skip images larger than 15 megapixels
        tesseract_timeout=settings.TESSERACT_TIMEOUT,
        tesseract_non_ocr_timeout=settings.TESSERACT_TIMEOUT,
    )
    # 以 fitz PDF 格式打开 OCR 处理后的输出
    ocr_pdf = pymupdf.open("pdf", outbytes.getvalue())  # read output as fitz PDF
    # 获取 OCR 处理后的文本块信息
    blocks = ocr_pdf[0].get_text("dict", sort=True, flags=settings.TEXT_FLAGS)["blocks"]
    # 获取 OCR 处理后的完整文本
    full_text = ocr_pdf[0].get_text("text", sort=True, flags=settings.TEXT_FLAGS)

    # 确保原始 PDF/EPUB/MOBI 的边界框和 OCR 处理后的 PDF 的边界框相同
    assert page.bound() == ocr_pdf[0].bound()

    # 如果完整文本为空,则返回空列表
    if len(full_text) == 0:
        return []

    # 如果检测到 OCR 处理不良,则返回空列表
    if detect_bad_ocr(full_text, spellchecker):
        return []

    # 返回文本块信息
    return blocks

.markermarkerocrutils.py

代码语言:javascript复制
# 导入必要的模块和类
from typing import Optional
from nltk import wordpunct_tokenize
from spellchecker import SpellChecker
from marker.settings import settings
import re

# 检测 OCR 文本质量是否差,返回布尔值
def detect_bad_ocr(text, spellchecker: Optional[SpellChecker], misspell_threshold=.7, space_threshold=.6, newline_threshold=.5, alphanum_threshold=.4):
    # 如果文本长度为0,则假定 OCR 失败
    if len(text) == 0:
        return True

    # 使用 wordpunct_tokenize 函数将文本分词
    words = wordpunct_tokenize(text)
    # 过滤掉空白字符
    words = [w for w in words if w.strip()]
    # 提取文本中的字母数字字符
    alpha_words = [word for word in words if word.isalnum()]

    # 如果提供了拼写检查器
    if spellchecker:
        # 检查文本中的拼写错误
        misspelled = spellchecker.unknown(alpha_words)
        # 如果拼写错误数量超过阈值,则返回 True
        if len(misspelled) > len(alpha_words) * misspell_threshold:
            return True

    # 计算文本中空格的数量
    spaces = len(re.findall(r's ', text))
    # 计算文本中字母字符的数量
    alpha_chars = len(re.sub(r's ', '', text))
    # 如果空格占比超过阈值,则返回 True
    if spaces / (alpha_chars   spaces) > space_threshold:
        return True

    # 计算文本中换行符的数量
    newlines = len(re.findall(r'n ', text))
    # 计算文本中非换行符的数量
    non_newlines = len(re.sub(r'n ', '', text))
    # 如果换行符占比超过阈值,则返回 True
    if newlines / (newlines   non_newlines) > newline_threshold:
        return True

    # 如果文本中字母数字字符比例低于阈值,则返回 True
    if alphanum_ratio(text) < alphanum_threshold: # Garbled text
        return True

    # 计算文本中无效字符的数量
    invalid_chars = len([c for c in text if c in settings.INVALID_CHARS])
    # 如果无效字符数量超过阈值,则返回 True
    if invalid_chars > max(3.0, len(text) * .02):
        return True

    # 默认情况下返回 False
    return False

# 将字体标志拆解为可读的形式
def font_flags_decomposer(flags):
    l = []
    # 检查字体标志中是否包含上标
    if flags & 2 ** 0:
        l.append("superscript")
    # 检查字体标志中是否包含斜体
    if flags & 2 ** 1:
        l.append("italic")
    # 检查字体标志中是否包含衬线
    if flags & 2 ** 2:
        l.append("serifed")
    else:
        l.append("sans")
    # 检查字体标志中是否包含等宽字体
    if flags & 2 ** 3:
        l.append("monospaced")
    else:
        l.append("proportional")
    # 检查字体标志中是否包含粗体
    if flags & 2 ** 4:
        l.append("bold")
    # 返回拆解后的字体标志字符串
    return "_".join(l)

# 计算文本中字母数字字符的比例
def alphanum_ratio(text):
    # 去除文本中的空格和换行符
    text = text.replace(" ", "")
    text = text.replace("n", "")
    # 统计文本中的字母数字字符数量
    alphanumeric_count = sum([1 for c in text if c.isalnum()])

    # 如果文本长度为0,则返回1
    if len(text) == 0:
        return 1

    # 计算字母数字字符比例
    ratio = alphanumeric_count / len(text)
    # 返回变量 ratio 的值
    return ratio

.markermarkerordering.py

代码语言:javascript复制
# 导入必要的模块
from copy import deepcopy
from typing import List
import torch
import sys, os
from marker.extract_text import convert_single_page
from transformers import LayoutLMv3ForSequenceClassification, LayoutLMv3Processor
from PIL import Image
import io
from marker.schema import Page
from marker.settings import settings

# 从设置中加载 LayoutLMv3Processor 模型
processor = LayoutLMv3Processor.from_pretrained(settings.ORDERER_MODEL_NAME)

# 加载 LayoutLMv3ForSequenceClassification 模型
def load_ordering_model():
    model = LayoutLMv3ForSequenceClassification.from_pretrained(
        settings.ORDERER_MODEL_NAME,
        torch_dtype=settings.MODEL_DTYPE,
    ).to(settings.TORCH_DEVICE_MODEL)
    model.eval()
    return model

# 获取推理数据
def get_inference_data(page, page_blocks: Page):
    # 深拷贝页面块的边界框
    bboxes = deepcopy([block.bbox for block in page_blocks.blocks])
    # 初始化单词列表
    words = ["."] * len(bboxes)

    # 获取页面的像素图像
    pix = page.get_pixmap(dpi=settings.LAYOUT_DPI, annots=False, clip=page_blocks.bbox)
    # 将像素图像转换为 PNG 格式
    png = pix.pil_tobytes(format="PNG")
    # 将 PNG 数据转换为 RGB 图像
    rgb_image = Image.open(io.BytesIO(png)).convert("RGB")

    # 获取页面块的边界框和宽高
    page_box = page_blocks.bbox
    pwidth = page_blocks.width
    pheight = page_blocks.height

    # 调整边界框的值
    for box in bboxes:
        if box[0] < page_box[0]:
            box[0] = page_box[0]
        if box[1] < page_box[1]:
            box[1] = page_box[1]
        if box[2] > page_box[2]:
            box[2] = page_box[2]
        if box[3] > page_box[3]:
            box[3] = page_box[3]

        # 将边界框的值转换为相对于页面宽高的比例
        box[0] = int(box[0] / pwidth * 1000)
        box[1] = int(box[1] / pheight * 1000)
        box[2] = int(box[2] / pwidth * 1000)
        box[3] = int(box[3] / pheight * 1000)

    return rgb_image, bboxes, words

# 批量推理
def batch_inference(rgb_images, bboxes, words, model):
    # 对 RGB 图像、单词和边界框进行编码
    encoding = processor(
        rgb_images,
        text=words,
        boxes=bboxes,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=128
    )

    # 将像素值转换为模型的数据类型
    encoding["pixel_values"] = encoding["pixel_values"].to(model.dtype)
    # 进入推断模式,不进行梯度计算
    with torch.inference_mode():
        # 将指定的键对应的值移动到模型所在设备上
        for k in ["bbox", "input_ids", "pixel_values", "attention_mask"]:
            encoding[k] = encoding[k].to(model.device)
        # 使用模型进行推理,获取输出
        outputs = model(**encoding)
        # 获取模型输出的预测结果
        logits = outputs.logits

    # 获取预测结果中概率最大的类别索引,并转换为列表
    predictions = logits.argmax(-1).squeeze().tolist()
    # 如果预测结果是整数,则转换为列表
    if isinstance(predictions, int):
        predictions = [predictions]
    # 将预测结果转换为类别标签
    predictions = [model.config.id2label[p] for p in predictions]
    # 返回预测结果
    return predictions
# 为文档中的每个块添加列数计数
def add_column_counts(doc, doc_blocks, model, batch_size):
    # 按照批量大小遍历文档块
    for i in range(0, len(doc_blocks), batch_size):
        # 创建当前批量的索引范围
        batch = range(i, min(i   batch_size, len(doc_blocks)))
        # 初始化空列表用于存储 RGB 图像、边界框和单词
        rgb_images = []
        bboxes = []
        words = []
        # 遍历当前批量的页码
        for pnum in batch:
            # 获取推理数据:RGB 图像、页边界框和页单词
            page = doc[pnum]
            rgb_image, page_bboxes, page_words = get_inference_data(page, doc_blocks[pnum])
            rgb_images.append(rgb_image)
            bboxes.append(page_bboxes)
            words.append(page_words)

        # 进行批量推理,获取预测结果
        predictions = batch_inference(rgb_images, bboxes, words, model)
        # 将预测结果与页码对应,更新文档块的列数计数
        for pnum, prediction in zip(batch, predictions):
            doc_blocks[pnum].column_count = prediction

# 对文档块进行排序
def order_blocks(doc, doc_blocks: List[Page], model, batch_size=settings.ORDERER_BATCH_SIZE):
    # 添加列数计数
    add_column_counts(doc, doc_blocks, model, batch_size)

    # 遍历文档块中的每一页
    for page_blocks in doc_blocks:
        # 如果该页的列数大于1
        if page_blocks.column_count > 1:
            # 根据位置重新排序块
            split_pos = page_blocks.x_start   page_blocks.width / 2
            left_blocks = []
            right_blocks = []
            # 遍历该页的每个块
            for block in page_blocks.blocks:
                # 根据位置将块分为左右两部分
                if block.x_start <= split_pos:
                    left_blocks.append(block)
                else:
                    right_blocks.append(block)
            # 更新该页的块顺序
            page_blocks.blocks = left_blocks   right_blocks
    # 返回排序后的文档块
    return doc_blocks

.markermarkerpostprocessorseditor.py

代码语言:javascript复制
# 导入必要的库
from collections import defaultdict, Counter
from itertools import chain
from typing import Optional

# 导入 transformers 库中的 AutoTokenizer 类
from transformers import AutoTokenizer

# 导入 settings 模块中的 settings 变量
from marker.settings import settings

# 导入 torch 库
import torch
import torch.nn.functional as F

# 导入 marker.postprocessors.t5 模块中的 T5ForTokenClassification 类和 byt5_tokenize 函数
from marker.postprocessors.t5 import T5ForTokenClassification, byt5_tokenize

# 定义加载编辑模型的函数
def load_editing_model():
    # 如果未启用编辑模型,则返回 None
    if not settings.ENABLE_EDITOR_MODEL:
        return None

    # 从预训练模型中加载 T5ForTokenClassification 模型
    model = T5ForTokenClassification.from_pretrained(
            settings.EDITOR_MODEL_NAME,
            torch_dtype=settings.MODEL_DTYPE,
        ).to(settings.TORCH_DEVICE_MODEL)
    model.eval()

    # 配置模型的标签映射
    model.config.label2id = {
        "equal": 0,
        "delete": 1,
        "newline-1": 2,
        "space-1": 3,
    }
    model.config.id2label = {v: k for k, v in model.config.label2id.items()}
    return model

# 定义编辑全文的函数
def edit_full_text(text: str, model: Optional[T5ForTokenClassification], batch_size: int = settings.EDITOR_BATCH_SIZE):
    # 如果模型为空,则直接返回原始文本和空字典
    if not model:
        return text, {}

    # 对文本进行 tokenization
    tokenized = byt5_tokenize(text, settings.EDITOR_MAX_LENGTH)
    input_ids = tokenized["input_ids"]
    char_token_lengths = tokenized["char_token_lengths"]

    # 准备 token_masks 列表
    token_masks = []
    # 遍历输入的 input_ids,按照 batch_size 进行分批处理
    for i in range(0, len(input_ids), batch_size):
        # 从 tokenized 中获取当前 batch 的 input_ids
        batch_input_ids = tokenized["input_ids"][i: i   batch_size]
        # 将 batch_input_ids 转换为 torch 张量,并指定设备为 model 的设备
        batch_input_ids = torch.tensor(batch_input_ids, device=model.device)
        # 从 tokenized 中获取当前 batch 的 attention_mask
        batch_attention_mask = tokenized["attention_mask"][i: i   batch_size]
        # 将 batch_attention_mask 转换为 torch 张量,并指定设备为 model 的设备
        batch_attention_mask = torch.tensor(batch_attention_mask, device=model.device)
        
        # 进入推理模式
        with torch.inference_mode():
            # 使用模型进行预测
            predictions = model(batch_input_ids, attention_mask=batch_attention_mask)

        # 将预测结果 logits 移动到 CPU 上
        logits = predictions.logits.cpu()

        # 如果最大概率小于阈值,则假设为不良预测
        # 我们希望保守一点,不要对文本进行过多编辑
        probs = F.softmax(logits, dim=-1)
        max_prob = torch.max(probs, dim=-1)
        cutoff_prob = max_prob.values < settings.EDITOR_CUTOFF_THRESH
        labels = logits.argmax(-1)
        labels[cutoff_prob] = model.config.label2id["equal"]
        labels = labels.squeeze().tolist()
        if len(labels) == settings.EDITOR_MAX_LENGTH:
            labels = [labels]
        labels = list(chain.from_iterable(labels))
        token_masks.extend(labels)

    # 文本中的字符列表
    flat_input_ids = list(chain.from_iterable(input_ids)

    # 去除特殊标记 0,1。保留未知标记,尽管它不应该被使用
    assert len(token_masks) == len(flat_input_ids)
    token_masks = [mask for mask, token in zip(token_masks, flat_input_ids) if token >= 2]

    # 确保 token_masks 的长度与文本编码后的长度相等
    assert len(token_masks) == len(list(text.encode("utf-8")))

    # 统计编辑次数的字典
    edit_stats = defaultdict(int)
    # 输出文本列表
    out_text = []
    # 起始位置
    start = 0
    # 遍历文本中的每个字符及其索引
    for i, char in enumerate(text):
        # 获取当前字符对应的 token 长度
        char_token_length = char_token_lengths[i]
        # 获取当前字符对应的 token 的 mask
        masks = token_masks[start: start   char_token_length]
        # 将 mask 转换为标签
        labels = [model.config.id2label[mask] for mask in masks]
        # 如果所有标签都是 "delete",则执行删除操作
        if all(l == "delete" for l in labels):
            # 如果删除的是空格,则保留,否则忽略
            if char.strip():
                out_text.append(char)
            else:
                edit_stats["delete"]  = 1
        # 如果标签为 "newline-1",则添加换行符
        elif labels[0] == "newline-1":
            out_text.append("n")
            out_text.append(char)
            edit_stats["newline-1"]  = 1
        # 如果标签为 "space-1",则添加空格
        elif labels[0] == "space-1":
            out_text.append(" ")
            out_text.append(char)
            edit_stats["space-1"]  = 1
        # 如果标签为其他情况,则保留字符
        else:
            out_text.append(char)
            edit_stats["equal"]  = 1

        # 更新下一个字符的起始位置
        start  = char_token_length

    # 将处理后的文本列表转换为字符串
    out_text = "".join(out_text)
    # 返回处理后的文本及编辑统计信息
    return out_text, edit_stats

.markermarkerpostprocessorst5.py

代码语言:javascript复制
# 从 transformers 库中导入 T5Config 和 T5PreTrainedModel 类
from transformers import T5Config, T5PreTrainedModel
# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn
# 从 copy 库中导入 deepcopy 函数
from copy import deepcopy
# 从 typing 库中导入 Optional, Tuple, Union, List 类型
from typing import Optional, Tuple, Union, List
# 从 itertools 库中导入 chain 函数
from itertools import chain

# 从 transformers.modeling_outputs 模块中导入 TokenClassifierOutput 类
from transformers.modeling_outputs import TokenClassifierOutput
# 从 transformers.models.t5.modeling_t5 模块中导入 T5Stack 类
from transformers.models.t5.modeling_t5 import T5Stack
# 从 transformers.utils.model_parallel_utils 模块中导入 get_device_map, assert_device_map 函数
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map

# 定义一个函数,用于将文本进行字节编码并分词
def byt5_tokenize(text: str, max_length: int, pad_token_id: int = 0):
    # 初始化一个空列表,用于存储字节编码
    byte_codes = []
    # 遍历文本中的每个字符
    for char in text:
        # 将每个字符进行 UTF-8 编码,并加上 3 以考虑特殊标记
        byte_codes.append([byte   3 for byte in char.encode('utf-8')])

    # 将字节编码展开成一个列表
    tokens = list(chain.from_iterable(byte_codes))
    # 记录每个字符对应的 token 长度
    char_token_lengths = [len(b) for b in byte_codes]

    # 初始化批量 token 和注意力掩码列表
    batched_tokens = []
    attention_mask = []
    # 按照最大长度将 token 进行分批
    for i in range(0, len(tokens), max_length):
        batched_tokens.append(tokens[i:i   max_length])
        attention_mask.append([1] * len(batched_tokens[-1])

    # 对最后一个批次进行填充
    if len(batched_tokens[-1]) < max_length:
        batched_tokens[-1]  = [pad_token_id] * (max_length - len(batched_tokens[-1]))
        attention_mask[-1]  = [0] * (max_length - len(attention_mask[-1]))

    # 返回包含分词结果的字典
    return {"input_ids": batched_tokens, "attention_mask": attention_mask, "char_token_lengths": char_token_lengths}

# 定义一个 T5ForTokenClassification 类,继承自 T5PreTrainedModel 类
class T5ForTokenClassification(T5PreTrainedModel):
    # 定义一个列表,用于指定加载时忽略的键
    _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
    # 初始化函数,接受一个T5Config对象作为参数
    def __init__(self, config: T5Config):
        # 调用父类的初始化函数
        super().__init__(config)
        # 设置模型维度为配置中的d_model值
        self.model_dim = config.d_model

        # 创建一个共享的嵌入层,词汇表大小为config.vocab_size,维度为config.d_model
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        # 复制配置对象,用于创建编码器
        encoder_config = deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.is_encoder_decoder = False
        encoder_config.use_cache = False
        # 创建T5Stack编码器
        self.encoder = T5Stack(encoder_config, self.shared)

        # 设置分类器的dropout值
        classifier_dropout = (
            config.classifier_dropout if hasattr(config, 'classifier_dropout') else config.dropout_rate
        )
        self.dropout = nn.Dropout(classifier_dropout)
        # 创建一个线性层,输入维度为config.d_model,输出维度为config.num_labels
        self.classifier = nn.Linear(config.d_model, config.num_labels)

        # 初始化权重并应用最终处理
        self.post_init()

        # 模型并行化
        self.model_parallel = False
        self.device_map = None


    # 并行化函数,接受一个设备映射device_map作为参数
    def parallelize(self, device_map=None):
        # 如果未提供device_map,则根据编码器块的数量和GPU数量生成一个默认的device_map
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        # 检查设备映射的有效性
        assert_device_map(self.device_map, len(self.encoder.block))
        # 将编码器并行化
        self.encoder.parallelize(self.device_map)
        # 将分类器移动到编码器的第一个设备上
        self.classifier.to(self.encoder.first_device)
        self.model_parallel = True

    # 反并行化函数
    def deparallelize(self):
        # 取消编码器的并行化
        self.encoder.deparallelize()
        # 将编码器和分类器移动到CPU上
        self.encoder = self.encoder.to("cpu")
        self.classifier = self.classifier.to("cpu")
        self.model_parallel = False
        self.device_map = None
        # 释放GPU缓存
        torch.cuda.empty_cache()

    # 获取输入嵌入层函数
    def get_input_embeddings(self):
        return self.shared

    # 设置输入嵌入层函数,接受一个新的嵌入层new_embeddings作为参数
    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        # 设置编码器的输入嵌入层为新的嵌入层
        self.encoder.set_input_embeddings(new_embeddings)

    # 获取编码器函数
    def get_encoder(self):
        return self.encoder
    # 对模型中的特定头部进行修剪
    def _prune_heads(self, heads_to_prune):
        # 遍历需要修剪的层和头部
        for layer, heads in heads_to_prune.items():
            # 调用 SelfAttention 模块的 prune_heads 方法进行修剪
            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)

    # 前向传播函数
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], TokenClassifierOutput]:
        # 如果 return_dict 为 None,则使用配置中的 use_return_dict
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 调用编码器进行前向传播
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 获取序列输出
        sequence_output = outputs[0]

        # 对序列输出进行 dropout
        sequence_output = self.dropout(sequence_output)
        # 将序列输出传入分类器得到 logits
        logits = self.classifier(sequence_output)

        # 初始化损失为 None
        loss = None

        # 如果不使用 return_dict,则返回输出结果
        if not return_dict:
            output = (logits,)   outputs[2:]
            return ((loss,)   output) if loss is not None else output

        # 使用 TokenClassifierOutput 类返回结果
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions
        )

.markermarkerschema.py

代码语言:javascript复制
# 导入 Counter 类,用于计数
# 导入 List、Optional、Tuple 类型,用于类型提示
from collections import Counter
from typing import List, Optional, Tuple

# 导入 BaseModel、field_validator 类,用于定义数据模型和字段验证
# 导入 ftfy 模块,用于修复文本中的 Unicode 错误
from pydantic import BaseModel, field_validator
import ftfy

# 导入 boxes_intersect_pct、multiple_boxes_intersect 函数,用于计算两个框的交集比例和多个框的交集情况
# 导入 settings 模块,用于获取配置信息
from marker.bbox import boxes_intersect_pct, multiple_boxes_intersect
from marker.settings import settings

# 定义函数 find_span_type,用于查找给定 span 在页面块中的类型
def find_span_type(span, page_blocks):
    # 默认块类型为 "Text"
    block_type = "Text"
    # 遍历页面块列表
    for block in page_blocks:
        # 如果 span 的边界框与页面块的边界框有交集
        if boxes_intersect_pct(span.bbox, block.bbox):
            # 更新块类型为页面块的类型
            block_type = block.block_type
            break
    # 返回块类型
    return block_type

# 定义类 BboxElement,继承自 BaseModel 类,表示具有边界框的元素
class BboxElement(BaseModel):
    bbox: List[float]

    # 验证 bbox 字段是否包含 4 个元素
    @field_validator('bbox')
    @classmethod
    def check_4_elements(cls, v: List[float]) -> List[float]:
        if len(v) != 4:
            raise ValueError('bbox must have 4 elements')
        return v

    # 计算元素的高度、宽度、起始 x 坐标、起始 y 坐标、面积
    @property
    def height(self):
        return self.bbox[3] - self.bbox[1]

    @property
    def width(self):
        return self.bbox[2] - self.bbox[0]

    @property
    def x_start(self):
        return self.bbox[0]

    @property
    def y_start(self):
        return self.bbox[1]

    @property
    def area(self):
        return self.width * self.height

# 定义类 BlockType,继承自 BboxElement 类,表示具有块类型的元素
class BlockType(BboxElement):
    block_type: str

# 定义类 Span,继承自 BboxElement 类,表示具有文本内容的元素
class Span(BboxElement):
    text: str
    span_id: str
    font: str
    color: int
    ascender: Optional[float] = None
    descender: Optional[float] = None
    block_type: Optional[str] = None
    selected: bool = True

    # 修复文本中的 Unicode 错误
    @field_validator('text')
    @classmethod
    def fix_unicode(cls, text: str) -> str:
        return ftfy.fix_text(text)

# 定义类 Line,继承自 BboxElement 类,表示具有多个 Span 的行元素
class Line(BboxElement):
    spans: List[Span]

    # 获取行的预备文本,即所有 Span 的文本拼接而成
    @property
    def prelim_text(self):
        return "".join([s.text for s in self.spans])

    # 获取行的起始 x 坐标
    @property
    def start(self):
        return self.spans[0].bbox[0]

# 定义类 Block,继承自 BboxElement 类,表示具有多个 Line 的块元素
class Block(BboxElement):
    lines: List[Line]
    pnum: int

    # 获取块的预备文本,即所有 Line 的预备文本拼接而成
    @property
    def prelim_text(self):
        return "n".join([l.prelim_text for l in self.lines])
    # 检查文本块是否包含公式,通过检查每个 span 的 block_type 是否为 "Formula" 来确定
    def contains_equation(self, equation_boxes=None):
        # 生成一个包含每个 span 的 block_type 是否为 "Formula" 的条件列表
        conditions = [s.block_type == "Formula" for l in self.lines for s in l.spans]
        # 如果提供了 equation_boxes 参数,则添加一个条件,检查文本块的边界框是否与给定框相交
        if equation_boxes:
            conditions  = [multiple_boxes_intersect(self.bbox, equation_boxes)]
        # 返回条件列表中是否有任何一个条件为 True
        return any(conditions)

    # 过滤掉包含在 bad_span_ids 中的 span
    def filter_spans(self, bad_span_ids):
        new_lines = []
        for line in self.lines:
            new_spans = []
            for span in line.spans:
                # 如果 span 的 span_id 不在 bad_span_ids 中,则保留该 span
                if not span.span_id in bad_span_ids:
                    new_spans.append(span)
            # 更新 line 的 spans 属性为过滤后的 new_spans
            line.spans = new_spans
            # 如果 line 中仍有 spans,则将其添加到 new_lines 中
            if len(new_spans) > 0:
                new_lines.append(line)
        # 更新 self.lines 为过滤后的 new_lines
        self.lines = new_lines

    # 过滤掉包含在 settings.BAD_SPAN_TYPES 中的 span 的 block_type
    def filter_bad_span_types(self):
        new_lines = []
        for line in self.lines:
            new_spans = []
            for span in line.spans:
                # 如果 span 的 block_type 不在 BAD_SPAN_TYPES 中,则保留该 span
                if span.block_type not in settings.BAD_SPAN_TYPES:
                    new_spans.append(span)
            # 更新 line 的 spans 属性为过滤后的 new_spans
            line.spans = new_spans
            # 如果 line 中仍有 spans,则将其添加到 new_lines 中
            if len(new_spans) > 0:
                new_lines.append(line)
        # 更新 self.lines 为过滤后的 new_lines
        self.lines = new_lines

    # 返回文本块中出现频率最高的 block_type
    def most_common_block_type(self):
        # 统计每个 span 的 block_type 出现的次数
        counter = Counter([s.block_type for l in self.lines for s in l.spans])
        # 返回出现次数最多的 block_type
        return counter.most_common(1)[0][0]

    # 设置文本块中所有 span 的 block_type 为给定的 block_type
    def set_block_type(self, block_type):
        for line in self.lines:
            for span in line.spans:
                # 将 span 的 block_type 设置为给定的 block_type
                span.block_type = block_type
# 定义一个名为 Page 的类,继承自 BboxElement 类
class Page(BboxElement):
    # 类属性:blocks 为 Block 对象列表,pnum 为整数,column_count 和 rotation 可选整数,默认为 None
    blocks: List[Block]
    pnum: int
    column_count: Optional[int] = None
    rotation: Optional[int] = None # 页面的旋转角度

    # 获取页面中非空行的方法
    def get_nonblank_lines(self):
        # 获取页面中所有行
        lines = self.get_all_lines()
        # 过滤出非空行
        nonblank_lines = [l for l in lines if l.prelim_text.strip()]
        return nonblank_lines

    # 获取页面中所有行的方法
    def get_all_lines(self):
        # 获取页面中所有行的列表
        lines = [l for b in self.blocks for l in b.lines]
        return lines

    # 获取页面中非空跨度的方法,返回 Span 对象列表
    def get_nonblank_spans(self) -> List[Span]:
        # 获取页面中所有行
        lines = [l for b in self.blocks for l in b.lines]
        # 过滤出非空跨度
        spans = [s for l in lines for s in l.spans if s.text.strip()]
        return spans

    # 添加块类型到行的方法
    def add_block_types(self, page_block_types):
        # 如果检测到的块类型数量与页面行数不匹配,则打印警告信息
        if len(page_block_types) != len(self.get_all_lines()):
            print(f"Warning: Number of detected lines {len(page_block_types)} does not match number of lines {len(self.get_all_lines())}")

        i = 0
        for block in self.blocks:
            for line in block.lines:
                if i < len(page_block_types):
                    line_block_type = page_block_types[i].block_type
                else:
                    line_block_type = "Text"
                i  = 1
                for span in line.spans:
                    span.block_type = line_block_type

    # 获取页面中字体统计信息的方法
    def get_font_stats(self):
        # 获取页面中非空跨度的字体信息
        fonts = [s.font for s in self.get_nonblank_spans()]
        # 统计字体出现次数
        font_counts = Counter(fonts)
        return font_counts

    # 获取页面中行高统计信息的方法
    def get_line_height_stats(self):
        # 获取页面中非空行的行高信息
        heights = [l.bbox[3] - l.bbox[1] for l in self.get_nonblank_lines()]
        # 统计行高出现次数
        height_counts = Counter(heights)
        return height_counts

    # 获取页面中行起始位置统计信息的方法
    def get_line_start_stats(self):
        # 获取页面中非空行的起始位置信息
        starts = [l.bbox[0] for l in self.get_nonblank_lines()]
        # 统计起始位置出现次数
        start_counts = Counter(starts)
        return start_counts
    # 获取文本块中非空行的起始位置列表
    def get_min_line_start(self):
        # 通过列表推导式获取非空行的起始位置,并且该行为文本类型
        starts = [l.bbox[0] for l in self.get_nonblank_lines() if l.spans[0].block_type == "Text"]
        # 如果没有找到非空行,则抛出索引错误
        if len(starts) == 0:
            raise IndexError("No lines found")
        # 返回起始位置列表中的最小值
        return min(starts)

    # 获取文本块中每个文本块的 prelim_text 属性,并用换行符连接成字符串
    @property
    def prelim_text(self):
        return "n".join([b.prelim_text for b in self.blocks])
# 定义一个继承自BboxElement的MergedLine类,包含文本和字体列表属性
class MergedLine(BboxElement):
    text: str
    fonts: List[str]

    # 返回该行中出现频率最高的字体
    def most_common_font(self):
        # 统计字体列表中各个字体出现的次数
        counter = Counter(self.fonts)
        # 返回出现频率最高的字体
        return counter.most_common(1)[0][0]


# 定义一个继承自BboxElement的MergedBlock类,包含行列表、段落号和块类型列表属性
class MergedBlock(BboxElement):
    lines: List[MergedLine]
    pnum: int
    block_types: List[str]

    # 返回该块中出现频率最高的块类型
    def most_common_block_type(self):
        # 统计块类型列表中各个类型出现的次数
        counter = Counter(self.block_types)
        # 返回出现频率最高的块类型
        return counter.most_common(1)[0][0]


# 定义一个继承自BaseModel的FullyMergedBlock类,包含文本和块类型属性
class FullyMergedBlock(BaseModel):
    text: str
    block_type: str

.markermarkersegmentation.py

代码语言:javascript复制
# 导入所需的库
from concurrent.futures import ThreadPoolExecutor
from typing import List

from transformers import LayoutLMv3ForTokenClassification

# 导入自定义的模块
from marker.bbox import unnormalize_box
from transformers.models.layoutlmv3.image_processing_layoutlmv3 import normalize_box
import io
from PIL import Image
from transformers import LayoutLMv3Processor
import numpy as np
from marker.settings import settings
from marker.schema import Page, BlockType
import torch
from math import isclose

# 设置图像最大像素值,避免部分图像被截断
Image.MAX_IMAGE_PIXELS = None

# 从预训练模型加载 LayoutLMv3Processor
processor = LayoutLMv3Processor.from_pretrained(settings.LAYOUT_MODEL_NAME, apply_ocr=False)

# 定义需要分块的键和不需要分块的键
CHUNK_KEYS = ["input_ids", "attention_mask", "bbox", "offset_mapping"]
NO_CHUNK_KEYS = ["pixel_values"]

# 加载 LayoutLMv3ForTokenClassification 模型
def load_layout_model():
    # 从预训练模型加载 LayoutLMv3ForTokenClassification 模型
    model = LayoutLMv3ForTokenClassification.from_pretrained(
        settings.LAYOUT_MODEL_NAME,
        torch_dtype=settings.MODEL_DTYPE,
    ).to(settings.TORCH_DEVICE_MODEL)

    # 设置模型的标签映射
    model.config.id2label = {
        0: "Caption",
        1: "Footnote",
        2: "Formula",
        3: "List-item",
        4: "Page-footer",
        5: "Page-header",
        6: "Picture",
        7: "Section-header",
        8: "Table",
        9: "Text",
        10: "Title"
    }

    model.config.label2id = {v: k for k, v in model.config.id2label.items()}
    return model

# 检测文档块类型
def detect_document_block_types(doc, blocks: List[Page], layoutlm_model, batch_size=settings.LAYOUT_BATCH_SIZE):
    # 获取特征编码、元数据和样本长度
    encodings, metadata, sample_lengths = get_features(doc, blocks)
    # 预测块类型
    predictions = predict_block_types(encodings, layoutlm_model, batch_size)
    # 将预测结果与框匹配
    block_types = match_predictions_to_boxes(encodings, predictions, metadata, sample_lengths, layoutlm_model)
    # 断言块类型数量与块数量相等
    assert len(block_types) == len(blocks)
    return block_types

# 获取临时框
def get_provisional_boxes(pred, box, is_subword, start_idx=0):
    # 从预测结果中获取临时框
    prov_predictions = [pred_ for idx, pred_ in enumerate(pred) if not is_subword[idx]][start_idx:]
    # 从列表中筛选出不是子词的元素,并从指定索引开始切片,得到新的列表
    prov_boxes = [box_ for idx, box_ in enumerate(box) if not is_subword[idx]][start_idx:]
    # 返回处理后的预测结果和框
    return prov_predictions, prov_boxes
# 获取页面编码信息,输入参数为页面和页面块对象
def get_page_encoding(page, page_blocks: Page):
    # 如果页面块中的所有行数为0,则返回空列表
    if len(page_blocks.get_all_lines()) == 0:
        return [], []

    # 获取页面块的边界框、宽度和高度
    page_box = page_blocks.bbox
    pwidth = page_blocks.width
    pheight = page_blocks.height

    # 获取页面块的像素图,并转换为 PNG 格式
    pix = page.get_pixmap(dpi=settings.LAYOUT_DPI, annots=False, clip=page_blocks.bbox)
    png = pix.pil_tobytes(format="PNG")
    png_image = Image.open(io.BytesIO(png))
    # 如果图像太大,则缩小以适应模型
    rgb_image = png_image.convert('RGB')
    rgb_width, rgb_height = rgb_image.size

    # 确保图像大小与 PDF 页面的比例正确
    assert isclose(rgb_width / pwidth, rgb_height / pheight, abs_tol=2e-2)

    # 获取页面块中的所有行
    lines = page_blocks.get_all_lines()

    boxes = []
    text = []
    for line in lines:
        box = line.bbox
        # 处理边界框溢出的情况
        if box[0] < page_box[0]:
            box[0] = page_box[0]
        if box[1] < page_box[1]:
            box[1] = page_box[1]
        if box[2] > page_box[2]:
            box[2] = page_box[2]
        if box[3] > page_box[3]:
            box[3] = page_box[3]

        # 处理边界框宽度或高度为0或负值的情况
        if box[2] <= box[0]:
            print("Zero width box found, cannot convert properly")
            raise ValueError
        if box[3] <= box[1]:
            print("Zero height box found, cannot convert properly")
            raise ValueError
        boxes.append(box)
        text.append(line.prelim_text)

    # 将边界框归一化为模型(缩放为1000x1000)
    boxes = [normalize_box(box, pwidth, pheight) for box in boxes]
    for box in boxes:
        # 验证所有边界框都是有效的
        assert(len(box) == 4)
        assert(max(box)) <= 1000
        assert(min(box)) >= 0
    # 使用 processor 处理 RGB 图像,传入文本、框、返回偏移映射等参数
    encoding = processor(
        rgb_image,
        text=text,
        boxes=boxes,
        return_offsets_mapping=True,
        truncation=True,
        return_tensors="pt",
        stride=settings.LAYOUT_CHUNK_OVERLAP,
        padding="max_length",
        max_length=settings.LAYOUT_MODEL_MAX,
        return_overflowing_tokens=True
    )
    # 从 encoding 中弹出 offset_mapping 和 overflow_to_sample_mapping
    offset_mapping = encoding.pop('offset_mapping')
    overflow_to_sample_mapping = encoding.pop('overflow_to_sample_mapping')
    # 将 encoding 中的 bbox、input_ids、attention_mask、pixel_values 转换为列表
    bbox = list(encoding["bbox"])
    input_ids = list(encoding["input_ids"])
    attention_mask = list(encoding["attention_mask"])
    pixel_values = list(encoding["pixel_values"])

    # 断言各列表长度相等
    assert len(bbox) == len(input_ids) == len(attention_mask) == len(pixel_values) == len(offset_mapping)

    # 将各列表中的元素组成字典,放入 list_encoding 列表中
    list_encoding = []
    for i in range(len(bbox)):
        list_encoding.append({
            "bbox": bbox[i],
            "input_ids": input_ids[i],
            "attention_mask": attention_mask[i],
            "pixel_values": pixel_values[i],
            "offset_mapping": offset_mapping[i]
        })

    # 其他数据包括原始框、pwidth 和 pheight
    other_data = {
        "original_bbox": boxes,
        "pwidth": pwidth,
        "pheight": pheight,
    }
    # 返回 list_encoding 和 other_data
    return list_encoding, other_data
# 获取文档的特征信息
def get_features(doc, blocks):
    # 初始化编码、元数据和样本长度列表
    encodings = []
    metadata = []
    sample_lengths = []
    # 遍历每个块
    for i in range(len(blocks)):
        # 调用函数获取页面编码和其他数据
        encoding, other_data = get_page_encoding(doc[i], blocks[i])
        # 将页面编码添加到编码列表中
        encodings.extend(encoding)
        # 将其他数据添加到元数据列表中
        metadata.append(other_data)
        # 记录当前页面编码的长度
        sample_lengths.append(len(encoding))
    # 返回编码、元数据和样本长度
    return encodings, metadata, sample_lengths


# 预测块类型
def predict_block_types(encodings, layoutlm_model, batch_size):
    # 初始化所有预测结果列表
    all_predictions = []
    # 按批次处理编码
    for i in range(0, len(encodings), batch_size):
        # 计算当前批次的起始和结束索引
        batch_start = i
        batch_end = min(i   batch_size, len(encodings))
        # 获取当前批次的编码
        batch = encodings[batch_start:batch_end]

        # 构建模型输入
        model_in = {}
        for k in ["bbox", "input_ids", "attention_mask", "pixel_values"]:
            model_in[k] = torch.stack([b[k] for b in batch]).to(layoutlm_model.device)

        model_in["pixel_values"] = model_in["pixel_values"].to(layoutlm_model.dtype)

        # 进入推理模式
        with torch.inference_mode():
            # 使用模型进行推理
            outputs = layoutlm_model(**model_in)
            logits = outputs.logits

        # 获取预测结果
        predictions = logits.argmax(-1).squeeze().tolist()
        if len(predictions) == settings.LAYOUT_MODEL_MAX:
            predictions = [predictions]
        # 将预测结果添加到所有预测结果列表中
        all_predictions.extend(predictions)
    # 返回所有预测结果
    return all_predictions


# 将预测结果与框匹配
def match_predictions_to_boxes(encodings, predictions, metadata, sample_lengths, layoutlm_model) -> List[List[BlockType]]:
    # 断言编码、预测结果和样本长度的长度相等
    assert len(encodings) == len(predictions) == sum(sample_lengths)
    assert len(metadata) == len(sample_lengths)

    # 初始化页面起始索引和页面块类型列表
    page_start = 0
    page_block_types = []
    # 返回页面块类型列表
    return page_block_types

.markermarkersettings.py

代码语言:javascript复制
import os
from typing import Optional, List, Dict

from dotenv import find_dotenv
from pydantic import computed_field
from pydantic_settings import BaseSettings
import fitz as pymupdf
import torch

# 定义一个设置类,继承自BaseSettings
class Settings(BaseSettings):
    # General
    TORCH_DEVICE: Optional[str] = None

    # 计算属性,返回TORCH_DEVICE_MODEL
    @computed_field
    @property
    def TORCH_DEVICE_MODEL(self) -> str:
        # 如果TORCH_DEVICE不为None,则返回TORCH_DEVICE
        if self.TORCH_DEVICE is not None:
            return self.TORCH_DEVICE

        # 如果CUDA可用,则返回"cuda"
        if torch.cuda.is_available():
            return "cuda"

        # 如果MPS可用,则返回"mps"
        if torch.backends.mps.is_available():
            return "mps"

        # 否则返回"cpu"
        return "cpu"

    INFERENCE_RAM: int = 40 # 每个GPU的VRAM量(以GB为单位)。
    VRAM_PER_TASK: float = 2.5 # 每个任务分配的VRAM量(以GB为单位)。 峰值标记VRAM使用量约为3GB,但工作程序的平均值较低。
    DEFAULT_LANG: str = "English" # 我们假设文件所在的默认语言,应该是TESSERACT_LANGUAGES中的一个键

    SUPPORTED_FILETYPES: Dict = {
        "application/pdf": "pdf",
        "application/epub zip": "epub",
        "application/x-mobipocket-ebook": "mobi",
        "application/vnd.ms-xpsdocument": "xps",
        "application/x-fictionbook xml": "fb2"
    }

    # PyMuPDF
    TEXT_FLAGS: int = pymupdf.TEXTFLAGS_DICT & ~pymupdf.TEXT_PRESERVE_LIGATURES & ~pymupdf.TEXT_PRESERVE_IMAGES

    # OCR
    INVALID_CHARS: List[str] = [chr(0xfffd), "�"]
    OCR_DPI: int = 400
    TESSDATA_PREFIX: str = ""
    TESSERACT_LANGUAGES: Dict = {
        "English": "eng",
        "Spanish": "spa",
        "Portuguese": "por",
        "French": "fra",
        "German": "deu",
        "Russian": "rus",
        "Chinese": "chi_sim",
        "Japanese": "jpn",
        "Korean": "kor",
        "Hindi": "hin",
    }
    TESSERACT_TIMEOUT: int = 20 # 何时放弃OCR
    # 定义拼写检查语言对应的字典
    SPELLCHECK_LANGUAGES: Dict = {
        "English": "en",
        "Spanish": "es",
        "Portuguese": "pt",
        "French": "fr",
        "German": "de",
        "Russian": "ru",
        "Chinese": None,
        "Japanese": None,
        "Korean": None,
        "Hindi": None,
    }
    # 是否在每一页运行 OCR,即使可以提取文本
    OCR_ALL_PAGES: bool = False
    # 用于 OCR 的并行 CPU 工作线程数
    OCR_PARALLEL_WORKERS: int = 2
    # 使用的 OCR 引擎,可以是 "tesseract" 或 "ocrmypdf",ocrmypdf 质量更高但速度较慢
    OCR_ENGINE: str = "ocrmypdf"

    # Texify 模型相关参数
    TEXIFY_MODEL_MAX: int = 384 # Texify 的最大推理长度
    TEXIFY_TOKEN_BUFFER: int = 256 # Texify 的 token 缓冲区大小
    TEXIFY_DPI: int = 96 # 渲染图像的 DPI
    TEXIFY_BATCH_SIZE: int = 2 if TORCH_DEVICE_MODEL == "cpu" else 6 # Texify 的批处理大小,CPU 上较低因为使用 float32
    TEXIFY_MODEL_NAME: str = "vikp/texify"

    # Layout 模型相关参数
    BAD_SPAN_TYPES: List[str] = ["Caption", "Footnote", "Page-footer", "Page-header", "Picture"]
    LAYOUT_MODEL_MAX: int = 512
    LAYOUT_CHUNK_OVERLAP: int = 64
    LAYOUT_DPI: int = 96
    LAYOUT_MODEL_NAME: str = "vikp/layout_segmenter"
    LAYOUT_BATCH_SIZE: int = 8 # 最大 512 个 token 意味着较高的批处理大小

    # Ordering 模型相关参数
    ORDERER_BATCH_SIZE: int = 32 # 可以较高,因为最大 token 数为 128
    ORDERER_MODEL_NAME: str = "vikp/column_detector"

    # 最终编辑模型相关参数
    EDITOR_BATCH_SIZE: int = 4
    EDITOR_MAX_LENGTH: int = 1024
    EDITOR_MODEL_NAME: str = "vikp/pdf_postprocessor_t5"
    ENABLE_EDITOR_MODEL: bool = False # 编辑模型可能会产生误报
    EDITOR_CUTOFF_THRESH: float = 0.9 # 忽略概率低于此阈值的预测

    # Ray 相关参数
    RAY_CACHE_PATH: Optional[str] = None # 保存 Ray 缓存的路径
    RAY_CORES_PER_WORKER: int = 1 # 每个 worker 分配的 CPU 核心数

    # 调试相关参数
    DEBUG: bool = False # 启用调试日志
    # 调试数据文件夹路径,默认为 None
    DEBUG_DATA_FOLDER: Optional[str] = None
    # 调试级别,范围从 0 到 2,2 表示记录所有信息
    DEBUG_LEVEL: int = 0
    
    # 计算属性,返回是否使用 CUDA
    @computed_field
    @property
    def CUDA(self) -> bool:
        return "cuda" in self.TORCH_DEVICE
    
    # 计算属性,返回模型数据类型
    @computed_field
    @property
    def MODEL_DTYPE(self) -> torch.dtype:
        if self.TORCH_DEVICE_MODEL == "cuda":
            return torch.bfloat16
        else:
            return torch.float32
    
    # 计算属性,返回用于转换的数据类型
    @computed_field
    @property
    def TEXIFY_DTYPE(self) -> torch.dtype:
        return torch.float32 if self.TORCH_DEVICE_MODEL == "cpu" else torch.float16
    
    # 类配置
    class Config:
        # 从环境文件中查找 local.env 文件
        env_file = find_dotenv("local.env")
        # 额外配置,忽略错误
        extra = "ignore"
# 创建一个 Settings 对象实例
settings = Settings()

.markerscriptsverify_benchmark_scores.py

代码语言:javascript复制
# 导入 json 模块和 argparse 模块
import json
import argparse

# 验证分数的函数,接收一个文件路径作为参数
def verify_scores(file_path):
    # 打开文件并加载 JSON 数据
    with open(file_path, 'r') as file:
        data = json.load(file)

    # 获取 multicolcnn.pdf 文件的分数
    multicolcnn_score = data["marker"]["files"]["multicolcnn.pdf"]["score"]
    # 获取 switch_trans.pdf 文件的分数
    switch_trans_score = data["marker"]["files"]["switch_trans.pdf"]["score"]

    # 如果其中一个分数小于等于 0.4,则抛出 ValueError 异常
    if multicolcnn_score <= 0.4 or switch_trans_score <= 0.4:
        raise ValueError("One or more scores are below the required threshold of 0.4")

# 如果当前脚本被直接执行
if __name__ == "__main__":
    # 创建 ArgumentParser 对象,设置描述信息
    parser = argparse.ArgumentParser(description="Verify benchmark scores")
    # 添加一个参数,指定文件路径,类型为字符串
    parser.add_argument("file_path", type=str, help="Path to the json file")
    # 解析命令行参数
    args = parser.parse_args()
    # 调用 verify_scores 函数,传入文件路径参数
    verify_scores(args.file_path)

0 人点赞