一、引言
pipeline(管道)是huggingface transformers库中一种极简方式使用大模型推理的抽象,将所有大模型分为音频(Audio)、计算机视觉(Computer vision)、自然语言处理(NLP)、多模态(Multimodal)等4大类,28小类任务(tasks)。共计覆盖32万个模型
今天介绍NLP自然语言处理的第三篇:表格问答(table-question-answering),在huggingface库内有100个表格问答(table-question-answering)模型。
二、表格问答(table-question-answering)
2.1 概述
表格问答(Table QA)是回答有关给定表格上的信息的问题。
2.2 基于BERT的表格问答模型—TAPAS(TAble PArSing)
回答表格上的自然语言问题通常被视为语义解析任务。为了减轻完整逻辑形式的收集成本,一种流行的方法侧重于由符号而不是逻辑形式组成的弱监督。然而,从弱监督中训练语义解析器会带来困难,此外,生成的逻辑形式仅用作检索符号之前的中间步骤。在本文中,我们提出了 TaPas,一种无需生成逻辑形式的表格问答方法。TaPas 从弱监督中进行训练,并通过选择表格单元格并可选地将相应的聚合运算符应用于此类选择来预测符号。TaPas 扩展了 BERT 的架构以将表格编码为输入,从从维基百科爬取的文本段和表格的有效联合预训练中进行初始化,并进行端到端训练。
2.3 应用场景
- 自动化客服系统
- 智能搜索引擎
- 数据可视化工具
- 企业知识图谱构建
- 科学文献自动化抽取等
2.4 pipeline参数
2.4.1 pipeline对象实例化参数
- model(PreTrainedModel或TFPreTrainedModel)— 管道将使用其进行预测的模型。 对于 PyTorch,这需要从PreTrainedModel继承;对于 TensorFlow,这需要从TFPreTrainedModel继承。
- tokenizer ( PreTrainedTokenizer ) — 管道将使用 tokenizer 来为模型编码数据。此对象继承自 PreTrainedTokenizer。
- modelcard(
str
或ModelCard
,可选) — 属于此管道模型的模型卡。 - framework(
str
,可选)— 要使用的框架,"pt"
适用于 PyTorch 或"tf"
TensorFlow。必须安装指定的框架。 - task(
str
,默认为""
)— 管道的任务标识符。 - num_workers(
int
,可选,默认为 8)— 当管道将使用DataLoader(传递数据集时,在 Pytorch 模型的 GPU 上)时,要使用的工作者数量。 - batch_size(
int
,可选,默认为 1)— 当管道将使用DataLoader(传递数据集时,在 Pytorch 模型的 GPU 上)时,要使用的批次的大小,对于推理来说,这并不总是有益的,请阅读使用管道进行批处理。 - args_parser(ArgumentHandler,可选) - 引用负责解析提供的管道参数的对象。
- device(
int
,可选,默认为 -1)— CPU/GPU 支持的设备序号。将其设置为 -1 将利用 CPU,设置为正数将在关联的 CUDA 设备 ID 上运行模型。您可以传递本机torch.device
或str
太 - torch_dtype(
str
或torch.dtype
,可选) - 直接发送model_kwargs
(只是一种更简单的快捷方式)以使用此模型的可用精度(torch.float16
,,torch.bfloat16
...或"auto"
) - binary_output(
bool
,可选,默认为False
)——标志指示管道的输出是否应以序列化格式(即 pickle)或原始输出数据(例如文本)进行。
2.4.2 pipeline对象使用参数
- table(
pd.DataFrame
或Dict
)——Pandas DataFrame 或字典,将转换为包含所有表值的 DataFrame。有关字典的示例,请参见上文。 - query(
str
或List[str]
)——将与表一起发送到模型的查询或查询列表。 - sequential(可选
bool
,默认为)— 是否按顺序或批量进行推理。批处理速度更快,但考虑到 SQA 等模型的对话性质,它们要求按顺序进行推理以提取序列内的关系。False
- padding(
bool
,str
或PaddingStrategy,可选,默认为False
)— 激活并控制填充。接受以下值: - truncation(
bool
,str
或TapasTruncationStrategy
,可选,默认为False
)— 激活并控制截断。接受以下值:True
或'drop_rows_to_fit'
:截断为参数指定的最大长度max_length
,或模型可接受的最大输入长度(如果未提供该参数)。这将逐行截断,从表中删除行。False
或'do_not_truncate'
(默认):不截断(即,可以输出序列长度大于模型最大可接受输入大小的批次)。
2.4.3 pipeline返回参数
- answer (
str
) — 给定表的查询的答案。如果有聚合器,答案前面会加上AGGREGATOR >
。 - coordinates(
List[Tuple[int, int]]
)——答案单元格的坐标。 - cells (
List[str]
) — 由答案单元格值组成的字符串列表。 - aggregator(
str
)— 如果模型具有聚合器,则返回该聚合器。
2.5 pipeline实战
采用pipeline,使用google的tapas-base-finetuned-wtq进行表格问答。
代码语言:javascript复制import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
from transformers import pipeline
oracle = pipeline(model="google/tapas-base-finetuned-wtq")
table = {
"Repository": ["Transformers", "Datasets", "Tokenizers"],
"Stars": ["36542", "4512", "3934"],
"Contributors": ["651", "77", "34"],
"Programming language": ["Python", "Python", "Rust, Python and NodeJS"],
}
output=oracle(query="How many stars does the transformers repository have?", table=table)
print(output)
执行后,自动下载模型文件并进行识别:
2.6 模型排名
在huggingface上,我们将表格问答(table-question-answering)模型按下载量从高到低排序,总计100个模型,排名第一是我们上述介绍的tapas-large-finetuned-wtq。
三、总结
本文对transformers之pipeline的表格问答(table-question-answering)从概述、技术原理、pipeline参数、pipeline实战、模型排名等方面进行介绍,读者可以基于pipeline使用文中的2行代码极简的使用NLP中的表格问答(table-question-answering)模型。