.markermarkermodels.py
代码语言:javascript
复制# 从 marker.cleaners.equations 模块中导入 load_texify_model 函数
from marker.cleaners.equations import load_texify_model
# 从 marker.ordering 模块中导入 load_ordering_model 函数
from marker.ordering import load_ordering_model
# 从 marker.postprocessors.editor 模块中导入 load_editing_model 函数
from marker.postprocessors.editor import load_editing_model
# 从 marker.segmentation 模块中导入 load_layout_model 函数
from marker.segmentation import load_layout_model
# 定义一个函数用于加载所有模型
def load_all_models():
# 调用 load_editing_model 函数,加载编辑模型
edit = load_editing_model()
# 调用 load_ordering_model 函数,加载排序模型
order = load_ordering_model()
# 调用 load_layout_model 函数,加载布局模型
layout = load_layout_model()
# 调用 load_texify_model 函数,加载 TeXify 模型
texify = load_texify_model()
# 将加载的模型按顺序存储在列表中
model_lst = [texify, layout, order, edit]
# 返回模型列表
return model_lst
.markermarkerocrpage.py
代码语言:javascript
复制import io # 导入io模块
from typing import List, Optional # 导入类型提示相关模块
import fitz as pymupdf # 导入fitz模块并重命名为pymupdf
import ocrmypdf # 导入ocrmypdf模块
from spellchecker import SpellChecker # 从spellchecker模块导入SpellChecker类
from marker.ocr.utils import detect_bad_ocr # 从marker.ocr.utils模块导入detect_bad_ocr函数
from marker.schema import Block # 从marker.schema模块导入Block类
from marker.settings import settings # 从marker.settings模块导入settings变量
ocrmypdf.configure_logging(verbosity=ocrmypdf.Verbosity.quiet) # 配置ocrmypdf的日志记录级别为quiet
# 对整个页面进行OCR识别,返回Block对象列表
def ocr_entire_page(page, lang: str, spellchecker: Optional[SpellChecker] = None) -> List[Block]:
# 如果OCR_ENGINE设置为"tesseract",则调用ocr_entire_page_tess函数
if settings.OCR_ENGINE == "tesseract":
return ocr_entire_page_tess(page, lang, spellchecker)
# 如果OCR_ENGINE设置为"ocrmypdf",则调用ocr_entire_page_ocrmp函数
elif settings.OCR_ENGINE == "ocrmypdf":
return ocr_entire_page_ocrmp(page, lang, spellchecker)
else:
raise ValueError(f"Unknown OCR engine {settings.OCR_ENGINE}") # 抛出数值错误异常,显示未知的OCR引擎
# 使用tesseract对整个页面进行OCR识别,返回Block对象列表
def ocr_entire_page_tess(page, lang: str, spellchecker: Optional[SpellChecker] = None) -> List[Block]:
try:
# 获取页面的完整OCR文本页
full_tp = page.get_textpage_ocr(flags=settings.TEXT_FLAGS, dpi=settings.OCR_DPI, full=True, language=lang)
# 获取页面的文本块列表
blocks = page.get_text("dict", sort=True, flags=settings.TEXT_FLAGS, textpage=full_tp)["blocks"]
# 获取页面的完整文本
full_text = page.get_text("text", sort=True, flags=settings.TEXT_FLAGS, textpage=full_tp)
# 如果完整文本长度为0,则返回空列表
if len(full_text) == 0:
return []
# 检查OCR是否成功。如果失败,返回空列表
# 例如,如果有一张扫描的空白页上有一些淡淡的文本印记,OCR可能会失败
if detect_bad_ocr(full_text, spellchecker):
return []
except RuntimeError:
return []
return blocks # 返回文本块列表
# 使用ocrmypdf对整个页面进行OCR识别,返回Block对象列表
def ocr_entire_page_ocrmp(page, lang: str, spellchecker: Optional[SpellChecker] = None) -> List[Block]:
# 使用ocrmypdf获取整个页面的OCR文本
src = page.parent # 页面所属文档
blank_doc = pymupdf.open() # 创建临时的1页文档
blank_doc.insert_pdf(src, from_page=page.number, to_page=page.number, annots=False, links=False) # 插入PDF页面
pdfbytes = blank_doc.tobytes() # 获取文档字节流
inbytes = io.BytesIO(pdfbytes) # 转换为BytesIO对象
# 创建一个字节流对象,用于存储 ocrmypdf 处理后的结果 PDF
outbytes = io.BytesIO() # let ocrmypdf store its result pdf here
# 使用 ocrmypdf 进行 OCR 处理
ocrmypdf.ocr(
inbytes,
outbytes,
language=lang,
output_type="pdf",
redo_ocr=None if settings.OCR_ALL_PAGES else True,
force_ocr=True if settings.OCR_ALL_PAGES else None,
progress_bar=False,
optimize=False,
fast_web_view=1e6,
skip_big=15, # skip images larger than 15 megapixels
tesseract_timeout=settings.TESSERACT_TIMEOUT,
tesseract_non_ocr_timeout=settings.TESSERACT_TIMEOUT,
)
# 以 fitz PDF 格式打开 OCR 处理后的输出
ocr_pdf = pymupdf.open("pdf", outbytes.getvalue()) # read output as fitz PDF
# 获取 OCR 处理后的文本块信息
blocks = ocr_pdf[0].get_text("dict", sort=True, flags=settings.TEXT_FLAGS)["blocks"]
# 获取 OCR 处理后的完整文本
full_text = ocr_pdf[0].get_text("text", sort=True, flags=settings.TEXT_FLAGS)
# 确保原始 PDF/EPUB/MOBI 的边界框和 OCR 处理后的 PDF 的边界框相同
assert page.bound() == ocr_pdf[0].bound()
# 如果完整文本为空,则返回空列表
if len(full_text) == 0:
return []
# 如果检测到 OCR 处理不良,则返回空列表
if detect_bad_ocr(full_text, spellchecker):
return []
# 返回文本块信息
return blocks
.markermarkerocrutils.py
代码语言:javascript
复制# 导入必要的模块和类
from typing import Optional
from nltk import wordpunct_tokenize
from spellchecker import SpellChecker
from marker.settings import settings
import re
# 检测 OCR 文本质量是否差,返回布尔值
def detect_bad_ocr(text, spellchecker: Optional[SpellChecker], misspell_threshold=.7, space_threshold=.6, newline_threshold=.5, alphanum_threshold=.4):
# 如果文本长度为0,则假定 OCR 失败
if len(text) == 0:
return True
# 使用 wordpunct_tokenize 函数将文本分词
words = wordpunct_tokenize(text)
# 过滤掉空白字符
words = [w for w in words if w.strip()]
# 提取文本中的字母数字字符
alpha_words = [word for word in words if word.isalnum()]
# 如果提供了拼写检查器
if spellchecker:
# 检查文本中的拼写错误
misspelled = spellchecker.unknown(alpha_words)
# 如果拼写错误数量超过阈值,则返回 True
if len(misspelled) > len(alpha_words) * misspell_threshold:
return True
# 计算文本中空格的数量
spaces = len(re.findall(r's ', text))
# 计算文本中字母字符的数量
alpha_chars = len(re.sub(r's ', '', text))
# 如果空格占比超过阈值,则返回 True
if spaces / (alpha_chars spaces) > space_threshold:
return True
# 计算文本中换行符的数量
newlines = len(re.findall(r'n ', text))
# 计算文本中非换行符的数量
non_newlines = len(re.sub(r'n ', '', text))
# 如果换行符占比超过阈值,则返回 True
if newlines / (newlines non_newlines) > newline_threshold:
return True
# 如果文本中字母数字字符比例低于阈值,则返回 True
if alphanum_ratio(text) < alphanum_threshold: # Garbled text
return True
# 计算文本中无效字符的数量
invalid_chars = len([c for c in text if c in settings.INVALID_CHARS])
# 如果无效字符数量超过阈值,则返回 True
if invalid_chars > max(3.0, len(text) * .02):
return True
# 默认情况下返回 False
return False
# 将字体标志拆解为可读的形式
def font_flags_decomposer(flags):
l = []
# 检查字体标志中是否包含上标
if flags & 2 ** 0:
l.append("superscript")
# 检查字体标志中是否包含斜体
if flags & 2 ** 1:
l.append("italic")
# 检查字体标志中是否包含衬线
if flags & 2 ** 2:
l.append("serifed")
else:
l.append("sans")
# 检查字体标志中是否包含等宽字体
if flags & 2 ** 3:
l.append("monospaced")
else:
l.append("proportional")
# 检查字体标志中是否包含粗体
if flags & 2 ** 4:
l.append("bold")
# 返回拆解后的字体标志字符串
return "_".join(l)
# 计算文本中字母数字字符的比例
def alphanum_ratio(text):
# 去除文本中的空格和换行符
text = text.replace(" ", "")
text = text.replace("n", "")
# 统计文本中的字母数字字符数量
alphanumeric_count = sum([1 for c in text if c.isalnum()])
# 如果文本长度为0,则返回1
if len(text) == 0:
return 1
# 计算字母数字字符比例
ratio = alphanumeric_count / len(text)
# 返回变量 ratio 的值
return ratio
.markermarkerordering.py
代码语言:javascript
复制# 导入必要的模块
from copy import deepcopy
from typing import List
import torch
import sys, os
from marker.extract_text import convert_single_page
from transformers import LayoutLMv3ForSequenceClassification, LayoutLMv3Processor
from PIL import Image
import io
from marker.schema import Page
from marker.settings import settings
# 从设置中加载 LayoutLMv3Processor 模型
processor = LayoutLMv3Processor.from_pretrained(settings.ORDERER_MODEL_NAME)
# 加载 LayoutLMv3ForSequenceClassification 模型
def load_ordering_model():
model = LayoutLMv3ForSequenceClassification.from_pretrained(
settings.ORDERER_MODEL_NAME,
torch_dtype=settings.MODEL_DTYPE,
).to(settings.TORCH_DEVICE_MODEL)
model.eval()
return model
# 获取推理数据
def get_inference_data(page, page_blocks: Page):
# 深拷贝页面块的边界框
bboxes = deepcopy([block.bbox for block in page_blocks.blocks])
# 初始化单词列表
words = ["."] * len(bboxes)
# 获取页面的像素图像
pix = page.get_pixmap(dpi=settings.LAYOUT_DPI, annots=False, clip=page_blocks.bbox)
# 将像素图像转换为 PNG 格式
png = pix.pil_tobytes(format="PNG")
# 将 PNG 数据转换为 RGB 图像
rgb_image = Image.open(io.BytesIO(png)).convert("RGB")
# 获取页面块的边界框和宽高
page_box = page_blocks.bbox
pwidth = page_blocks.width
pheight = page_blocks.height
# 调整边界框的值
for box in bboxes:
if box[0] < page_box[0]:
box[0] = page_box[0]
if box[1] < page_box[1]:
box[1] = page_box[1]
if box[2] > page_box[2]:
box[2] = page_box[2]
if box[3] > page_box[3]:
box[3] = page_box[3]
# 将边界框的值转换为相对于页面宽高的比例
box[0] = int(box[0] / pwidth * 1000)
box[1] = int(box[1] / pheight * 1000)
box[2] = int(box[2] / pwidth * 1000)
box[3] = int(box[3] / pheight * 1000)
return rgb_image, bboxes, words
# 批量推理
def batch_inference(rgb_images, bboxes, words, model):
# 对 RGB 图像、单词和边界框进行编码
encoding = processor(
rgb_images,
text=words,
boxes=bboxes,
return_tensors="pt",
truncation=True,
padding="max_length",
max_length=128
)
# 将像素值转换为模型的数据类型
encoding["pixel_values"] = encoding["pixel_values"].to(model.dtype)
# 进入推断模式,不进行梯度计算
with torch.inference_mode():
# 将指定的键对应的值移动到模型所在设备上
for k in ["bbox", "input_ids", "pixel_values", "attention_mask"]:
encoding[k] = encoding[k].to(model.device)
# 使用模型进行推理,获取输出
outputs = model(**encoding)
# 获取模型输出的预测结果
logits = outputs.logits
# 获取预测结果中概率最大的类别索引,并转换为列表
predictions = logits.argmax(-1).squeeze().tolist()
# 如果预测结果是整数,则转换为列表
if isinstance(predictions, int):
predictions = [predictions]
# 将预测结果转换为类别标签
predictions = [model.config.id2label[p] for p in predictions]
# 返回预测结果
return predictions
# 为文档中的每个块添加列数计数
def add_column_counts(doc, doc_blocks, model, batch_size):
# 按照批量大小遍历文档块
for i in range(0, len(doc_blocks), batch_size):
# 创建当前批量的索引范围
batch = range(i, min(i batch_size, len(doc_blocks)))
# 初始化空列表用于存储 RGB 图像、边界框和单词
rgb_images = []
bboxes = []
words = []
# 遍历当前批量的页码
for pnum in batch:
# 获取推理数据:RGB 图像、页边界框和页单词
page = doc[pnum]
rgb_image, page_bboxes, page_words = get_inference_data(page, doc_blocks[pnum])
rgb_images.append(rgb_image)
bboxes.append(page_bboxes)
words.append(page_words)
# 进行批量推理,获取预测结果
predictions = batch_inference(rgb_images, bboxes, words, model)
# 将预测结果与页码对应,更新文档块的列数计数
for pnum, prediction in zip(batch, predictions):
doc_blocks[pnum].column_count = prediction
# 对文档块进行排序
def order_blocks(doc, doc_blocks: List[Page], model, batch_size=settings.ORDERER_BATCH_SIZE):
# 添加列数计数
add_column_counts(doc, doc_blocks, model, batch_size)
# 遍历文档块中的每一页
for page_blocks in doc_blocks:
# 如果该页的列数大于1
if page_blocks.column_count > 1:
# 根据位置重新排序块
split_pos = page_blocks.x_start page_blocks.width / 2
left_blocks = []
right_blocks = []
# 遍历该页的每个块
for block in page_blocks.blocks:
# 根据位置将块分为左右两部分
if block.x_start <= split_pos:
left_blocks.append(block)
else:
right_blocks.append(block)
# 更新该页的块顺序
page_blocks.blocks = left_blocks right_blocks
# 返回排序后的文档块
return doc_blocks
.markermarkerpostprocessorseditor.py
代码语言:javascript
复制# 导入必要的库
from collections import defaultdict, Counter
from itertools import chain
from typing import Optional
# 导入 transformers 库中的 AutoTokenizer 类
from transformers import AutoTokenizer
# 导入 settings 模块中的 settings 变量
from marker.settings import settings
# 导入 torch 库
import torch
import torch.nn.functional as F
# 导入 marker.postprocessors.t5 模块中的 T5ForTokenClassification 类和 byt5_tokenize 函数
from marker.postprocessors.t5 import T5ForTokenClassification, byt5_tokenize
# 定义加载编辑模型的函数
def load_editing_model():
# 如果未启用编辑模型,则返回 None
if not settings.ENABLE_EDITOR_MODEL:
return None
# 从预训练模型中加载 T5ForTokenClassification 模型
model = T5ForTokenClassification.from_pretrained(
settings.EDITOR_MODEL_NAME,
torch_dtype=settings.MODEL_DTYPE,
).to(settings.TORCH_DEVICE_MODEL)
model.eval()
# 配置模型的标签映射
model.config.label2id = {
"equal": 0,
"delete": 1,
"newline-1": 2,
"space-1": 3,
}
model.config.id2label = {v: k for k, v in model.config.label2id.items()}
return model
# 定义编辑全文的函数
def edit_full_text(text: str, model: Optional[T5ForTokenClassification], batch_size: int = settings.EDITOR_BATCH_SIZE):
# 如果模型为空,则直接返回原始文本和空字典
if not model:
return text, {}
# 对文本进行 tokenization
tokenized = byt5_tokenize(text, settings.EDITOR_MAX_LENGTH)
input_ids = tokenized["input_ids"]
char_token_lengths = tokenized["char_token_lengths"]
# 准备 token_masks 列表
token_masks = []
# 遍历输入的 input_ids,按照 batch_size 进行分批处理
for i in range(0, len(input_ids), batch_size):
# 从 tokenized 中获取当前 batch 的 input_ids
batch_input_ids = tokenized["input_ids"][i: i batch_size]
# 将 batch_input_ids 转换为 torch 张量,并指定设备为 model 的设备
batch_input_ids = torch.tensor(batch_input_ids, device=model.device)
# 从 tokenized 中获取当前 batch 的 attention_mask
batch_attention_mask = tokenized["attention_mask"][i: i batch_size]
# 将 batch_attention_mask 转换为 torch 张量,并指定设备为 model 的设备
batch_attention_mask = torch.tensor(batch_attention_mask, device=model.device)
# 进入推理模式
with torch.inference_mode():
# 使用模型进行预测
predictions = model(batch_input_ids, attention_mask=batch_attention_mask)
# 将预测结果 logits 移动到 CPU 上
logits = predictions.logits.cpu()
# 如果最大概率小于阈值,则假设为不良预测
# 我们希望保守一点,不要对文本进行过多编辑
probs = F.softmax(logits, dim=-1)
max_prob = torch.max(probs, dim=-1)
cutoff_prob = max_prob.values < settings.EDITOR_CUTOFF_THRESH
labels = logits.argmax(-1)
labels[cutoff_prob] = model.config.label2id["equal"]
labels = labels.squeeze().tolist()
if len(labels) == settings.EDITOR_MAX_LENGTH:
labels = [labels]
labels = list(chain.from_iterable(labels))
token_masks.extend(labels)
# 文本中的字符列表
flat_input_ids = list(chain.from_iterable(input_ids)
# 去除特殊标记 0,1。保留未知标记,尽管它不应该被使用
assert len(token_masks) == len(flat_input_ids)
token_masks = [mask for mask, token in zip(token_masks, flat_input_ids) if token >= 2]
# 确保 token_masks 的长度与文本编码后的长度相等
assert len(token_masks) == len(list(text.encode("utf-8")))
# 统计编辑次数的字典
edit_stats = defaultdict(int)
# 输出文本列表
out_text = []
# 起始位置
start = 0
# 遍历文本中的每个字符及其索引
for i, char in enumerate(text):
# 获取当前字符对应的 token 长度
char_token_length = char_token_lengths[i]
# 获取当前字符对应的 token 的 mask
masks = token_masks[start: start char_token_length]
# 将 mask 转换为标签
labels = [model.config.id2label[mask] for mask in masks]
# 如果所有标签都是 "delete",则执行删除操作
if all(l == "delete" for l in labels):
# 如果删除的是空格,则保留,否则忽略
if char.strip():
out_text.append(char)
else:
edit_stats["delete"] = 1
# 如果标签为 "newline-1",则添加换行符
elif labels[0] == "newline-1":
out_text.append("n")
out_text.append(char)
edit_stats["newline-1"] = 1
# 如果标签为 "space-1",则添加空格
elif labels[0] == "space-1":
out_text.append(" ")
out_text.append(char)
edit_stats["space-1"] = 1
# 如果标签为其他情况,则保留字符
else:
out_text.append(char)
edit_stats["equal"] = 1
# 更新下一个字符的起始位置
start = char_token_length
# 将处理后的文本列表转换为字符串
out_text = "".join(out_text)
# 返回处理后的文本及编辑统计信息
return out_text, edit_stats
.markermarkerpostprocessorst5.py
代码语言:javascript
复制# 从 transformers 库中导入 T5Config 和 T5PreTrainedModel 类
from transformers import T5Config, T5PreTrainedModel
# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn
# 从 copy 库中导入 deepcopy 函数
from copy import deepcopy
# 从 typing 库中导入 Optional, Tuple, Union, List 类型
from typing import Optional, Tuple, Union, List
# 从 itertools 库中导入 chain 函数
from itertools import chain
# 从 transformers.modeling_outputs 模块中导入 TokenClassifierOutput 类
from transformers.modeling_outputs import TokenClassifierOutput
# 从 transformers.models.t5.modeling_t5 模块中导入 T5Stack 类
from transformers.models.t5.modeling_t5 import T5Stack
# 从 transformers.utils.model_parallel_utils 模块中导入 get_device_map, assert_device_map 函数
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
# 定义一个函数,用于将文本进行字节编码并分词
def byt5_tokenize(text: str, max_length: int, pad_token_id: int = 0):
# 初始化一个空列表,用于存储字节编码
byte_codes = []
# 遍历文本中的每个字符
for char in text:
# 将每个字符进行 UTF-8 编码,并加上 3 以考虑特殊标记
byte_codes.append([byte 3 for byte in char.encode('utf-8')])
# 将字节编码展开成一个列表
tokens = list(chain.from_iterable(byte_codes))
# 记录每个字符对应的 token 长度
char_token_lengths = [len(b) for b in byte_codes]
# 初始化批量 token 和注意力掩码列表
batched_tokens = []
attention_mask = []
# 按照最大长度将 token 进行分批
for i in range(0, len(tokens), max_length):
batched_tokens.append(tokens[i:i max_length])
attention_mask.append([1] * len(batched_tokens[-1])
# 对最后一个批次进行填充
if len(batched_tokens[-1]) < max_length:
batched_tokens[-1] = [pad_token_id] * (max_length - len(batched_tokens[-1]))
attention_mask[-1] = [0] * (max_length - len(attention_mask[-1]))
# 返回包含分词结果的字典
return {"input_ids": batched_tokens, "attention_mask": attention_mask, "char_token_lengths": char_token_lengths}
# 定义一个 T5ForTokenClassification 类,继承自 T5PreTrainedModel 类
class T5ForTokenClassification(T5PreTrainedModel):
# 定义一个列表,用于指定加载时忽略的键
_keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
# 初始化函数,接受一个T5Config对象作为参数
def __init__(self, config: T5Config):
# 调用父类的初始化函数
super().__init__(config)
# 设置模型维度为配置中的d_model值
self.model_dim = config.d_model
# 创建一个共享的嵌入层,词汇表大小为config.vocab_size,维度为config.d_model
self.shared = nn.Embedding(config.vocab_size, config.d_model)
# 复制配置对象,用于创建编码器
encoder_config = deepcopy(config)
encoder_config.is_decoder = False
encoder_config.is_encoder_decoder = False
encoder_config.use_cache = False
# 创建T5Stack编码器
self.encoder = T5Stack(encoder_config, self.shared)
# 设置分类器的dropout值
classifier_dropout = (
config.classifier_dropout if hasattr(config, 'classifier_dropout') else config.dropout_rate
)
self.dropout = nn.Dropout(classifier_dropout)
# 创建一个线性层,输入维度为config.d_model,输出维度为config.num_labels
self.classifier = nn.Linear(config.d_model, config.num_labels)
# 初始化权重并应用最终处理
self.post_init()
# 模型并行化
self.model_parallel = False
self.device_map = None
# 并行化函数,接受一个设备映射device_map作为参数
def parallelize(self, device_map=None):
# 如果未提供device_map,则根据编码器块的数量和GPU数量生成一个默认的device_map
self.device_map = (
get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
if device_map is None
else device_map
)
# 检查设备映射的有效性
assert_device_map(self.device_map, len(self.encoder.block))
# 将编码器并行化
self.encoder.parallelize(self.device_map)
# 将分类器移动到编码器的第一个设备上
self.classifier.to(self.encoder.first_device)
self.model_parallel = True
# 反并行化函数
def deparallelize(self):
# 取消编码器的并行化
self.encoder.deparallelize()
# 将编码器和分类器移动到CPU上
self.encoder = self.encoder.to("cpu")
self.classifier = self.classifier.to("cpu")
self.model_parallel = False
self.device_map = None
# 释放GPU缓存
torch.cuda.empty_cache()
# 获取输入嵌入层函数
def get_input_embeddings(self):
return self.shared
# 设置输入嵌入层函数,接受一个新的嵌入层new_embeddings作为参数
def set_input_embeddings(self, new_embeddings):
self.shared = new_embeddings
# 设置编码器的输入嵌入层为新的嵌入层
self.encoder.set_input_embeddings(new_embeddings)
# 获取编码器函数
def get_encoder(self):
return self.encoder
# 对模型中的特定头部进行修剪
def _prune_heads(self, heads_to_prune):
# 遍历需要修剪的层和头部
for layer, heads in heads_to_prune.items():
# 调用 SelfAttention 模块的 prune_heads 方法进行修剪
self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
# 前向传播函数
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.FloatTensor], TokenClassifierOutput]:
# 如果 return_dict 为 None,则使用配置中的 use_return_dict
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# 调用编码器进行前向传播
outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# 获取序列输出
sequence_output = outputs[0]
# 对序列输出进行 dropout
sequence_output = self.dropout(sequence_output)
# 将序列输出传入分类器得到 logits
logits = self.classifier(sequence_output)
# 初始化损失为 None
loss = None
# 如果不使用 return_dict,则返回输出结果
if not return_dict:
output = (logits,) outputs[2:]
return ((loss,) output) if loss is not None else output
# 使用 TokenClassifierOutput 类返回结果
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions
)
.markermarkerschema.py
代码语言:javascript
复制# 导入 Counter 类,用于计数
# 导入 List、Optional、Tuple 类型,用于类型提示
from collections import Counter
from typing import List, Optional, Tuple
# 导入 BaseModel、field_validator 类,用于定义数据模型和字段验证
# 导入 ftfy 模块,用于修复文本中的 Unicode 错误
from pydantic import BaseModel, field_validator
import ftfy
# 导入 boxes_intersect_pct、multiple_boxes_intersect 函数,用于计算两个框的交集比例和多个框的交集情况
# 导入 settings 模块,用于获取配置信息
from marker.bbox import boxes_intersect_pct, multiple_boxes_intersect
from marker.settings import settings
# 定义函数 find_span_type,用于查找给定 span 在页面块中的类型
def find_span_type(span, page_blocks):
# 默认块类型为 "Text"
block_type = "Text"
# 遍历页面块列表
for block in page_blocks:
# 如果 span 的边界框与页面块的边界框有交集
if boxes_intersect_pct(span.bbox, block.bbox):
# 更新块类型为页面块的类型
block_type = block.block_type
break
# 返回块类型
return block_type
# 定义类 BboxElement,继承自 BaseModel 类,表示具有边界框的元素
class BboxElement(BaseModel):
bbox: List[float]
# 验证 bbox 字段是否包含 4 个元素
@field_validator('bbox')
@classmethod
def check_4_elements(cls, v: List[float]) -> List[float]:
if len(v) != 4:
raise ValueError('bbox must have 4 elements')
return v
# 计算元素的高度、宽度、起始 x 坐标、起始 y 坐标、面积
@property
def height(self):
return self.bbox[3] - self.bbox[1]
@property
def width(self):
return self.bbox[2] - self.bbox[0]
@property
def x_start(self):
return self.bbox[0]
@property
def y_start(self):
return self.bbox[1]
@property
def area(self):
return self.width * self.height
# 定义类 BlockType,继承自 BboxElement 类,表示具有块类型的元素
class BlockType(BboxElement):
block_type: str
# 定义类 Span,继承自 BboxElement 类,表示具有文本内容的元素
class Span(BboxElement):
text: str
span_id: str
font: str
color: int
ascender: Optional[float] = None
descender: Optional[float] = None
block_type: Optional[str] = None
selected: bool = True
# 修复文本中的 Unicode 错误
@field_validator('text')
@classmethod
def fix_unicode(cls, text: str) -> str:
return ftfy.fix_text(text)
# 定义类 Line,继承自 BboxElement 类,表示具有多个 Span 的行元素
class Line(BboxElement):
spans: List[Span]
# 获取行的预备文本,即所有 Span 的文本拼接而成
@property
def prelim_text(self):
return "".join([s.text for s in self.spans])
# 获取行的起始 x 坐标
@property
def start(self):
return self.spans[0].bbox[0]
# 定义类 Block,继承自 BboxElement 类,表示具有多个 Line 的块元素
class Block(BboxElement):
lines: List[Line]
pnum: int
# 获取块的预备文本,即所有 Line 的预备文本拼接而成
@property
def prelim_text(self):
return "n".join([l.prelim_text for l in self.lines])
# 检查文本块是否包含公式,通过检查每个 span 的 block_type 是否为 "Formula" 来确定
def contains_equation(self, equation_boxes=None):
# 生成一个包含每个 span 的 block_type 是否为 "Formula" 的条件列表
conditions = [s.block_type == "Formula" for l in self.lines for s in l.spans]
# 如果提供了 equation_boxes 参数,则添加一个条件,检查文本块的边界框是否与给定框相交
if equation_boxes:
conditions = [multiple_boxes_intersect(self.bbox, equation_boxes)]
# 返回条件列表中是否有任何一个条件为 True
return any(conditions)
# 过滤掉包含在 bad_span_ids 中的 span
def filter_spans(self, bad_span_ids):
new_lines = []
for line in self.lines:
new_spans = []
for span in line.spans:
# 如果 span 的 span_id 不在 bad_span_ids 中,则保留该 span
if not span.span_id in bad_span_ids:
new_spans.append(span)
# 更新 line 的 spans 属性为过滤后的 new_spans
line.spans = new_spans
# 如果 line 中仍有 spans,则将其添加到 new_lines 中
if len(new_spans) > 0:
new_lines.append(line)
# 更新 self.lines 为过滤后的 new_lines
self.lines = new_lines
# 过滤掉包含在 settings.BAD_SPAN_TYPES 中的 span 的 block_type
def filter_bad_span_types(self):
new_lines = []
for line in self.lines:
new_spans = []
for span in line.spans:
# 如果 span 的 block_type 不在 BAD_SPAN_TYPES 中,则保留该 span
if span.block_type not in settings.BAD_SPAN_TYPES:
new_spans.append(span)
# 更新 line 的 spans 属性为过滤后的 new_spans
line.spans = new_spans
# 如果 line 中仍有 spans,则将其添加到 new_lines 中
if len(new_spans) > 0:
new_lines.append(line)
# 更新 self.lines 为过滤后的 new_lines
self.lines = new_lines
# 返回文本块中出现频率最高的 block_type
def most_common_block_type(self):
# 统计每个 span 的 block_type 出现的次数
counter = Counter([s.block_type for l in self.lines for s in l.spans])
# 返回出现次数最多的 block_type
return counter.most_common(1)[0][0]
# 设置文本块中所有 span 的 block_type 为给定的 block_type
def set_block_type(self, block_type):
for line in self.lines:
for span in line.spans:
# 将 span 的 block_type 设置为给定的 block_type
span.block_type = block_type
# 定义一个名为 Page 的类,继承自 BboxElement 类
class Page(BboxElement):
# 类属性:blocks 为 Block 对象列表,pnum 为整数,column_count 和 rotation 可选整数,默认为 None
blocks: List[Block]
pnum: int
column_count: Optional[int] = None
rotation: Optional[int] = None # 页面的旋转角度
# 获取页面中非空行的方法
def get_nonblank_lines(self):
# 获取页面中所有行
lines = self.get_all_lines()
# 过滤出非空行
nonblank_lines = [l for l in lines if l.prelim_text.strip()]
return nonblank_lines
# 获取页面中所有行的方法
def get_all_lines(self):
# 获取页面中所有行的列表
lines = [l for b in self.blocks for l in b.lines]
return lines
# 获取页面中非空跨度的方法,返回 Span 对象列表
def get_nonblank_spans(self) -> List[Span]:
# 获取页面中所有行
lines = [l for b in self.blocks for l in b.lines]
# 过滤出非空跨度
spans = [s for l in lines for s in l.spans if s.text.strip()]
return spans
# 添加块类型到行的方法
def add_block_types(self, page_block_types):
# 如果检测到的块类型数量与页面行数不匹配,则打印警告信息
if len(page_block_types) != len(self.get_all_lines()):
print(f"Warning: Number of detected lines {len(page_block_types)} does not match number of lines {len(self.get_all_lines())}")
i = 0
for block in self.blocks:
for line in block.lines:
if i < len(page_block_types):
line_block_type = page_block_types[i].block_type
else:
line_block_type = "Text"
i = 1
for span in line.spans:
span.block_type = line_block_type
# 获取页面中字体统计信息的方法
def get_font_stats(self):
# 获取页面中非空跨度的字体信息
fonts = [s.font for s in self.get_nonblank_spans()]
# 统计字体出现次数
font_counts = Counter(fonts)
return font_counts
# 获取页面中行高统计信息的方法
def get_line_height_stats(self):
# 获取页面中非空行的行高信息
heights = [l.bbox[3] - l.bbox[1] for l in self.get_nonblank_lines()]
# 统计行高出现次数
height_counts = Counter(heights)
return height_counts
# 获取页面中行起始位置统计信息的方法
def get_line_start_stats(self):
# 获取页面中非空行的起始位置信息
starts = [l.bbox[0] for l in self.get_nonblank_lines()]
# 统计起始位置出现次数
start_counts = Counter(starts)
return start_counts
# 获取文本块中非空行的起始位置列表
def get_min_line_start(self):
# 通过列表推导式获取非空行的起始位置,并且该行为文本类型
starts = [l.bbox[0] for l in self.get_nonblank_lines() if l.spans[0].block_type == "Text"]
# 如果没有找到非空行,则抛出索引错误
if len(starts) == 0:
raise IndexError("No lines found")
# 返回起始位置列表中的最小值
return min(starts)
# 获取文本块中每个文本块的 prelim_text 属性,并用换行符连接成字符串
@property
def prelim_text(self):
return "n".join([b.prelim_text for b in self.blocks])
# 定义一个继承自BboxElement的MergedLine类,包含文本和字体列表属性
class MergedLine(BboxElement):
text: str
fonts: List[str]
# 返回该行中出现频率最高的字体
def most_common_font(self):
# 统计字体列表中各个字体出现的次数
counter = Counter(self.fonts)
# 返回出现频率最高的字体
return counter.most_common(1)[0][0]
# 定义一个继承自BboxElement的MergedBlock类,包含行列表、段落号和块类型列表属性
class MergedBlock(BboxElement):
lines: List[MergedLine]
pnum: int
block_types: List[str]
# 返回该块中出现频率最高的块类型
def most_common_block_type(self):
# 统计块类型列表中各个类型出现的次数
counter = Counter(self.block_types)
# 返回出现频率最高的块类型
return counter.most_common(1)[0][0]
# 定义一个继承自BaseModel的FullyMergedBlock类,包含文本和块类型属性
class FullyMergedBlock(BaseModel):
text: str
block_type: str
.markermarkersegmentation.py
代码语言:javascript
复制# 导入所需的库
from concurrent.futures import ThreadPoolExecutor
from typing import List
from transformers import LayoutLMv3ForTokenClassification
# 导入自定义的模块
from marker.bbox import unnormalize_box
from transformers.models.layoutlmv3.image_processing_layoutlmv3 import normalize_box
import io
from PIL import Image
from transformers import LayoutLMv3Processor
import numpy as np
from marker.settings import settings
from marker.schema import Page, BlockType
import torch
from math import isclose
# 设置图像最大像素值,避免部分图像被截断
Image.MAX_IMAGE_PIXELS = None
# 从预训练模型加载 LayoutLMv3Processor
processor = LayoutLMv3Processor.from_pretrained(settings.LAYOUT_MODEL_NAME, apply_ocr=False)
# 定义需要分块的键和不需要分块的键
CHUNK_KEYS = ["input_ids", "attention_mask", "bbox", "offset_mapping"]
NO_CHUNK_KEYS = ["pixel_values"]
# 加载 LayoutLMv3ForTokenClassification 模型
def load_layout_model():
# 从预训练模型加载 LayoutLMv3ForTokenClassification 模型
model = LayoutLMv3ForTokenClassification.from_pretrained(
settings.LAYOUT_MODEL_NAME,
torch_dtype=settings.MODEL_DTYPE,
).to(settings.TORCH_DEVICE_MODEL)
# 设置模型的标签映射
model.config.id2label = {
0: "Caption",
1: "Footnote",
2: "Formula",
3: "List-item",
4: "Page-footer",
5: "Page-header",
6: "Picture",
7: "Section-header",
8: "Table",
9: "Text",
10: "Title"
}
model.config.label2id = {v: k for k, v in model.config.id2label.items()}
return model
# 检测文档块类型
def detect_document_block_types(doc, blocks: List[Page], layoutlm_model, batch_size=settings.LAYOUT_BATCH_SIZE):
# 获取特征编码、元数据和样本长度
encodings, metadata, sample_lengths = get_features(doc, blocks)
# 预测块类型
predictions = predict_block_types(encodings, layoutlm_model, batch_size)
# 将预测结果与框匹配
block_types = match_predictions_to_boxes(encodings, predictions, metadata, sample_lengths, layoutlm_model)
# 断言块类型数量与块数量相等
assert len(block_types) == len(blocks)
return block_types
# 获取临时框
def get_provisional_boxes(pred, box, is_subword, start_idx=0):
# 从预测结果中获取临时框
prov_predictions = [pred_ for idx, pred_ in enumerate(pred) if not is_subword[idx]][start_idx:]
# 从列表中筛选出不是子词的元素,并从指定索引开始切片,得到新的列表
prov_boxes = [box_ for idx, box_ in enumerate(box) if not is_subword[idx]][start_idx:]
# 返回处理后的预测结果和框
return prov_predictions, prov_boxes
# 获取页面编码信息,输入参数为页面和页面块对象
def get_page_encoding(page, page_blocks: Page):
# 如果页面块中的所有行数为0,则返回空列表
if len(page_blocks.get_all_lines()) == 0:
return [], []
# 获取页面块的边界框、宽度和高度
page_box = page_blocks.bbox
pwidth = page_blocks.width
pheight = page_blocks.height
# 获取页面块的像素图,并转换为 PNG 格式
pix = page.get_pixmap(dpi=settings.LAYOUT_DPI, annots=False, clip=page_blocks.bbox)
png = pix.pil_tobytes(format="PNG")
png_image = Image.open(io.BytesIO(png))
# 如果图像太大,则缩小以适应模型
rgb_image = png_image.convert('RGB')
rgb_width, rgb_height = rgb_image.size
# 确保图像大小与 PDF 页面的比例正确
assert isclose(rgb_width / pwidth, rgb_height / pheight, abs_tol=2e-2)
# 获取页面块中的所有行
lines = page_blocks.get_all_lines()
boxes = []
text = []
for line in lines:
box = line.bbox
# 处理边界框溢出的情况
if box[0] < page_box[0]:
box[0] = page_box[0]
if box[1] < page_box[1]:
box[1] = page_box[1]
if box[2] > page_box[2]:
box[2] = page_box[2]
if box[3] > page_box[3]:
box[3] = page_box[3]
# 处理边界框宽度或高度为0或负值的情况
if box[2] <= box[0]:
print("Zero width box found, cannot convert properly")
raise ValueError
if box[3] <= box[1]:
print("Zero height box found, cannot convert properly")
raise ValueError
boxes.append(box)
text.append(line.prelim_text)
# 将边界框归一化为模型(缩放为1000x1000)
boxes = [normalize_box(box, pwidth, pheight) for box in boxes]
for box in boxes:
# 验证所有边界框都是有效的
assert(len(box) == 4)
assert(max(box)) <= 1000
assert(min(box)) >= 0
# 使用 processor 处理 RGB 图像,传入文本、框、返回偏移映射等参数
encoding = processor(
rgb_image,
text=text,
boxes=boxes,
return_offsets_mapping=True,
truncation=True,
return_tensors="pt",
stride=settings.LAYOUT_CHUNK_OVERLAP,
padding="max_length",
max_length=settings.LAYOUT_MODEL_MAX,
return_overflowing_tokens=True
)
# 从 encoding 中弹出 offset_mapping 和 overflow_to_sample_mapping
offset_mapping = encoding.pop('offset_mapping')
overflow_to_sample_mapping = encoding.pop('overflow_to_sample_mapping')
# 将 encoding 中的 bbox、input_ids、attention_mask、pixel_values 转换为列表
bbox = list(encoding["bbox"])
input_ids = list(encoding["input_ids"])
attention_mask = list(encoding["attention_mask"])
pixel_values = list(encoding["pixel_values"])
# 断言各列表长度相等
assert len(bbox) == len(input_ids) == len(attention_mask) == len(pixel_values) == len(offset_mapping)
# 将各列表中的元素组成字典,放入 list_encoding 列表中
list_encoding = []
for i in range(len(bbox)):
list_encoding.append({
"bbox": bbox[i],
"input_ids": input_ids[i],
"attention_mask": attention_mask[i],
"pixel_values": pixel_values[i],
"offset_mapping": offset_mapping[i]
})
# 其他数据包括原始框、pwidth 和 pheight
other_data = {
"original_bbox": boxes,
"pwidth": pwidth,
"pheight": pheight,
}
# 返回 list_encoding 和 other_data
return list_encoding, other_data
# 获取文档的特征信息
def get_features(doc, blocks):
# 初始化编码、元数据和样本长度列表
encodings = []
metadata = []
sample_lengths = []
# 遍历每个块
for i in range(len(blocks)):
# 调用函数获取页面编码和其他数据
encoding, other_data = get_page_encoding(doc[i], blocks[i])
# 将页面编码添加到编码列表中
encodings.extend(encoding)
# 将其他数据添加到元数据列表中
metadata.append(other_data)
# 记录当前页面编码的长度
sample_lengths.append(len(encoding))
# 返回编码、元数据和样本长度
return encodings, metadata, sample_lengths
# 预测块类型
def predict_block_types(encodings, layoutlm_model, batch_size):
# 初始化所有预测结果列表
all_predictions = []
# 按批次处理编码
for i in range(0, len(encodings), batch_size):
# 计算当前批次的起始和结束索引
batch_start = i
batch_end = min(i batch_size, len(encodings))
# 获取当前批次的编码
batch = encodings[batch_start:batch_end]
# 构建模型输入
model_in = {}
for k in ["bbox", "input_ids", "attention_mask", "pixel_values"]:
model_in[k] = torch.stack([b[k] for b in batch]).to(layoutlm_model.device)
model_in["pixel_values"] = model_in["pixel_values"].to(layoutlm_model.dtype)
# 进入推理模式
with torch.inference_mode():
# 使用模型进行推理
outputs = layoutlm_model(**model_in)
logits = outputs.logits
# 获取预测结果
predictions = logits.argmax(-1).squeeze().tolist()
if len(predictions) == settings.LAYOUT_MODEL_MAX:
predictions = [predictions]
# 将预测结果添加到所有预测结果列表中
all_predictions.extend(predictions)
# 返回所有预测结果
return all_predictions
# 将预测结果与框匹配
def match_predictions_to_boxes(encodings, predictions, metadata, sample_lengths, layoutlm_model) -> List[List[BlockType]]:
# 断言编码、预测结果和样本长度的长度相等
assert len(encodings) == len(predictions) == sum(sample_lengths)
assert len(metadata) == len(sample_lengths)
# 初始化页面起始索引和页面块类型列表
page_start = 0
page_block_types = []
# 返回页面块类型列表
return page_block_types
.markermarkersettings.py
代码语言:javascript
复制import os
from typing import Optional, List, Dict
from dotenv import find_dotenv
from pydantic import computed_field
from pydantic_settings import BaseSettings
import fitz as pymupdf
import torch
# 定义一个设置类,继承自BaseSettings
class Settings(BaseSettings):
# General
TORCH_DEVICE: Optional[str] = None
# 计算属性,返回TORCH_DEVICE_MODEL
@computed_field
@property
def TORCH_DEVICE_MODEL(self) -> str:
# 如果TORCH_DEVICE不为None,则返回TORCH_DEVICE
if self.TORCH_DEVICE is not None:
return self.TORCH_DEVICE
# 如果CUDA可用,则返回"cuda"
if torch.cuda.is_available():
return "cuda"
# 如果MPS可用,则返回"mps"
if torch.backends.mps.is_available():
return "mps"
# 否则返回"cpu"
return "cpu"
INFERENCE_RAM: int = 40 # 每个GPU的VRAM量(以GB为单位)。
VRAM_PER_TASK: float = 2.5 # 每个任务分配的VRAM量(以GB为单位)。 峰值标记VRAM使用量约为3GB,但工作程序的平均值较低。
DEFAULT_LANG: str = "English" # 我们假设文件所在的默认语言,应该是TESSERACT_LANGUAGES中的一个键
SUPPORTED_FILETYPES: Dict = {
"application/pdf": "pdf",
"application/epub zip": "epub",
"application/x-mobipocket-ebook": "mobi",
"application/vnd.ms-xpsdocument": "xps",
"application/x-fictionbook xml": "fb2"
}
# PyMuPDF
TEXT_FLAGS: int = pymupdf.TEXTFLAGS_DICT & ~pymupdf.TEXT_PRESERVE_LIGATURES & ~pymupdf.TEXT_PRESERVE_IMAGES
# OCR
INVALID_CHARS: List[str] = [chr(0xfffd), "�"]
OCR_DPI: int = 400
TESSDATA_PREFIX: str = ""
TESSERACT_LANGUAGES: Dict = {
"English": "eng",
"Spanish": "spa",
"Portuguese": "por",
"French": "fra",
"German": "deu",
"Russian": "rus",
"Chinese": "chi_sim",
"Japanese": "jpn",
"Korean": "kor",
"Hindi": "hin",
}
TESSERACT_TIMEOUT: int = 20 # 何时放弃OCR
# 定义拼写检查语言对应的字典
SPELLCHECK_LANGUAGES: Dict = {
"English": "en",
"Spanish": "es",
"Portuguese": "pt",
"French": "fr",
"German": "de",
"Russian": "ru",
"Chinese": None,
"Japanese": None,
"Korean": None,
"Hindi": None,
}
# 是否在每一页运行 OCR,即使可以提取文本
OCR_ALL_PAGES: bool = False
# 用于 OCR 的并行 CPU 工作线程数
OCR_PARALLEL_WORKERS: int = 2
# 使用的 OCR 引擎,可以是 "tesseract" 或 "ocrmypdf",ocrmypdf 质量更高但速度较慢
OCR_ENGINE: str = "ocrmypdf"
# Texify 模型相关参数
TEXIFY_MODEL_MAX: int = 384 # Texify 的最大推理长度
TEXIFY_TOKEN_BUFFER: int = 256 # Texify 的 token 缓冲区大小
TEXIFY_DPI: int = 96 # 渲染图像的 DPI
TEXIFY_BATCH_SIZE: int = 2 if TORCH_DEVICE_MODEL == "cpu" else 6 # Texify 的批处理大小,CPU 上较低因为使用 float32
TEXIFY_MODEL_NAME: str = "vikp/texify"
# Layout 模型相关参数
BAD_SPAN_TYPES: List[str] = ["Caption", "Footnote", "Page-footer", "Page-header", "Picture"]
LAYOUT_MODEL_MAX: int = 512
LAYOUT_CHUNK_OVERLAP: int = 64
LAYOUT_DPI: int = 96
LAYOUT_MODEL_NAME: str = "vikp/layout_segmenter"
LAYOUT_BATCH_SIZE: int = 8 # 最大 512 个 token 意味着较高的批处理大小
# Ordering 模型相关参数
ORDERER_BATCH_SIZE: int = 32 # 可以较高,因为最大 token 数为 128
ORDERER_MODEL_NAME: str = "vikp/column_detector"
# 最终编辑模型相关参数
EDITOR_BATCH_SIZE: int = 4
EDITOR_MAX_LENGTH: int = 1024
EDITOR_MODEL_NAME: str = "vikp/pdf_postprocessor_t5"
ENABLE_EDITOR_MODEL: bool = False # 编辑模型可能会产生误报
EDITOR_CUTOFF_THRESH: float = 0.9 # 忽略概率低于此阈值的预测
# Ray 相关参数
RAY_CACHE_PATH: Optional[str] = None # 保存 Ray 缓存的路径
RAY_CORES_PER_WORKER: int = 1 # 每个 worker 分配的 CPU 核心数
# 调试相关参数
DEBUG: bool = False # 启用调试日志
# 调试数据文件夹路径,默认为 None
DEBUG_DATA_FOLDER: Optional[str] = None
# 调试级别,范围从 0 到 2,2 表示记录所有信息
DEBUG_LEVEL: int = 0
# 计算属性,返回是否使用 CUDA
@computed_field
@property
def CUDA(self) -> bool:
return "cuda" in self.TORCH_DEVICE
# 计算属性,返回模型数据类型
@computed_field
@property
def MODEL_DTYPE(self) -> torch.dtype:
if self.TORCH_DEVICE_MODEL == "cuda":
return torch.bfloat16
else:
return torch.float32
# 计算属性,返回用于转换的数据类型
@computed_field
@property
def TEXIFY_DTYPE(self) -> torch.dtype:
return torch.float32 if self.TORCH_DEVICE_MODEL == "cpu" else torch.float16
# 类配置
class Config:
# 从环境文件中查找 local.env 文件
env_file = find_dotenv("local.env")
# 额外配置,忽略错误
extra = "ignore"
# 创建一个 Settings 对象实例
settings = Settings()
.markerscriptsverify_benchmark_scores.py
代码语言:javascript
复制# 导入 json 模块和 argparse 模块
import json
import argparse
# 验证分数的函数,接收一个文件路径作为参数
def verify_scores(file_path):
# 打开文件并加载 JSON 数据
with open(file_path, 'r') as file:
data = json.load(file)
# 获取 multicolcnn.pdf 文件的分数
multicolcnn_score = data["marker"]["files"]["multicolcnn.pdf"]["score"]
# 获取 switch_trans.pdf 文件的分数
switch_trans_score = data["marker"]["files"]["switch_trans.pdf"]["score"]
# 如果其中一个分数小于等于 0.4,则抛出 ValueError 异常
if multicolcnn_score <= 0.4 or switch_trans_score <= 0.4:
raise ValueError("One or more scores are below the required threshold of 0.4")
# 如果当前脚本被直接执行
if __name__ == "__main__":
# 创建 ArgumentParser 对象,设置描述信息
parser = argparse.ArgumentParser(description="Verify benchmark scores")
# 添加一个参数,指定文件路径,类型为字符串
parser.add_argument("file_path", type=str, help="Path to the json file")
# 解析命令行参数
args = parser.parse_args()
# 调用 verify_scores 函数,传入文件路径参数
verify_scores(args.file_path)