原文:
huggingface.co/docs/transformers
应用程序接口
主要类
代理和工具
原文:
huggingface.co/docs/transformers/v4.37.2/en/main_classes/agent
Transformers Agents 是一个实验性 API,随时可能发生变化。代理返回的结果可能会有所不同,因为 API 或底层模型可能会发生变化。
要了解更多关于代理和工具的信息,请确保阅读入门指南。此页面包含底层类的 API 文档。
代理
我们提供三种类型的代理:HfAgent 使用开源模型的推理端点,LocalAgent 在本地使用您选择的模型,OpenAiAgent 使用 OpenAI 封闭模型。
HfAgent
class transformers.HfAgent
<来源>
代码语言:javascript复制( url_endpoint token = None chat_prompt_template = None run_prompt_template = None additional_tools = None )
参数
-
url_endpoint
(str
)— 要使用的 url 端点的名称。 -
token
(str
,可选)— 用作远程文件的 HTTP 令牌的授权。如果未设置,将使用运行huggingface-cli login
时生成的令牌(存储在~/.huggingface
中)。 -
chat_prompt_template
(str
,可选)— 如果要覆盖chat
方法的默认模板,请传递您自己的提示。在这种情况下,提示应该在此存储库中的名为chat_prompt_template.txt
的文件中。 -
run_prompt_template
(str
,可选)— 如果要覆盖run
方法的默认模板,请传递您自己的提示。在这种情况下,提示应该在此存储库中的名为run_prompt_template.txt
的文件中。 -
additional_tools
(Tool,工具列表或具有工具值的字典,可选)— 除默认工具外要包含的任何其他工具。如果传递具有与默认工具相同名称的工具,则将覆盖该默认工具。
使用推理端点生成代码的代理。
示例:
代码语言:javascript复制from transformers import HfAgent
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
LocalAgent
class transformers.LocalAgent
<来源>
代码语言:javascript复制( model tokenizer chat_prompt_template = None run_prompt_template = None additional_tools = None )
参数
model
(PreTrainedModel,工具列表或具有工具值的字典,可选)— 除默认工具外要包含的任何其他工具。如果传递具有与默认工具相同名称的工具,则将覆盖该默认工具。
使用本地模型和分词器生成代码的代理。
示例:
代码语言:javascript复制import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LocalAgent
checkpoint = "bigcode/starcoder"
model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
agent = LocalAgent(model, tokenizer)
agent.run("Draw me a picture of rivers and lakes.")
from_pretrained
<来源>
代码语言:javascript复制( pretrained_model_name_or_path **kwargs )
参数
-
pretrained_model_name_or_path
(str
或os.PathLike
) — Hub 上的存储库名称或包含模型和分词器的本地路径的文件夹的名称。 -
kwargs
(Dict[str, Any]
, optional) — 传递给 from_pretrained()的关键字参数。
从预训练检查点构建LocalAgent
的便利方法。
示例:
代码语言:javascript复制import torch
from transformers import LocalAgent
agent = LocalAgent.from_pretrained("bigcode/starcoder", device_map="auto", torch_dtype=torch.bfloat16)
agent.run("Draw me a picture of rivers and lakes.")
OpenAiAgent
class transformers.OpenAiAgent
<来源>
代码语言:javascript复制( model = 'text-davinci-003' api_key = None chat_prompt_template = None run_prompt_template = None additional_tools = None )
参数
-
model
(str
, optional, 默认为"text-davinci-003"
) — 要使用的 OpenAI 模型的名称。 -
api_key
(str
, optional) — 要使用的 API 密钥。如果未设置,将查找环境变量"OPENAI_API_KEY"
。 -
chat_prompt_template
(str
, optional) — 如果要覆盖chat
方法的默认模板,请传递您自己的提示。可以是实际的提示模板,也可以是存储库 ID(在 Hugging Face Hub 上)。在这种情况下,提示应该在此存储库中的名为chat_prompt_template.txt
的文件中。 -
run_prompt_template
(str
, optional) — 如果要覆盖run
方法的默认模板,请传递您自己的提示。可以是实际的提示模板,也可以是存储库 ID(在 Hugging Face Hub 上)。在这种情况下,提示应该在此存储库中的名为run_prompt_template.txt
的文件中。 -
additional_tools
(Tool,工具列表或具有工具值的字典,optional) — 除默认工具之外包含的任何其他工具。如果传递具有与默认工具之一相同名称的工具,则将覆盖该默认工具。
使用 openai API 生成代码的代理。
openAI 模型以生成模式使用,因此即使对于chat()
API,最好使用像"text-davinci-003"
这样的模型,而不是 chat-GPT 变体。对 chat-GPT 模型的适当支持将在下一个版本中提供。
示例:
代码语言:javascript复制from transformers import OpenAiAgent
agent = OpenAiAgent(model="text-davinci-003", api_key=xxx)
agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
AzureOpenAiAgent
class transformers.AzureOpenAiAgent
<来源>
代码语言:javascript复制( deployment_id api_key = None resource_name = None api_version = '2022-12-01' is_chat_model = None chat_prompt_template = None run_prompt_template = None additional_tools = None )
参数
-
deployment_id
(str
) — 要使用的部署的 Azure openAI 模型的名称。 -
api_key
(str
, optional) — 要使用的 API 密钥。如果未设置,将查找环境变量"AZURE_OPENAI_API_KEY"
。 -
resource_name
(str
, optional) — 您的 Azure OpenAI 资源的名称。如果未设置,将查找环境变量"AZURE_OPENAI_RESOURCE_NAME"
。 -
api_version
(str
, optional, 默认为"2022-12-01"
) — 用于此代理的 API 版本。 -
is_chat_mode
(bool
, optional) — 您是否使用完成模型或聊天模型(请参见上面的说明,聊天模型效率不高)。默认情况下,gpt
将在deployment_id
中或不在其中。 -
chat_prompt_template
(str
, optional) — 如果要覆盖chat
方法的默认模板,请传递您自己的提示。可以是实际的提示模板,也可以是存储库 ID(在 Hugging Face Hub 上)。在这种情况下,提示应该在此存储库中的名为chat_prompt_template.txt
的文件中。 -
run_prompt_template
(str
, optional) — 如果要覆盖run
方法的默认模板,请传递您自己的提示。可以是实际的提示模板,也可以是存储库 ID(在 Hugging Face Hub 上)。在这种情况下,提示应该在此存储库中的名为run_prompt_template.txt
的文件中。 -
additional_tools
(Tool,工具列表或具有工具值的字典,optional) — 除默认工具之外包含的任何其他工具。如果传递具有与默认工具之一相同名称的工具,则将覆盖该默认工具。
使用 Azure OpenAI 生成代码的代理。查看官方文档以了解如何在 Azure 上部署 openAI 模型
openAI 模型以生成模式使用,因此即使对于 chat()
API,最好使用像 "text-davinci-003"
这样的模型,而不是 chat-GPT 变体。chat-GPT 模型的适当支持将在下一个版本中提供。
示例:
代码语言:javascript复制from transformers import AzureOpenAiAgent
agent = AzureAiAgent(deployment_id="Davinci-003", api_key=xxx, resource_name=yyy)
agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
代理
class transformers.Agent
< source >
代码语言:javascript复制( chat_prompt_template = None run_prompt_template = None additional_tools = None )
参数
-
chat_prompt_template
(str
, 可选) — 如果要覆盖chat
方法的默认模板,则传递您自己的提示。在这种情况下,提示应该在此 repo 中名为chat_prompt_template.txt
的文件中。 -
run_prompt_template
(str
, 可选) — 如果要覆盖run
方法的默认模板,则传递您自己的提示。在这种情况下,提示应该在此 repo 中名为run_prompt_template.txt
的文件中。 -
additional_tools
(Tool, 工具列表或具有工具值的字典, 可选) — 除默认工具外要包含的任何其他工具。如果传递一个与默认工具中的某个同名的工具,那么默认工具将被覆盖。
包含主要 API 方法的所有代理的基类。
chat
< source >
代码语言:javascript复制( task return_code = False remote = False **kwargs )
参数
-
task
(str
) — 要执行的任务 -
return_code
(bool
, 可选, 默认为False
) — 是否只返回代码而不评估它。 -
remote
(bool
, 可选, 默认为False
) — 是否使用远程工具(推理端点)而不是本地工具。 -
kwargs
(额外的关键字参数, 可选) — 在评估代码时发送给代理的任何关键字参数。
向代理发送一个新的请求。将使用其历史记录中的先前请求。
示例:
代码语言:javascript复制from transformers import HfAgent
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
agent.chat("Draw me a picture of rivers and lakes")
agent.chat("Transform the picture so that there is a rock in there")
run
< source >
代码语言:javascript复制( task return_code = False remote = False **kwargs )
参数
-
task
(str
) — 要执行的任务 -
return_code
(bool
, 可选, 默认为False
) — 是否只返回代码而不评估它。 -
remote
(bool
, 可选, 默认为False
) — 是否使用远程工具(推理端点)而不是本地工具。 -
kwargs
(额外的关键字参数, 可选) — 在评估代码时发送给代理的任何关键字参数。
向代理发送一个请求。
示例:
代码语言:javascript复制from transformers import HfAgent
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
agent.run("Draw me a picture of rivers and lakes")
prepare_for_new_chat
< source >
代码语言:javascript复制( )
清除之前调用 chat() 的历史记录。
工具
load_tool
transformers.load_tool
< source >
代码语言:javascript复制( task_or_repo_id model_repo_id = None remote = False token = None **kwargs )
参数
-
task_or_repo_id
(str
) — 要加载工具的任务或 Hub 上工具的 repo ID。在 Transformers 中实现的任务有:-
"文档问答"
-
"图像字幕"
-
"图像问答"
-
"图像分割"
-
"语音到文本"
-
"摘要"
-
"文本分类"
-
"文本问答"
-
"文本到语音"
-
"翻译"
-
-
model_repo_id
(str
, 可选) — 使用此参数来使用与您选择的工具不同的模型。 -
remote
(bool
, 可选, 默认为False
) — 是否通过下载模型或(如果可用)使用推理端点来使用您的工具。 -
token
(str
, 可选) — 用于在 hf.co 上识别您的令牌。如果未设置,将使用运行huggingface-cli login
时生成的令牌(存储在~/.huggingface
中)。 -
kwargs
(额外的关键字参数,可选) — 将被分成两部分的额外关键字参数:所有与 Hub 相关的参数(如cache_dir
、revision
、subfolder
)将在下载工具文件时使用,其他参数将传递给其 init。
快速加载工具的主要函数,无论是在 Hub 上还是在 Transformers 库中。
工具
class transformers.Tool
< source >
代码语言:javascript复制( *args **kwargs )
代理使用的函数的基类。子类化这个类并实现__call__
方法以及以下类属性:
-
description
(str
) — 您的工具的简短描述,它做什么,它期望的输入以及它将返回的输出。例如‘这是一个从url
下载文件的工具。它以url
作为输入,并返回文件中包含的文本’。 -
name
(str
) — 一个表现性的名称,将在提示代理时用于您的工具。例如"text-classifier"
或"image_generator"
。 -
inputs
(List[str]
) — 期望输入的模态列表(与调用中的顺序相同)。模态应为"text"
、"image"
或"audio"
。仅供launch_gradio_demo
使用或为您的工具创建一个漂亮的空间。 -
outputs
(List[str]
) — 工具返回的模态列表(与调用方法返回的顺序相同)。模态应为"text"
、"image"
或"audio"
。仅供launch_gradio_demo
使用或为您的工具创建一个漂亮的空间。
如果您的工具有一个昂贵的操作需要在可用之前执行(例如加载模型),您也可以重写方法 setup()。setup()将在您第一次使用工具时调用,但不会在实例化时调用。
from_gradio
< source >
代码语言:javascript复制( gradio_tool )
从 gradio 工具创建一个 Tool。
from_hub
< source >
代码语言:javascript复制( repo_id: str model_repo_id: Optional = None token: Optional = None remote: bool = False **kwargs )
参数
-
repo_id
(str
) — 在 Hub 上定义您的工具的 repo 的名称。 -
model_repo_id
(str
, 可选) — 如果您的工具使用模型并且想要使用不同于默认的模型,您可以将第二个 repo ID 或端点 url 传递给此参数。 -
token
(str
, 可选) — 用于在 hf.co 上识别您的令牌。如果未设置,将使用运行huggingface-cli login
时生成的令牌(存储在~/.huggingface
中)。 -
remote
(bool
, 可选,默认为False
) — 是否通过下载模型或(如果可用)使用推理端点来使用您的工具。 -
kwargs
(额外的关键字参数,可选) — 将被分成两部分的额外关键字参数:所有与 Hub 相关的参数(如cache_dir
、revision
、subfolder
)将在下载工具文件时使用,其他参数将传递给其 init。
加载在 Hub 上定义的工具。
push_to_hub
< source >
代码语言:javascript复制( repo_id: str commit_message: str = 'Upload tool' private: Optional = None token: Union = None create_pr: bool = False )
参数
-
repo_id
(str
) — 您要将工具推送到的存储库的名称。在将工具推送到给定组织时,它应包含您的组织名称。 -
commit_message
(str
, 可选,默认为"Upload tool"
) — 推送时要提交的消息。 -
private
(bool
, 可选) — 是否应该创建私有的存储库。 -
token
(bool
或str
, 可选) — 用作远程文件的 HTTP bearer 授权的令牌。如果未设置,将使用运行huggingface-cli login
时生成的令牌(存储在~/.huggingface
中)。 -
create_pr
(bool
, 可选, 默认为False
) — 是否创建一个带有上传文件的 PR 或直接提交。
将工具上传到 Hub。
save
<来源>
代码语言:javascript复制( output_dir )
参数
output_dir
(str
) — 您想要保存工具的文件夹。
保存与您的工具相关的代码文件,以便将其推送到 Hub。这将在output_dir
中复制您的工具的代码,并自动生成:
- 一个名为
tool_config.json
的配置文件 - 一个
app.py
文件,以便将您的工具转换为一个空间 - 一个包含您的工具使用的模块名称的
requirements.txt
文件(在检查其代码时检测到)。
您应该只使用此方法来保存在单独模块中定义的工具(不是__main__
)。
setup
<来源>
代码语言:javascript复制( )
在此处覆盖任何昂贵且需要在开始使用工具之前执行的操作的方法。例如加载一个大模型。
PipelineTool
class transformers.PipelineTool
<来源>
代码语言:javascript复制( model = None pre_processor = None post_processor = None device = None device_map = None model_kwargs = None token = None **hub_kwargs )
参数
-
model
(str
或 PreTrainedModel, 可选) — 用于模型的检查点名称,或实例化的模型。如果未设置,将默认为类属性default_checkpoint
的值。 -
pre_processor
(str
或Any
, 可选) — 用于预处理器的检查点名称,或实例化的预处理器(可以是分词器、图像处理器、特征提取器或处理器)。如果未设置,将默认为model
的值。 -
post_processor
(str
或Any
, 可选) — 用于后处理器的检查点名称,或实例化的预处理器(可以是分词器、图像处理器、特征提取器或处理器)。如果未设置,将默认为pre_processor
的值。 -
device
(int
、str
或torch.device
, 可选) — 用于执行模型的设备。默认情况下将使用任何可用的加速器(GPU、MPS 等),否则使用 CPU。 -
device_map
(str
或dict
, 可选) — 如果传递,将用于实例化模型。 -
model_kwargs
(dict
, 可选) — 发送到模型实例化的任何关键字参数。 -
token
(str
, 可选) — 用作远程文件的 HTTP bearer 授权的令牌。如果未设置,将使用运行huggingface-cli login
时生成的令牌(存储在~/.huggingface
中)。 -
hub_kwargs
(额外的关键字参数,可选) — 发送到从 Hub 加载数据的方法的任何额外关键字参数。
一个 Tool 专为 Transformer 模型定制。除了基类 Tool 的类属性外,您还需要指定:
-
model_class
(type
) — 用于在此工具中加载模型的类。 -
default_checkpoint
(str
) — 当用户未指定时应使用的默认检查点。 -
pre_processor_class
(type
, 可选, 默认为 AutoProcessor) — 用于加载预处理器的类 -
post_processor_class
(type
, 可选, 默认为 AutoProcessor) — 用于加载后处理器的类(当与预处理器不同时)。
decode
<来源>
代码语言:javascript复制( outputs )
使用post_processor
解码模型输出。
encode
<来源>
代码语言:javascript复制( raw_inputs )
使用pre_processor
准备输入以供model
使用。
forward
<来源>
代码语言:javascript复制( inputs )
将输入发送到model
。
setup
<来源>
代码语言:javascript复制( )
如有必要,实例化pre_processor
、model
和post_processor
。
RemoteTool
class transformers.RemoteTool
<来源>
代码语言:javascript复制( endpoint_url = None token = None tool_class = None )
参数
-
endpoint_url
(str
,可选)—要使用的端点的 url。 -
token
(str
,可选)—用作远程文件的 HTTP bearer 授权的令牌。如果未设置,将使用运行huggingface-cli login
时生成的令牌(存储在~/.huggingface
中)。 -
tool_class
(type
,可选)—如果这是现有工具的远程版本,则相应的tool_class
。将有助于确定何时应将输出转换为另一种类型(如图像)。
将向推理端点发出请求的 Tool。
extract_outputs
<来源>
代码语言:javascript复制( outputs )
您可以在您的自定义类 RemoteTool 中覆盖此方法,以对端点的输出应用一些自定义后处理。
prepare_inputs
<来源>
代码语言:javascript复制( *args **kwargs )
准备接收到的输入以便通过 HTTP 客户端将数据发送到端点。如果在实例化时提供了tool_class
,则位置参数将与tool_class
的签名匹配。图像将被编码为字节。
您可以在您的自定义类 RemoteTool 中覆盖此方法。
launch_gradio_demo
transformers.launch_gradio_demo
<来源>
代码语言:javascript复制( tool_class: Tool )
参数
tool_class
(type
)—要启动演示的工具类。
启动工具的 gradio 演示。相应的工具类需要正确实现类属性inputs
和outputs
。
Agent 类型
代理可以处理工具之间的任何类型的对象;工具完全多模态,可以接受和返回文本、图像、音频、视频等类型。为了增加工具之间的兼容性,并正确在 ipython(jupyter、colab、ipython 笔记本等)中呈现这些返回,我们实现了这些类型的包装类。
包装对象应继续最初的行为;文本对象仍应表现为字符串,图像对象仍应表现为PIL.Image
。
这些类型有三个特定目的:
- 在类型上调用
to_raw
应该返回底层对象 - 在类型上调用
to_string
应该将对象作为字符串返回:在AgentText
的情况下可以是字符串,但在其他实例中将是对象的序列化版本的路径 - 在 ipython 内核中显示它应该正确显示对象
AgentText
class transformers.tools.agent_types.AgentText
<来源>
代码语言:javascript复制( value )
代理返回的文本类型。表现为字符串。
AgentImage
class transformers.tools.agent_types.AgentImage
<来源>
代码语言:javascript复制( value )
代理返回的图像类型。行为类似于 PIL.Image。
to_raw
<来源>
代码语言:javascript复制( )
返回该对象的“原始”版本。在 AgentImage 的情况下,它是一个 PIL.Image。
to_string
<来源>
代码语言:javascript复制( )
返回该对象的字符串版本。在 AgentImage 的情况下,它是图像序列化版本的路径。
AgentAudio
class transformers.tools.agent_types.AgentAudio
<来源>
代码语言:javascript复制( value samplerate = 16000 )
代理返回的音频类型。
to_raw
<来源>
代码语言:javascript复制( )
返回该对象的“原始”版本。它是一个torch.Tensor
对象。
to_string
<来源>
代码语言:javascript复制( )
返回该对象的字符串版本。在 AgentAudio 的情况下,它是音频序列化版本的路径。
自动类
原始文本:
huggingface.co/docs/transformers/v4.37.2/en/model_doc/auto
在许多情况下,您想要使用的架构可以从您提供给from_pretrained()
方法的预训练模型的名称或路径中猜出。AutoClasses 在这里为您执行此操作,以便根据预训练权重/配置/词汇的名称/路径自动检索相关模型。
实例化 AutoConfig、AutoModel 和 AutoTokenizer 中的一个将直接创建相关架构的类。例如
代码语言:javascript复制model = AutoModel.from_pretrained("bert-base-cased")
将创建一个 BertModel 的实例模型。
每个任务和每个后端(PyTorch、TensorFlow 或 Flax)都有一个AutoModel
类。
扩展自动类
每个自动类都有一个方法可以用来扩展您的自定义类。例如,如果您定义了一个名为NewModel
的自定义模型类,请确保有一个NewModelConfig
,然后您可以像这样将它们添加到自动类中:
from transformers import AutoConfig, AutoModel
AutoConfig.register("new-model", NewModelConfig)
AutoModel.register(NewModelConfig, NewModel)
然后您就可以像通常一样使用自动类了!
如果您的NewModelConfig
是 PretrainedConfig 的子类,请确保其model_type
属性设置为注册配置时使用的相同键(这里是"new-model"
)。
同样,如果您的NewModel
是 PreTrainedModel 的子类,请确保其config_class
属性设置为注册模型时使用的相同类(这里是NewModelConfig
)。
AutoConfig
class transformers.AutoConfig
<来源>
代码语言:javascript复制( )
这是一个通用的配置类,当使用 from_pretrained()类方法创建时,将实例化为库的配置类之一。
这个类不能直接使用__init__()
实例化(会抛出错误)。
from_pretrained
<来源>
代码语言:javascript复制( pretrained_model_name_or_path **kwargs )
参数
-
pretrained_model_name_or_path
(str
或os.PathLike
) — 可以是:- 一个字符串,预训练模型配置的模型 id,托管在 huggingface.co 上的模型存储库中。有效的模型 id 可以位于根级别,如
bert-base-uncased
,或者在用户或组织名称下命名空间,如dbmdz/bert-base-german-cased
。 - 一个目录的路径,其中包含使用 save_pretrained()方法保存的配置文件,或者 save_pretrained()方法,例如
./my_model_directory/
。 - 一个保存的配置 JSON 文件的路径或 url,例如
./my_model_directory/configuration.json
。
- 一个字符串,预训练模型配置的模型 id,托管在 huggingface.co 上的模型存储库中。有效的模型 id 可以位于根级别,如
-
cache_dir
(str
或os.PathLike
, optional) — 下载的预训练模型配置应该缓存在其中的目录路径,如果不使用标准缓存。 -
force_download
(bool
, optional, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,并覆盖缓存版本(如果存在)。 -
resume_download
(bool
, optional, 默认为False
) — 是否删除接收不完整的文件。如果存在这样的文件,将尝试恢复下载。 -
proxies
(Dict[str, str]
,可选)— 一个代理服务器字典,按协议或端点使用,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。这些代理在每个请求上使用。 -
revision
(str
,可选,默认为"main"
)— 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们在 huggingface.co 上使用基于 git 的系统来存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
return_unused_kwargs
(bool
,可选,默认为False
)— 如果为False
,则此函数仅返回最终配置对象。 如果为True
,则此函数返回一个Tuple(config, unused_kwargs)
,其中unused_kwargs是一个字典,由那些键/值对组成,其键不是配置属性:即kwargs
的一部分,未被用于更新config
且被忽略的部分。 -
trust_remote_code
(bool
,可选,默认为False
)— 是否允许在 Hub 上定义自定义模型的建模文件。此选项应仅对您信任的存储库设置为True
,并且您已阅读代码,因为它将在本地机器上执行 Hub 上存在的代码。 -
kwargs
(附加关键字参数,可选)— kwargs 中任何键的值,其为配置属性,将用于覆盖加载的值。关于键/值对中键不是配置属性的行为由return_unused_kwargs
关键字参数控制。
从预训练模型配置中实例化库的配置类之一。
实例化的配置类是根据加载的配置对象的model_type
属性选择的,或者当它缺失时,通过在pretrained_model_name_or_path
上使用模式匹配来回退:
-
albert
— AlbertConfig(ALBERT 模型) -
align
— AlignConfig(ALIGN 模型) -
altclip
— AltCLIPConfig(AltCLIP 模型) -
audio-spectrogram-transformer
— ASTConfig(音频频谱变换器模型) -
autoformer
— AutoformerConfig(Autoformer 模型) -
bark
— BarkConfig(Bark 模型) -
bart
— BartConfig(BART 模型) -
beit
— BeitConfig(BEiT 模型) -
bert
— BertConfig(BERT 模型) -
bert-generation
— BertGenerationConfig(Bert 生成模型) -
big_bird
— BigBirdConfig(BigBird 模型) -
bigbird_pegasus
— BigBirdPegasusConfig(BigBird-Pegasus 模型) -
biogpt
— BioGptConfig(BioGpt 模型) -
bit
— BitConfig(BiT 模型) -
blenderbot
— BlenderbotConfig(Blenderbot 模型) -
blenderbot-small
— BlenderbotSmallConfig(BlenderbotSmall 模型) -
blip
— BlipConfig (BLIP 模型) -
blip-2
— Blip2Config (BLIP-2 模型) -
bloom
— BloomConfig (BLOOM 模型) -
bridgetower
— BridgeTowerConfig (BridgeTower 模型) -
bros
— BrosConfig (BROS 模型) -
camembert
— CamembertConfig (CamemBERT 模型) -
canine
— CanineConfig (CANINE 模型) -
chinese_clip
— ChineseCLIPConfig (Chinese-CLIP 模型) -
clap
— ClapConfig (CLAP 模型) -
clip
— CLIPConfig (CLIP 模型) -
clip_vision_model
— CLIPVisionConfig (CLIPVisionModel 模型) -
clipseg
— CLIPSegConfig (CLIPSeg 模型) -
clvp
— ClvpConfig (CLVP 模型) -
code_llama
— LlamaConfig (CodeLlama 模型) -
codegen
— CodeGenConfig (CodeGen 模型) -
conditional_detr
— ConditionalDetrConfig (Conditional DETR 模型) -
convbert
— ConvBertConfig (ConvBERT 模型) -
convnext
— ConvNextConfig (ConvNeXT 模型) -
convnextv2
— ConvNextV2Config (ConvNeXTV2 模型) -
cpmant
— CpmAntConfig (CPM-Ant 模型) -
ctrl
— CTRLConfig (CTRL 模型) -
cvt
— CvtConfig (CvT 模型) -
data2vec-audio
— Data2VecAudioConfig (Data2VecAudio 模型) -
data2vec-text
— Data2VecTextConfig (Data2VecText 模型) -
data2vec-vision
— Data2VecVisionConfig (Data2VecVision 模型) -
deberta
— DebertaConfig (DeBERTa 模型) -
deberta-v2
— DebertaV2Config (DeBERTa-v2 模型) -
decision_transformer
— DecisionTransformerConfig (Decision Transformer 模型) -
deformable_detr
— DeformableDetrConfig (Deformable DETR 模型) -
deit
— DeiTConfig (DeiT 模型) -
deta
— DetaConfig (DETA 模型) -
detr
— DetrConfig (DETR 模型) -
dinat
— DinatConfig (DiNAT 模型) -
dinov2
— Dinov2Config (DINOv2 模型) -
distilbert
— DistilBertConfig (DistilBERT 模型) -
donut-swin
— DonutSwinConfig (DonutSwin 模型) -
dpr
— DPRConfig (DPR 模型) -
dpt
— DPTConfig (DPT 模型) -
efficientformer
— EfficientFormerConfig (EfficientFormer 模型) -
efficientnet
— EfficientNetConfig (EfficientNet 模型) -
electra
— ElectraConfig (ELECTRA 模型) -
encodec
— EncodecConfig (EnCodec 模型) -
encoder-decoder
— EncoderDecoderConfig (编码器解码器模型) -
ernie
— ErnieConfig (ERNIE 模型) -
ernie_m
— ErnieMConfig (ErnieM 模型) -
esm
— EsmConfig (ESM 模型) -
falcon
— FalconConfig (Falcon 模型) -
fastspeech2_conformer
— FastSpeech2ConformerConfig (FastSpeech2Conformer 模型) -
flaubert
— FlaubertConfig (FlauBERT 模型) -
flava
— FlavaConfig (FLAVA 模型) -
fnet
— FNetConfig (FNet 模型) -
focalnet
— FocalNetConfig (FocalNet 模型) -
fsmt
— FSMTConfig (FairSeq 机器翻译模型) -
funnel
— FunnelConfig (Funnel Transformer 模型) -
fuyu
— FuyuConfig (Fuyu 模型) -
git
— GitConfig (GIT 模型) -
glpn
— GLPNConfig (GLPN 模型) -
gpt-sw3
— GPT2Config (GPT-Sw3 模型) -
gpt2
— GPT2Config (OpenAI GPT-2 模型) -
gpt_bigcode
— GPTBigCodeConfig (GPTBigCode 模型) -
gpt_neo
— GPTNeoConfig (GPT Neo 模型) -
gpt_neox
— GPTNeoXConfig (GPT NeoX 模型) -
gpt_neox_japanese
— GPTNeoXJapaneseConfig (GPT NeoX 日语模型) -
gptj
— GPTJConfig (GPT-J 模型) -
gptsan-japanese
— GPTSanJapaneseConfig(GPTSAN-japanese 模型) -
graphormer
— GraphormerConfig(Graphormer 模型) -
groupvit
— GroupViTConfig(GroupViT 模型) -
hubert
— HubertConfig(Hubert 模型) -
ibert
— IBertConfig(I-BERT 模型) -
idefics
— IdeficsConfig(IDEFICS 模型) -
imagegpt
— ImageGPTConfig(ImageGPT 模型) -
informer
— InformerConfig(Informer 模型) -
instructblip
— InstructBlipConfig(InstructBLIP 模型) -
jukebox
— JukeboxConfig(Jukebox 模型) -
kosmos-2
— Kosmos2Config(KOSMOS-2 模型) -
layoutlm
— LayoutLMConfig(LayoutLM 模型) -
layoutlmv2
— LayoutLMv2Config(LayoutLMv2 模型) -
layoutlmv3
— LayoutLMv3Config(LayoutLMv3 模型) -
led
— LEDConfig(LED 模型) -
levit
— LevitConfig(LeViT 模型) -
lilt
— LiltConfig(LiLT 模型) -
llama
— LlamaConfig(LLaMA 模型) -
llava
— LlavaConfig(LLaVa 模型) -
longformer
— LongformerConfig(Longformer 模型) -
longt5
— LongT5Config(LongT5 模型) -
luke
— LukeConfig(LUKE 模型) -
lxmert
— LxmertConfig(LXMERT 模型) -
m2m_100
— M2M100Config(M2M100 模型) -
marian
— MarianConfig(Marian 模型) -
markuplm
— MarkupLMConfig(MarkupLM 模型) -
mask2former
— Mask2FormerConfig(Mask2Former 模型) -
maskformer
— MaskFormerConfig(MaskFormer 模型) -
maskformer-swin
—MaskFormerSwinConfig
(MaskFormerSwin 模型) -
mbart
— MBartConfig(mBART 模型) -
mctct
— MCTCTConfig(M-CTC-T 模型) -
mega
— MegaConfig(MEGA 模型) -
megatron-bert
— MegatronBertConfig(Megatron-BERT 模型) -
mgp-str
— MgpstrConfig (MGP-STR 模型) -
mistral
— MistralConfig (Mistral 模型) -
mixtral
— MixtralConfig (Mixtral 模型) -
mobilebert
— MobileBertConfig (MobileBERT 模型) -
mobilenet_v1
— MobileNetV1Config (MobileNetV1 模型) -
mobilenet_v2
— MobileNetV2Config (MobileNetV2 模型) -
mobilevit
— MobileViTConfig (MobileViT 模型) -
mobilevitv2
— MobileViTV2Config (MobileViTV2 模型) -
mpnet
— MPNetConfig (MPNet 模型) -
mpt
— MptConfig (MPT 模型) -
mra
— MraConfig (MRA 模型) -
mt5
— MT5Config (MT5 模型) -
musicgen
— MusicgenConfig (MusicGen 模型) -
mvp
— MvpConfig (MVP 模型) -
nat
— NatConfig (NAT 模型) -
nezha
— NezhaConfig (Nezha 模型) -
nllb-moe
— NllbMoeConfig (NLLB-MOE 模型) -
nougat
— VisionEncoderDecoderConfig (Nougat 模型) -
nystromformer
— NystromformerConfig (Nyströmformer 模型) -
oneformer
— OneFormerConfig (OneFormer 模型) -
open-llama
— OpenLlamaConfig (OpenLlama 模型) -
openai-gpt
— OpenAIGPTConfig (OpenAI GPT 模型) -
opt
— OPTConfig (OPT 模型) -
owlv2
— Owlv2Config (OWLv2 模型) -
owlvit
— OwlViTConfig (OWL-ViT 模型) -
patchtsmixer
— PatchTSMixerConfig (PatchTSMixer 模型) -
patchtst
— PatchTSTConfig (PatchTST 模型) -
pegasus
— PegasusConfig (Pegasus 模型) -
pegasus_x
— PegasusXConfig (PEGASUS-X 模型) -
perceiver
— PerceiverConfig (Perceiver 模型) -
persimmon
— PersimmonConfig (Persimmon 模型) -
phi
— PhiConfig (Phi 模型) -
pix2struct
— Pix2StructConfig (Pix2Struct 模型) -
plbart
— PLBartConfig (PLBart 模型) -
poolformer
— PoolFormerConfig (PoolFormer 模型) -
pop2piano
— Pop2PianoConfig (Pop2Piano 模型) -
prophetnet
— ProphetNetConfig (ProphetNet 模型) -
pvt
— PvtConfig (PVT 模型) -
qdqbert
— QDQBertConfig (QDQBert 模型) -
qwen2
— Qwen2Config (Qwen2 模型) -
rag
— RagConfig (RAG 模型) -
realm
— RealmConfig (REALM 模型) -
reformer
— ReformerConfig (Reformer 模型) -
regnet
— RegNetConfig (RegNet 模型) -
rembert
— RemBertConfig (RemBERT 模型) -
resnet
— ResNetConfig (ResNet 模型) -
retribert
— RetriBertConfig (RetriBERT 模型) -
roberta
— RobertaConfig (RoBERTa 模型) -
roberta-prelayernorm
— RobertaPreLayerNormConfig (RoBERTa-PreLayerNorm 模型) -
roc_bert
— RoCBertConfig (RoCBert 模型) -
roformer
— RoFormerConfig (RoFormer 模型) -
rwkv
— RwkvConfig (RWKV 模型) -
sam
— SamConfig (SAM 模型) -
seamless_m4t
— SeamlessM4TConfig (SeamlessM4T 模型) -
seamless_m4t_v2
— SeamlessM4Tv2Config (SeamlessM4Tv2 模型) -
segformer
— SegformerConfig (SegFormer 模型) -
sew
— SEWConfig (SEW 模型) -
sew-d
— SEWDConfig (SEW-D 模型) -
siglip
— SiglipConfig (SigLIP 模型) -
siglip_vision_model
— SiglipVisionConfig (SiglipVisionModel 模型) -
speech-encoder-decoder
— SpeechEncoderDecoderConfig (Speech 编码器解码器模型) -
speech_to_text
— Speech2TextConfig (Speech2Text 模型) -
speech_to_text_2
— Speech2Text2Config (Speech2Text2 模型) -
speecht5
— SpeechT5Config (SpeechT5 模型) -
splinter
— SplinterConfig (Splinter 模型) -
squeezebert
— SqueezeBertConfig (SqueezeBERT 模型) -
swiftformer
— SwiftFormerConfig (SwiftFormer 模型) -
swin
— SwinConfig (Swin Transformer 模型) -
swin2sr
— Swin2SRConfig (Swin2SR 模型) -
swinv2
— Swinv2Config (Swin Transformer V2 模型) -
switch_transformers
— SwitchTransformersConfig (SwitchTransformers 模型) -
t5
— T5Config (T5 模型) -
table-transformer
— TableTransformerConfig (Table Transformer 模型) -
tapas
— TapasConfig (TAPAS 模型) -
time_series_transformer
— TimeSeriesTransformerConfig (Time Series Transformer 模型) -
timesformer
— TimesformerConfig (TimeSformer 模型) -
timm_backbone
—TimmBackboneConfig
(TimmBackbone 模型) -
trajectory_transformer
— TrajectoryTransformerConfig (Trajectory Transformer 模型) -
transfo-xl
— TransfoXLConfig (Transformer-XL 模型) -
trocr
— TrOCRConfig (TrOCR 模型) -
tvlt
— TvltConfig (TVLT 模型) -
tvp
— TvpConfig (TVP 模型) -
umt5
— UMT5Config (UMT5 模型) -
unispeech
— UniSpeechConfig (UniSpeech 模型) -
unispeech-sat
— UniSpeechSatConfig (UniSpeechSat 模型) -
univnet
— UnivNetConfig (UnivNet 模型) -
upernet
— UperNetConfig (UPerNet 模型) -
van
— VanConfig (VAN 模型) -
videomae
— VideoMAEConfig (VideoMAE 模型) -
vilt
— ViltConfig (ViLT 模型) -
vipllava
— VipLlavaConfig (VipLlava 模型) -
vision-encoder-decoder
— VisionEncoderDecoderConfig (Vision Encoder decoder 模型) -
vision-text-dual-encoder
— VisionTextDualEncoderConfig (VisionTextDualEncoder 模型) -
visual_bert
— VisualBertConfig(VisualBERT 模型) -
vit
— ViTConfig(ViT 模型) -
vit_hybrid
— ViTHybridConfig(ViT 混合模型) -
vit_mae
— ViTMAEConfig(ViTMAE 模型) -
vit_msn
— ViTMSNConfig(ViTMSN 模型) -
vitdet
— VitDetConfig(VitDet 模型) -
vitmatte
— VitMatteConfig(ViTMatte 模型) -
vits
— VitsConfig(VITS 模型) -
vivit
— VivitConfig(ViViT 模型) -
wav2vec2
— Wav2Vec2Config(Wav2Vec2 模型) -
wav2vec2-bert
— Wav2Vec2BertConfig(Wav2Vec2-BERT 模型) -
wav2vec2-conformer
— Wav2Vec2ConformerConfig(Wav2Vec2-Conformer 模型) -
wavlm
— WavLMConfig(WavLM 模型) -
whisper
— WhisperConfig(Whisper 模型) -
xclip
— XCLIPConfig(X-CLIP 模型) -
xglm
— XGLMConfig(XGLM 模型) -
xlm
— XLMConfig(XLM 模型) -
xlm-prophetnet
— XLMProphetNetConfig(XLM-ProphetNet 模型) -
xlm-roberta
— XLMRobertaConfig(XLM-RoBERTa 模型) -
xlm-roberta-xl
— XLMRobertaXLConfig(XLM-RoBERTa-XL 模型) -
xlnet
— XLNetConfig(XLNet 模型) -
xmod
— XmodConfig(X-MOD 模型) -
yolos
— YolosConfig(YOLOS 模型) -
yoso
— YosoConfig(YOSO 模型)
示例:
代码语言:javascript复制>>> from transformers import AutoConfig
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("bert-base-uncased")
>>> # Download configuration from huggingface.co (user-uploaded) and cache.
>>> config = AutoConfig.from_pretrained("dbmdz/bert-base-german-cased")
>>> # If configuration file is in a directory (e.g., was saved using *save_pretrained('./test/saved_model/')*).
>>> config = AutoConfig.from_pretrained("./test/bert_saved_model/")
>>> # Load a specific configuration file.
>>> config = AutoConfig.from_pretrained("./test/bert_saved_model/my_configuration.json")
>>> # Change some config attributes when loading a pretrained config.
>>> config = AutoConfig.from_pretrained("bert-base-uncased", output_attentions=True, foo=False)
>>> config.output_attentions
True
>>> config, unused_kwargs = AutoConfig.from_pretrained(
... "bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
... )
>>> config.output_attentions
True
>>> unused_kwargs
{'foo': False}
register
<来源>
代码语言:javascript复制( model_type config exist_ok = False )
参数
-
model_type
(str
)— 模型类型,如“bert”或“gpt”。 -
config
(PretrainedConfig)— 要注册的配置。
为这个类注册一个新的配置。
AutoTokenizer
class transformers.AutoTokenizer
<来源>
代码语言:javascript复制( )
这是一个通用的分词器类,当使用 AutoTokenizer.from_pretrained()类方法创建时,将实例化为库中的分词器类之一。
这个类不能直接使用__init__()
实例化(会报错)。
from_pretrained
<来源>
代码语言:javascript复制( pretrained_model_name_or_path *inputs **kwargs )
参数
-
pretrained_model_name_or_path
(str
或os.PathLike
)- 可以是:- 一个字符串,预定义的分词器的模型 ID,托管在 huggingface.co 上的模型存储库内。有效的模型 ID 可以位于根级别,如
bert-base-uncased
,或者命名空间在用户或组织名称下,如dbmdz/bert-base-german-cased
。 - 包含分词器所需的词汇文件的目录路径,例如使用 save_pretrained()方法保存的,例如,
./my_model_directory/
。 - 如果且仅当分词器只需要单个词汇文件(如 Bert 或 XLNet)时,可以是单个保存的词汇文件的路径或 url,例如:
./my_model_directory/vocab.txt
。(不适用于所有派生类)
- 一个字符串,预定义的分词器的模型 ID,托管在 huggingface.co 上的模型存储库内。有效的模型 ID 可以位于根级别,如
-
inputs
(额外的位置参数,可选)- 将传递给分词器__init__()
方法。 -
config
(PretrainedConfig,可选)- 用于确定要实例化的分词器类的配置对象。 -
cache_dir
(str
或os.PathLike
,可选)- 如果不应使用标准缓存,则应将下载的预训练模型配置缓存在其中的目录路径。 -
force_download
(bool
,可选,默认为False
)- 是否强制(重新)下载模型权重和配置文件,并覆盖缓存版本(如果存在)。 -
resume_download
(bool
,可选,默认为False
)- 是否删除接收不完整的文件。如果存在这样的文件,将尝试恢复下载。 -
proxies
(Dict[str, str]
,可选)- 要按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理在每个请求上使用。 -
revision
(str
,可选,默认为"main"
)- 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们在 huggingface.co 上使用基于 git 的系统存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
subfolder
(str
,可选)- 如果相关文件位于 huggingface.co 上模型存储库的子文件夹中(例如对于 facebook/rag-token-base),请在此处指定。 -
use_fast
(bool
,可选,默认为True
)- 如果给定模型支持,使用快速基于 Rust 的分词器。如果给定模型不支持快速分词器,则将返回普通的基于 Python 的分词器。 -
tokenizer_type
(str
,可选)- 要加载的分词器类型。 -
trust_remote_code
(bool
,可选,默认为False
)- 是否允许在 Hub 上定义自定义模型的建模文件。此选项应仅对您信任的存储库设置为True
,并且您已阅读代码,因为它将在本地计算机上执行 Hub 上存在的代码。 -
kwargs
(额外的关键字参数,可选)- 将传递给分词器__init__()
方法。可用于设置特殊标记,如bos_token
、eos_token
、unk_token
、sep_token
、pad_token
、cls_token
、mask_token
、additional_special_tokens
。有关更多详细信息,请参阅__init__()
中的参数。
从预训练模型词汇实例化库中的一个分词器类。
要实例化的分词器类是根据配置对象的model_type
属性(如果可能作为参数传递或从pretrained_model_name_or_path
加载)选择的,或者当缺少时,通过在pretrained_model_name_or_path
上使用模式匹配来回退选择:
-
albert
— AlbertTokenizer 或 AlbertTokenizerFast (ALBERT 模型) -
align
— BertTokenizer 或 BertTokenizerFast (ALIGN 模型) -
bark
— BertTokenizer 或 BertTokenizerFast (Bark 模型) -
bart
— BartTokenizer 或 BartTokenizerFast (BART 模型) -
barthez
— BarthezTokenizer 或 BarthezTokenizerFast (BARThez 模型) -
bartpho
— BartphoTokenizer (BARTpho 模型) -
bert
— BertTokenizer 或 BertTokenizerFast (BERT 模型) -
bert-generation
— BertGenerationTokenizer (Bert Generation 模型) -
bert-japanese
— BertJapaneseTokenizer (BertJapanese 模型) -
bertweet
— BertweetTokenizer (BERTweet 模型) -
big_bird
— BigBirdTokenizer 或 BigBirdTokenizerFast (BigBird 模型) -
bigbird_pegasus
— PegasusTokenizer 或 PegasusTokenizerFast (BigBird-Pegasus 模型) -
biogpt
— BioGptTokenizer (BioGpt 模型) -
blenderbot
— BlenderbotTokenizer 或 BlenderbotTokenizerFast (Blenderbot 模型) -
blenderbot-small
— BlenderbotSmallTokenizer (BlenderbotSmall 模型) -
blip
— BertTokenizer 或 BertTokenizerFast (BLIP 模型) -
blip-2
— GPT2Tokenizer 或 GPT2TokenizerFast (BLIP-2 模型) -
bloom
— BloomTokenizerFast (BLOOM 模型) -
bridgetower
— RobertaTokenizer 或 RobertaTokenizerFast (BridgeTower 模型) -
bros
— BertTokenizer 或 BertTokenizerFast (BROS 模型) -
byt5
— ByT5Tokenizer (ByT5 模型) -
camembert
— CamembertTokenizer 或 CamembertTokenizerFast (CamemBERT 模型) -
canine
— CanineTokenizer (CANINE 模型) -
chinese_clip
— BertTokenizer 或 BertTokenizerFast (Chinese-CLIP 模型) -
clap
— RobertaTokenizer 或 RobertaTokenizerFast (CLAP 模型) -
clip
— CLIPTokenizer 或 CLIPTokenizerFast (CLIP 模型) -
clipseg
— CLIPTokenizer 或 CLIPTokenizerFast (CLIPSeg 模型) -
clvp
— ClvpTokenizer (CLVP 模型) -
code_llama
— CodeLlamaTokenizer 或 CodeLlamaTokenizerFast (CodeLlama 模型) -
codegen
— CodeGenTokenizer 或 CodeGenTokenizerFast (CodeGen 模型) -
convbert
— ConvBertTokenizer 或 ConvBertTokenizerFast (ConvBERT 模型) -
cpm
— CpmTokenizer 或 CpmTokenizerFast (CPM 模型) -
cpmant
— CpmAntTokenizer (CPM-Ant 模型) -
ctrl
— CTRLTokenizer (CTRL 模型) -
data2vec-audio
— Wav2Vec2CTCTokenizer (Data2VecAudio 模型) -
data2vec-text
— RobertaTokenizer 或 RobertaTokenizerFast (Data2VecText 模型) -
deberta
— DebertaTokenizer 或 DebertaTokenizerFast (DeBERTa 模型) -
deberta-v2
— DebertaV2Tokenizer 或 DebertaV2TokenizerFast (DeBERTa-v2 模型) -
distilbert
— DistilBertTokenizer 或 DistilBertTokenizerFast (DistilBERT 模型) -
dpr
— DPRQuestionEncoderTokenizer 或 DPRQuestionEncoderTokenizerFast (DPR 模型) -
electra
— ElectraTokenizer 或 ElectraTokenizerFast (ELECTRA 模型) -
ernie
— BertTokenizer 或 BertTokenizerFast (ERNIE 模型) -
ernie_m
— ErnieMTokenizer (ErnieM 模型) -
esm
— EsmTokenizer (ESM 模型) -
falcon
— PreTrainedTokenizerFast (Falcon 模型) -
fastspeech2_conformer
— (FastSpeech2Conformer 模型) -
flaubert
— FlaubertTokenizer (FlauBERT 模型) -
fnet
— FNetTokenizer 或 FNetTokenizerFast (FNet 模型) -
fsmt
— FSMTTokenizer (FairSeq 机器翻译模型) -
funnel
— FunnelTokenizer 或 FunnelTokenizerFast (Funnel Transformer 模型) -
git
— BertTokenizer 或 BertTokenizerFast (GIT 模型) -
gpt-sw3
— GPTSw3Tokenizer (GPT-Sw3 模型) -
gpt2
— GPT2Tokenizer 或 GPT2TokenizerFast (OpenAI GPT-2 模型) -
gpt_bigcode
— GPT2Tokenizer 或 GPT2TokenizerFast (GPTBigCode 模型) -
gpt_neo
— GPT2Tokenizer 或 GPT2TokenizerFast (GPT Neo 模型) -
gpt_neox
— GPTNeoXTokenizerFast (GPT NeoX 模型) -
gpt_neox_japanese
— GPTNeoXJapaneseTokenizer (GPT NeoX Japanese 模型) -
gptj
— GPT2Tokenizer 或 GPT2TokenizerFast (GPT-J 模型) -
gptsan-japanese
— GPTSanJapaneseTokenizer (GPTSAN-japanese 模型) -
groupvit
— CLIPTokenizer 或 CLIPTokenizerFast (GroupViT 模型) -
herbert
— HerbertTokenizer 或 HerbertTokenizerFast (HerBERT 模型) -
hubert
— Wav2Vec2CTCTokenizer (Hubert 模型) -
ibert
— RobertaTokenizer 或 RobertaTokenizerFast (I-BERT 模型) -
idefics
— LlamaTokenizerFast (IDEFICS 模型) -
instructblip
— GPT2Tokenizer 或 GPT2TokenizerFast (InstructBLIP 模型) -
jukebox
— JukeboxTokenizer (Jukebox 模型) -
kosmos-2
— XLMRobertaTokenizer 或 XLMRobertaTokenizerFast (KOSMOS-2 模型) -
layoutlm
— LayoutLMTokenizer 或 LayoutLMTokenizerFast (LayoutLM 模型) -
layoutlmv2
— LayoutLMv2Tokenizer 或 LayoutLMv2TokenizerFast (LayoutLMv2 模型) -
layoutlmv3
— LayoutLMv3Tokenizer 或 LayoutLMv3TokenizerFast (LayoutLMv3 模型) -
layoutxlm
— LayoutXLMTokenizer 或 LayoutXLMTokenizerFast (LayoutXLM 模型) -
led
— LEDTokenizer 或 LEDTokenizerFast (LED 模型) -
lilt
— LayoutLMv3Tokenizer 或 LayoutLMv3TokenizerFast (LiLT 模型) -
llama
— LlamaTokenizer 或 LlamaTokenizerFast (LLaMA 模型) -
llava
— LlamaTokenizer 或 LlamaTokenizerFast (LLaVa 模型) -
longformer
— LongformerTokenizer 或 LongformerTokenizerFast (Longformer 模型) -
longt5
— T5Tokenizer 或 T5TokenizerFast (LongT5 模型) -
luke
— LukeTokenizer (LUKE 模型) -
lxmert
— LxmertTokenizer 或 LxmertTokenizerFast (LXMERT 模型) -
m2m_100
— M2M100Tokenizer (M2M100 模型) -
marian
— MarianTokenizer (Marian 模型) -
mbart
— MBartTokenizer 或 MBartTokenizerFast (mBART 模型) -
mbart50
— MBart50Tokenizer 或 MBart50TokenizerFast (mBART-50 模型) -
mega
— RobertaTokenizer 或 RobertaTokenizerFast (MEGA 模型) -
megatron-bert
— BertTokenizer 或 BertTokenizerFast (Megatron-BERT 模型) -
mgp-str
— MgpstrTokenizer (MGP-STR 模型) -
mistral
— LlamaTokenizer 或 LlamaTokenizerFast (Mistral 模型) -
mixtral
— LlamaTokenizer 或 LlamaTokenizerFast (Mixtral 模型) -
mluke
— MLukeTokenizer (mLUKE 模型) -
mobilebert
— MobileBertTokenizer 或 MobileBertTokenizerFast (MobileBERT 模型) -
mpnet
— MPNetTokenizer 或 MPNetTokenizerFast (MPNet 模型) -
mpt
— GPTNeoXTokenizerFast (MPT 模型) -
mra
— RobertaTokenizer 或 RobertaTokenizerFast (MRA 模型) -
mt5
— MT5Tokenizer 或 MT5TokenizerFast (MT5 模型) -
musicgen
— T5Tokenizer 或 T5TokenizerFast (MusicGen 模型) -
mvp
— MvpTokenizer 或 MvpTokenizerFast (MVP 模型) -
nezha
— BertTokenizer 或 BertTokenizerFast (Nezha 模型) -
nllb
— NllbTokenizer 或 NllbTokenizerFast (NLLB 模型) -
nllb-moe
— NllbTokenizer 或 NllbTokenizerFast (NLLB-MOE 模型) -
nystromformer
— AlbertTokenizer 或 AlbertTokenizerFast(Nyströmformer 模型) -
oneformer
— CLIPTokenizer 或 CLIPTokenizerFast(OneFormer 模型) -
openai-gpt
— OpenAIGPTTokenizer 或 OpenAIGPTTokenizerFast(OpenAI GPT 模型) -
opt
— GPT2Tokenizer 或 GPT2TokenizerFast(OPT 模型) -
owlv2
— CLIPTokenizer 或 CLIPTokenizerFast(OWLv2 模型) -
owlvit
— CLIPTokenizer 或 CLIPTokenizerFast(OWL-ViT 模型) -
pegasus
— PegasusTokenizer 或 PegasusTokenizerFast(Pegasus 模型) -
pegasus_x
— PegasusTokenizer 或 PegasusTokenizerFast(PEGASUS-X 模型) -
perceiver
— PerceiverTokenizer(Perceiver 模型) -
persimmon
— LlamaTokenizer 或 LlamaTokenizerFast(Persimmon 模型) -
phi
— CodeGenTokenizer 或 CodeGenTokenizerFast(Phi 模型) -
phobert
— PhobertTokenizer(PhoBERT 模型) -
pix2struct
— T5Tokenizer 或 T5TokenizerFast(Pix2Struct 模型) -
plbart
— PLBartTokenizer(PLBart 模型) -
prophetnet
— ProphetNetTokenizer(ProphetNet 模型) -
qdqbert
— BertTokenizer 或 BertTokenizerFast(QDQBert 模型) -
qwen2
— Qwen2Tokenizer 或 Qwen2TokenizerFast(Qwen2 模型) -
rag
— RagTokenizer(RAG 模型) -
realm
— RealmTokenizer 或 RealmTokenizerFast(REALM 模型) -
reformer
— ReformerTokenizer 或 ReformerTokenizerFast (Reformer 模型) -
rembert
— RemBertTokenizer 或 RemBertTokenizerFast (RemBERT 模型) -
retribert
— RetriBertTokenizer 或 RetriBertTokenizerFast (RetriBERT 模型) -
roberta
— RobertaTokenizer 或 RobertaTokenizerFast (RoBERTa 模型) -
roberta-prelayernorm
— RobertaTokenizer 或 RobertaTokenizerFast (RoBERTa-PreLayerNorm 模型) -
roc_bert
— RoCBertTokenizer (RoCBert 模型) -
roformer
— RoFormerTokenizer 或 RoFormerTokenizerFast (RoFormer 模型) -
rwkv
— GPTNeoXTokenizerFast (RWKV 模型) -
seamless_m4t
— SeamlessM4TTokenizer 或 SeamlessM4TTokenizerFast (SeamlessM4T 模型) -
seamless_m4t_v2
— SeamlessM4TTokenizer 或 SeamlessM4TTokenizerFast (SeamlessM4Tv2 模型) -
siglip
— SiglipTokenizer (SigLIP 模型) -
speech_to_text
— Speech2TextTokenizer (Speech2Text 模型) -
speech_to_text_2
— Speech2Text2Tokenizer (Speech2Text2 模型) -
speecht5
— SpeechT5Tokenizer (SpeechT5 模型) -
splinter
— SplinterTokenizer 或 SplinterTokenizerFast (Splinter 模型) -
squeezebert
— SqueezeBertTokenizer 或 SqueezeBertTokenizerFast (SqueezeBERT 模型) -
switch_transformers
— T5Tokenizer 或 T5TokenizerFast (SwitchTransformers 模型) -
t5
— T5Tokenizer 或 T5TokenizerFast (T5 模型) -
tapas
— TapasTokenizer (TAPAS 模型) -
tapex
— TapexTokenizer (TAPEX 模型) -
transfo-xl
— TransfoXLTokenizer (Transformer-XL 模型) -
tvp
— BertTokenizer 或 BertTokenizerFast (TVP 模型) -
umt5
— T5Tokenizer 或 T5TokenizerFast (UMT5 模型) -
vilt
— BertTokenizer 或 BertTokenizerFast (ViLT 模型) -
vipllava
— LlamaTokenizer 或 LlamaTokenizerFast (VipLlava 模型) -
visual_bert
— BertTokenizer 或 BertTokenizerFast (VisualBERT 模型) -
vits
— VitsTokenizer (VITS 模型) -
wav2vec2
— Wav2Vec2CTCTokenizer (Wav2Vec2 模型) -
wav2vec2-bert
— Wav2Vec2CTCTokenizer (Wav2Vec2-BERT 模型) -
wav2vec2-conformer
— Wav2Vec2CTCTokenizer (Wav2Vec2-Conformer 模型) -
wav2vec2_phoneme
— Wav2Vec2PhonemeCTCTokenizer (Wav2Vec2Phoneme 模型) -
whisper
— WhisperTokenizer 或 WhisperTokenizerFast (Whisper 模型) -
xclip
— CLIPTokenizer 或 CLIPTokenizerFast (X-CLIP 模型) -
xglm
— XGLMTokenizer 或 XGLMTokenizerFast (XGLM 模型) -
xlm
— XLMTokenizer (XLM 模型) -
xlm-prophetnet
— XLMProphetNetTokenizer (XLM-ProphetNet 模型) -
xlm-roberta
— XLMRobertaTokenizer 或 XLMRobertaTokenizerFast (XLM-RoBERTa 模型) -
xlm-roberta-xl
— XLMRobertaTokenizer 或 XLMRobertaTokenizerFast (XLM-RoBERTa-XL 模型) -
xlnet
— XLNetTokenizer 或 XLNetTokenizerFast (XLNet 模型) -
xmod
— XLMRobertaTokenizer 或 XLMRobertaTokenizerFast (X-MOD 模型) -
yoso
— AlbertTokenizer 或 AlbertTokenizerFast(YOSO 模型)
示例:
代码语言:javascript复制>>> from transformers import AutoTokenizer
>>> # Download vocabulary from huggingface.co and cache.
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
>>> # Download vocabulary from huggingface.co (user-uploaded) and cache.
>>> tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-base-german-cased")
>>> # If vocabulary files are in a directory (e.g. tokenizer was saved using *save_pretrained('./test/saved_model/')*)
>>> # tokenizer = AutoTokenizer.from_pretrained("./test/bert_saved_model/")
>>> # Download vocabulary from huggingface.co and define model-specific arguments
>>> tokenizer = AutoTokenizer.from_pretrained("roberta-base", add_prefix_space=True)
register
< source >
代码语言:javascript复制( config_class slow_tokenizer_class = None fast_tokenizer_class = None exist_ok = False )
参数
-
config_class
(PretrainedConfig) — 与要注册的模型对应的配置。 -
slow_tokenizer_class
(PretrainedTokenizer
, optional) — 要注册的慢速分词器。 -
fast_tokenizer_class
(PretrainedTokenizerFast
, optional) — 要注册的快速分词器。
在此映射中注册一个新的分词器。
AutoFeatureExtractor
class transformers.AutoFeatureExtractor
< source >
代码语言:javascript复制( )
这是一个通用的特征提取器类,在使用 AutoFeatureExtractor.from_pretrained() 类方法创建时,将实例化为库的特征提取器类之一。
这个类不能直接使用 __init__()
实例化(会抛出错误)。
from_pretrained
< source >
代码语言:javascript复制( pretrained_model_name_or_path **kwargs )
参数
-
pretrained_model_name_or_path
(str
或os.PathLike
) — 这可以是:- 一个字符串,预训练特征提取器的 模型 ID,托管在 huggingface.co 上的模型存储库中。有效的模型 ID 可以位于根级别,如
bert-base-uncased
,或者在用户或组织名称下命名空间化,如dbmdz/bert-base-german-cased
。 - 一个包含使用 save_pretrained() 方法保存的特征提取器文件的 目录 路径,例如
./my_model_directory/
。 - 一个保存的特征提取器 JSON 文件 的路径或 URL,例如
./my_model_directory/preprocessor_config.json
。
- 一个字符串,预训练特征提取器的 模型 ID,托管在 huggingface.co 上的模型存储库中。有效的模型 ID 可以位于根级别,如
-
cache_dir
(str
或os.PathLike
, optional) — 下载的预训练模型特征提取器应该缓存在其中的目录路径,如果不想使用标准缓存。 -
force_download
(bool
, optional, defaults toFalse
) — 是否强制(重新)下载特征提取器文件并覆盖缓存版本(如果存在)。 -
resume_download
(bool
, optional, defaults toFalse
) — 是否删除接收不完整的文件。如果存在这样的文件,则尝试恢复下载。 -
proxies
(Dict[str, str]
, optional) — 一个按协议或端点使用的代理服务器字典,例如{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。这些代理在每个请求中使用。 -
token
(str
或 bool, optional) — 用作远程文件的 HTTP 令牌授权的令牌。如果为True
,将使用运行huggingface-cli login
时生成的令牌(存储在~/.huggingface
中)。 -
revision
(str
, optional, defaults to"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们在 huggingface.co 上使用基于 git 的系统来存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
return_unused_kwargs
(bool
, optional, defaults toFalse
) — 如果为False
,则此函数仅返回最终的特征提取器对象。如果为True
,则此函数返回一个Tuple(feature_extractor, unused_kwargs)
,其中 unused_kwargs 是一个字典,包含那些未被用于更新feature_extractor
的键/值对:即kwargs
的一部分,未被用于更新feature_extractor
且被忽略的部分。 -
trust_remote_code
(bool
, optional, 默认为False
) — 是否允许在 Hub 上定义自定义模型并在其自己的建模文件中执行。只有对您信任的存储库以及您已阅读代码的情况下,才应将此选项设置为True
,因为它将在本地机器上执行 Hub 上存在的代码。 -
kwargs
(Dict[str, Any]
, optional) — 任何键为特征提取器属性的 kwargs 中的值将用于覆盖加载的值。关于键/值对中键不是特征提取器属性的行为由return_unused_kwargs
关键参数控制。
从预训练模型词汇表中实例化库中的特征提取器类之一。
要实例化的特征提取器类是根据配置对象的 model_type
属性(如果可能,作为参数传递或从 pretrained_model_name_or_path
加载)选择的,或者当缺少时,通过在 pretrained_model_name_or_path
上使用模式匹配来回退:
-
audio-spectrogram-transformer
— ASTFeatureExtractor (Audio Spectrogram Transformer 模型) -
beit
— BeitFeatureExtractor (BEiT 模型) -
chinese_clip
— ChineseCLIPFeatureExtractor (Chinese-CLIP 模型) -
clap
— ClapFeatureExtractor (CLAP 模型) -
clip
— CLIPFeatureExtractor (CLIP 模型) -
clipseg
— ViTFeatureExtractor (CLIPSeg 模型) -
clvp
— ClvpFeatureExtractor (CLVP 模型) -
conditional_detr
— ConditionalDetrFeatureExtractor (Conditional DETR 模型) -
convnext
— ConvNextFeatureExtractor (ConvNeXT 模型) -
cvt
— ConvNextFeatureExtractor (CvT 模型) -
data2vec-audio
— Wav2Vec2FeatureExtractor (Data2VecAudio 模型) -
data2vec-vision
— BeitFeatureExtractor (Data2VecVision 模型) -
deformable_detr
— DeformableDetrFeatureExtractor (Deformable DETR 模型) -
deit
— DeiTFeatureExtractor (DeiT 模型) -
detr
— DetrFeatureExtractor (DETR 模型) -
dinat
— ViTFeatureExtractor (DiNAT 模型) -
donut-swin
— DonutFeatureExtractor (DonutSwin 模型) -
dpt
— DPTFeatureExtractor (DPT 模型) -
encodec
— EncodecFeatureExtractor (EnCodec 模型) -
flava
— FlavaFeatureExtractor (FLAVA 模型) -
glpn
— GLPNFeatureExtractor (GLPN 模型) -
groupvit
— CLIPFeatureExtractor (GroupViT 模型) -
hubert
— Wav2Vec2FeatureExtractor (Hubert 模型) -
imagegpt
— ImageGPTFeatureExtractor (ImageGPT 模型) -
layoutlmv2
— LayoutLMv2FeatureExtractor (LayoutLMv2 模型) -
layoutlmv3
— LayoutLMv3FeatureExtractor (LayoutLMv3 模型) -
levit
— LevitFeatureExtractor (LeViT 模型) -
maskformer
— MaskFormerFeatureExtractor (MaskFormer 模型) -
mctct
— MCTCTFeatureExtractor (M-CTC-T 模型) -
mobilenet_v1
— MobileNetV1FeatureExtractor (MobileNetV1 模型) -
mobilenet_v2
— MobileNetV2FeatureExtractor (MobileNetV2 模型) -
mobilevit
— MobileViTFeatureExtractor (MobileViT 模型) -
nat
— ViTFeatureExtractor (NAT 模型) -
owlvit
— OwlViTFeatureExtractor (OWL-ViT 模型) -
perceiver
— PerceiverFeatureExtractor (Perceiver 模型) -
poolformer
— PoolFormerFeatureExtractor (PoolFormer 模型) -
pop2piano
— Pop2PianoFeatureExtractor (Pop2Piano 模型) -
regnet
— ConvNextFeatureExtractor (RegNet 模型) -
resnet
— ConvNextFeatureExtractor (ResNet 模型) -
seamless_m4t
— SeamlessM4TFeatureExtractor (SeamlessM4T 模型) -
seamless_m4t_v2
— SeamlessM4TFeatureExtractor (SeamlessM4Tv2 模型) -
segformer
— SegformerFeatureExtractor (SegFormer 模型) -
sew
— Wav2Vec2FeatureExtractor (SEW 模型) -
sew-d
— Wav2Vec2FeatureExtractor (SEW-D 模型) -
speech_to_text
— Speech2TextFeatureExtractor (Speech2Text 模型) -
speecht5
— SpeechT5FeatureExtractor (SpeechT5 模型) -
swiftformer
— ViTFeatureExtractor (SwiftFormer 模型) -
swin
— ViTFeatureExtractor (Swin Transformer 模型) -
swinv2
— ViTFeatureExtractor (Swin Transformer V2 模型) -
table-transformer
— DetrFeatureExtractor (Table Transformer 模型) -
timesformer
— VideoMAEFeatureExtractor (TimeSformer 模型) -
tvlt
— TvltFeatureExtractor (TVLT 模型) -
unispeech
— Wav2Vec2FeatureExtractor (UniSpeech 模型) -
unispeech-sat
— Wav2Vec2FeatureExtractor (UniSpeechSat 模型) -
univnet
— UnivNetFeatureExtractor (UnivNet 模型) -
van
— ConvNextFeatureExtractor (VAN 模型) -
videomae
— VideoMAEFeatureExtractor (VideoMAE 模型) -
vilt
— ViltFeatureExtractor (ViLT 模型) -
vit
— ViTFeatureExtractor (ViT 模型) -
vit_mae
— ViTFeatureExtractor (ViTMAE 模型) -
vit_msn
— ViTFeatureExtractor (ViTMSN 模型) -
wav2vec2
— Wav2Vec2FeatureExtractor (Wav2Vec2 模型) -
wav2vec2-bert
— Wav2Vec2FeatureExtractor (Wav2Vec2-BERT 模型) -
wav2vec2-conformer
— Wav2Vec2FeatureExtractor (Wav2Vec2-Conformer 模型) -
wavlm
— Wav2Vec2FeatureExtractor (WavLM 模型) -
whisper
— WhisperFeatureExtractor (Whisper 模型) -
xclip
— CLIPFeatureExtractor (X-CLIP 模型) -
yolos
— YolosFeatureExtractor (YOLOS 模型)
当您想使用私有模型时,需要传递 token=True
。
示例:
代码语言:javascript复制>>> from transformers import AutoFeatureExtractor
>>> # Download feature extractor from huggingface.co and cache.
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
>>> # If feature extractor files are in a directory (e.g. feature extractor was saved using *save_pretrained('./test/saved_model/')*)
>>> # feature_extractor = AutoFeatureExtractor.from_pretrained("./test/saved_model/")
register
< source >
代码语言:javascript复制( config_class feature_extractor_class exist_ok = False )
参数
-
config_class
(PretrainedConfig) — 要注册的模型对应的配置。 -
feature_extractor_class
(FeatureExtractorMixin
) — 要注册的特征提取器。
为此类注册一个新的特征提取器。
AutoImageProcessor
class transformers.AutoImageProcessor
< source >
代码语言:javascript复制( )
这是一个通用的图像处理器类,在使用 AutoImageProcessor.from_pretrained() 类方法创建时,将被实例化为库中的图像处理器类之一。
这个类不能直接使用 __init__()
实例化(会抛出错误)。
from_pretrained
< source >
代码语言:javascript复制( pretrained_model_name_or_path **kwargs )
参数
-
pretrained_model_name_or_path
(str
或os.PathLike
) — 这可以是:- 一个预训练图像处理器的 模型 id,托管在 huggingface.co 上的模型存储库中。有效的模型 id 可以位于根级别,如
bert-base-uncased
,或者在用户或组织名称下命名空间化,如dbmdz/bert-base-german-cased
。 - 一个包含使用 save_pretrained() 方法保存的图像处理器文件的 目录 路径,例如,
./my_model_directory/
。 - 一个保存的图像处理器 JSON 文件 的路径或 URL,例如,
./my_model_directory/preprocessor_config.json
。
- 一个预训练图像处理器的 模型 id,托管在 huggingface.co 上的模型存储库中。有效的模型 id 可以位于根级别,如
-
cache_dir
(str
或os.PathLike
, 可选) — 预下载的预训练模型图像处理器应该缓存在其中的目录路径,如果不应使用标准缓存。 -
force_download
(bool
, 可选, 默认为False
) — 是否强制(重新)下载图像处理器文件并覆盖缓存版本(如果存在)。 -
resume_download
(bool
, 可选, 默认为False
) — 是否删除接收不完整的文件。如果存在这样的文件,则尝试恢复下载。 -
proxies
(Dict[str, str]
, 可选) — 一个按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。这些代理在每个请求上使用。 -
token
(str
或 bool, 可选) — 用作远程文件的 HTTP bearer 授权的令牌。如果为True
,将使用运行huggingface-cli login
时生成的令牌(存储在~/.huggingface
中)。 -
revision
(str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 id,因为我们在 huggingface.co 上使用基于 git 的系统来存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
return_unused_kwargs
(bool
, 可选, 默认为False
) — 如果为False
,则此函数仅返回最终的图像处理器对象。如果为True
,则此函数返回一个Tuple(image_processor, unused_kwargs)
,其中 unused_kwargs 是一个字典,包含那些键/值对,其键不是图像处理器属性:即kwargs
中未被用于更新image_processor
且被忽略的部分。 -
trust_remote_code
(bool
, 可选, 默认为False
) — 是否允许在 Hub 上定义自定义模型的建模文件。此选项应仅对您信任的存储库设置为True
,并且您已阅读了代码,因为它将在本地机器上执行 Hub 上存在的代码。 -
kwargs
(Dict[str, Any]
, 可选) — 任何键为图像处理器属性的 kwargs 中的值将用于覆盖加载的值。关于键/值对中键 不是 图像处理器属性的行为由return_unused_kwargs
关键字参数控制。
从预训练模型词汇表中实例化库中的一个图像处理器类。
要实例化的图像处理器类是根据配置对象的 model_type
属性选择的(如果可能,作为参数传递或从 pretrained_model_name_or_path
加载),或者当缺少时,通过在 pretrained_model_name_or_path
上使用模式匹配来回退:
-
align
— EfficientNetImageProcessor (ALIGN 模型) -
beit
— BeitImageProcessor (BEiT 模型) -
bit
— BitImageProcessor (BiT 模型) -
blip
— BlipImageProcessor (BLIP 模型) -
blip-2
— BlipImageProcessor (BLIP-2 模型) -
bridgetower
— BridgeTowerImageProcessor (BridgeTower 模型) -
chinese_clip
— ChineseCLIPImageProcessor (Chinese-CLIP 模型) -
clip
— CLIPImageProcessor (CLIP 模型) -
clipseg
— ViTImageProcessor (CLIPSeg 模型) -
conditional_detr
— ConditionalDetrImageProcessor (Conditional DETR 模型) -
convnext
— ConvNextImageProcessor (ConvNeXT 模型) -
convnextv2
— ConvNextImageProcessor (ConvNeXTV2 模型) -
cvt
— ConvNextImageProcessor (CvT 模型) -
data2vec-vision
— BeitImageProcessor (Data2VecVision 模型) -
deformable_detr
— DeformableDetrImageProcessor (Deformable DETR 模型) -
deit
— DeiTImageProcessor (DeiT 模型) -
deta
— DetaImageProcessor (DETA 模型) -
detr
— DetrImageProcessor (DETR 模型) -
dinat
— ViTImageProcessor (DiNAT 模型) -
dinov2
— BitImageProcessor (DINOv2 模型) -
donut-swin
— DonutImageProcessor (DonutSwin 模型) -
dpt
— DPTImageProcessor (DPT 模型) -
efficientformer
— EfficientFormerImageProcessor (EfficientFormer 模型) -
efficientnet
— EfficientNetImageProcessor (EfficientNet 模型) -
flava
— FlavaImageProcessor (FLAVA 模型) -
focalnet
— BitImageProcessor (FocalNet 模型) -
fuyu
— FuyuImageProcessor (Fuyu 模型) -
git
— CLIPImageProcessor (GIT 模型) -
glpn
— GLPNImageProcessor (GLPN 模型) -
groupvit
— CLIPImageProcessor (GroupViT 模型) -
idefics
— IdeficsImageProcessor (IDEFICS 模型) -
imagegpt
— ImageGPTImageProcessor (ImageGPT 模型) -
instructblip
— BlipImageProcessor (InstructBLIP 模型) -
kosmos-2
— CLIPImageProcessor (KOSMOS-2 模型) -
layoutlmv2
— LayoutLMv2ImageProcessor (LayoutLMv2 模型) -
layoutlmv3
— LayoutLMv3ImageProcessor (LayoutLMv3 模型) -
levit
— LevitImageProcessor (LeViT 模型) -
llava
— CLIPImageProcessor (LLaVa 模型) -
mask2former
— Mask2FormerImageProcessor (Mask2Former 模型) -
maskformer
— MaskFormerImageProcessor (MaskFormer 模型) -
mgp-str
— ViTImageProcessor (MGP-STR 模型) -
mobilenet_v1
— MobileNetV1ImageProcessor (MobileNetV1 模型) -
mobilenet_v2
— MobileNetV2ImageProcessor (MobileNetV2 模型) -
mobilevit
— MobileViTImageProcessor (MobileViT 模型) -
mobilevitv2
— MobileViTImageProcessor (MobileViTV2 模型) -
nat
— ViTImageProcessor (NAT 模型) -
nougat
— NougatImageProcessor (Nougat 模型) -
oneformer
— OneFormerImageProcessor (OneFormer 模型) -
owlv2
— Owlv2ImageProcessor (OWLv2 模型) -
owlvit
— OwlViTImageProcessor (OWL-ViT 模型) -
perceiver
— PerceiverImageProcessor (Perceiver 模型) -
pix2struct
— Pix2StructImageProcessor (Pix2Struct 模型) -
poolformer
— PoolFormerImageProcessor (PoolFormer 模型) -
pvt
— PvtImageProcessor (PVT 模型) -
regnet
— ConvNextImageProcessor (RegNet 模型) -
resnet
— ConvNextImageProcessor (ResNet 模型) -
sam
— SamImageProcessor (SAM 模型) -
segformer
— SegformerImageProcessor (SegFormer 模型) -
siglip
— SiglipImageProcessor (SigLIP 模型) -
swiftformer
— ViTImageProcessor (SwiftFormer 模型) -
swin
— ViTImageProcessor (Swin Transformer model) -
swin2sr
— Swin2SRImageProcessor (Swin2SR model) -
swinv2
— ViTImageProcessor (Swin Transformer V2 model) -
table-transformer
— DetrImageProcessor (Table Transformer model) -
timesformer
— VideoMAEImageProcessor (TimeSformer model) -
tvlt
— TvltImageProcessor (TVLT model) -
tvp
— TvpImageProcessor (TVP model) -
upernet
— SegformerImageProcessor (UPerNet model) -
van
— ConvNextImageProcessor (VAN model) -
videomae
— VideoMAEImageProcessor (VideoMAE model) -
vilt
— ViltImageProcessor (ViLT model) -
vipllava
— CLIPImageProcessor (VipLlava model) -
vit
— ViTImageProcessor (ViT model) -
vit_hybrid
— ViTHybridImageProcessor (ViT Hybrid model) -
vit_mae
— ViTImageProcessor (ViTMAE model) -
vit_msn
— ViTImageProcessor (ViTMSN model) -
vitmatte
— VitMatteImageProcessor (ViTMatte model) -
xclip
— CLIPImageProcessor (X-CLIP model) -
yolos
— YolosImageProcessor (YOLOS model)
当您想使用私有模型时,需要传递token=True
。
示例:
代码语言:javascript复制>>> from transformers import AutoImageProcessor
>>> # Download image processor from huggingface.co and cache.
>>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
>>> # If image processor files are in a directory (e.g. image processor was saved using *save_pretrained('./test/saved_model/')*)
>>> # image_processor = AutoImageProcessor.from_pretrained("./test/saved_model/")
register
< source >
代码语言:javascript复制( config_class image_processor_class exist_ok = False )
参数
-
config_class
(PretrainedConfig) — 与要注册的模型对应的配置。 -
image_processor_class
(ImageProcessingMixin) — 要注册的图像处理器。
为这个类注册一个新的图像处理器。
AutoProcessor
class transformers.AutoProcessor
< source >
代码语言:javascript复制( )
这是一个通用的处理器类,在使用 AutoProcessor.from_pretrained()类方法创建时,将作为库的处理器类之一实例化。
这个类不能直接使用__init__()
实例化(会抛出错误)。
from_pretrained
< source >
代码语言:javascript复制( pretrained_model_name_or_path **kwargs )
参数
-
pretrained_model_name_or_path
(str
oros.PathLike
) — 这可以是:- 一个字符串,预训练特征提取器的模型 ID,托管在 huggingface.co 上的模型存储库中。有效的模型 ID 可以位于根级别,如
bert-base-uncased
,或者命名空间在用户或组织名称下,如dbmdz/bert-base-german-cased
。 - 一个目录的路径,其中包含使用
save_pretrained()
方法保存的处理器文件,例如./my_model_directory/
。
- 一个字符串,预训练特征提取器的模型 ID,托管在 huggingface.co 上的模型存储库中。有效的模型 ID 可以位于根级别,如
-
cache_dir
(str
或os.PathLike
,可选) — 下载的预训练模型特征提取器应该缓存在其中的目录路径,如果不应使用标准缓存。 -
force_download
(bool
,可选,默认为False
) — 是否强制(重新)下载特征提取器文件并覆盖缓存版本(如果存在)。 -
resume_download
(bool
,可选,默认为False
) — 是否删除接收不完整的文件。如果存在这样的文件,则尝试恢复下载。 -
proxies
(Dict[str, str]
,可选) — 一个代理服务器字典,按协议或端点使用,例如,{'http': 'foo.bar:3128','http://hostname': 'foo.bar:4012'}
。代理在每个请求上使用。 -
token
(str
或bool,可选) — 用作远程文件的 HTTP bearer 授权的令牌。如果为True
,将使用运行huggingface-cli login
时生成的令牌(存储在~/.huggingface
)。 -
revision
(str
,可选,默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们在 huggingface.co 上使用基于 git 的系统来存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
return_unused_kwargs
(bool
,可选,默认为False
) — 如果为False
,则此函数仅返回最终特征提取器对象。如果为True
,则此函数返回一个Tuple(feature_extractor, unused_kwargs)
,其中unused_kwargs是一个字典,包含未使用的键/值对,这些键不是特征提取器属性:即kwargs
的一部分,未用于更新feature_extractor
且被忽略。 -
trust_remote_code
(bool
,可选,默认为False
) — 是否允许在 Hub 上定义自定义模型的代码。此选项应仅对您信任的存储库设置为True
,并且您已阅读代码,因为它将在本地计算机上执行 Hub 上存在的代码。 -
kwargs
(Dict[str, Any]
,可选) — 任何键为特征提取器属性的 kwargs 中的值将用于覆盖加载的值。关于键/值对中键不是特征提取器属性的行为由return_unused_kwargs
关键字参数控制。
从预训练模型词汇表中实例化库中的处理器类之一。
要实例化的处理器类是根据配置对象的model_type
属性选择的(如果可能,作为参数传递或从pretrained_model_name_or_path
加载):
-
align
— AlignProcessor(ALIGN 模型) -
altclip
— AltCLIPProcessor(AltCLIP 模型) -
bark
— BarkProcessor(Bark 模型) -
blip
— BlipProcessor(BLIP 模型) -
blip-2
— Blip2Processor(BLIP-2 模型) -
bridgetower
— BridgeTowerProcessor(BridgeTower 模型) -
chinese_clip
— ChineseCLIPProcessor(Chinese-CLIP 模型) -
clap
— ClapProcessor (CLAP 模型) -
clip
— CLIPProcessor (CLIP 模型) -
clipseg
— CLIPSegProcessor (CLIPSeg 模型) -
clvp
— ClvpProcessor (CLVP 模型) -
flava
— FlavaProcessor (FLAVA 模型) -
fuyu
— FuyuProcessor (Fuyu 模型) -
git
— GitProcessor (GIT 模型) -
groupvit
— CLIPProcessor (GroupViT 模型) -
hubert
— Wav2Vec2Processor (Hubert 模型) -
idefics
— IdeficsProcessor (IDEFICS 模型) -
instructblip
— InstructBlipProcessor (InstructBLIP 模型) -
kosmos-2
— Kosmos2Processor (KOSMOS-2 模型) -
layoutlmv2
— LayoutLMv2Processor (LayoutLMv2 模型) -
layoutlmv3
— LayoutLMv3Processor (LayoutLMv3 模型) -
llava
— LlavaProcessor (LLaVa 模型) -
markuplm
— MarkupLMProcessor (MarkupLM 模型) -
mctct
— MCTCTProcessor (M-CTC-T 模型) -
mgp-str
— MgpstrProcessor (MGP-STR 模型) -
oneformer
— OneFormerProcessor (OneFormer 模型) -
owlv2
— Owlv2Processor (OWLv2 模型) -
owlvit
— OwlViTProcessor (OWL-ViT 模型) -
pix2struct
— Pix2StructProcessor (Pix2Struct 模型) -
pop2piano
— Pop2PianoProcessor (Pop2Piano 模型) -
sam
— SamProcessor (SAM 模型) -
seamless_m4t
— SeamlessM4TProcessor (SeamlessM4T 模型) -
sew
— Wav2Vec2Processor (SEW 模型) -
sew-d
— Wav2Vec2Processor (SEW-D 模型) -
siglip
— SiglipProcessor (SigLIP 模型) -
speech_to_text
— Speech2TextProcessor (Speech2Text 模型) -
speech_to_text_2
— Speech2Text2Processor (Speech2Text2 模型) -
speecht5
— SpeechT5Processor (SpeechT5 模型) -
trocr
— TrOCRProcessor(TrOCR 模型) -
tvlt
— TvltProcessor(TVLT 模型) -
tvp
— TvpProcessor(TVP 模型) -
unispeech
— Wav2Vec2Processor(UniSpeech 模型) -
unispeech-sat
— Wav2Vec2Processor(UniSpeechSat 模型) -
vilt
— ViltProcessor(ViLT 模型) -
vipllava
— LlavaProcessor(VipLlava 模型) -
vision-text-dual-encoder
— VisionTextDualEncoderProcessor(VisionTextDualEncoder 模型) -
wav2vec2
— Wav2Vec2Processor(Wav2Vec2 模型) -
wav2vec2-bert
— Wav2Vec2Processor(Wav2Vec2-BERT 模型) -
wav2vec2-conformer
— Wav2Vec2Processor(Wav2Vec2-Conformer 模型) -
wavlm
— Wav2Vec2Processor(WavLM 模型) -
whisper
— WhisperProcessor(Whisper 模型) -
xclip
— XCLIPProcessor(X-CLIP 模型)
当您想使用私有模型时,需要传递token=True
。
示例:
代码语言:javascript复制>>> from transformers import AutoProcessor
>>> # Download processor from huggingface.co and cache.
>>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
>>> # If processor files are in a directory (e.g. processor was saved using *save_pretrained('./test/saved_model/')*)
>>> # processor = AutoProcessor.from_pretrained("./test/saved_model/")
register
< source >
代码语言:javascript复制( config_class processor_class exist_ok = False )
参数
-
config_class
(PretrainedConfig) — 与要注册的模型对应的配置。 -
processor_class
(FeatureExtractorMixin
) — 要注册的处理器。
为这个类注册一个新的处理器。
通用模型类
以下自动类可用于实例化一个基本模型类,而无需特定头部。
AutoModel
class transformers.AutoModel
< source >
代码语言:javascript复制( *args **kwargs )
这是一个通用模型类,当使用 class method 或 class method 创建时,将作为库的基本模型类之一实例化。
这个类不能直接使用__init__()
实例化(会抛出错误)。
from_config
< source >
代码语言:javascript复制( **kwargs )
参数
-
config
(PretrainedConfig) — 选择要实例化的模型类基于配置类:- ASTConfig 配置类:ASTModel(音频频谱变换器模型)
- AlbertConfig 配置类:AlbertModel(ALBERT 模型)
- AlignConfig 配置类: AlignModel (ALIGN 模型)
- AltCLIPConfig 配置类: AltCLIPModel (AltCLIP 模型)
- AutoformerConfig 配置类: AutoformerModel (Autoformer 模型)
- BarkConfig 配置类: BarkModel (Bark 模型)
- BartConfig 配置类: BartModel (BART 模型)
- BeitConfig 配置类: BeitModel (BEiT 模型)
- BertConfig 配置类: BertModel (BERT 模型)
- BertGenerationConfig 配置类: BertGenerationEncoder (Bert Generation 模型)
- BigBirdConfig 配置类: BigBirdModel (BigBird 模型)
- BigBirdPegasusConfig 配置类: BigBirdPegasusModel (BigBird-Pegasus 模型)
- BioGptConfig 配置类: BioGptModel (BioGpt 模型)
- BitConfig 配置类: BitModel (BiT 模型)
- BlenderbotConfig 配置类: BlenderbotModel (Blenderbot 模型)
- BlenderbotSmallConfig 配置类: BlenderbotSmallModel (BlenderbotSmall 模型)
- Blip2Config 配置类: Blip2Model (BLIP-2 模型)
- BlipConfig 配置类: BlipModel (BLIP 模型)
- BloomConfig 配置类: BloomModel (BLOOM 模型)
- BridgeTowerConfig 配置类: BridgeTowerModel (BridgeTower 模型)
- BrosConfig 配置类: BrosModel (BROS 模型)
- CLIPConfig 配置类: CLIPModel (CLIP 模型)
- CLIPSegConfig 配置类: CLIPSegModel (CLIPSeg 模型)
- CLIPVisionConfig 配置类: CLIPVisionModel (CLIPVisionModel 模型)
- CTRLConfig 配置类: CTRLModel (CTRL 模型)
- CamembertConfig 配置类: CamembertModel (CamemBERT 模型)
- CanineConfig 配置类: CanineModel (CANINE 模型)
- ChineseCLIPConfig 配置类: ChineseCLIPModel (Chinese-CLIP 模型)
- ClapConfig 配置类: ClapModel (CLAP 模型)
- ClvpConfig 配置类: ClvpModelForConditionalGeneration (CLVP 模型)
- CodeGenConfig 配置类: CodeGenModel (CodeGen 模型)
- ConditionalDetrConfig 配置类: ConditionalDetrModel (Conditional DETR 模型)
- ConvBertConfig 配置类: ConvBertModel (ConvBERT 模型)
- ConvNextConfig 配置类: ConvNextModel (ConvNeXT 模型)
- ConvNextV2Config 配置类: ConvNextV2Model (ConvNeXTV2 模型)
- CpmAntConfig 配置类: CpmAntModel (CPM-Ant 模型)
- CvtConfig 配置类: CvtModel (CvT 模型)
- DPRConfig 配置类: DPRQuestionEncoder (DPR 模型)
- DPTConfig 配置类: DPTModel (DPT 模型)
- Data2VecAudioConfig 配置类: Data2VecAudioModel (Data2VecAudio 模型)
- Data2VecTextConfig 配置类: Data2VecTextModel (Data2VecText 模型)
- Data2VecVisionConfig 配置类: Data2VecVisionModel (Data2VecVision 模型)
- DebertaConfig 配置类: DebertaModel (DeBERTa 模型)
- DebertaV2Config 配置类: DebertaV2Model (DeBERTa-v2 模型)
- DecisionTransformerConfig 配置类: DecisionTransformerModel (Decision Transformer 模型)
- DeformableDetrConfig 配置类: DeformableDetrModel (Deformable DETR 模型)
- DeiTConfig 配置类: DeiTModel (DeiT 模型)
- DetaConfig 配置类: DetaModel (DETA 模型)
- DetrConfig 配置类: DetrModel (DETR 模型)
- DinatConfig 配置类: DinatModel (DiNAT 模型)
- Dinov2Config 配置类: Dinov2Model (DINOv2 模型)
- DistilBertConfig 配置类: DistilBertModel (DistilBERT 模型)
- DonutSwinConfig 配置类: DonutSwinModel (DonutSwin 模型)
- EfficientFormerConfig 配置类: EfficientFormerModel (EfficientFormer 模型)
- EfficientNetConfig 配置类: EfficientNetModel (EfficientNet 模型)
- ElectraConfig 配置类: ElectraModel (ELECTRA 模型)
- EncodecConfig 配置类: EncodecModel (EnCodec 模型)
- ErnieConfig 配置类: ErnieModel (ERNIE 模型)
- ErnieMConfig 配置类: ErnieMModel (ErnieM 模型)
- EsmConfig 配置类: EsmModel (ESM 模型)
- FNetConfig 配置类: FNetModel (FNet 模型)
- FSMTConfig 配置类: FSMTModel (FairSeq 机器翻译模型)
- FalconConfig 配置类: FalconModel (Falcon 模型)
- FastSpeech2ConformerConfig 配置类: FastSpeech2ConformerModel (FastSpeech2Conformer 模型)
- FlaubertConfig 配置类: FlaubertModel (FlauBERT 模型)
- FlavaConfig 配置类: FlavaModel (FLAVA 模型)
- FocalNetConfig 配置类: FocalNetModel (FocalNet 模型)
- FunnelConfig 配置类: FunnelModel 或 FunnelBaseModel (Funnel Transformer 模型)
- GLPNConfig 配置类: GLPNModel (GLPN 模型)
- GPT2Config 配置类: GPT2Model (OpenAI GPT-2 模型)
- GPTBigCodeConfig 配置类: GPTBigCodeModel (GPTBigCode 模型)
- GPTJConfig 配置类: GPTJModel (GPT-J 模型)
- GPTNeoConfig 配置类: GPTNeoModel (GPT Neo 模型)
- GPTNeoXConfig 配置类: GPTNeoXModel (GPT NeoX 模型)
- GPTNeoXJapaneseConfig 配置类: GPTNeoXJapaneseModel (GPT NeoX 日语模型)
- GPTSanJapaneseConfig 配置类: GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese 模型)
- GitConfig 配置类: GitModel (GIT 模型)
- GraphormerConfig 配置类: GraphormerModel (Graphormer 模型)
- GroupViTConfig 配置类: GroupViTModel (GroupViT 模型)
- HubertConfig 配置类: HubertModel (Hubert 模型)
- IBertConfig 配置类: IBertModel (I-BERT 模型)
- IdeficsConfig 配置类: IdeficsModel (IDEFICS 模型)
- ImageGPTConfig 配置类: ImageGPTModel (ImageGPT 模型)
- InformerConfig 配置类: InformerModel (Informer 模型)
- JukeboxConfig 配置类: JukeboxModel (Jukebox 模型)
- Kosmos2Config 配置类: Kosmos2Model (KOSMOS-2 模型)
- LEDConfig 配置类: LEDModel (LED 模型)
- LayoutLMConfig 配置类: LayoutLMModel (LayoutLM 模型)
- LayoutLMv2Config 配置类: LayoutLMv2Model (LayoutLMv2 模型)
- LayoutLMv3Config 配置类: LayoutLMv3Model (LayoutLMv3 模型)
- LevitConfig 配置类: LevitModel (LeViT 模型)
- LiltConfig 配置类: LiltModel (LiLT 模型)
- LlamaConfig 配置类: LlamaModel (LLaMA 模型)
- LongT5Config 配置类: LongT5Model (LongT5 模型)
- LongformerConfig 配置类: LongformerModel (Longformer 模型)
- LukeConfig 配置类: LukeModel (LUKE 模型)
- LxmertConfig 配置类: LxmertModel (LXMERT 模型)
- M2M100Config 配置类: M2M100Model (M2M100 模型)
- MBartConfig 配置类: MBartModel (mBART 模型)
- MCTCTConfig 配置类: MCTCTModel (M-CTC-T 模型)
- MPNetConfig 配置类: MPNetModel (MPNet 模型)
- MT5Config 配置类: MT5Model (MT5 模型)
- MarianConfig 配置类: MarianModel (Marian 模型)
- MarkupLMConfig 配置类: MarkupLMModel (MarkupLM 模型)
- Mask2FormerConfig 配置类: Mask2FormerModel (Mask2Former 模型)
- MaskFormerConfig 配置类: MaskFormerModel (MaskFormer 模型)
-
MaskFormerSwinConfig
配置类:MaskFormerSwinModel
(MaskFormerSwin 模型) - MegaConfig 配置类: MegaModel (MEGA 模型)
- MegatronBertConfig 配置类: MegatronBertModel (Megatron-BERT 模型)
- MgpstrConfig 配置类: MgpstrForSceneTextRecognition (MGP-STR 模型)
- MistralConfig 配置类: MistralModel (Mistral 模型)
- MixtralConfig 配置类: MixtralModel (Mixtral 模型)
- MobileBertConfig 配置类: MobileBertModel (MobileBERT 模型)
- MobileNetV1Config 配置类: MobileNetV1Model (MobileNetV1 模型)
- MobileNetV2Config 配置类: MobileNetV2Model (MobileNetV2 模型)
- MobileViTConfig 配置类: MobileViTModel (MobileViT 模型)
- MobileViTV2Config 配置类: MobileViTV2Model (MobileViTV2 模型)
- MptConfig 配置类: MptModel (MPT 模型)
- MraConfig 配置类: MraModel (MRA 模型)
- MvpConfig 配置类: MvpModel (MVP 模型)
- NatConfig 配置类: NatModel (NAT 模型)
- NezhaConfig 配置类: NezhaModel (Nezha 模型)
- NllbMoeConfig 配置类: NllbMoeModel (NLLB-MOE 模型)
- NystromformerConfig 配置类: NystromformerModel (Nyströmformer 模型)
- OPTConfig 配置类: OPTModel (OPT 模型)
- OneFormerConfig 配置类: OneFormerModel (OneFormer 模型)
- OpenAIGPTConfig 配置类: OpenAIGPTModel (OpenAI GPT 模型)
- OpenLlamaConfig 配置类: OpenLlamaModel (OpenLlama 模型)
- OwlViTConfig 配置类: OwlViTModel (OWL-ViT 模型)
- Owlv2Config 配置类: Owlv2Model (OWLv2 模型)
- PLBartConfig 配置类: PLBartModel (PLBart 模型)
- PatchTSMixerConfig 配置类: PatchTSMixerModel (PatchTSMixer 模型)
- PatchTSTConfig 配置类: PatchTSTModel (PatchTST 模型)
- PegasusConfig 配置类: PegasusModel (Pegasus 模型)
- PegasusXConfig 配置类: PegasusXModel (PEGASUS-X 模型)
- PerceiverConfig 配置类: PerceiverModel (Perceiver 模型)
- PersimmonConfig 配置类: PersimmonModel (Persimmon 模型)
- PhiConfig 配置类: PhiModel (Phi 模型)
- PoolFormerConfig 配置类: PoolFormerModel (PoolFormer 模型)
- ProphetNetConfig 配置类: ProphetNetModel (ProphetNet 模型)
- PvtConfig 配置类: PvtModel (PVT 模型)
- QDQBertConfig 配置类: QDQBertModel (QDQBert 模型)
- Qwen2Config 配置类: Qwen2Model (Qwen2 模型)
- ReformerConfig 配置类: ReformerModel (Reformer 模型)
- RegNetConfig 配置类: RegNetModel (RegNet 模型)
- RemBertConfig 配置类: RemBertModel (RemBERT 模型)
- ResNetConfig 配置类: ResNetModel (ResNet 模型)
- RetriBertConfig 配置类: RetriBertModel (RetriBERT 模型)
- RoCBertConfig 配置类: RoCBertModel (RoCBert 模型)
- RoFormerConfig 配置类: RoFormerModel (RoFormer 模型)
- RobertaConfig 配置类: RobertaModel (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: RobertaPreLayerNormModel (RoBERTa-PreLayerNorm 模型)
- RwkvConfig 配置类: RwkvModel (RWKV 模型)
- SEWConfig 配置类: SEWModel (SEW 模型)
- SEWDConfig 配置类: SEWDModel (SEW-D 模型)
- SamConfig 配置类: SamModel (SAM 模型)
- SeamlessM4TConfig 配置类: SeamlessM4TModel (SeamlessM4T 模型)
- SeamlessM4Tv2Config 配置类: SeamlessM4Tv2Model (SeamlessM4Tv2 模型)
- SegformerConfig 配置类: SegformerModel (SegFormer 模型)
- SiglipConfig 配置类: SiglipModel (SigLIP 模型)
- SiglipVisionConfig 配置类: SiglipVisionModel (SiglipVisionModel 模型)
- Speech2TextConfig 配置类: Speech2TextModel (Speech2Text 模型)
- SpeechT5Config 配置类: SpeechT5Model (SpeechT5 模型)
- SplinterConfig 配置类: SplinterModel (Splinter 模型)
- SqueezeBertConfig 配置类: SqueezeBertModel (SqueezeBERT 模型)
- SwiftFormerConfig 配置类: SwiftFormerModel (SwiftFormer 模型)
- Swin2SRConfig 配置类: Swin2SRModel (Swin2SR 模型)
- SwinConfig 配置类: SwinModel (Swin Transformer 模型)
- Swinv2Config 配置类: Swinv2Model (Swin Transformer V2 模型)
- SwitchTransformersConfig 配置类: SwitchTransformersModel (SwitchTransformers 模型)
- T5Config 配置类: T5Model (T5 模型)
- TableTransformerConfig 配置类: TableTransformerModel (Table Transformer 模型)
- TapasConfig 配置类: TapasModel (TAPAS 模型)
- TimeSeriesTransformerConfig 配置类: TimeSeriesTransformerModel (Time Series Transformer 模型)
- TimesformerConfig 配置类: TimesformerModel (TimeSformer 模型)
-
TimmBackboneConfig
配置类:TimmBackbone
(TimmBackbone 模型) - TrajectoryTransformerConfig 配置类: TrajectoryTransformerModel (轨迹 Transformer 模型)
- TransfoXLConfig 配置类: TransfoXLModel (Transformer-XL 模型)
- TvltConfig 配置类: TvltModel (TVLT 模型)
- TvpConfig 配置类: TvpModel (TVP 模型)
- UMT5Config 配置类: UMT5Model (UMT5 模型)
- UniSpeechConfig 配置类: UniSpeechModel (UniSpeech 模型)
- UniSpeechSatConfig 配置类: UniSpeechSatModel (UniSpeechSat 模型)
- UnivNetConfig 配置类: UnivNetModel (UnivNet 模型)
- VanConfig 配置类: VanModel (VAN 模型)
- ViTConfig 配置类: ViTModel (ViT 模型)
- ViTHybridConfig 配置类: ViTHybridModel (ViT 混合模型)
- ViTMAEConfig 配置类: ViTMAEModel (ViTMAE 模型)
- ViTMSNConfig 配置类: ViTMSNModel (ViTMSN 模型)
- VideoMAEConfig 配置类: VideoMAEModel (VideoMAE 模型)
- ViltConfig 配置类: ViltModel (ViLT 模型)
- VisionTextDualEncoderConfig 配置类: VisionTextDualEncoderModel (VisionTextDualEncoder 模型)
- VisualBertConfig 配置类: VisualBertModel (VisualBERT 模型)
- VitDetConfig 配置类: VitDetModel (VitDet 模型)
- VitsConfig 配置类: VitsModel (VITS 模型)
- VivitConfig 配置类: VivitModel (ViViT 模型)
- Wav2Vec2BertConfig 配置类: Wav2Vec2BertModel (Wav2Vec2-BERT 模型)
- Wav2Vec2Config 配置类: Wav2Vec2Model (Wav2Vec2 模型)
- Wav2Vec2ConformerConfig 配置类: Wav2Vec2ConformerModel (Wav2Vec2-Conformer 模型)
- WavLMConfig 配置类: WavLMModel (WavLM 模型)
- WhisperConfig 配置类: WhisperModel (Whisper 模型)
- XCLIPConfig 配置类: XCLIPModel (X-CLIP 模型)
- XGLMConfig 配置类: XGLMModel (XGLM 模型)
- XLMConfig 配置类: XLMModel (XLM 模型)
- XLMProphetNetConfig 配置类: XLMProphetNetModel (XLM-ProphetNet 模型)
- XLMRobertaConfig 配置类: XLMRobertaModel (XLM-RoBERTa 模型)
- XLMRobertaXLConfig 配置类: XLMRobertaXLModel (XLM-RoBERTa-XL 模型)
- XLNetConfig 配置类: XLNetModel (XLNet 模型)
- XmodConfig 配置类: XmodModel (X-MOD 模型)
- YolosConfig 配置类: YolosModel (YOLOS 模型)
- YosoConfig 配置类: YosoModel (YOSO 模型)
从配置中实例化库的基础模型类。
注意:从配置文件加载模型不会加载模型权重。它只影响模型的配置。使用 from_pretrained()来加载模型权重。
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, AutoModel
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("bert-base-cased")
>>> model = AutoModel.from_config(config)
from_pretrained
<来源>
代码语言:javascript复制( *model_args **kwargs )
参数
-
pretrained_model_name_or_path
(str
或os.PathLike
)— 可以是:- 一个字符串,预训练模型的模型 id,托管在 huggingface.co 上的模型存储库中。有效的模型 id 可以位于根级别,如
bert-base-uncased
,或者在用户或组织名称下命名空间化,如dbmdz/bert-base-german-cased
。 - 一个目录的路径,其中包含使用 save_pretrained()保存的模型权重,例如,
./my_model_directory/
。 - 一个tensorflow 索引检查点文件的路径或 url(例如,
./tf_model/model.ckpt.index
)。在这种情况下,from_tf
应设置为True
,并且应该将配置对象作为config
参数提供。使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并加载 PyTorch 模型的加载路径比较慢。
- 一个字符串,预训练模型的模型 id,托管在 huggingface.co 上的模型存储库中。有效的模型 id 可以位于根级别,如
-
model_args
(额外的位置参数,可选)— 将传递给底层模型__init__()
方法。 -
config
(PretrainedConfig,可选)— 模型使用的配置,而不是自动加载的配置。当以下情况发生时,配置可以被自动加载:- 模型是库提供的模型(使用预训练模型的模型 id字符串加载)。
- 模型是使用 save_pretrained()保存的,并通过提供保存目录重新加载。
- 通过提供本地目录作为
pretrained_model_name_or_path
加载模型,并在目录中找到名为config.json的配置 JSON 文件。
-
state_dict
(Dict[str, torch.Tensor],可选)— 一个状态字典,用于替代从保存的权重文件加载的状态字典。 如果您想从预训练配置创建模型但加载自己的权重,则可以使用此选项。但在这种情况下,您应该检查是否使用 save_pretrained()和 from_pretrained()不是更简单的选项。 -
cache_dir
(str
或os.PathLike
,可选)— 下载预训练模型配置应该被缓存的目录路径,如果不使用标准缓存。 -
from_tf
(bool
,可选,默认为False
)— 从 TensorFlow 检查点保存文件加载模型权重(参见pretrained_model_name_or_path
参数的文档字符串)。 -
force_download
(bool
,可选,默认为False
)— 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。 -
resume_download
(bool
,可选,默认为False
)— 是否删除接收不完整的文件。如果存在这样的文件,将尝试恢复下载。 -
proxies
(Dict[str, str]
,可选)— 一个按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理在每个请求上使用。 -
output_loading_info(bool,
可选,默认为False
) — 是否返回一个包含缺失键、意外键和错误消息的字典。 -
local_files_only(bool,
optional, defaults toFalse
) — 是否仅查看本地文件(例如,不尝试下载模型)。 -
revision
(str
, optional, defaults to"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
trust_remote_code
(bool
, optional, defaults toFalse
) — 是否允许在 Hub 上定义自定义模型的代码。只有在您信任的存储库中并且已阅读代码的情况下,才应将此选项设置为True
,因为它将在本地机器上执行 Hub 上存在的代码。 -
code_revision
(str
, optional, defaults to"main"
) — 用于 Hub 上代码的特定修订版本,如果代码存储在与模型其余部分不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
kwargs
(额外的关键字参数,可选) — 可用于更新配置对象(加载后)并初始化模型(例如,output_attentions=True
)。根据是否提供了config
或自动加载,行为不同:- 如果提供了
config
,**kwargs
将直接传递给底层模型的__init__
方法(我们假设配置的所有相关更新已经完成) - 如果未提供配置,
kwargs
将首先传递给配置类初始化函数(from_pretrained())。与配置属性对应的kwargs
的每个键将用提供的kwargs
值覆盖该属性。不对应任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果提供了
从预训练模型实例化库的基本模型类之一。
要实例化的模型类是根据配置对象的 model_type
属性选择的(如果可能,作为参数传递或从 pretrained_model_name_or_path
加载),或者当缺少时,通过在 pretrained_model_name_or_path
上使用模式匹配来回退:
-
albert
— AlbertModel (ALBERT 模型) -
align
— AlignModel (ALIGN 模型) -
altclip
— AltCLIPModel (AltCLIP 模型) -
audio-spectrogram-transformer
— ASTModel (音频频谱变换器模型) -
autoformer
— AutoformerModel (Autoformer 模型) -
bark
— BarkModel (Bark 模型) -
bart
— BartModel (BART 模型) -
beit
— BeitModel (BEiT 模型) -
bert
— BertModel (BERT 模型) -
bert-generation
— BertGenerationEncoder (Bert Generation 模型) -
big_bird
— BigBirdModel (BigBird 模型) -
bigbird_pegasus
— BigBirdPegasusModel (BigBird-Pegasus 模型) -
biogpt
— BioGptModel (BioGpt 模型) -
bit
— BitModel (BiT 模型) -
blenderbot
— BlenderbotModel (Blenderbot 模型) -
blenderbot-small
— BlenderbotSmallModel (BlenderbotSmall 模型) -
blip
— BlipModel (BLIP 模型) -
blip-2
— Blip2Model (BLIP-2 模型) -
bloom
— BloomModel (BLOOM 模型) -
bridgetower
— BridgeTowerModel (BridgeTower 模型) -
bros
— BrosModel (BROS 模型) -
camembert
— CamembertModel (CamemBERT 模型) -
canine
— CanineModel (CANINE 模型) -
chinese_clip
— ChineseCLIPModel (Chinese-CLIP 模型) -
clap
— ClapModel (CLAP 模型) -
clip
— CLIPModel (CLIP 模型) -
clip_vision_model
— CLIPVisionModel (CLIPVisionModel 模型) -
clipseg
— CLIPSegModel (CLIPSeg 模型) -
clvp
— ClvpModelForConditionalGeneration (CLVP 模型) -
code_llama
— LlamaModel (CodeLlama 模型) -
codegen
— CodeGenModel (CodeGen 模型) -
conditional_detr
— ConditionalDetrModel (Conditional DETR 模型) -
convbert
— ConvBertModel (ConvBERT 模型) -
convnext
— ConvNextModel (ConvNeXT 模型) -
convnextv2
— ConvNextV2Model (ConvNeXTV2 模型) -
cpmant
— CpmAntModel (CPM-Ant 模型) -
ctrl
— CTRLModel (CTRL 模型) -
cvt
— CvtModel (CvT 模型) -
data2vec-audio
— Data2VecAudioModel (Data2VecAudio 模型) -
data2vec-text
— Data2VecTextModel (Data2VecText 模型) -
data2vec-vision
— Data2VecVisionModel (Data2VecVision 模型) -
deberta
— DebertaModel (DeBERTa 模型) -
deberta-v2
— DebertaV2Model (DeBERTa-v2 模型) -
decision_transformer
— DecisionTransformerModel (Decision Transformer 模型) -
deformable_detr
— DeformableDetrModel (Deformable DETR 模型) -
deit
— DeiTModel (DeiT 模型) -
deta
— DetaModel (DETA 模型) -
detr
— DetrModel (DETR 模型) -
dinat
— DinatModel (DiNAT 模型) -
dinov2
— Dinov2Model (DINOv2 模型) -
distilbert
— DistilBertModel (DistilBERT 模型) -
donut-swin
— DonutSwinModel (DonutSwin 模型) -
dpr
— DPRQuestionEncoder (DPR 模型) -
dpt
— DPTModel (DPT 模型) -
efficientformer
— EfficientFormerModel (EfficientFormer 模型) -
efficientnet
— EfficientNetModel (EfficientNet 模型) -
electra
— ElectraModel (ELECTRA 模型) -
encodec
— EncodecModel (EnCodec 模型) -
ernie
— ErnieModel (ERNIE 模型) -
ernie_m
— ErnieMModel (ErnieM 模型) -
esm
— EsmModel (ESM 模型) -
falcon
— FalconModel (Falcon 模型) -
fastspeech2_conformer
— FastSpeech2ConformerModel (FastSpeech2Conformer 模型) -
flaubert
— FlaubertModel (FlauBERT 模型) -
flava
— FlavaModel (FLAVA 模型) -
fnet
— FNetModel (FNet 模型) -
focalnet
— FocalNetModel (FocalNet 模型) -
fsmt
— FSMTModel (FairSeq 机器翻译模型) -
funnel
— FunnelModel 或 FunnelBaseModel (Funnel Transformer 模型) -
git
— GitModel (GIT 模型) -
glpn
— GLPNModel (GLPN 模型) -
gpt-sw3
— GPT2Model (GPT-Sw3 模型) -
gpt2
— GPT2Model (OpenAI GPT-2 模型) -
gpt_bigcode
— GPTBigCodeModel (GPTBigCode 模型) -
gpt_neo
— GPTNeoModel (GPT Neo 模型) -
gpt_neox
— GPTNeoXModel (GPT NeoX 模型) -
gpt_neox_japanese
— GPTNeoXJapaneseModel (GPT NeoX Japanese 模型) -
gptj
— GPTJModel (GPT-J 模型) -
gptsan-japanese
— GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese 模型) -
graphormer
— GraphormerModel (Graphormer 模型) -
groupvit
— GroupViTModel (GroupViT 模型) -
hubert
— HubertModel (Hubert 模型) -
ibert
— IBertModel (I-BERT 模型) -
idefics
— IdeficsModel (IDEFICS 模型) -
imagegpt
— ImageGPTModel (ImageGPT 模型) -
informer
— InformerModel (Informer 模型) -
jukebox
— JukeboxModel (Jukebox 模型) -
kosmos-2
— Kosmos2Model (KOSMOS-2 模型) -
layoutlm
— LayoutLMModel (LayoutLM 模型) -
layoutlmv2
— LayoutLMv2Model (LayoutLMv2 模型) -
layoutlmv3
— LayoutLMv3Model (LayoutLMv3 模型) -
led
— LEDModel (LED 模型) -
levit
— LevitModel (LeViT 模型) -
lilt
— LiltModel (LiLT 模型) -
llama
— LlamaModel (LLaMA 模型) -
longformer
— LongformerModel (Longformer 模型) -
longt5
— LongT5Model (LongT5 模型) -
luke
— LukeModel (LUKE 模型) -
lxmert
— LxmertModel (LXMERT 模型) -
m2m_100
— M2M100Model (M2M100 模型) -
marian
— MarianModel (Marian 模型) -
markuplm
— MarkupLMModel (MarkupLM 模型) -
mask2former
— Mask2FormerModel (Mask2Former 模型) -
maskformer
— MaskFormerModel (MaskFormer 模型) -
maskformer-swin
—MaskFormerSwinModel
(MaskFormerSwin 模型) -
mbart
— MBartModel (mBART 模型) -
mctct
— MCTCTModel (M-CTC-T 模型) -
mega
— MegaModel (MEGA 模型) -
megatron-bert
— MegatronBertModel (Megatron-BERT 模型) -
mgp-str
— MgpstrForSceneTextRecognition (MGP-STR 模型) -
mistral
— MistralModel (Mistral 模型) -
mixtral
— MixtralModel (Mixtral 模型) -
mobilebert
— MobileBertModel (MobileBERT 模型) -
mobilenet_v1
— MobileNetV1Model (MobileNetV1 模型) -
mobilenet_v2
— MobileNetV2Model (MobileNetV2 模型) -
mobilevit
— MobileViTModel (MobileViT 模型) -
mobilevitv2
— MobileViTV2Model (MobileViTV2 模型) -
mpnet
— MPNetModel (MPNet 模型) -
mpt
— MptModel (MPT 模型) -
mra
— MraModel (MRA 模型) -
mt5
— MT5Model (MT5 模型) -
mvp
— MvpModel (MVP 模型) -
nat
— NatModel (NAT 模型) -
nezha
— NezhaModel (Nezha 模型) -
nllb-moe
— NllbMoeModel (NLLB-MOE 模型) -
nystromformer
— NystromformerModel (Nyströmformer 模型) -
oneformer
— OneFormerModel (OneFormer 模型) -
open-llama
— OpenLlamaModel (OpenLlama 模型) -
openai-gpt
— OpenAIGPTModel (OpenAI GPT 模型) -
opt
— OPTModel (OPT 模型) -
owlv2
— Owlv2Model (OWLv2 模型) -
owlvit
— OwlViTModel (OWL-ViT 模型) -
patchtsmixer
— PatchTSMixerModel (PatchTSMixer 模型) -
patchtst
— PatchTSTModel (PatchTST 模型) -
pegasus
— PegasusModel (Pegasus 模型) -
pegasus_x
— PegasusXModel (PEGASUS-X 模型) -
perceiver
— PerceiverModel (感知器模型) -
persimmon
— PersimmonModel (Persimmon 模型) -
phi
— PhiModel (Phi 模型) -
plbart
— PLBartModel (PLBart 模型) -
poolformer
— PoolFormerModel (PoolFormer 模型) -
prophetnet
— ProphetNetModel (ProphetNet 模型) -
pvt
— PvtModel (PVT 模型) -
qdqbert
— QDQBertModel (QDQBert 模型) -
qwen2
— Qwen2Model (Qwen2 模型) -
reformer
— ReformerModel (Reformer 模型) -
regnet
— RegNetModel (RegNet 模型) -
rembert
— RemBertModel (RemBERT 模型) -
resnet
— ResNetModel (ResNet 模型) -
retribert
— RetriBertModel (RetriBERT 模型) -
roberta
— RobertaModel (RoBERTa 模型) -
roberta-prelayernorm
— RobertaPreLayerNormModel (RoBERTa-PreLayerNorm 模型) -
roc_bert
— RoCBertModel (RoCBert 模型) -
roformer
— RoFormerModel (RoFormer 模型) -
rwkv
— RwkvModel (RWKV 模型) -
sam
— SamModel (SAM 模型) -
seamless_m4t
— SeamlessM4TModel (SeamlessM4T 模型) -
seamless_m4t_v2
— SeamlessM4Tv2Model (SeamlessM4Tv2 模型) -
segformer
— SegformerModel (SegFormer 模型) -
sew
— SEWModel (SEW 模型) -
sew-d
— SEWDModel (SEW-D 模型) -
siglip
— SiglipModel (SigLIP 模型) -
siglip_vision_model
— SiglipVisionModel (SiglipVisionModel 模型) -
speech_to_text
— Speech2TextModel (Speech2Text 模型) -
speecht5
— SpeechT5Model (SpeechT5 模型) -
splinter
— SplinterModel (Splinter 模型) -
squeezebert
— SqueezeBertModel (SqueezeBERT 模型) -
swiftformer
— SwiftFormerModel (SwiftFormer 模型) -
swin
— SwinModel (Swin Transformer 模型) -
swin2sr
— Swin2SRModel (Swin2SR 模型) -
swinv2
— Swinv2Model (Swin Transformer V2 模型) -
switch_transformers
— SwitchTransformersModel (SwitchTransformers 模型) -
t5
— T5Model (T5 模型) -
table-transformer
— TableTransformerModel (Table Transformer 模型) -
tapas
— TapasModel (TAPAS 模型) -
time_series_transformer
— TimeSeriesTransformerModel (Time Series Transformer 模型) -
timesformer
— TimesformerModel (TimeSformer 模型) -
timm_backbone
—TimmBackbone
(TimmBackbone 模型) -
trajectory_transformer
— TrajectoryTransformerModel (Trajectory Transformer 模型) -
transfo-xl
— TransfoXLModel (Transformer-XL 模型) -
tvlt
— TvltModel (TVLT 模型) -
tvp
— TvpModel (TVP 模型) -
umt5
— UMT5Model (UMT5 模型) -
unispeech
— UniSpeechModel (UniSpeech 模型) -
unispeech-sat
— UniSpeechSatModel (UniSpeechSat 模型) -
univnet
— UnivNetModel (UnivNet 模型) -
van
— VanModel (VAN 模型) -
videomae
— VideoMAEModel (VideoMAE 模型) -
vilt
— ViltModel (ViLT 模型) -
vision-text-dual-encoder
— VisionTextDualEncoderModel (VisionTextDualEncoder 模型) -
visual_bert
— VisualBertModel (VisualBERT 模型) -
vit
— ViTModel (ViT 模型) -
vit_hybrid
— ViTHybridModel (ViT Hybrid 模型) -
vit_mae
— ViTMAEModel (ViTMAE 模型) -
vit_msn
— ViTMSNModel (ViTMSN 模型) -
vitdet
— VitDetModel (VitDet 模型) -
vits
— VitsModel (VITS 模型) -
vivit
— VivitModel (ViViT 模型) -
wav2vec2
— Wav2Vec2Model (Wav2Vec2 模型) -
wav2vec2-bert
— Wav2Vec2BertModel (Wav2Vec2-BERT 模型) -
wav2vec2-conformer
— Wav2Vec2ConformerModel (Wav2Vec2-Conformer 模型) -
wavlm
— WavLMModel (WavLM 模型) -
whisper
— WhisperModel (Whisper 模型) -
xclip
— XCLIPModel (X-CLIP 模型) -
xglm
— XGLMModel (XGLM 模型) -
xlm
— XLMModel (XLM 模型) -
xlm-prophetnet
— XLMProphetNetModel (XLM-ProphetNet 模型) -
xlm-roberta
— XLMRobertaModel (XLM-RoBERTa 模型) -
xlm-roberta-xl
— XLMRobertaXLModel (XLM-RoBERTa-XL 模型) -
xlnet
— XLNetModel (XLNet 模型) -
xmod
— XmodModel (X-MOD 模型) -
yolos
— YolosModel (YOLOS 模型) -
yoso
— YosoModel (YOSO 模型)
默认情况下,模型处于评估模式,使用 model.eval()
(例如,dropout 模块被停用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, AutoModel
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModel.from_pretrained("bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModel.from_pretrained("bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModel.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModel
class transformers.TFAutoModel
< source >
代码语言:javascript复制( *args **kwargs )
这是一个通用的模型类,在使用 from_pretrained() 类方法或 from_config() 类方法创建时,会被实例化为库中的基础模型类之一。
这个类不能直接使用 __init__()
实例化(会报错)。
from_config
< source >
代码语言:javascript复制( **kwargs )
参数
-
config
(PretrainedConfig) — 根据配置类选择要实例化的模型类:- AlbertConfig 配置类: TFAlbertModel (ALBERT 模型)
- BartConfig 配置类: TFBartModel (BART 模型)
- BertConfig 配置类: TFBertModel (BERT 模型)
- BlenderbotConfig 配置类: TFBlenderbotModel (Blenderbot 模型)
- BlenderbotSmallConfig 配置类: TFBlenderbotSmallModel (BlenderbotSmall 模型)
- BlipConfig 配置类: TFBlipModel (BLIP 模型)
- CLIPConfig 配置类: TFCLIPModel (CLIP 模型)
- CTRLConfig 配置类: TFCTRLModel (CTRL 模型)
- CamembertConfig 配置类: TFCamembertModel (CamemBERT 模型)
- ConvBertConfig 配置类: TFConvBertModel (ConvBERT 模型)
- ConvNextConfig 配置类: TFConvNextModel (ConvNeXT 模型)
- ConvNextV2Config 配置类: TFConvNextV2Model (ConvNeXTV2 模型)
- CvtConfig 配置类: TFCvtModel (CvT 模型)
- DPRConfig 配置类: TFDPRQuestionEncoder (DPR 模型)
- Data2VecVisionConfig 配置类: TFData2VecVisionModel (Data2VecVision 模型)
- DebertaConfig 配置类: TFDebertaModel (DeBERTa 模型)
- DebertaV2Config 配置类: TFDebertaV2Model (DeBERTa-v2 模型)
- DeiTConfig 配置类: TFDeiTModel (DeiT 模型)
- DistilBertConfig 配置类: TFDistilBertModel (DistilBERT 模型)
- EfficientFormerConfig 配置类: TFEfficientFormerModel (EfficientFormer 模型)
- ElectraConfig 配置类: TFElectraModel (ELECTRA 模型)
- EsmConfig 配置类: TFEsmModel (ESM 模型)
- FlaubertConfig 配置类: TFFlaubertModel (FlauBERT 模型)
- FunnelConfig 配置类: TFFunnelModel 或 TFFunnelBaseModel (Funnel Transformer 模型)
- GPT2Config 配置类: TFGPT2Model (OpenAI GPT-2 模型)
- GPTJConfig 配置类: TFGPTJModel (GPT-J 模型)
- GroupViTConfig 配置类: TFGroupViTModel (GroupViT 模型)
- HubertConfig 配置类: TFHubertModel (Hubert 模型)
- LEDConfig 配置类: TFLEDModel (LED 模型)
- LayoutLMConfig 配置类: TFLayoutLMModel (LayoutLM 模型)
- LayoutLMv3Config 配置类: TFLayoutLMv3Model (LayoutLMv3 模型)
- LongformerConfig 配置类: TFLongformerModel (Longformer 模型)
- LxmertConfig 配置类: TFLxmertModel (LXMERT 模型)
- MBartConfig 配置类: TFMBartModel (mBART 模型)
- MPNetConfig 配置类: TFMPNetModel (MPNet 模型)
- MT5Config 配置类: TFMT5Model (MT5 模型)
- MarianConfig 配置类: TFMarianModel (Marian 模型)
- MobileBertConfig 配置类: TFMobileBertModel (MobileBERT 模型)
- MobileViTConfig 配置类: TFMobileViTModel (MobileViT 模型)
- OPTConfig 配置类: TFOPTModel (OPT 模型)
- OpenAIGPTConfig 配置类: TFOpenAIGPTModel (OpenAI GPT 模型)
- PegasusConfig 配置类: TFPegasusModel (Pegasus 模型)
- RegNetConfig 配置类: TFRegNetModel (RegNet 模型)
- RemBertConfig 配置类: TFRemBertModel (RemBERT 模型)
- ResNetConfig 配置类: TFResNetModel (ResNet 模型)
- RoFormerConfig 配置类: TFRoFormerModel (RoFormer 模型)
- RobertaConfig 配置类: TFRobertaModel (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: TFRobertaPreLayerNormModel (RoBERTa-PreLayerNorm 模型)
- SamConfig 配置类: TFSamModel (SAM 模型)
- SegformerConfig 配置类: TFSegformerModel (SegFormer 模型)
- Speech2TextConfig 配置类: TFSpeech2TextModel (Speech2Text 模型)
- SwinConfig 配置类: TFSwinModel (Swin Transformer 模型)
- T5Config 配置类: TFT5Model (T5 模型)
- TapasConfig 配置类: TFTapasModel (TAPAS 模型)
- TransfoXLConfig 配置类: TFTransfoXLModel (Transformer-XL 模型)
- ViTConfig 配置类: TFViTModel (ViT 模型)
- ViTMAEConfig 配置类: TFViTMAEModel (ViTMAE 模型)
- VisionTextDualEncoderConfig 配置类:TFVisionTextDualEncoderModel(VisionTextDualEncoder 模型)
- Wav2Vec2Config 配置类:TFWav2Vec2Model(Wav2Vec2 模型)
- WhisperConfig 配置类:TFWhisperModel(Whisper 模型)
- XGLMConfig 配置类:TFXGLMModel(XGLM 模型)
- XLMConfig 配置类:TFXLMModel(XLM 模型)
- XLMRobertaConfig 配置类:TFXLMRobertaModel(XLM-RoBERTa 模型)
- XLNetConfig 配置类:TFXLNetModel(XLNet 模型)
从配置实例化库中的一个基础模型类。
注意:从配置文件加载模型不会加载模型权重。它只会影响模型的配置。使用 from_pretrained()来加载模型权重。
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, TFAutoModel
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("bert-base-cased")
>>> model = TFAutoModel.from_config(config)
from_pretrained
< source >
代码语言:javascript复制( *model_args **kwargs )
参数
-
pretrained_model_name_or_path
(str
或os.PathLike
)- 可以是:- 一个字符串,预训练模型的模型 id,托管在 huggingface.co 上的模型存储库中。有效的模型 id 可以位于根级别,如
bert-base-uncased
,或者在用户或组织名称下命名空间,如dbmdz/bert-base-german-cased
。 - 一个目录的路径,其中包含使用 save_pretrained()保存的模型权重,例如
./my_model_directory/
。 - 一个PyTorch 状态字典保存文件的路径或 url(例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应设置为True
,并且应提供配置对象作为config
参数。使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并加载 TensorFlow 模型的加载路径比较慢。
- 一个字符串,预训练模型的模型 id,托管在 huggingface.co 上的模型存储库中。有效的模型 id 可以位于根级别,如
-
model_args
(额外的位置参数,可选)- 将传递给底层模型的__init__()
方法。 -
config
(PretrainedConfig,可选)- 模型使用的配置,而不是自动加载的配置。当以下情况发生时,配置可以自动加载:- 该模型是库中提供的一个模型(使用预训练模型的模型 id字符串加载)。
- 模型是使用 save_pretrained()保存的,并通过提供保存目录重新加载。
- 通过提供本地目录作为
pretrained_model_name_or_path
加载模型,并且在目录中找到名为config.json的配置 JSON 文件。
-
cache_dir
(str
oros.PathLike
, optional) — 下载的预训练模型配置应缓存在其中的目录路径,如果不使用标准缓存。 -
from_pt
(bool
, optional, defaults toFalse
) — 从 PyTorch 检查点保存文件加载模型权重(参见pretrained_model_name_or_path
参数的文档字符串)。 -
force_download
(bool
, optional, defaults toFalse
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。 -
resume_download
(bool
, optional, defaults toFalse
) — 是否删除接收不完整的文件。如果存在这样的文件,将尝试恢复下载。 -
proxies
(Dict[str, str]
, optional) — 一个代理服务器字典,按协议或端点使用,例如{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。这些代理在每个请求中使用。 -
output_loading_info(bool,
optional, defaults toFalse
) — 是否返回一个包含缺失键、意外键和错误消息的字典。 -
local_files_only(bool,
optional, defaults toFalse
) — 是否仅查看本地文件(例如,不尝试下载模型)。 -
revision
(str
, optional, defaults to"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
trust_remote_code
(bool
, optional, defaults toFalse
) — 是否允许在 Hub 上定义自定义模型的模型文件。此选项应仅对您信任的存储库设置为True
,并且您已阅读代码,因为它将在本地机器上执行 Hub 上存在的代码。 -
code_revision
(str
, optional, defaults to"main"
) — 用于 Hub 上的代码的特定修订版本,如果代码位于与模型其余部分不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
kwargs
(附加关键字参数,optional) — 可用于更新配置对象(加载后)并初始化模型(例如,output_attentions=True
)。根据是否提供或自动加载了config
,行为不同:- 如果提供了
config
,**kwargs
将直接传递给底层模型的__init__
方法(我们假设配置的所有相关更新已经完成) - 如果未提供配置,
kwargs
将首先传递给配置类初始化函数(from_pretrained())。与配置属性对应的kwargs
的每个键将用提供的kwargs
值覆盖该属性。不对应任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果提供了
从预训练模型实例化库的基本模型类之一。
根据配置对象的 model_type
属性选择要实例化的模型类(如果可能,作为参数传递或从 pretrained_model_name_or_path
加载),或者当缺失时,通过在 pretrained_model_name_or_path
上进行模式匹配来回退:
-
albert
— TFAlbertModel (ALBERT 模型) -
bart
— TFBartModel (BART 模型) -
bert
— TFBertModel (BERT 模型) -
blenderbot
— TFBlenderbotModel (Blenderbot 模型) -
blenderbot-small
— TFBlenderbotSmallModel (BlenderbotSmall 模型) -
blip
— TFBlipModel (BLIP 模型) -
camembert
— TFCamembertModel (CamemBERT 模型) -
clip
— TFCLIPModel (CLIP 模型) -
convbert
— TFConvBertModel (ConvBERT 模型) -
convnext
— TFConvNextModel (ConvNeXT 模型) -
convnextv2
— TFConvNextV2Model (ConvNeXTV2 模型) -
ctrl
— TFCTRLModel (CTRL 模型) -
cvt
— TFCvtModel (CvT 模型) -
data2vec-vision
— TFData2VecVisionModel (Data2VecVision 模型) -
deberta
— TFDebertaModel (DeBERTa 模型) -
deberta-v2
— TFDebertaV2Model (DeBERTa-v2 模型) -
deit
— TFDeiTModel (DeiT 模型) -
distilbert
— TFDistilBertModel (DistilBERT 模型) -
dpr
— TFDPRQuestionEncoder (DPR 模型) -
efficientformer
— TFEfficientFormerModel (EfficientFormer 模型) -
electra
— TFElectraModel (ELECTRA 模型) -
esm
— TFEsmModel (ESM 模型) -
flaubert
— TFFlaubertModel (FlauBERT 模型) -
funnel
— TFFunnelModel 或 TFFunnelBaseModel (Funnel Transformer 模型) -
gpt-sw3
— TFGPT2Model (GPT-Sw3 模型) -
gpt2
— TFGPT2Model (OpenAI GPT-2 模型) -
gptj
— TFGPTJModel (GPT-J 模型) -
groupvit
— TFGroupViTModel (GroupViT 模型) -
hubert
— TFHubertModel (Hubert 模型) -
layoutlm
— TFLayoutLMModel (LayoutLM 模型) -
layoutlmv3
— TFLayoutLMv3Model (LayoutLMv3 模型) -
led
— TFLEDModel (LED 模型) -
longformer
— TFLongformerModel (Longformer 模型) -
lxmert
— TFLxmertModel (LXMERT 模型) -
marian
— TFMarianModel (Marian 模型) -
mbart
— TFMBartModel (mBART 模型) -
mobilebert
— TFMobileBertModel (MobileBERT 模型) -
mobilevit
— TFMobileViTModel (MobileViT 模型) -
mpnet
— TFMPNetModel (MPNet 模型) -
mt5
— TFMT5Model (MT5 模型) -
openai-gpt
— TFOpenAIGPTModel (OpenAI GPT 模型) -
opt
— TFOPTModel (OPT 模型) -
pegasus
— TFPegasusModel (Pegasus 模型) -
regnet
— TFRegNetModel (RegNet 模型) -
rembert
— TFRemBertModel (RemBERT 模型) -
resnet
— TFResNetModel (ResNet 模型) -
roberta
— TFRobertaModel (RoBERTa 模型) -
roberta-prelayernorm
— TFRobertaPreLayerNormModel (RoBERTa-PreLayerNorm 模型) -
roformer
— TFRoFormerModel (RoFormer 模型) -
sam
— TFSamModel (SAM 模型) -
segformer
— TFSegformerModel (SegFormer 模型) -
speech_to_text
— TFSpeech2TextModel (Speech2Text 模型) -
swin
— TFSwinModel (Swin Transformer 模型) -
t5
— TFT5Model (T5 模型) -
tapas
— TFTapasModel (TAPAS 模型) -
transfo-xl
— TFTransfoXLModel (Transformer-XL 模型) -
vision-text-dual-encoder
— TFVisionTextDualEncoderModel (VisionTextDualEncoder 模型) -
vit
— TFViTModel (ViT 模型) -
vit_mae
— TFViTMAEModel (ViTMAE 模型) -
wav2vec2
— TFWav2Vec2Model (Wav2Vec2 模型) -
whisper
— TFWhisperModel (Whisper 模型) -
xglm
— TFXGLMModel (XGLM 模型) -
xlm
— TFXLMModel (XLM 模型) -
xlm-roberta
— TFXLMRobertaModel (XLM-RoBERTa 模型) -
xlnet
— TFXLNetModel (XLNet 模型)
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, TFAutoModel
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModel.from_pretrained("bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModel.from_pretrained("bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModel.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModel
class transformers.FlaxAutoModel
< source >
代码语言:javascript复制( *args **kwargs )
这是一个通用的模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,将被实例化为库的基础模型类之一。
这个类不能直接使用 __init__()
实例化(会抛出错误)。
from_config
< source >
代码语言:javascript复制( **kwargs )
参数
-
config
(PretrainedConfig) — 根据配置类选择要实例化的模型类:- AlbertConfig 配置类:FlaxAlbertModel(ALBERT 模型)
- BartConfig 配置类:FlaxBartModel(BART 模型)
- BeitConfig 配置类:FlaxBeitModel(BEiT 模型)
- BertConfig 配置类:FlaxBertModel(BERT 模型)
- BigBirdConfig 配置类:FlaxBigBirdModel(BigBird 模型)
- BlenderbotConfig 配置类:FlaxBlenderbotModel(Blenderbot 模型)
- BlenderbotSmallConfig 配置类:FlaxBlenderbotSmallModel(BlenderbotSmall 模型)
- BloomConfig 配置类:FlaxBloomModel(BLOOM 模型)
- CLIPConfig 配置类:FlaxCLIPModel(CLIP 模型)
- DistilBertConfig 配置类:FlaxDistilBertModel(DistilBERT 模型)
- ElectraConfig 配置类:FlaxElectraModel(ELECTRA 模型)
- GPT2Config 配置类:FlaxGPT2Model(OpenAI GPT-2 模型)
- GPTJConfig 配置类:FlaxGPTJModel(GPT-J 模型)
- GPTNeoConfig 配置类: FlaxGPTNeoModel (GPT Neo 模型)
- LlamaConfig 配置类: FlaxLlamaModel (LLaMA 模型)
- LongT5Config 配置类: FlaxLongT5Model (LongT5 模型)
- MBartConfig 配置类: FlaxMBartModel (mBART 模型)
- MT5Config 配置类: FlaxMT5Model (MT5 模型)
- MarianConfig 配置类: FlaxMarianModel (Marian 模型)
- OPTConfig 配置类: FlaxOPTModel (OPT 模型)
- PegasusConfig 配置类: FlaxPegasusModel (Pegasus 模型)
- RegNetConfig 配置类: FlaxRegNetModel (RegNet 模型)
- ResNetConfig 配置类: FlaxResNetModel (ResNet 模型)
- RoFormerConfig 配置类: FlaxRoFormerModel (RoFormer 模型)
- RobertaConfig 配置类: FlaxRobertaModel (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: FlaxRobertaPreLayerNormModel (RoBERTa-PreLayerNorm 模型)
- T5Config 配置类: FlaxT5Model (T5 模型)
- ViTConfig 配置类: FlaxViTModel (ViT 模型)
- VisionTextDualEncoderConfig 配置类: FlaxVisionTextDualEncoderModel (VisionTextDualEncoder 模型)
- Wav2Vec2Config 配置类: FlaxWav2Vec2Model (Wav2Vec2 模型)
- WhisperConfig 配置类: FlaxWhisperModel (Whisper 模型)
- XGLMConfig 配置类: FlaxXGLMModel (XGLM 模型)
- XLMRobertaConfig 配置类: FlaxXLMRobertaModel (XLM-RoBERTa 模型)
从配置中实例化库中的基础模型类之一。
注意:从配置文件加载模型不会加载模型权重。它只影响模型的配置。使用 from_pretrained() 来加载模型权重。
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, FlaxAutoModel
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("bert-base-cased")
>>> model = FlaxAutoModel.from_config(config)
from_pretrained
< source >
代码语言:javascript复制( *model_args **kwargs )
参数
-
pretrained_model_name_or_path
(str
或os.PathLike
) — 可以是:- 一个字符串,预训练模型的 模型 id,托管在 huggingface.co 上的模型存储库内。有效的模型 id 可以位于根级别,如
bert-base-uncased
,或者在用户或组织名称下命名空间化,如dbmdz/bert-base-german-cased
。 - 一个包含使用 save_pretrained() 保存的模型权重的 目录 的路径,例如,
./my_model_directory/
。 - 一个 PyTorch state_dict save file 的路径或 URL(例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应设置为True
,并且应提供配置对象作为config
参数。这种加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- 一个字符串,预训练模型的 模型 id,托管在 huggingface.co 上的模型存储库内。有效的模型 id 可以位于根级别,如
-
model_args
(额外的位置参数, 可选) — 将传递给底层模型__init__()
方法。 -
config
(PretrainedConfig, 可选) — 模型使用的配置,而不是自动加载的配置。当以下情况时,配置可以自动加载:- 该模型是库提供的模型(使用预训练模型的 模型 id 字符串加载)。
- 该模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 通过提供本地目录作为
pretrained_model_name_or_path
并在目录中找到名为 config.json 的配置 JSON 文件来加载模型。
-
cache_dir
(str
或os.PathLike
, 可选) — 如果不使用标准缓存,则应将下载的预训练模型配置缓存在其中的目录路径。 -
from_pt
(bool
, 可选, 默认为False
) — 从 PyTorch checkpoint save 文件加载模型权重(参见pretrained_model_name_or_path
参数的文档字符串)。 -
force_download
(bool
, 可选, 默认为False
) — 是否强制下载模型权重和配置文件,覆盖缓存版本(如果存在)。 -
resume_download
(bool
, 可选, 默认为False
) — 是否删除接收不完整的文件。如果存在这样的文件,将尝试恢复下载。 -
proxies
(Dict[str, str]
, 可选) — 一个代理服务器字典,按协议或端点使用,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理在每个请求上使用。 -
output_loading_info(bool,
可选, 默认为False
) — 是否返回一个包含缺失键、意外键和错误消息的字典。 -
local_files_only(bool,
可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 -
revision
(str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
trust_remote_code
(bool
, 可选, 默认为False
) — 是否允许在 Hub 上定义自定义模型的建模文件。此选项应仅对您信任的存储库设置为True
,并且您已经阅读了代码,因为它将在本地机器上执行 Hub 上存在的代码。 -
code_revision
(str
, 可选, 默认为"main"
) — 用于 Hub 上代码的特定修订版本,如果代码位于与模型其余部分不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
kwargs
(额外的关键字参数,可选) — 可用于更新配置对象(加载后)并初始化模型(例如,output_attentions=True
)。根据是否提供了config
或自动加载,行为不同:- 如果提供了带有
config
的配置,**kwargs
将直接传递给底层模型的__init__
方法(我们假设配置的所有相关更新已经完成) - 如果未提供配置,
kwargs
将首先传递给配置类初始化函数(from_pretrained())。kwargs
的每个键对应于配置属性,将用提供的kwargs
值覆盖该属性。不对应任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果提供了带有
从预训练模型中实例化库的基本模型类之一。
根据配置对象的 model_type
属性选择要实例化的模型类(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能的话),或者当缺失时,通过在 pretrained_model_name_or_path
上使用模式匹配来回退:
-
albert
— FlaxAlbertModel (ALBERT 模型) -
bart
— FlaxBartModel (BART 模型) -
beit
— FlaxBeitModel (BEiT 模型) -
bert
— FlaxBertModel (BERT 模型) -
big_bird
— FlaxBigBirdModel (BigBird 模型) -
blenderbot
— FlaxBlenderbotModel (Blenderbot 模型) -
blenderbot-small
— FlaxBlenderbotSmallModel (BlenderbotSmall 模型) -
bloom
— FlaxBloomModel (BLOOM 模型) -
clip
— FlaxCLIPModel (CLIP 模型) -
distilbert
— FlaxDistilBertModel (DistilBERT 模型) -
electra
— FlaxElectraModel(ELECTRA 模型) -
gpt-sw3
— FlaxGPT2Model(GPT-Sw3 模型) -
gpt2
— FlaxGPT2Model(OpenAI GPT-2 模型) -
gpt_neo
— FlaxGPTNeoModel(GPT Neo 模型) -
gptj
— FlaxGPTJModel(GPT-J 模型) -
llama
— FlaxLlamaModel(LLaMA 模型) -
longt5
— FlaxLongT5Model(LongT5 模型) -
marian
— FlaxMarianModel(Marian 模型) -
mbart
— FlaxMBartModel(mBART 模型) -
mt5
— FlaxMT5Model(MT5 模型) -
opt
— FlaxOPTModel(OPT 模型) -
pegasus
— FlaxPegasusModel(Pegasus 模型) -
regnet
— FlaxRegNetModel(RegNet 模型) -
resnet
— FlaxResNetModel(ResNet 模型) -
roberta
— FlaxRobertaModel(RoBERTa 模型) -
roberta-prelayernorm
— FlaxRobertaPreLayerNormModel(RoBERTa-PreLayerNorm 模型) -
roformer
— FlaxRoFormerModel(RoFormer 模型) -
t5
— FlaxT5Model(T5 模型) -
vision-text-dual-encoder
— FlaxVisionTextDualEncoderModel(VisionTextDualEncoder 模型) -
vit
— FlaxViTModel(ViT 模型) -
wav2vec2
— FlaxWav2Vec2Model(Wav2Vec2 模型) -
whisper
— FlaxWhisperModel(Whisper 模型) -
xglm
— FlaxXGLMModel(XGLM 模型) -
xlm-roberta
— FlaxXLMRobertaModel(XLM-RoBERTa 模型)
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, FlaxAutoModel
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModel.from_pretrained("bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModel.from_pretrained("bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModel.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
通用预训练类
以下自动类可用于实例化带有预训练头部的模型。
AutoModelForPreTraining
class transformers.AutoModelForPreTraining
< source >
代码语言:javascript复制( *args **kwargs )
这是一个通用的模型类,当使用 class method 或 class method 创建时,将作为库的模型类之一实例化(带有预训练头部)。
这个类不能直接使用__init__()
实例化(会抛出错误)。
from_config
< source >
代码语言:javascript复制( **kwargs )
参数
-
config
(PretrainedConfig) — 根据配置类选择要实例化的模型类:- AlbertConfig 配置类: AlbertForPreTraining (ALBERT 模型)
- BartConfig 配置类: BartForConditionalGeneration (BART 模型)
- BertConfig 配置类: BertForPreTraining (BERT 模型)
- BigBirdConfig 配置类: BigBirdForPreTraining (BigBird 模型)
- BloomConfig 配置类: BloomForCausalLM (BLOOM 模型)
- CTRLConfig 配置类: CTRLLMHeadModel (CTRL 模型)
- CamembertConfig 配置类: CamembertForMaskedLM (CamemBERT 模型)
- Data2VecTextConfig 配置类: Data2VecTextForMaskedLM (Data2VecText 模型)
- DebertaConfig 配置类: DebertaForMaskedLM (DeBERTa 模型)
- DebertaV2Config 配置类: DebertaV2ForMaskedLM (DeBERTa-v2 模型)
- DistilBertConfig 配置类: DistilBertForMaskedLM (DistilBERT 模型)
- ElectraConfig 配置类: ElectraForPreTraining (ELECTRA 模型)
- ErnieConfig 配置类: ErnieForPreTraining (ERNIE 模型)
- FNetConfig 配置类: FNetForPreTraining (FNet 模型)
- FSMTConfig 配置类: FSMTForConditionalGeneration (FairSeq 机器翻译模型)
- FlaubertConfig 配置类: FlaubertWithLMHeadModel (FlauBERT 模型)
- FlavaConfig 配置类: FlavaForPreTraining (FLAVA 模型)
- FunnelConfig 配置类: FunnelForPreTraining (Funnel Transformer 模型)
- GPT2Config 配置类: GPT2LMHeadModel (OpenAI GPT-2 模型)
- GPTBigCodeConfig 配置类: GPTBigCodeForCausalLM (GPTBigCode 模型)
- GPTSanJapaneseConfig 配置类: GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese 模型)
- IBertConfig 配置类: IBertForMaskedLM (I-BERT 模型)
- IdeficsConfig 配置类: IdeficsForVisionText2Text (IDEFICS 模型)
- LayoutLMConfig 配置类: LayoutLMForMaskedLM (LayoutLM 模型)
- LlavaConfig 配置类: LlavaForConditionalGeneration (LLaVa 模型)
- LongformerConfig 配置类: LongformerForMaskedLM (Longformer 模型)
- LukeConfig 配置类: LukeForMaskedLM (LUKE 模型)
- LxmertConfig 配置类: LxmertForPreTraining (LXMERT 模型)
- MPNetConfig 配置类: MPNetForMaskedLM (MPNet 模型)
- MegaConfig 配置类: MegaForMaskedLM (MEGA 模型)
- MegatronBertConfig 配置类: MegatronBertForPreTraining (Megatron-BERT 模型)
- MobileBertConfig 配置类: MobileBertForPreTraining (MobileBERT 模型)
- MptConfig 配置类: MptForCausalLM (MPT 模型)
- MraConfig 配置类: MraForMaskedLM (MRA 模型)
- MvpConfig 配置类: MvpForConditionalGeneration (MVP 模型)
- NezhaConfig 配置类: NezhaForPreTraining (Nezha 模型)
- NllbMoeConfig 配置类: NllbMoeForConditionalGeneration (NLLB-MOE 模型)
- OpenAIGPTConfig 配置类: OpenAIGPTLMHeadModel (OpenAI GPT 模型)
- RetriBertConfig 配置类: RetriBertModel (RetriBERT 模型)
- RoCBertConfig 配置类: RoCBertForPreTraining (RoCBert 模型)
- RobertaConfig 配置类: RobertaForMaskedLM (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: RobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm 模型)
- RwkvConfig 配置类: RwkvForCausalLM (RWKV 模型)
- SplinterConfig 配置类: SplinterForPreTraining (Splinter 模型)
- SqueezeBertConfig 配置类: SqueezeBertForMaskedLM (SqueezeBERT 模型)
- SwitchTransformersConfig 配置类: SwitchTransformersForConditionalGeneration (SwitchTransformers 模型)
- T5Config 配置类: T5ForConditionalGeneration (T5 模型)
- TapasConfig 配置类: TapasForMaskedLM (TAPAS 模型)
- TransfoXLConfig 配置类: TransfoXLLMHeadModel (Transformer-XL 模型)
- TvltConfig 配置类: TvltForPreTraining (TVLT 模型)
- UniSpeechConfig 配置类: UniSpeechForPreTraining (UniSpeech 模型)
- UniSpeechSatConfig 配置类: UniSpeechSatForPreTraining (UniSpeechSat 模型)
- ViTMAEConfig 配置类: ViTMAEForPreTraining (ViTMAE 模型)
- VideoMAEConfig 配置类: VideoMAEForPreTraining (VideoMAE 模型)
- VipLlavaConfig 配置类: VipLlavaForConditionalGeneration (VipLlava 模型)
- VisualBertConfig 配置类: VisualBertForPreTraining (VisualBERT 模型)
- Wav2Vec2Config 配置类: Wav2Vec2ForPreTraining (Wav2Vec2 模型)
- Wav2Vec2ConformerConfig 配置类: Wav2Vec2ConformerForPreTraining (Wav2Vec2-Conformer 模型)
- XLMConfig 配置类: XLMWithLMHeadModel (XLM 模型)
- XLMRobertaConfig 配置类: XLMRobertaForMaskedLM (XLM-RoBERTa 模型)
- XLMRobertaXLConfig 配置类: XLMRobertaXLForMaskedLM (XLM-RoBERTa-XL 模型)
- XLNetConfig 配置类: XLNetLMHeadModel (XLNet 模型)
- XmodConfig 配置类: XmodForMaskedLM (X-MOD 模型)
从配置实例化库中的一个模型类(带有预训练头)。
注意:从配置文件加载模型不会加载模型权重。它只影响模型的配置。使用 from_pretrained()加载模型权重。
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, AutoModelForPreTraining
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("bert-base-cased")
>>> model = AutoModelForPreTraining.from_config(config)
from_pretrained
<来源>
代码语言:javascript复制( *model_args **kwargs )
参数
-
pretrained_model_name_or_path
(str
或os.PathLike
)— 可以是:- 一个字符串,托管在 huggingface.co 模型存储库内的预训练模型的模型 id。有效的模型 id 可以位于根级别,如
bert-base-uncased
,或命名空间在用户或组织名称下,如dbmdz/bert-base-german-cased
。 - 一个目录的路径,其中包含使用 save_pretrained()保存的模型权重,例如,
./my_model_directory/
。 - 一个TensorFlow 索引检查点文件的路径或 url(例如,
./tf_model/model.ckpt.index
)。在这种情况下,from_tf
应设置为True
,并且应将配置对象提供为config
参数。这种加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并加载 PyTorch 模型要慢。
- 一个字符串,托管在 huggingface.co 模型存储库内的预训练模型的模型 id。有效的模型 id 可以位于根级别,如
-
model_args
(额外的位置参数,可选)— 将传递给底层模型__init__()
方法。 -
config
(PretrainedConfig,可选)— 用于替代自动加载的配置的模型配置。当以下情况时,配置可以自动加载:- 该模型是库提供的模型(使用预训练模型的模型 id字符串加载)。
- 模型是使用 save_pretrained()保存的,并通过提供保存目录重新加载。
- 通过提供本地目录作为
pretrained_model_name_or_path
加载模型,并在目录中找到名为config.json的配置 JSON 文件。
-
state_dict
(Dict[str, torch.Tensor],可选)— 一个状态字典,用于替代从保存的权重文件加载的状态字典。 如果您想从预训练配置创建模型但加载自己的权重,则可以使用此选项。在这种情况下,您应该检查是否使用 save_pretrained()和 from_pretrained()不是一个更简单的选项。 -
cache_dir
(str
或os.PathLike
,可选)— 如果不使用标准缓存,则应将下载的预训练模型配置缓存在其中的目录路径。 -
from_tf
(bool
,可选,默认为False
) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 -
force_download
(bool
,可选,默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。 -
resume_download
(bool
,可选,默认为False
)— 是否删除未完全接收的文件。如果存在这样的文件,将尝试恢复下载。 -
proxies
(Dict[str, str]
,可选)— 要使用的代理服务器的字典,按协议或端点,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理将在每个请求上使用。 -
output_loading_info(bool,
可选,默认为False
) — 是否返回一个包含缺失键、意外键和错误消息的字典。 -
local_files_only(bool,
可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 -
revision
(str
, optional, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们在 huggingface.co 上使用基于 git 的系统来存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
trust_remote_code
(bool
, 可选, 默认为False
) — 是否允许在 Hub 上定义自定义模型的建模文件。此选项应仅对您信任的存储库设置为True
,并且您已经阅读了代码,因为它将在本地机器上执行 Hub 上存在的代码。 -
code_revision
(str
, optional, 默认为"main"
) — 用于 Hub 上代码的特定修订版本,如果代码存储在与模型其余部分不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们在 huggingface.co 上使用基于 git 的系统来存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
kwargs
(额外的关键字参数,可选) — 可用于更新配置对象(加载后)并初始化模型(例如,output_attentions=True
)。根据是否提供了config
或自动加载了config
,行为会有所不同:- 如果提供了带有
config
的配置,**kwargs
将直接传递给底层模型的__init__
方法(我们假设配置的所有相关更新已经完成) - 如果未提供配置,
kwargs
将首先传递给配置类初始化函数(from_pretrained())。kwargs
的每个与配置属性对应的键将用提供的kwargs
值覆盖该属性。不对应任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果提供了带有
从预训练模型实例化库中的一个模型类(带有预训练头)。
要实例化的模型类是根据配置对象的 model_type
属性选择的(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能的话),或者当缺少时,通过在 pretrained_model_name_or_path
上使用模式匹配来回退:
-
albert
— AlbertForPreTraining (ALBERT 模型) -
bart
— BartForConditionalGeneration (BART 模型) -
bert
— BertForPreTraining (BERT 模型) -
big_bird
— BigBirdForPreTraining (BigBird 模型) -
bloom
— BloomForCausalLM (BLOOM 模型) -
camembert
— CamembertForMaskedLM (CamemBERT 模型) -
ctrl
— CTRLLMHeadModel (CTRL 模型) -
data2vec-text
— Data2VecTextForMaskedLM (Data2VecText 模型) -
deberta
— DebertaForMaskedLM (DeBERTa 模型) -
deberta-v2
— DebertaV2ForMaskedLM (DeBERTa-v2 模型) -
distilbert
— DistilBertForMaskedLM (DistilBERT 模型) -
electra
— ElectraForPreTraining (ELECTRA 模型) -
ernie
— ErnieForPreTraining (ERNIE 模型) -
flaubert
— FlaubertWithLMHeadModel (FlauBERT 模型) -
flava
— FlavaForPreTraining (FLAVA 模型) -
fnet
— FNetForPreTraining (FNet 模型) -
fsmt
— FSMTForConditionalGeneration (FairSeq 机器翻译模型) -
funnel
— FunnelForPreTraining (Funnel Transformer 模型) -
gpt-sw3
— GPT2LMHeadModel (GPT-Sw3 模型) -
gpt2
— GPT2LMHeadModel (OpenAI GPT-2 模型) -
gpt_bigcode
— GPTBigCodeForCausalLM (GPTBigCode 模型) -
gptsan-japanese
— GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese 模型) -
ibert
— IBertForMaskedLM (I-BERT 模型) -
idefics
— IdeficsForVisionText2Text (IDEFICS 模型) -
layoutlm
— LayoutLMForMaskedLM (LayoutLM 模型) -
llava
— LlavaForConditionalGeneration (LLaVa 模型) -
longformer
— LongformerForMaskedLM (Longformer 模型) -
luke
— LukeForMaskedLM (LUKE 模型) -
lxmert
— LxmertForPreTraining (LXMERT 模型) -
mega
— MegaForMaskedLM (MEGA 模型) -
megatron-bert
— MegatronBertForPreTraining (Megatron-BERT 模型) -
mobilebert
— MobileBertForPreTraining (MobileBERT 模型) -
mpnet
— MPNetForMaskedLM (MPNet 模型) -
mpt
— MptForCausalLM (MPT 模型) -
mra
— MraForMaskedLM (MRA 模型) -
mvp
— MvpForConditionalGeneration (MVP 模型) -
nezha
— NezhaForPreTraining (Nezha 模型) -
nllb-moe
— NllbMoeForConditionalGeneration (NLLB-MOE 模型) -
openai-gpt
— OpenAIGPTLMHeadModel (OpenAI GPT 模型) -
retribert
— RetriBertModel (RetriBERT 模型) -
roberta
— RobertaForMaskedLM (RoBERTa 模型) -
roberta-prelayernorm
— RobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm 模型) -
roc_bert
— RoCBertForPreTraining (RoCBert 模型) -
rwkv
— RwkvForCausalLM (RWKV 模型) -
splinter
— SplinterForPreTraining (Splinter 模型) -
squeezebert
— SqueezeBertForMaskedLM (SqueezeBERT 模型) -
switch_transformers
— SwitchTransformersForConditionalGeneration (SwitchTransformers 模型) -
t5
— T5ForConditionalGeneration (T5 模型) -
tapas
— TapasForMaskedLM (TAPAS 模型) -
transfo-xl
— TransfoXLLMHeadModel (Transformer-XL 模型) -
tvlt
— TvltForPreTraining (TVLT 模型) -
unispeech
— UniSpeechForPreTraining (UniSpeech 模型) -
unispeech-sat
— UniSpeechSatForPreTraining (UniSpeechSat 模型) -
videomae
— VideoMAEForPreTraining (VideoMAE 模型) -
vipllava
— VipLlavaForConditionalGeneration (VipLlava 模型) -
visual_bert
— VisualBertForPreTraining (VisualBERT 模型) -
vit_mae
— ViTMAEForPreTraining (ViTMAE 模型) -
wav2vec2
— Wav2Vec2ForPreTraining (Wav2Vec2 模型) -
wav2vec2-conformer
— Wav2Vec2ConformerForPreTraining (Wav2Vec2-Conformer 模型) -
xlm
— XLMWithLMHeadModel (XLM 模型) -
xlm-roberta
— XLMRobertaForMaskedLM (XLM-RoBERTa 模型) -
xlm-roberta-xl
— XLMRobertaXLForMaskedLM (XLM-RoBERTa-XL 模型) -
xlnet
— XLNetLMHeadModel (XLNet 模型) -
xmod
— XmodForMaskedLM (X-MOD 模型)
默认情况下,模型处于评估模式,使用 model.eval()
(例如,dropout 模块被停用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, AutoModelForPreTraining
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForPreTraining.from_pretrained("bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForPreTraining.from_pretrained("bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForPreTraining.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForPreTraining
class transformers.TFAutoModelForPreTraining
< source >
代码语言:javascript复制( *args **kwargs )
这是一个通用的模型类,当使用 class method 或 class method 创建时,将作为库的模型类之一实例化(带有预训练头)。
这个类不能直接使用__init__()
实例化(会抛出错误)。
from_config
<来源>
代码语言:javascript复制( **kwargs )
参数
-
config
(PretrainedConfig)— 根据配置类选择要实例化的模型类:- AlbertConfig 配置类:TFAlbertForPreTraining(ALBERT 模型)
- BartConfig 配置类:TFBartForConditionalGeneration(BART 模型)
- BertConfig 配置类:TFBertForPreTraining(BERT 模型)
- CTRLConfig 配置类:TFCTRLLMHeadModel(CTRL 模型)
- CamembertConfig 配置类:TFCamembertForMaskedLM(CamemBERT 模型)
- DistilBertConfig 配置类:TFDistilBertForMaskedLM(DistilBERT 模型)
- ElectraConfig 配置类:TFElectraForPreTraining(ELECTRA 模型)
- FlaubertConfig 配置类:TFFlaubertWithLMHeadModel(FlauBERT 模型)
- FunnelConfig 配置类:TFFunnelForPreTraining(漏斗 Transformer 模型)
- GPT2Config 配置类:TFGPT2LMHeadModel(OpenAI GPT-2 模型)
- LayoutLMConfig 配置类:TFLayoutLMForMaskedLM(LayoutLM 模型)
- LxmertConfig 配置类:TFLxmertForPreTraining(LXMERT 模型)
- MPNetConfig 配置类:TFMPNetForMaskedLM(MPNet 模型)
- MobileBertConfig 配置类:TFMobileBertForPreTraining(MobileBERT 模型)
- OpenAIGPTConfig 配置类:TFOpenAIGPTLMHeadModel(OpenAI GPT 模型)
- RobertaConfig 配置类:TFRobertaForMaskedLM(RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类:TFRobertaPreLayerNormForMaskedLM(RoBERTa-PreLayerNorm 模型)
- T5Config 配置类:TFT5ForConditionalGeneration(T5 模型)
- TapasConfig 配置类:TFTapasForMaskedLM(TAPAS 模型)
- TransfoXLConfig 配置类:TFTransfoXLLMHeadModel(Transformer-XL 模型)
- ViTMAEConfig 配置类:TFViTMAEForPreTraining(ViTMAE 模型)
- XLMConfig 配置类:TFXLMWithLMHeadModel(XLM 模型)
- XLMRobertaConfig 配置类:TFXLMRobertaForMaskedLM(XLM-RoBERTa 模型)
- XLNetConfig 配置类:TFXLNetLMHeadModel(XLNet 模型)
从配置实例化库中的一个模型类(带有预训练头)。
注意:从配置文件加载模型 不会 加载模型权重。它只影响模型的配置。使用 from_pretrained() 来加载模型权重。
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, TFAutoModelForPreTraining
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("bert-base-cased")
>>> model = TFAutoModelForPreTraining.from_config(config)
from_pretrained
< source >
代码语言:javascript复制( *model_args **kwargs )
参数
-
pretrained_model_name_or_path
(str
或os.PathLike
) — 可以是以下之一:- 一个字符串,预训练模型的 模型 id,托管在 huggingface.co 上的模型存储库中。有效的模型 id 可以位于根级别,如
bert-base-uncased
,或者在用户或组织名称下命名空间化,如dbmdz/bert-base-german-cased
。 - 一个包含使用 save_pretrained() 保存的模型权重的 目录 路径,例如,
./my_model_directory/
。 - 路径或 URL 指向 PyTorch state_dict 保存文件(例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应设置为True
,并且应提供配置对象作为config
参数。这种加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- 一个字符串,预训练模型的 模型 id,托管在 huggingface.co 上的模型存储库中。有效的模型 id 可以位于根级别,如
-
model_args
(额外的位置参数,optional) — 将传递给底层模型__init__()
方法。 -
config
(PretrainedConfig,optional) — 用于模型的配置,而不是自动加载的配置。当以下情况发生时,配置可以自动加载:- 该模型是库提供的模型(使用预训练模型的 模型 ID 字符串加载)。
- 该模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 通过提供本地目录作为
pretrained_model_name_or_path
并在目录中找到名为 config.json 的配置 JSON 文件加载模型。
-
cache_dir
(str
oros.PathLike
, optional) — 如果不使用标准缓存,则应将下载的预训练模型配置缓存在其中的目录路径。 -
from_pt
(bool
, optional, defaults toFalse
) — 从 PyTorch 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 -
force_download
(bool
, optional, defaults toFalse
) — 是否强制下载(重新下载)模型权重和配置文件,覆盖缓存版本(如果存在)。 -
resume_download
(bool
, optional, defaults toFalse
) — 是否删除接收不完整的文件。如果存在这样的文件,将尝试恢复下载。 -
proxies
(Dict[str, str]
, optional) — 一个代理服务器字典,按协议或端点使用,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理在每个请求中使用。 -
output_loading_info(bool,
optional, defaults toFalse
) — 是否还返回包含缺失键、意外键和错误消息的字典。 -
local_files_only(bool,
optional, defaults toFalse
) — 是否仅查看本地文件(例如,不尝试下载模型)。 -
revision
(str
, optional, defaults to"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们在 huggingface.co 上使用基于 git 的系统存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
trust_remote_code
(bool
, optional, defaults toFalse
) — 是否允许在 Hub 上定义自定义模型的代码文件。此选项应仅对您信任的存储库设置为True
,并且您已经阅读了代码,因为它将在本地机器上执行 Hub 上存在的代码。 -
code_revision
(str
, optional, defaults to"main"
) — 用于 Hub 上代码的特定修订版本,如果代码存储在与模型其余部分不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们在 huggingface.co 上使用基于 git 的系统存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
kwargs
(额外的关键字参数,optional) — 可用于更新配置对象(加载后)并初始化模型(例如,output_attentions=True
)。根据是否提供或自动加载了config
,行为会有所不同:- 如果提供了
config
,**kwargs
将直接传递给底层模型的__init__
方法(我们假设配置的所有相关更新已经完成) - 如果未提供配置,
kwargs
将首先传递给配置类初始化函数(from_pretrained())。kwargs
的每个对应配置属性的键将用提供的kwargs
值覆盖该属性。不对应任何配置属性的剩余键将传递给基础模型的__init__
函数。
- 如果提供了
从预训练模型实例化库中的一个模型类(带有预训练头)。
要实例化的模型类是根据配置对象的 model_type
属性选择的(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能的话),或者当缺少时,通过在 pretrained_model_name_or_path
上使用模式匹配来回退:
-
albert
— TFAlbertForPreTraining (ALBERT 模型) -
bart
— TFBartForConditionalGeneration (BART 模型) -
bert
— TFBertForPreTraining (BERT 模型) -
camembert
— TFCamembertForMaskedLM (CamemBERT 模型) -
ctrl
— TFCTRLLMHeadModel (CTRL 模型) -
distilbert
— TFDistilBertForMaskedLM (DistilBERT 模型) -
electra
— TFElectraForPreTraining (ELECTRA 模型) -
flaubert
— TFFlaubertWithLMHeadModel (FlauBERT 模型) -
funnel
— TFFunnelForPreTraining (Funnel Transformer 模型) -
gpt-sw3
— TFGPT2LMHeadModel (GPT-Sw3 模型) -
gpt2
— TFGPT2LMHeadModel (OpenAI GPT-2 模型) -
layoutlm
— TFLayoutLMForMaskedLM (LayoutLM 模型) -
lxmert
— TFLxmertForPreTraining (LXMERT 模型) -
mobilebert
— TFMobileBertForPreTraining (MobileBERT 模型) -
mpnet
— TFMPNetForMaskedLM (MPNet 模型) -
openai-gpt
— TFOpenAIGPTLMHeadModel (OpenAI GPT 模型) -
roberta
— TFRobertaForMaskedLM (RoBERTa 模型) -
roberta-prelayernorm
— TFRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm 模型) -
t5
— TFT5ForConditionalGeneration (T5 模型) -
tapas
— TFTapasForMaskedLM (TAPAS 模型) -
transfo-xl
— TFTransfoXLLMHeadModel (Transformer-XL 模型) -
vit_mae
— TFViTMAEForPreTraining (ViTMAE 模型) -
xlm
— TFXLMWithLMHeadModel(XLM 模型) -
xlm-roberta
— TFXLMRobertaForMaskedLM(XLM-RoBERTa 模型) -
xlnet
— TFXLNetLMHeadModel(XLNet 模型)
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, TFAutoModelForPreTraining
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForPreTraining.from_pretrained("bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForPreTraining.from_pretrained("bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForPreTraining.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForPreTraining
class transformers.FlaxAutoModelForPreTraining
< source >
代码语言:javascript复制( *args **kwargs )
这是一个通用的模型类,当使用 class method 或 class method 创建时,将作为库的模型类之一实例化(带有预训练头)。
这个类不能直接使用__init__()
实例化(会抛出错误)。
from_config
< source >
代码语言:javascript复制( **kwargs )
参数
-
config
(PretrainedConfig) — 选择要实例化的模型类基于配置类:- AlbertConfig 配置类:FlaxAlbertForPreTraining(ALBERT 模型)
- BartConfig 配置类:FlaxBartForConditionalGeneration(BART 模型)
- BertConfig 配置类:FlaxBertForPreTraining(BERT 模型)
- BigBirdConfig 配置类:FlaxBigBirdForPreTraining(BigBird 模型)
- ElectraConfig 配置类:FlaxElectraForPreTraining(ELECTRA 模型)
- LongT5Config 配置类:FlaxLongT5ForConditionalGeneration(LongT5 模型)
- MBartConfig 配置类:FlaxMBartForConditionalGeneration(mBART 模型)
- MT5Config 配置类:FlaxMT5ForConditionalGeneration(MT5 模型)
- RoFormerConfig 配置类:FlaxRoFormerForMaskedLM(RoFormer 模型)
- RobertaConfig 配置类:FlaxRobertaForMaskedLM(RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: FlaxRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm 模型)
- T5Config 配置类: FlaxT5ForConditionalGeneration (T5 模型)
- Wav2Vec2Config 配置类: FlaxWav2Vec2ForPreTraining (Wav2Vec2 模型)
- WhisperConfig 配置类: FlaxWhisperForConditionalGeneration (Whisper 模型)
- XLMRobertaConfig 配置类: FlaxXLMRobertaForMaskedLM (XLM-RoBERTa 模型)
从配置实例化库中的一个模型类(带有预训练头)时,可以自动加载配置。
注意: 从配置文件加载模型不会加载模型权重。它只影响模型的配置。使用 from_pretrained() 来加载模型权重。
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, FlaxAutoModelForPreTraining
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("bert-base-cased")
>>> model = FlaxAutoModelForPreTraining.from_config(config)
from_pretrained
< source >
代码语言:javascript复制( *model_args **kwargs )
参数
-
pretrained_model_name_or_path
(str
或os.PathLike
) — 可以是:- 一个字符串,预训练模型的 model id,托管在 huggingface.co 上的模型存储库内。有效的模型 id 可以位于根级别,如
bert-base-uncased
,或者命名空间在用户或组织名称下,如dbmdz/bert-base-german-cased
。 - 一个包含使用 save_pretrained() 保存的模型权重的 目录 的路径,例如,
./my_model_directory/
。 - PyTorch state_dict save file 的路径或 URL(例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应设置为True
,并且应将配置对象作为config
参数提供。使用此加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- 一个字符串,预训练模型的 model id,托管在 huggingface.co 上的模型存储库内。有效的模型 id 可以位于根级别,如
-
model_args
(额外的位置参数,可选) — 将传递给底层模型__init__()
方法。 -
config
(PretrainedConfig,可选) — 用于替代自动加载的配置的模型配置。当:- 是库提供的模型(使用预训练模型的 model id 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 通过提供本地目录作为
pretrained_model_name_or_path
并在目录中找到名为 config.json 的配置 JSON 文件来加载模型。
-
cache_dir
(str
或os.PathLike
,可选) — 如果不使用标准缓存,则应将下载的预训练模型配置缓存在其中的目录路径。 -
from_pt
(bool
, 可选, 默认为False
) — 从 PyTorch 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 -
force_download
(bool
, optional, defaults toFalse
) — 是否强制下载模型权重和配置文件,覆盖缓存版本(如果存在)。 -
resume_download
(bool
, optional, defaults toFalse
) — 是否删除接收不完整的文件。如果存在这样的文件,将尝试恢复下载。 -
proxies
(Dict[str, str]
, optional) — 一个代理服务器字典,按协议或端点使用,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理在每个请求上使用。 -
output_loading_info(bool,
optional, defaults toFalse
) — 是否返回一个包含缺失键、意外键和错误消息的字典。 -
local_files_only(bool,
optional, defaults toFalse
) — 是否仅查看本地文件(例如,不尝试下载模型)。 -
revision
(str
, optional, defaults to"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们在 huggingface.co 上使用基于 git 的系统来存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
trust_remote_code
(bool
, optional, defaults toFalse
) — 是否允许在 Hub 上定义自定义模型的建模文件。此选项应仅对您信任的存储库设置为True
,并且您已阅读了代码,因为它将在本地机器上执行 Hub 上存在的代码。 -
code_revision
(str
, optional, defaults to"main"
) — 在 Hub 上使用的特定代码修订版本,如果代码存储在与模型其余部分不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们在 huggingface.co 上使用基于 git 的系统来存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
kwargs
(额外的关键字参数,optional) — 可用于更新配置对象(加载后)并启动模型(例如,output_attentions=True
)。根据是否提供或自动加载了config
,其行为有所不同:- 如果提供了带有
config
的配置,**kwargs
将直接传递给底层模型的__init__
方法(我们假设配置的所有相关更新已完成) - 如果未提供配置,
kwargs
将首先传递给配置类初始化函数(from_pretrained())。与配置属性对应的kwargs
的每个键将用提供的kwargs
值覆盖该属性。不对应任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果提供了带有
从预训练模型实例化库中的一个模型类(带有预训练头)。
根据配置对象的 model_type
属性选择要实例化的模型类(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能),或者当缺失时,通过在 pretrained_model_name_or_path
上使用模式匹配来回退:
-
albert
— FlaxAlbertForPreTraining(ALBERT 模型) -
bart
— FlaxBartForConditionalGeneration(BART 模型) -
bert
— FlaxBertForPreTraining(BERT 模型) -
big_bird
— FlaxBigBirdForPreTraining(BigBird 模型) -
electra
— FlaxElectraForPreTraining(ELECTRA 模型) -
longt5
— FlaxLongT5ForConditionalGeneration (LongT5 模型) -
mbart
— FlaxMBartForConditionalGeneration (mBART 模型) -
mt5
— FlaxMT5ForConditionalGeneration (MT5 模型) -
roberta
— FlaxRobertaForMaskedLM (RoBERTa 模型) -
roberta-prelayernorm
— FlaxRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm 模型) -
roformer
— FlaxRoFormerForMaskedLM (RoFormer 模型) -
t5
— FlaxT5ForConditionalGeneration (T5 模型) -
wav2vec2
— FlaxWav2Vec2ForPreTraining (Wav2Vec2 模型) -
whisper
— FlaxWhisperForConditionalGeneration (Whisper 模型) -
xlm-roberta
— FlaxXLMRobertaForMaskedLM (XLM-RoBERTa 模型)
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, FlaxAutoModelForPreTraining
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForPreTraining.from_pretrained("bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForPreTraining.from_pretrained("bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForPreTraining.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
自然语言处理
以下自动类适用于以下自然语言处理任务。
AutoModelForCausalLM
class transformers.AutoModelForCausalLM
< source >
代码语言:javascript复制( *args **kwargs )
这是一个通用的模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,将实例化为库中的一个模型类(带有因果语言建模头)。
这个类不能直接使用 __init__()
实例化(会抛出错误)。
from_config
< source >
代码语言:javascript复制( **kwargs )
参数
-
config
(PretrainedConfig) — 根据配置类选择要实例化的模型类:- BartConfig 配置类: BartForCausalLM (BART 模型)
- BertConfig 配置类: BertLMHeadModel (BERT 模型)
- BertGenerationConfig 配置类: BertGenerationDecoder (Bert Generation 模型)
- BigBirdConfig 配置类: BigBirdForCausalLM (BigBird 模型)
- BigBirdPegasusConfig 配置类: BigBirdPegasusForCausalLM (BigBird-Pegasus 模型)
- BioGptConfig 配置类: BioGptForCausalLM (BioGpt 模型)
- BlenderbotConfig 配置类: BlenderbotForCausalLM (Blenderbot 模型)
- BlenderbotSmallConfig 配置类: BlenderbotSmallForCausalLM (BlenderbotSmall 模型)
- BloomConfig 配置类: BloomForCausalLM (BLOOM 模型)
- CTRLConfig 配置类: CTRLLMHeadModel (CTRL 模型)
- CamembertConfig 配置类: CamembertForCausalLM (CamemBERT 模型)
- CodeGenConfig 配置类: CodeGenForCausalLM (CodeGen 模型)
- CpmAntConfig 配置类: CpmAntForCausalLM (CPM-Ant 模型)
- Data2VecTextConfig 配置类: Data2VecTextForCausalLM (Data2VecText 模型)
- ElectraConfig 配置类: ElectraForCausalLM (ELECTRA 模型)
- ErnieConfig 配置类: ErnieForCausalLM (ERNIE 模型)
- FalconConfig 配置类: FalconForCausalLM (Falcon 模型)
- FuyuConfig 配置类: FuyuForCausalLM (Fuyu 模型)
- GPT2Config 配置类: GPT2LMHeadModel (OpenAI GPT-2 模型)
- GPTBigCodeConfig 配置类: GPTBigCodeForCausalLM (GPTBigCode 模型)
- GPTJConfig 配置类: GPTJForCausalLM (GPT-J 模型)
- GPTNeoConfig 配置类: GPTNeoForCausalLM (GPT Neo 模型)
- GPTNeoXConfig 配置类: GPTNeoXForCausalLM (GPT NeoX 模型)
- GPTNeoXJapaneseConfig 配置类: GPTNeoXJapaneseForCausalLM (GPT NeoX Japanese 模型)
- GitConfig 配置类: GitForCausalLM (GIT 模型)
- LlamaConfig 配置类: LlamaForCausalLM (LLaMA 模型)
- MBartConfig 配置类: MBartForCausalLM (mBART 模型)
- MarianConfig 配置类: MarianForCausalLM (Marian 模型)
- MegaConfig 配置类: MegaForCausalLM (MEGA 模型)
- MegatronBertConfig 配置类: MegatronBertForCausalLM (Megatron-BERT 模型)
- MistralConfig 配置类: MistralForCausalLM (Mistral 模型)
- MixtralConfig 配置类: MixtralForCausalLM (Mixtral 模型)
- MptConfig 配置类: MptForCausalLM (MPT 模型)
- MusicgenConfig 配置类: MusicgenForCausalLM (MusicGen 模型)
- MvpConfig 配置类: MvpForCausalLM (MVP 模型)
- OPTConfig 配置类: OPTForCausalLM (OPT 模型)
- OpenAIGPTConfig 配置类: OpenAIGPTLMHeadModel (OpenAI GPT 模型)
- OpenLlamaConfig 配置类: OpenLlamaForCausalLM (OpenLlama 模型)
- PLBartConfig 配置类: PLBartForCausalLM (PLBart 模型)
- PegasusConfig 配置类: PegasusForCausalLM (Pegasus 模型)
- PersimmonConfig 配置类: PersimmonForCausalLM (Persimmon 模型)
- PhiConfig 配置类: PhiForCausalLM (Phi 模型)
- ProphetNetConfig 配置类: ProphetNetForCausalLM (ProphetNet 模型)
- QDQBertConfig 配置类: QDQBertLMHeadModel (QDQBert 模型)
- Qwen2Config 配置类: Qwen2ForCausalLM (Qwen2 模型)
- ReformerConfig 配置类: ReformerModelWithLMHead (Reformer 模型)
- RemBertConfig 配置类: RemBertForCausalLM (RemBERT 模型)
- RoCBertConfig 配置类: RoCBertForCausalLM (RoCBert 模型)
- RoFormerConfig 配置类: RoFormerForCausalLM (RoFormer 模型)
- RobertaConfig 配置类: RobertaForCausalLM (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: RobertaPreLayerNormForCausalLM (RoBERTa-PreLayerNorm 模型)
- RwkvConfig 配置类: RwkvForCausalLM (RWKV 模型)
- Speech2Text2Config 配置类: Speech2Text2ForCausalLM (Speech2Text2 模型)
- TrOCRConfig 配置类: TrOCRForCausalLM (TrOCR 模型)
- TransfoXLConfig 配置类: TransfoXLLMHeadModel (Transformer-XL 模型)
- WhisperConfig 配置类:WhisperForCausalLM(Whisper 模型)
- XGLMConfig 配置类:XGLMForCausalLM(XGLM 模型)
- XLMConfig 配置类:XLMWithLMHeadModel(XLM 模型)
- XLMProphetNetConfig 配置类:XLMProphetNetForCausalLM(XLM-ProphetNet 模型)
- XLMRobertaConfig 配置类:XLMRobertaForCausalLM(XLM-RoBERTa 模型)
- XLMRobertaXLConfig 配置类:XLMRobertaXLForCausalLM(XLM-RoBERTa-XL 模型)
- XLNetConfig 配置类:XLNetLMHeadModel(XLNet 模型)
- XmodConfig 配置类:XmodForCausalLM(X-MOD 模型)
从配置实例化库中的一个模型类(带有因果语言建模头)。
注意:从配置文件加载模型不会加载模型权重。它只影响模型的配置。使用 from_pretrained() 来加载模型权重。
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, AutoModelForCausalLM
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("bert-base-cased")
>>> model = AutoModelForCausalLM.from_config(config)
from_pretrained
< source >
代码语言:javascript复制( *model_args **kwargs )
参数
-
pretrained_model_name_or_path
(str
或os.PathLike
)- 可以是:- 一个字符串,预训练模型的模型 id,托管在 huggingface.co 上的模型仓库中。有效的模型 id 可以位于根级别,如
bert-base-uncased
,或者命名空间下的用户或组织名称,如dbmdz/bert-base-german-cased
。 - 一个包含使用 save_pretrained() 保存的模型权重的目录路径,例如,
./my_model_directory/
。 - 一个TensorFlow 索引检查点文件的路径或 URL(例如,
./tf_model/model.ckpt.index
)。在这种情况下,应将from_tf
设置为True
,并且应提供配置对象作为config
参数。使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并加载 PyTorch 模型后,此加载路径比较慢。
- 一个字符串,预训练模型的模型 id,托管在 huggingface.co 上的模型仓库中。有效的模型 id 可以位于根级别,如
-
model_args
(额外的位置参数,可选)- 将传递给底层模型的__init__()
方法。 -
config
(PretrainedConfig,可选)- 用于替代自动加载的配置的模型配置。当:- 模型是库提供的模型(使用预训练模型的模型 id 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 通过提供本地目录作为
pretrained_model_name_or_path
加载模型,并在目录中找到名为 config.json 的配置 JSON 文件。
-
state_dict
(Dict[str, torch.Tensor], optional) — 用于替代从保存的权重文件加载的状态字典。 如果要从预训练配置创建模型但加载自己的权重,则可以使用此选项。但在这种情况下,您应该检查是否使用 save_pretrained() 和 from_pretrained() 不是更简单的选项。 -
cache_dir
(str
或os.PathLike
, optional) — 如果不使用标准缓存,则应将下载的预训练模型配置缓存在其中的目录路径。 -
from_tf
(bool
, optional, 默认为False
) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 -
force_download
(bool
, optional, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。 -
resume_download
(bool
, optional, 默认为False
) — 是否删除接收不完整的文件。如果存在这样的文件,将尝试恢复下载。 -
proxies
(Dict[str, str]
, optional) — 一个代理服务器字典,按协议或端点使用,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。这些代理在每个请求中使用。 -
output_loading_info(bool,
optional, 默认为False
) — 是否返回一个包含缺失键、意外键和错误消息的字典。 -
local_files_only(bool,
optional, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 -
revision
(str
, optional, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们在 huggingface.co 上使用基于 git 的系统来存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
trust_remote_code
(bool
, optional, 默认为False
) — 是否允许在 Hub 上定义自定义模型的建模文件。此选项应仅对您信任的存储库设置为True
,并且您已阅读了代码,因为它将在本地机器上执行 Hub 上存在的代码。 -
code_revision
(str
, optional, 默认为"main"
) — 用于 Hub 上代码的特定修订版本,如果代码位于与模型其余部分不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们在 huggingface.co 上使用基于 git 的系统来存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
kwargs
(额外的关键字参数, optional) — 可用于更新配置对象(加载后)并启动模型(例如,output_attentions=True
)。根据是否提供或自动加载了config
,其行为会有所不同:- 如果提供了
config
,**kwargs
将直接传递给底层模型的__init__
方法(我们假设配置的所有相关更新已经完成) - 如果未提供配置,则首先将
kwargs
传递给配置类的初始化函数(from_pretrained())。kwargs
的每个键对应于一个配置属性,将用提供的kwargs
值覆盖该属性。不对应任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果提供了
从预训练模型中实例化库中的一个模型类(带有因果语言建模头)。
根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能的话),选择要实例化的模型类,或者当缺失时,通过在 pretrained_model_name_or_path
上使用模式匹配来回退:
-
bart
— BartForCausalLM (BART 模型) -
bert
— BertLMHeadModel (BERT 模型) -
bert-generation
— BertGenerationDecoder (Bert Generation 模型) -
big_bird
— BigBirdForCausalLM (BigBird 模型) -
bigbird_pegasus
— BigBirdPegasusForCausalLM (BigBird-Pegasus 模型) -
biogpt
— BioGptForCausalLM (BioGpt 模型) -
blenderbot
— BlenderbotForCausalLM (Blenderbot 模型) -
blenderbot-small
— BlenderbotSmallForCausalLM (BlenderbotSmall 模型) -
bloom
— BloomForCausalLM (BLOOM 模型) -
camembert
— CamembertForCausalLM (CamemBERT 模型) -
code_llama
— LlamaForCausalLM (CodeLlama 模型) -
codegen
— CodeGenForCausalLM (CodeGen 模型) -
cpmant
— CpmAntForCausalLM (CPM-Ant 模型) -
ctrl
— CTRLLMHeadModel (CTRL 模型) -
data2vec-text
— Data2VecTextForCausalLM (Data2VecText 模型) -
electra
— ElectraForCausalLM (ELECTRA 模型) -
ernie
— ErnieForCausalLM (ERNIE 模型) -
falcon
— FalconForCausalLM (Falcon 模型) -
fuyu
— FuyuForCausalLM (Fuyu 模型) -
git
— GitForCausalLM (GIT 模型) -
gpt-sw3
— GPT2LMHeadModel (GPT-Sw3 模型) -
gpt2
— GPT2LMHeadModel (OpenAI GPT-2 模型) -
gpt_bigcode
— GPTBigCodeForCausalLM (GPTBigCode 模型) -
gpt_neo
— GPTNeoForCausalLM (GPT Neo 模型) -
gpt_neox
— GPTNeoXForCausalLM (GPT NeoX 模型) -
gpt_neox_japanese
— GPTNeoXJapaneseForCausalLM (GPT NeoX 日语模型) -
gptj
— GPTJForCausalLM (GPT-J 模型) -
llama
— LlamaForCausalLM (LLaMA 模型) -
marian
— MarianForCausalLM (Marian 模型) -
mbart
— MBartForCausalLM (mBART 模型) -
mega
— MegaForCausalLM (MEGA 模型) -
megatron-bert
— MegatronBertForCausalLM (Megatron-BERT 模型) -
mistral
— MistralForCausalLM (Mistral 模型) -
mixtral
— MixtralForCausalLM (Mixtral 模型) -
mpt
— MptForCausalLM (MPT 模型) -
musicgen
— MusicgenForCausalLM (MusicGen 模型) -
mvp
— MvpForCausalLM (MVP 模型) -
open-llama
— OpenLlamaForCausalLM (OpenLlama 模型) -
openai-gpt
— OpenAIGPTLMHeadModel (OpenAI GPT 模型) -
opt
— OPTForCausalLM (OPT 模型) -
pegasus
— PegasusForCausalLM (Pegasus 模型) -
persimmon
— PersimmonForCausalLM (Persimmon 模型) -
phi
— PhiForCausalLM (Phi 模型) -
plbart
— PLBartForCausalLM (PLBart 模型) -
prophetnet
— ProphetNetForCausalLM (ProphetNet 模型) -
qdqbert
— QDQBertLMHeadModel (QDQBert 模型) -
qwen2
— Qwen2ForCausalLM (Qwen2 模型) -
reformer
— ReformerModelWithLMHead (Reformer 模型) -
rembert
— RemBertForCausalLM (RemBERT 模型) -
roberta
— RobertaForCausalLM (RoBERTa 模型) -
roberta-prelayernorm
— RobertaPreLayerNormForCausalLM (RoBERTa-PreLayerNorm 模型) -
roc_bert
— RoCBertForCausalLM (RoCBert 模型) -
roformer
— RoFormerForCausalLM (RoFormer 模型) -
rwkv
— RwkvForCausalLM(RWKV 模型) -
speech_to_text_2
— Speech2Text2ForCausalLM(Speech2Text2 模型) -
transfo-xl
— TransfoXLLMHeadModel(Transformer-XL 模型) -
trocr
— TrOCRForCausalLM(TrOCR 模型) -
whisper
— WhisperForCausalLM(Whisper 模型) -
xglm
— XGLMForCausalLM(XGLM 模型) -
xlm
— XLMWithLMHeadModel(XLM 模型) -
xlm-prophetnet
— XLMProphetNetForCausalLM(XLM-ProphetNet 模型) -
xlm-roberta
— XLMRobertaForCausalLM(XLM-RoBERTa 模型) -
xlm-roberta-xl
— XLMRobertaXLForCausalLM(XLM-RoBERTa-XL 模型) -
xlnet
— XLNetLMHeadModel(XLNet 模型) -
xmod
— XmodForCausalLM(X-MOD 模型)
默认情况下,该模型处于评估模式,使用model.eval()
(例如,关闭了 dropout 模块)。要训练模型,您应该首先使用model.train()
将其设置回训练模式
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, AutoModelForCausalLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForCausalLM.from_pretrained("bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForCausalLM.from_pretrained("bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForCausalLM.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForCausalLM
class transformers.TFAutoModelForCausalLM
< source >
代码语言:javascript复制( *args **kwargs )
这是一个通用的模型类,当使用 class method 或 class method 创建时,将作为库中的模型类之一实例化(带有因果语言建模头)。
这个类不能直接使用__init__()
实例化(会抛出错误)。
from_config
< source >
代码语言:javascript复制( **kwargs )
参数
-
config
(PretrainedConfig)—将要实例化的模型类是基于配置类选择的:- BertConfig 配置类:TFBertLMHeadModel(BERT 模型)
- CTRLConfig 配置类:TFCTRLLMHeadModel(CTRL 模型)
- CamembertConfig 配置类:TFCamembertForCausalLM(CamemBERT 模型)
- GPT2Config 配置类:TFGPT2LMHeadModel(OpenAI GPT-2 模型)
- GPTJConfig 配置类:TFGPTJForCausalLM(GPT-J 模型)
- OPTConfig 配置类: TFOPTForCausalLM (OPT 模型)
- OpenAIGPTConfig 配置类: TFOpenAIGPTLMHeadModel (OpenAI GPT 模型)
- RemBertConfig 配置类: TFRemBertForCausalLM (RemBERT 模型)
- RoFormerConfig 配置类: TFRoFormerForCausalLM (RoFormer 模型)
- RobertaConfig 配置类: TFRobertaForCausalLM (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: TFRobertaPreLayerNormForCausalLM (RoBERTa-PreLayerNorm 模型)
- TransfoXLConfig 配置类: TFTransfoXLLMHeadModel (Transformer-XL 模型)
- XGLMConfig 配置类: TFXGLMForCausalLM (XGLM 模型)
- XLMConfig 配置类: TFXLMWithLMHeadModel (XLM 模型)
- XLMRobertaConfig 配置类: TFXLMRobertaForCausalLM (XLM-RoBERTa 模型)
- XLNetConfig 配置类: TFXLNetLMHeadModel (XLNet 模型)
从配置实例化库中的一个模型类(带有因果语言建模头)。
注意:从配置文件加载模型不会加载模型权重。它只影响模型的配置。使用 from_pretrained() 来加载模型权重。
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, TFAutoModelForCausalLM
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("bert-base-cased")
>>> model = TFAutoModelForCausalLM.from_config(config)
from_pretrained
< source >
代码语言:javascript复制( *model_args **kwargs )
参数
-
pretrained_model_name_or_path
(str
或os.PathLike
) — 可以是以下之一:- 一个字符串,预训练模型的模型 id,托管在 huggingface.co 上的模型存储库中。有效的模型 id 可以位于根级别,如
bert-base-uncased
,或者命名空间在用户或组织名称下,如dbmdz/bert-base-german-cased
。 - 一个包含使用 save_pretrained() 保存的模型权重的目录路径,例如,
./my_model_directory/
。 - 路径或 url 到PyTorch 状态字典保存文件(例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应设置为True
,并且应将配置对象作为config
参数提供。这种加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- 一个字符串,预训练模型的模型 id,托管在 huggingface.co 上的模型存储库中。有效的模型 id 可以位于根级别,如
-
model_args
(额外的位置参数,可选)— 将传递给底层模型__init__()
方法。 -
config
(PretrainedConfig,可选)— 模型使用的配置,而不是自动加载的配置。当:- 模型是库提供的模型(使用预训练模型的模型 ID字符串加载)。
- 模型是使用 save_pretrained()保存的,并通过提供保存目录重新加载。
- 通过提供本地目录作为
pretrained_model_name_or_path
加载模型,并在目录中找到名为config.json的配置 JSON 文件。
-
cache_dir
(str
或os.PathLike
,可选)— 如果不使用标准缓存,则应将下载的预训练模型配置缓存在其中的目录路径。 -
from_pt
(bool
,可选,默认为False
)— 从 PyTorch 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 -
force_download
(bool
,可选,默认为False
)— 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。 -
resume_download
(bool
,可选,默认为False
)— 是否删除接收不完整的文件。如果存在这样的文件,将尝试恢复下载。 -
proxies
(Dict[str, str]
,可选)— 一个代理服务器字典,按协议或端点使用,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理在每个请求上使用。 -
output_loading_info(bool,
可选,默认为False
)— 是否返回包含丢失键、意外键和错误消息的字典。 -
local_files_only(bool,
可选,默认为False
)— 是否仅查看本地文件(例如,不尝试下载模型)。 -
revision
(str
,可选,默认为"main"
)— 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们在 huggingface.co 上使用基于 git 的系统来存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
trust_remote_code
(bool
,可选,默认为False
)— 是否允许在 Hub 上定义自定义模型的建模文件。此选项应仅对您信任的存储库设置为True
,并且您已经阅读了代码,因为它将在本地机器上执行 Hub 上存在的代码。 -
code_revision
(str
,可选,默认为"main"
)— 用于 Hub 上的代码的特定修订版本,如果代码存储在与模型其余部分不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们在 huggingface.co 上使用基于 git 的系统来存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
kwargs
(额外的关键字参数,可选)— 可以用于更新配置对象(在加载后)并初始化模型(例如,output_attentions=True
)。根据是否提供了config
,行为会有所不同:- 如果提供了
config
,**kwargs
将直接传递给底层模型的__init__
方法(我们假设配置的所有相关更新已经完成)。 - 如果未提供配置,
kwargs
将首先传递给配置类初始化函数(from_pretrained())。与配置属性对应的kwargs
的每个键将用提供的kwargs
值覆盖该属性。不对应任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果提供了
从预训练模型实例化库中的一个模型类(带有因果语言建模头)。
根据配置对象的 model_type
属性选择要实例化的模型类(如果可能,作为参数传递或从 pretrained_model_name_or_path
加载),或者当缺失时,通过在 pretrained_model_name_or_path
上使用模式匹配来回退:
-
bert
— TFBertLMHeadModel (BERT 模型) -
camembert
— TFCamembertForCausalLM (CamemBERT 模型) -
ctrl
— TFCTRLLMHeadModel (CTRL 模型) -
gpt-sw3
— TFGPT2LMHeadModel (GPT-Sw3 模型) -
gpt2
— TFGPT2LMHeadModel (OpenAI GPT-2 模型) -
gptj
— TFGPTJForCausalLM (GPT-J 模型) -
openai-gpt
— TFOpenAIGPTLMHeadModel (OpenAI GPT 模型) -
opt
— TFOPTForCausalLM (OPT 模型) -
rembert
— TFRemBertForCausalLM (RemBERT 模型) -
roberta
— TFRobertaForCausalLM (RoBERTa 模型) -
roberta-prelayernorm
— TFRobertaPreLayerNormForCausalLM (RoBERTa-PreLayerNorm 模型) -
roformer
— TFRoFormerForCausalLM (RoFormer 模型) -
transfo-xl
— TFTransfoXLLMHeadModel (Transformer-XL 模型) -
xglm
— TFXGLMForCausalLM (XGLM 模型) -
xlm
— TFXLMWithLMHeadModel (XLM 模型) -
xlm-roberta
— TFXLMRobertaForCausalLM (XLM-RoBERTa 模型) -
xlnet
— TFXLNetLMHeadModel (XLNet 模型)
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, TFAutoModelForCausalLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForCausalLM.from_pretrained("bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForCausalLM.from_pretrained("bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForCausalLM.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForCausalLM
class transformers.FlaxAutoModelForCausalLM
< source >
代码语言:javascript复制( *args **kwargs )
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,将实例化为库中的一个模型类(带有因果语言建模头)。
这个类不能直接使用 __init__()
实例化(会抛出错误)。
from_config
< source >
代码语言:javascript复制( **kwargs )
参数
-
config
(PretrainedConfig) — 根据配置类选择要实例化的模型类:- BartConfig 配置类: FlaxBartForCausalLM (BART 模型)
- BertConfig 配置类: FlaxBertForCausalLM (BERT 模型)
- BigBirdConfig 配置类: FlaxBigBirdForCausalLM (BigBird 模型)
- BloomConfig 配置类: FlaxBloomForCausalLM (BLOOM 模型)
- ElectraConfig 配置类: FlaxElectraForCausalLM (ELECTRA 模型)
- GPT2Config 配置类: FlaxGPT2LMHeadModel (OpenAI GPT-2 模型)
- GPTJConfig 配置类: FlaxGPTJForCausalLM (GPT-J 模型)
- GPTNeoConfig 配置类: FlaxGPTNeoForCausalLM (GPT Neo 模型)
- LlamaConfig 配置类: FlaxLlamaForCausalLM (LLaMA 模型)
- OPTConfig 配置类: FlaxOPTForCausalLM (OPT 模型)
- RobertaConfig 配置类: FlaxRobertaForCausalLM (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: FlaxRobertaPreLayerNormForCausalLM (RoBERTa-PreLayerNorm 模型)
- XGLMConfig 配置类: FlaxXGLMForCausalLM (XGLM 模型)
- XLMRobertaConfig 配置类: FlaxXLMRobertaForCausalLM (XLM-RoBERTa 模型)
从配置中实例化库中的一个模型类(带有因果语言建模头)。
注意:从配置文件加载模型不会加载模型权重。它只影响模型的配置。使用 from_pretrained() 来加载模型权重。
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, FlaxAutoModelForCausalLM
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("bert-base-cased")
>>> model = FlaxAutoModelForCausalLM.from_config(config)
from_pretrained
< source >
代码语言:javascript复制( *model_args **kwargs )
参数
-
pretrained_model_name_or_path
(str
oros.PathLike
) — 可以是:- 一个字符串,预训练模型的模型 ID,托管在 huggingface.co 上的模型存储库中。有效的模型 ID 可以位于根级别,如
bert-base-uncased
,或者在用户或组织名称下命名空间化,如dbmdz/bert-base-german-cased
。 - 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个PyTorch state_dict save file的路径或 URL(例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应设置为True
,并且应该将配置对象作为config
参数提供。使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并加载 TensorFlow 模型后,此加载路径比较慢。
- 一个字符串,预训练模型的模型 ID,托管在 huggingface.co 上的模型存储库中。有效的模型 ID 可以位于根级别,如
-
model_args
(额外的位置参数,optional) — 将传递给底层模型__init__()
方法。 -
config
(PretrainedConfig, optional) — 用于模型的配置,而不是自动加载的配置。当以下情况时,配置可以自动加载:- 该模型是库提供的模型(使用预训练模型的模型 ID字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 通过提供本地目录作为
pretrained_model_name_or_path
并在目录中找到名为 config.json 的配置 JSON 文件来加载模型。
-
cache_dir
(str
oros.PathLike
, optional) — 下载的预训练模型配置应该缓存在其中的目录路径,如果不应使用标准缓存。 -
from_pt
(bool
, optional, defaults toFalse
) — 从 PyTorch 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 -
force_download
(bool
, optional, defaults toFalse
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。 -
resume_download
(bool
, optional, defaults toFalse
) — 是否删除接收不完整的文件。如果存在这样的文件,将尝试恢复下载。 -
proxies
(Dict[str, str]
, optional) — 一个代理服务器字典,按协议或端点使用,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理服务器在每个请求上使用。 -
output_loading_info(bool,
optional, defaults toFalse
) — 是否返回一个包含缺失键、意外键和错误消息的字典。 -
local_files_only(bool,
optional, defaults toFalse
) — 是否只查看本地文件(例如,不尝试下载模型)。 -
revision
(str
, optional, defaults to"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们在 huggingface.co 上使用基于 git 的系统来存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
trust_remote_code
(bool
,可选,默认为False
) — 是否允许在 Hub 上定义自定义模型的建模文件。此选项应仅对您信任的存储库设置为True
,并且您已阅读了代码,因为它将在本地机器上执行 Hub 上存在的代码。 -
code_revision
(str
,可选,默认为"main"
) — 用于 Hub 上代码的特定修订版本,如果代码位于与模型其余部分不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们在 huggingface.co 上使用基于 git 的系统来存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
kwargs
(附加关键字参数,可选) — 可用于更新配置对象(加载后)并启动模型(例如,output_attentions=True
)。根据是否提供或自动加载config
,行为会有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给基础模型的__init__
方法(我们假设配置的所有相关更新已经完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数(from_pretrained())。kwargs
中与配置属性对应的每个键将用于使用提供的kwargs
值覆盖该属性。不对应任何配置属性的剩余键将传递给基础模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的一个模型类(带有因果语言建模头)。
要实例化的模型类是根据配置对象的model_type
属性选择的(作为参数传递或从pretrained_model_name_or_path
加载,如果可能的话),或者当缺少时,通过在pretrained_model_name_or_path
上使用模式匹配来回退:
-
bart
— FlaxBartForCausalLM(BART 模型) -
bert
— FlaxBertForCausalLM(BERT 模型) -
big_bird
— FlaxBigBirdForCausalLM(BigBird 模型) -
bloom
— FlaxBloomForCausalLM(BLOOM 模型) -
electra
— FlaxElectraForCausalLM(ELECTRA 模型) -
gpt-sw3
— FlaxGPT2LMHeadModel(GPT-Sw3 模型) -
gpt2
— FlaxGPT2LMHeadModel(OpenAI GPT-2 模型) -
gpt_neo
— FlaxGPTNeoForCausalLM(GPT Neo 模型) -
gptj
— FlaxGPTJForCausalLM(GPT-J 模型) -
llama
— FlaxLlamaForCausalLM(LLaMA 模型) -
opt
— FlaxOPTForCausalLM(OPT 模型) -
roberta
— FlaxRobertaForCausalLM(RoBERTa 模型) -
roberta-prelayernorm
— FlaxRobertaPreLayerNormForCausalLM(RoBERTa-PreLayerNorm 模型) -
xglm
— FlaxXGLMForCausalLM(XGLM 模型) -
xlm-roberta
— FlaxXLMRobertaForCausalLM(XLM-RoBERTa 模型)
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, FlaxAutoModelForCausalLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForCausalLM.from_pretrained("bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForCausalLM.from_pretrained("bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForCausalLM.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForMaskedLM
class transformers.AutoModelForMaskedLM
< source >
代码语言:javascript复制( *args **kwargs )
这是一个通用的模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,将被实例化为库中的一个模型类(带有一个掩码语言建模头)。
这个类不能直接使用 __init__()
实例化(会抛出错误)。
from_config
< source >
代码语言:javascript复制( **kwargs )
参数
-
config
(PretrainedConfig) — 根据配置类选择要实例化的模型类:- AlbertConfig 配置类:AlbertForMaskedLM(ALBERT 模型)
- BartConfig 配置类:BartForConditionalGeneration(BART 模型)
- BertConfig 配置类:BertForMaskedLM(BERT 模型)
- BigBirdConfig 配置类:BigBirdForMaskedLM(BigBird 模型)
- CamembertConfig 配置类:CamembertForMaskedLM(CamemBERT 模型)
- ConvBertConfig 配置类:ConvBertForMaskedLM(ConvBERT 模型)
- Data2VecTextConfig 配置类:Data2VecTextForMaskedLM(Data2VecText 模型)
- DebertaConfig 配置类:DebertaForMaskedLM(DeBERTa 模型)
- DebertaV2Config 配置类:DebertaV2ForMaskedLM(DeBERTa-v2 模型)
- DistilBertConfig 配置类:DistilBertForMaskedLM(DistilBERT 模型)
- ElectraConfig 配置类:ElectraForMaskedLM(ELECTRA 模型)
- ErnieConfig 配置类: ErnieForMaskedLM (ERNIE 模型)
- EsmConfig 配置类: EsmForMaskedLM (ESM 模型)
- FNetConfig 配置类: FNetForMaskedLM (FNet 模型)
- FlaubertConfig 配置类: FlaubertWithLMHeadModel (FlauBERT 模型)
- FunnelConfig 配置类: FunnelForMaskedLM (Funnel Transformer 模型)
- IBertConfig 配置类: IBertForMaskedLM (I-BERT 模型)
- LayoutLMConfig 配置类: LayoutLMForMaskedLM (LayoutLM 模型)
- LongformerConfig 配置类: LongformerForMaskedLM (Longformer 模型)
- LukeConfig 配置类: LukeForMaskedLM (LUKE 模型)
- MBartConfig 配置类: MBartForConditionalGeneration (mBART 模型)
- MPNetConfig 配置类: MPNetForMaskedLM (MPNet 模型)
- MegaConfig 配置类: MegaForMaskedLM (MEGA 模型)
- MegatronBertConfig 配置类: MegatronBertForMaskedLM (Megatron-BERT 模型)
- MobileBertConfig 配置类: MobileBertForMaskedLM (MobileBERT 模型)
- MraConfig 配置类: MraForMaskedLM (MRA 模型)
- MvpConfig 配置类: MvpForConditionalGeneration (MVP 模型)
- NezhaConfig 配置类: NezhaForMaskedLM (Nezha 模型)
- NystromformerConfig 配置类: NystromformerForMaskedLM (Nyströmformer 模型)
- PerceiverConfig 配置类: PerceiverForMaskedLM (Perceiver 模型)
- QDQBertConfig 配置类: QDQBertForMaskedLM (QDQBert 模型)
- ReformerConfig 配置类: ReformerForMaskedLM (Reformer 模型)
- RemBertConfig 配置类: RemBertForMaskedLM (RemBERT 模型)
- RoCBertConfig 配置类: RoCBertForMaskedLM (RoCBert 模型)
- RoFormerConfig 配置类: RoFormerForMaskedLM (RoFormer 模型)
- RobertaConfig 配置类: RobertaForMaskedLM (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: RobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm 模型)
- SqueezeBertConfig 配置类: SqueezeBertForMaskedLM (SqueezeBERT 模型)
- TapasConfig 配置类: TapasForMaskedLM (TAPAS 模型)
- Wav2Vec2Config 配置类:
Wav2Vec2ForMaskedLM
(Wav2Vec2 模型) - XLMConfig 配置类: XLMWithLMHeadModel (XLM 模型)
- XLMRobertaConfig 配置类: XLMRobertaForMaskedLM (XLM-RoBERTa 模型)
- XLMRobertaXLConfig 配置类: XLMRobertaXLForMaskedLM (XLM-RoBERTa-XL 模型)
- XmodConfig 配置类: XmodForMaskedLM (X-MOD 模型)
- YosoConfig 配置类:YosoForMaskedLM(YOSO 模型)
从配置实例化库中的一个模型类(带有掩码语言建模头)。
注意:从配置文件加载模型不会加载模型权重。它只影响模型的配置。使用 from_pretrained()加载模型权重。
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, AutoModelForMaskedLM
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("bert-base-cased")
>>> model = AutoModelForMaskedLM.from_config(config)
from_pretrained
<来源>
代码语言:javascript复制( *model_args **kwargs )
参数
-
pretrained_model_name_or_path
(str
或os.PathLike
) — 可以是:- 一个字符串,预训练模型的模型 ID,托管在 huggingface.co 上的模型存储库中。有效的模型 ID 可以位于根级别,如
bert-base-uncased
,或者在用户或组织名称下命名空间,如dbmdz/bert-base-german-cased
。 - 一个目录的路径,其中包含使用 save_pretrained()保存的模型权重,例如,
./my_model_directory/
。 - 一个TensorFlow 索引检查点文件的路径或 URL(例如,
./tf_model/model.ckpt.index
)。在这种情况下,from_tf
应设置为True
,并且应将配置对象提供为config
参数。使用此加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并随后加载 PyTorch 模型要慢。
- 一个字符串,预训练模型的模型 ID,托管在 huggingface.co 上的模型存储库中。有效的模型 ID 可以位于根级别,如
-
model_args
(额外的位置参数,可选) — 将传递给底层模型__init__()
方法。 -
config
(PretrainedConfig, 可选) — 用于模型的配置,而不是自动加载的配置。当以下情况时,可以自动加载配置:- 该模型是库提供的模型(使用预训练模型的模型 ID字符串加载)。
- 该模型是使用 save_pretrained()保存的,并通过提供保存目录重新加载。
- 通过提供本地目录作为
pretrained_model_name_or_path
加载模型,并且在目录中找到名为config.json的配置 JSON 文件。
-
state_dict
(Dict[str, torch.Tensor],可选) — 一个状态字典,用于替代从保存的权重文件加载的状态字典。 如果要从预训练配置创建模型但加载自己的权重,则可以使用此选项。但在这种情况下,您应该检查是否使用 save_pretrained()和 from_pretrained()不是更简单的选项。 -
cache_dir
(str
或os.PathLike
, 可选) — 下载预训练模型配置应该缓存在其中的目录路径,如果不应使用标准缓存。 -
from_tf
(bool
, 可选, 默认为False
) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 -
force_download
(bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。 -
resume_download
(bool
, 可选, 默认为False
) — 是否删除接收不完整的文件。如果存在这样的文件,将尝试恢复下载。 -
proxies
(Dict[str, str]
, optional) — 一个代理服务器字典,按协议或端点使用,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理在每个请求中使用。 -
output_loading_info(bool,
optional, defaults toFalse
) — 是否还返回包含缺少键、意外键和错误消息的字典。 -
local_files_only(bool,
optional, defaults toFalse
) — 是否仅查看本地文件(例如,不尝试下载模型)。 -
revision
(str
, optional, defaults to"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们在 huggingface.co 上使用基于 git 的系统来存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
trust_remote_code
(bool
, optional, defaults toFalse
) — 是否允许在 Hub 上定义自定义模型的建模文件。此选项应仅对您信任的存储库设置为True
,并且您已阅读了代码,因为它将在本地计算机上执行 Hub 上存在的代码。 -
code_revision
(str
, optional, defaults to"main"
) — 代码在 Hub 上使用的特定修订版本,如果代码存储在与模型其余部分不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们在 huggingface.co 上使用基于 git 的系统来存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
kwargs
(附加关键字参数,optional) — 可用于更新配置对象(加载后)并初始化模型(例如,output_attentions=True
)。根据是否提供了config
或自动加载:- 如果提供了
config
,**kwargs
将直接传递给底层模型的__init__
方法(我们假设配置的所有相关更新已经完成) - 如果未提供配置,
kwargs
将首先传递给配置类初始化函数(from_pretrained())。与配置属性对应的kwargs
的每个键将用提供的kwargs
值覆盖该属性。不对应任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果提供了
从预训练模型中实例化库中的一个模型类(带有掩码语言建模头)。
要实例化的模型类是根据配置对象的model_type
属性选择的(如果可能作为参数传递或从pretrained_model_name_or_path
加载),或者当缺少时,通过在pretrained_model_name_or_path
上使用模式匹配来回退:
-
albert
— AlbertForMaskedLM(ALBERT 模型) -
bart
— BartForConditionalGeneration(BART 模型) -
bert
— BertForMaskedLM(BERT 模型) -
big_bird
— BigBirdForMaskedLM(BigBird 模型) -
camembert
— CamembertForMaskedLM(CamemBERT 模型) -
convbert
— ConvBertForMaskedLM(ConvBERT 模型) -
data2vec-text
— Data2VecTextForMaskedLM(Data2VecText 模型) -
deberta
— DebertaForMaskedLM(DeBERTa 模型) -
deberta-v2
— DebertaV2ForMaskedLM (DeBERTa-v2 模型) -
distilbert
— DistilBertForMaskedLM (DistilBERT 模型) -
electra
— ElectraForMaskedLM (ELECTRA 模型) -
ernie
— ErnieForMaskedLM (ERNIE 模型) -
esm
— EsmForMaskedLM (ESM 模型) -
flaubert
— FlaubertWithLMHeadModel (FlauBERT 模型) -
fnet
— FNetForMaskedLM (FNet 模型) -
funnel
— FunnelForMaskedLM (Funnel Transformer 模型) -
ibert
— IBertForMaskedLM (I-BERT 模型) -
layoutlm
— LayoutLMForMaskedLM (LayoutLM 模型) -
longformer
— LongformerForMaskedLM (Longformer 模型) -
luke
— LukeForMaskedLM (LUKE 模型) -
mbart
— MBartForConditionalGeneration (mBART 模型) -
mega
— MegaForMaskedLM (MEGA 模型) -
megatron-bert
— MegatronBertForMaskedLM (Megatron-BERT 模型) -
mobilebert
— MobileBertForMaskedLM (MobileBERT 模型) -
mpnet
— MPNetForMaskedLM (MPNet 模型) -
mra
— MraForMaskedLM (MRA 模型) -
mvp
— MvpForConditionalGeneration (MVP 模型) -
nezha
— NezhaForMaskedLM (Nezha 模型) -
nystromformer
— NystromformerForMaskedLM (Nyströmformer 模型) -
perceiver
— PerceiverForMaskedLM (Perceiver 模型) -
qdqbert
— QDQBertForMaskedLM (QDQBert 模型) -
reformer
— ReformerForMaskedLM (Reformer 模型) -
rembert
— RemBertForMaskedLM (RemBERT 模型) -
roberta
— RobertaForMaskedLM (RoBERTa 模型) -
roberta-prelayernorm
— RobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm 模型) -
roc_bert
— RoCBertForMaskedLM (RoCBert 模型) -
roformer
— RoFormerForMaskedLM (RoFormer 模型) -
squeezebert
— SqueezeBertForMaskedLM(SqueezeBERT 模型) -
tapas
— TapasForMaskedLM(TAPAS 模型) -
wav2vec2
—Wav2Vec2ForMaskedLM
(Wav2Vec2 模型) -
xlm
— XLMWithLMHeadModel(XLM 模型) -
xlm-roberta
— XLMRobertaForMaskedLM(XLM-RoBERTa 模型) -
xlm-roberta-xl
— XLMRobertaXLForMaskedLM(XLM-RoBERTa-XL 模型) -
xmod
— XmodForMaskedLM(X-MOD 模型) -
yoso
— YosoForMaskedLM(YOSO 模型)
默认情况下,使用 model.eval()
将模型设置为评估模式(例如,关闭了 dropout 模块)。要训练模型,应该首先使用 model.train()
将其设置回训练模式。
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, AutoModelForMaskedLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForMaskedLM.from_pretrained("bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForMaskedLM.from_pretrained("bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForMaskedLM.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForMaskedLM
class transformers.TFAutoModelForMaskedLM
<来源>
代码语言:javascript复制( *args **kwargs )
这是一个通用的模型类,在使用 from_pretrained() 类方法或 from_config() 类方法创建时,将作为库中的模型类之一实例化(带有掩码语言建模头)。
这个类不能直接使用 __init__()
实例化(会报错)。
from_config
<来源>
代码语言:javascript复制( **kwargs )
参数
-
config
(PretrainedConfig) — 根据配置类选择要实例化的模型类:- AlbertConfig 配置类:TFAlbertForMaskedLM(ALBERT 模型)
- BertConfig 配置类:TFBertForMaskedLM(BERT 模型)
- CamembertConfig 配置类:TFCamembertForMaskedLM(CamemBERT 模型)
- ConvBertConfig 配置类:TFConvBertForMaskedLM(ConvBERT 模型)
- DebertaConfig 配置类:TFDebertaForMaskedLM(DeBERTa 模型)
- DebertaV2Config 配置类:TFDebertaV2ForMaskedLM(DeBERTa-v2 模型)
- DistilBertConfig 配置类:TFDistilBertForMaskedLM(DistilBERT 模型)
- ElectraConfig 配置类: TFElectraForMaskedLM (ELECTRA 模型)
- EsmConfig 配置类: TFEsmForMaskedLM (ESM 模型)
- FlaubertConfig 配置类: TFFlaubertWithLMHeadModel (FlauBERT 模型)
- FunnelConfig 配置类: TFFunnelForMaskedLM (Funnel Transformer 模型)
- LayoutLMConfig 配置类: TFLayoutLMForMaskedLM (LayoutLM 模型)
- LongformerConfig 配置类: TFLongformerForMaskedLM (Longformer 模型)
- MPNetConfig 配置类: TFMPNetForMaskedLM (MPNet 模型)
- MobileBertConfig 配置类: TFMobileBertForMaskedLM (MobileBERT 模型)
- RemBertConfig 配置类: TFRemBertForMaskedLM (RemBERT 模型)
- RoFormerConfig 配置类: TFRoFormerForMaskedLM (RoFormer 模型)
- RobertaConfig 配置类: TFRobertaForMaskedLM (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: TFRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm 模型)
- TapasConfig 配置类: TFTapasForMaskedLM (TAPAS 模型)
- XLMConfig 配置类: TFXLMWithLMHeadModel (XLM 模型)
- XLMRobertaConfig 配置类: TFXLMRobertaForMaskedLM (XLM-RoBERTa 模型)
从配置实例化库中的一个模型类(带有掩码语言建模头)。
注意:从配置文件加载模型不会加载模型权重。它只影响模型的配置。使用 from_pretrained() 来加载模型权重。
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, TFAutoModelForMaskedLM
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("bert-base-cased")
>>> model = TFAutoModelForMaskedLM.from_config(config)
from_pretrained
< source >
代码语言:javascript复制( *model_args **kwargs )
参数
-
pretrained_model_name_or_path
(str
或os.PathLike
) — 可以是:- 预训练模型的 model id 字符串,托管在 huggingface.co 上的模型存储库内。有效的模型 id 可以位于根级别,如
bert-base-uncased
,或者在用户或组织名称下进行命名空间,如dbmdz/bert-base-german-cased
。 - 包含使用 save_pretrained() 保存的模型权重的 目录 路径,例如
./my_model_directory/
。 - 指向 PyTorch state_dict 保存文件 的路径或 url(例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应设置为True
,并且应将配置对象提供为config
参数。这种加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- 预训练模型的 model id 字符串,托管在 huggingface.co 上的模型存储库内。有效的模型 id 可以位于根级别,如
-
model_args
(额外的位置参数,可选) — 将传递给底层模型__init__()
方法。 -
config
(PretrainedConfig,可选) — 用于模型的配置,而不是自动加载的配置。当:- 该模型是库提供的模型(使用预训练模型的 model id 字符串加载)。
- 该模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 通过提供本地目录作为
pretrained_model_name_or_path
并在该目录中找到名为 config.json 的配置 JSON 文件来加载模型。
-
cache_dir
(str
或os.PathLike
,可选) — 如果不使用标准缓存,则应将下载的预训练模型配置缓存在其中的目录路径。 -
from_pt
(bool
,可选,默认为False
) — 从 PyTorch 检查点保存文件加载模型权重(参见pretrained_model_name_or_path
参数的文档字符串)。 -
force_download
(bool
,可选,默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖已存在的缓存版本。 -
resume_download
(bool
,可选,默认为False
) — 是否删除接收不完整的文件。如果存在这样的文件,将尝试恢复下载。 -
proxies
(Dict[str, str]
,可选) — 要按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。这些代理将在每个请求上使用。 -
output_loading_info(bool,
可选,默认为False
) — 是否还返回一个包含缺失键、意外键和错误消息的字典。 -
local_files_only(bool,
可选,默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 -
revision
(str
,可选,默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 id,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
trust_remote_code
(bool
,可选,默认为False
) — 是否允许在 Hub 上定义自定义模型的建模文件。此选项应仅对您信任的存储库设置为True
,并且您已经阅读了代码,因为它将在本地机器上执行 Hub 上存在的代码。 -
code_revision
(str
,可选,默认为"main"
) — 用于 Hub 上代码的特定修订版本,如果代码存储在与模型其余部分不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们在 huggingface.co 上使用基于 git 的系统来存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
kwargs
(附加关键字参数,可选) — 可用于更新配置对象(加载后)并启动模型(例如,output_attentions=True
)。根据是否提供或自动加载了config
,行为会有所不同:- 如果提供了带有
config
的配置,**kwargs
将直接传递给基础模型的__init__
方法(我们假设配置的所有相关更新已经完成) - 如果未提供配置,
kwargs
将首先传递给配置类初始化函数(from_pretrained())。kwargs
的每个对应于配置属性的键将用提供的kwargs
值覆盖该属性。不对应任何配置属性的剩余键将传递给基础模型的__init__
函数。
- 如果提供了带有
从预训练模型实例化库中的一个模型类(带有遮蔽语言建模头)。
要实例化的模型类基于配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能的话),或者当缺少时,通过在 pretrained_model_name_or_path
上使用模式匹配来选择:
-
albert
— TFAlbertForMaskedLM (ALBERT 模型) -
bert
— TFBertForMaskedLM (BERT 模型) -
camembert
— TFCamembertForMaskedLM (CamemBERT 模型) -
convbert
— TFConvBertForMaskedLM (ConvBERT 模型) -
deberta
— TFDebertaForMaskedLM (DeBERTa 模型) -
deberta-v2
— TFDebertaV2ForMaskedLM (DeBERTa-v2 模型) -
distilbert
— TFDistilBertForMaskedLM (DistilBERT 模型) -
electra
— TFElectraForMaskedLM (ELECTRA 模型) -
esm
— TFEsmForMaskedLM (ESM 模型) -
flaubert
— TFFlaubertWithLMHeadModel (FlauBERT 模型) -
funnel
— TFFunnelForMaskedLM (Funnel Transformer 模型) -
layoutlm
— TFLayoutLMForMaskedLM (LayoutLM 模型) -
longformer
— TFLongformerForMaskedLM (Longformer 模型) -
mobilebert
— TFMobileBertForMaskedLM (MobileBERT 模型) -
mpnet
— TFMPNetForMaskedLM (MPNet 模型) -
rembert
— TFRemBertForMaskedLM (RemBERT 模型) -
roberta
— TFRobertaForMaskedLM (RoBERTa 模型) -
roberta-prelayernorm
— TFRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm 模型) -
roformer
— TFRoFormerForMaskedLM (RoFormer 模型) -
tapas
— TFTapasForMaskedLM (TAPAS 模型) -
xlm
— TFXLMWithLMHeadModel (XLM 模型) -
xlm-roberta
— TFXLMRobertaForMaskedLM (XLM-RoBERTa 模型)
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, TFAutoModelForMaskedLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForMaskedLM.from_pretrained("bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForMaskedLM.from_pretrained("bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForMaskedLM.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForMaskedLM
class transformers.FlaxAutoModelForMaskedLM
< source >
代码语言:javascript复制( *args **kwargs )
这是一个通用的模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,将被实例化为库中的模型类之一(带有遮蔽语言建模头)。
这个类不能直接使用 __init__()
实例化(会抛出错误)。
from_config
< source >
代码语言:javascript复制( **kwargs )
参数
-
config
(PretrainedConfig) — 根据配置类选择要实例化的模型类:- AlbertConfig 配置类: FlaxAlbertForMaskedLM (ALBERT 模型)
- BartConfig 配置类: FlaxBartForConditionalGeneration (BART 模型)
- BertConfig 配置类: FlaxBertForMaskedLM (BERT 模型)
- BigBirdConfig 配置类: FlaxBigBirdForMaskedLM (BigBird 模型)
- DistilBertConfig 配置类: FlaxDistilBertForMaskedLM (DistilBERT 模型)
- ElectraConfig 配置类: FlaxElectraForMaskedLM (ELECTRA 模型)
- MBartConfig 配置类:FlaxMBartForConditionalGeneration(mBART 模型)
- RoFormerConfig 配置类:FlaxRoFormerForMaskedLM(RoFormer 模型)
- RobertaConfig 配置类:FlaxRobertaForMaskedLM(RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类:FlaxRobertaPreLayerNormForMaskedLM(RoBERTa-PreLayerNorm 模型)
- XLMRobertaConfig 配置类:FlaxXLMRobertaForMaskedLM(XLM-RoBERTa 模型)
从配置实例化库中的一个模型类(带有掩码语言建模头)。
注意:从配置文件加载模型不会加载模型权重。它只影响模型的配置。使用 from_pretrained()加载模型权重。
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, FlaxAutoModelForMaskedLM
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("bert-base-cased")
>>> model = FlaxAutoModelForMaskedLM.from_config(config)
from_pretrained
< source >
代码语言:javascript复制( *model_args **kwargs )
参数
-
pretrained_model_name_or_path
(str
或os.PathLike
)— 可以是:- 一个字符串,托管在 huggingface.co 模型存储库中的预训练模型的模型 id。有效的模型 id 可以位于根级别,如
bert-base-uncased
,或者在用户或组织名称下命名空间化,如dbmdz/bert-base-german-cased
。 - 一个目录的路径,其中包含使用 save_pretrained()保存的模型权重,例如,
./my_model_directory/
。 - 一个PyTorch 状态字典保存文件的路径或 url(例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应设置为True
,并且应将配置对象提供为config
参数。这种加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- 一个字符串,托管在 huggingface.co 模型存储库中的预训练模型的模型 id。有效的模型 id 可以位于根级别,如
-
model_args
(额外的位置参数,可选)— 将传递给底层模型__init__()
方法。 -
config
(PretrainedConfig,可选)— 模型使用的配置,而不是自动加载的配置。当以下情况自动加载配置时:- 该模型是库提供的模型(使用预训练模型的模型 id字符串加载)。
- 模型是使用 save_pretrained()保存的,并通过提供保存目录重新加载。
- 通过提供本地目录作为
pretrained_model_name_or_path
加载模型,并在目录中找到名为config.json的配置 JSON 文件。
-
cache_dir
(str
或os.PathLike
,可选)— 预下载的模型配置应缓存在其中的目录路径,如果不使用标准缓存。 -
from_pt
(bool
,可选,默认为False
)— 从 PyTorch 检查点保存文件加载模型权重(参见pretrained_model_name_or_path
参数的文档字符串)。 -
force_download
(bool
,可选,默认为False
)— 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。 -
resume_download
(bool
,可选,默认为False
)— 是否删除接收不完整的文件。如果存在这样的文件,将尝试恢复下载。 -
proxies
(Dict[str, str]
,可选)— 一个代理服务器字典,按协议或端点使用,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理在每个请求上使用。 -
output_loading_info(bool,
可选,默认为False
) — 是否还返回一个包含缺失键、意外键和错误消息的字典。 -
local_files_only(bool,
可选,默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 -
revision
(str
,可选,默认为"main"
)— 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们在 huggingface.co 上使用基于 git 的系统来存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
trust_remote_code
(bool
,可选,默认为False
) — 是否允许在 Hub 上定义自定义模型的代码文件。此选项应仅对您信任的存储库设置为True
,并且您已阅读代码,因为它将在本地机器上执行 Hub 上存在的代码。 -
code_revision
(str
,可选,默认为"main"
)— 用于 Hub 上的代码的特定修订版本,如果代码与模型的其余部分不在同一存储库中。它可以是分支名称、标签名称或提交 ID,因为我们在 huggingface.co 上使用基于 git 的系统来存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
kwargs
(额外的关键字参数,可选)— 可用于更新配置对象(加载后)并初始化模型(例如,output_attentions=True
)。根据是否提供或自动加载了config
,行为会有所不同:- 如果提供了
config
配置,**kwargs
将直接传递给基础模型的__init__
方法(我们假设配置的所有相关更新已经完成) - 如果未提供配置,
kwargs
将首先传递给配置类初始化函数(from_pretrained())。与配置属性对应的kwargs
的每个键将用提供的kwargs
值覆盖该属性。不对应任何配置属性的剩余键将传递给基础模型的__init__
函数。
- 如果提供了
从预训练模型实例化库中的一个模型类(带有掩码语言建模头)。
实例化的模型类是根据配置对象的model_type
属性选择的(作为参数传递或从pretrained_model_name_or_path
加载,如果可能的话),或者当缺失时,通过在pretrained_model_name_or_path
上进行模式匹配来回退:
-
albert
— FlaxAlbertForMaskedLM(ALBERT 模型) -
bart
— FlaxBartForConditionalGeneration(BART 模型) -
bert
— FlaxBertForMaskedLM(BERT 模型) -
big_bird
— FlaxBigBirdForMaskedLM(BigBird 模型) -
distilbert
— FlaxDistilBertForMaskedLM(DistilBERT 模型) -
electra
- FlaxElectraForMaskedLM(ELECTRA 模型) -
mbart
- FlaxMBartForConditionalGeneration(mBART 模型) -
roberta
- FlaxRobertaForMaskedLM(RoBERTa 模型) -
roberta-prelayernorm
- FlaxRobertaPreLayerNormForMaskedLM(RoBERTa-PreLayerNorm 模型) -
roformer
- FlaxRoFormerForMaskedLM(RoFormer 模型) -
xlm-roberta
- FlaxXLMRobertaForMaskedLM(XLM-RoBERTa 模型)
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, FlaxAutoModelForMaskedLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForMaskedLM.from_pretrained("bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForMaskedLM.from_pretrained("bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForMaskedLM.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
自动模型用于生成口罩
class transformers.AutoModelForMaskGeneration
<来源>
代码语言:javascript复制( *args **kwargs )
TFAutoModelForMaskGeneration
class transformers.TFAutoModelForMaskGeneration
<来源>
代码语言:javascript复制( *args **kwargs )
AutoModelForSeq2SeqLM
class transformers.AutoModelForSeq2SeqLM
<来源>
代码语言:javascript复制( *args **kwargs )
这是一个通用的模型类,当使用 from_pretrained()类方法或 from_config()类方法创建时,将实例化为库的模型类之一(带有序列到序列语言建模头)。
这个类不能直接使用__init__()
实例化(会抛出错误)。
from_config
<来源>
代码语言:javascript复制( **kwargs )
参数
-
config
(PretrainedConfig)-选择要实例化的模型类基于配置类:- BartConfig 配置类:BartForConditionalGeneration(BART 模型)
- BigBirdPegasusConfig 配置类:BigBirdPegasusForConditionalGeneration(BigBird-Pegasus 模型)
- BlenderbotConfig 配置类:BlenderbotForConditionalGeneration(Blenderbot 模型)
- BlenderbotSmallConfig 配置类:BlenderbotSmallForConditionalGeneration(BlenderbotSmall 模型)
- EncoderDecoderConfig 配置类:EncoderDecoderModel(编码器解码器模型)
- FSMTConfig 配置类: FSMTForConditionalGeneration (FairSeq 机器翻译模型)
- GPTSanJapaneseConfig 配置类: GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese 模型)
- LEDConfig 配置类: LEDForConditionalGeneration (LED 模型)
- LongT5Config 配置类: LongT5ForConditionalGeneration (LongT5 模型)
- M2M100Config 配置类: M2M100ForConditionalGeneration (M2M100 模型)
- MBartConfig 配置类: MBartForConditionalGeneration (mBART 模型)
- MT5Config 配置类: MT5ForConditionalGeneration (MT5 模型)
- MarianConfig 配置类: MarianMTModel (Marian 模型)
- MvpConfig 配置类: MvpForConditionalGeneration (MVP 模型)
- NllbMoeConfig 配置类: NllbMoeForConditionalGeneration (NLLB-MOE 模型)
- PLBartConfig 配置类: PLBartForConditionalGeneration (PLBart 模型)
- PegasusConfig 配置类: PegasusForConditionalGeneration (Pegasus 模型)
- PegasusXConfig 配置类: PegasusXForConditionalGeneration (PEGASUS-X 模型)
- ProphetNetConfig 配置类: ProphetNetForConditionalGeneration (ProphetNet 模型)
- SeamlessM4TConfig 配置类: SeamlessM4TForTextToText (SeamlessM4T 模型)
- SeamlessM4Tv2Config 配置类:SeamlessM4Tv2ForTextToText (SeamlessM4Tv2 模型)
- SwitchTransformersConfig 配置类:SwitchTransformersForConditionalGeneration (SwitchTransformers 模型)
- T5Config 配置类:T5ForConditionalGeneration (T5 模型)
- UMT5Config 配置类:UMT5ForConditionalGeneration (UMT5 模型)
- XLMProphetNetConfig 配置类:XLMProphetNetForConditionalGeneration (XLM-ProphetNet 模型)
从配置实例化库中的模型类(带有序列到序列语言建模头)。
注意:从配置文件加载模型不会加载模型权重。它只影响模型的配置。使用 from_pretrained() 来加载模型权重。
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, AutoModelForSeq2SeqLM
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("t5-base")
>>> model = AutoModelForSeq2SeqLM.from_config(config)
from_pretrained
< source >
代码语言:javascript复制( *model_args **kwargs )
参数
-
pretrained_model_name_or_path
(str
或os.PathLike
) — 可以是以下之一:- 一个字符串,即在 huggingface.co 上托管的预训练模型的 模型 id。有效的模型 id 可以位于根级别,如
bert-base-uncased
,或者在用户或组织名称下进行命名空间化,如dbmdz/bert-base-german-cased
。 - 一个指向使用 save_pretrained() 保存的模型权重的 目录 的路径,例如
./my_model_directory/
。 - 一个指向 tensorflow 索引检查点文件 的路径或 url(例如,
./tf_model/model.ckpt.index
)。在这种情况下,from_tf
应设置为True
,并且应将配置对象提供为config
参数。使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并加载 PyTorch 模型后,此加载路径比较慢。
- 一个字符串,即在 huggingface.co 上托管的预训练模型的 模型 id。有效的模型 id 可以位于根级别,如
-
model_args
(额外的位置参数,可选) — 将传递给底层模型__init__()
方法。 -
config
(PretrainedConfig, 可选) — 用于模型的配置,而不是自动加载的配置。当以下情况时,配置可以自动加载:- 模型是库提供的模型(使用预训练模型的 模型 id 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 通过提供本地目录作为
pretrained_model_name_or_path
并且在目录中找到名为 config.json 的配置 JSON 文件来加载模型。
-
state_dict
(Dict[str, torch.Tensor], 可选) — 用于替代从保存的权重文件加载的状态字典的状态字典。 如果您想从预训练配置创建模型,但加载自己的权重,可以使用此选项。不过,在这种情况下,您应该检查使用 save_pretrained()和 from_pretrained()是否不是更简单的选项。 -
cache_dir
(str
或os.PathLike
,可选) — 下载的预训练模型配置应该缓存在其中的目录路径,如果不使用标准缓存。 -
from_tf
(bool
, 可选, 默认为False
) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 -
force_download
(bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。 -
resume_download
(bool
, 可选, 默认为False
) — 是否删除接收不完整的文件。如果存在这样的文件,将尝试恢复下载。 -
proxies
(Dict[str, str]
, 可选) — 一个代理服务器字典,按协议或端点使用,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。这些代理在每个请求中使用。 -
output_loading_info(bool,
可选, 默认为False
) — 是否还返回一个包含缺失键、意外键和错误消息的字典。 -
local_files_only(bool,
可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 -
revision
(str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
trust_remote_code
(bool
, 可选, 默认为False
) — 是否允许在 Hub 上定义自定义模型的代码。此选项应仅对您信任的存储库设置为True
,并且您已经阅读了代码,因为它将在本地机器上执行 Hub 上存在的代码。 -
code_revision
(str
, 可选, 默认为"main"
) — 用于 Hub 上代码的特定修订版本,如果代码存储在与模型其余部分不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
kwargs
(额外的关键字参数,可选) — 可用于更新配置对象(加载后)并初始化模型(例如,output_attentions=True
)。根据是否提供或自动加载了config
,行为会有所不同:- 如果提供了配置
config
,**kwargs
将直接传递给底层模型的__init__
方法(我们假设配置的所有相关更新已经完成) - 如果未提供配置,
kwargs
将首先传递给配置类初始化函数(from_pretrained())。与配置属性对应的kwargs
的每个键将用提供的kwargs
值覆盖该属性。不对应任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果提供了配置
从预训练模型实例化库中的一个模型类(带有序列到序列语言建模头)。
要实例化的模型类基于配置对象的 model_type
属性进行选择(如果可能,作为参数传递或从 pretrained_model_name_or_path
加载),或者当缺失时,通过在 pretrained_model_name_or_path
上使用模式匹配来回退:
-
bart
— BartForConditionalGeneration (BART 模型) -
bigbird_pegasus
— BigBirdPegasusForConditionalGeneration (BigBird-Pegasus 模型) -
blenderbot
— BlenderbotForConditionalGeneration (Blenderbot 模型) -
blenderbot-small
— BlenderbotSmallForConditionalGeneration (BlenderbotSmall 模型) -
encoder-decoder
— EncoderDecoderModel (编码器解码器模型) -
fsmt
— FSMTForConditionalGeneration (FairSeq 机器翻译模型) -
gptsan-japanese
— GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese 模型) -
led
— LEDForConditionalGeneration (LED 模型) -
longt5
— LongT5ForConditionalGeneration (LongT5 模型) -
m2m_100
— M2M100ForConditionalGeneration (M2M100 模型) -
marian
— MarianMTModel (Marian 模型) -
mbart
— MBartForConditionalGeneration (mBART 模型) -
mt5
— MT5ForConditionalGeneration (MT5 模型) -
mvp
— MvpForConditionalGeneration (MVP 模型) -
nllb-moe
— NllbMoeForConditionalGeneration (NLLB-MOE 模型) -
pegasus
— PegasusForConditionalGeneration (Pegasus 模型) -
pegasus_x
— PegasusXForConditionalGeneration (PEGASUS-X 模型) -
plbart
— PLBartForConditionalGeneration (PLBart 模型) -
prophetnet
— ProphetNetForConditionalGeneration (ProphetNet 模型) -
seamless_m4t
— SeamlessM4TForTextToText (SeamlessM4T 模型) -
seamless_m4t_v2
— SeamlessM4Tv2ForTextToText (SeamlessM4Tv2 模型) -
switch_transformers
— SwitchTransformersForConditionalGeneration (SwitchTransformers 模型) -
t5
— T5ForConditionalGeneration (T5 模型) -
umt5
— UMT5ForConditionalGeneration (UMT5 模型) -
xlm-prophetnet
— XLMProphetNetForConditionalGeneration (XLM-ProphetNet 模型)
默认情况下,模型处于评估模式,使用 model.eval()
(例如,dropout 模块被停用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, AutoModelForSeq2SeqLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
>>> # Update configuration during loading
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/t5_tf_model_config.json")
>>> model = AutoModelForSeq2SeqLM.from_pretrained(
... "./tf_model/t5_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForSeq2SeqLM
class transformers.TFAutoModelForSeq2SeqLM
< source >
代码语言:javascript复制( *args **kwargs )
这是一个通用的模型类,在使用 from_pretrained() 类方法或 from_config() 类方法创建时,将作为库中的模型类之一实例化(带有序列到序列语言建模头)。
这个类不能直接使用 __init__()
实例化(会报错)。
from_config
< source >
代码语言:javascript复制( **kwargs )
参数
-
config
(PretrainedConfig) — 根据配置类选择要实例化的模型类:- BartConfig 配置类: TFBartForConditionalGeneration (BART 模型)
- BlenderbotConfig 配置类: TFBlenderbotForConditionalGeneration (Blenderbot 模型)
- BlenderbotSmallConfig 配置类: TFBlenderbotSmallForConditionalGeneration (BlenderbotSmall 模型)
- EncoderDecoderConfig 配置类: TFEncoderDecoderModel (编码器解码器模型)
- LEDConfig 配置类: TFLEDForConditionalGeneration (LED 模型)
- MBartConfig 配置类: TFMBartForConditionalGeneration (mBART 模型)
- MT5Config 配置类: TFMT5ForConditionalGeneration (MT5 模型)
- MarianConfig 配置类: TFMarianMTModel (Marian 模型)
- PegasusConfig 配置类: TFPegasusForConditionalGeneration (Pegasus 模型)
- T5Config 配置类:TFT5ForConditionalGeneration(T5 模型)
从配置中实例化库中的一个模型类(带有序列到序列语言建模头)。
注意:从配置文件加载模型不会加载模型权重。它只影响模型的配置。使用 from_pretrained()加载模型权重。
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, TFAutoModelForSeq2SeqLM
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("t5-base")
>>> model = TFAutoModelForSeq2SeqLM.from_config(config)
from_pretrained
<来源>
代码语言:javascript复制( *model_args **kwargs )
参数
-
pretrained_model_name_or_path
(str
或os.PathLike
)— 可以是:- 一个字符串,托管在 huggingface.co 模型存储库中的预训练模型的模型 ID。有效的模型 ID 可以位于根级别,如
bert-base-uncased
,或命名空间下的用户或组织名称,如dbmdz/bert-base-german-cased
。 - 一个包含使用 save_pretrained()保存的模型权重的目录的路径,例如,
./my_model_directory/
。 - 一个PyTorch 状态字典保存文件的路径或 URL(例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应设置为True
,并且应将配置对象提供为config
参数。这种加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- 一个字符串,托管在 huggingface.co 模型存储库中的预训练模型的模型 ID。有效的模型 ID 可以位于根级别,如
-
model_args
(额外的位置参数,可选)— 将传递给底层模型__init__()
方法。 -
config
(PretrainedConfig,可选)— 模型使用的配置,而不是自动加载的配置。当以下情况时,配置可以自动加载:- 该模型是由库提供的模型(使用预训练模型的模型 ID字符串加载)。
- 该模型是使用 save_pretrained()保存的,并通过提供保存目录重新加载。
- 通过提供本地目录作为
pretrained_model_name_or_path
加载模型,并在目录中找到名为config.json的配置 JSON 文件。
-
cache_dir
(str
或os.PathLike
,可选)— 如果不使用标准缓存,则应将下载的预训练模型配置缓存在其中的目录路径。 -
from_pt
(bool
,可选,默认为False
)— 从 PyTorch 检查点保存文件加载模型权重(参见pretrained_model_name_or_path
参数的文档字符串)。 -
force_download
(bool
,可选,默认为False
)— 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。 -
resume_download
(bool
,可选,默认为False
)— 是否删除未完全接收的文件。如果存在这样的文件,将尝试恢复下载。 -
proxies
(Dict[str, str]
,可选)— 一个按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。这些代理在每个请求中使用。 -
output_loading_info(bool,
可选,默认为False
) — 是否还返回一个包含缺失键、意外键和错误消息的字典。 -
local_files_only(bool,
可选,默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 -
revision
(str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 id,因为我们在 huggingface.co 上使用基于 git 的系统来存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
trust_remote_code
(bool
, 可选, 默认为False
) — 是否允许在 Hub 上定义自定义模型的建模文件。此选项应仅对您信任的存储库设置为True
,并且您已经阅读了代码,因为它将在本地机器上执行 Hub 上存在的代码。 -
code_revision
(str
, 可选, 默认为"main"
) — 用于 Hub 上的代码的特定修订版本,如果代码存储在与模型其余部分不同的存储库中。它可以是分支名称、标签名称或提交 id,因为我们在 huggingface.co 上使用基于 git 的系统来存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
kwargs
(额外的关键字参数,可选) — 可用于更新配置对象(加载后)并初始化模型(例如,output_attentions=True
)。根据是否提供或自动加载了config
,行为会有所不同:- 如果提供了
config
,**kwargs
将直接传递给底层模型的__init__
方法(我们假设配置的所有相关更新已经完成) - 如果未提供配置,
kwargs
将首先传递给配置类的初始化函数(from_pretrained())。kwargs
的每个键对应一个配置属性,将用提供的kwargs
值覆盖该属性。不对应任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果提供了
从预训练模型实例化库中的一个模型类(带有序列到序列语言建模头)。
要实例化的模型类是根据配置对象的 model_type
属性选择的(如果可能,作为参数传递或从 pretrained_model_name_or_path
加载),或者当缺少时,通过在 pretrained_model_name_or_path
上使用模式匹配来回退:
-
bart
— TFBartForConditionalGeneration (BART 模型) -
blenderbot
— TFBlenderbotForConditionalGeneration (Blenderbot 模型) -
blenderbot-small
— TFBlenderbotSmallForConditionalGeneration (BlenderbotSmall 模型) -
encoder-decoder
— TFEncoderDecoderModel (编码器解码器模型) -
led
— TFLEDForConditionalGeneration (LED 模型) -
marian
— TFMarianMTModel (Marian 模型) -
mbart
— TFMBartForConditionalGeneration (mBART 模型) -
mt5
— TFMT5ForConditionalGeneration (MT5 模型) -
pegasus
— TFPegasusForConditionalGeneration (Pegasus 模型) -
t5
— TFT5ForConditionalGeneration (T5 模型)
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, TFAutoModelForSeq2SeqLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-base")
>>> # Update configuration during loading
>>> model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-base", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/t5_pt_model_config.json")
>>> model = TFAutoModelForSeq2SeqLM.from_pretrained(
... "./pt_model/t5_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForSeq2SeqLM
class transformers.FlaxAutoModelForSeq2SeqLM
<来源>
代码语言:javascript复制( *args **kwargs )
这是一个通用的模型类,当使用 class method 或 class method 创建时,将作为库的模型类之一实例化(带有序列到序列语言建模头)。
这个类不能直接使用__init__()
进行实例化(会抛出错误)。
from_config
<来源>
代码语言:javascript复制( **kwargs )
参数
-
config
(PretrainedConfig)— 选择要实例化的模型类基于配置类:- BartConfig 配置类:FlaxBartForConditionalGeneration(BART 模型)
- BlenderbotConfig 配置类:FlaxBlenderbotForConditionalGeneration(Blenderbot 模型)
- BlenderbotSmallConfig 配置类:FlaxBlenderbotSmallForConditionalGeneration(BlenderbotSmall 模型)
- EncoderDecoderConfig 配置类:FlaxEncoderDecoderModel(编码器解码器模型)
- LongT5Config 配置类:FlaxLongT5ForConditionalGeneration(LongT5 模型)
- MBartConfig 配置类:FlaxMBartForConditionalGeneration(mBART 模型)
- MT5Config 配置类:FlaxMT5ForConditionalGeneration(MT5 模型)
- MarianConfig 配置类:FlaxMarianMTModel(Marian 模型)
- PegasusConfig 配置类:FlaxPegasusForConditionalGeneration(Pegasus 模型)
- T5Config 配置类:FlaxT5ForConditionalGeneration(T5 模型)
从配置实例化库的模型类之一(带有序列到序列语言建模头)。
注意:从配置文件加载模型不会加载模型权重。它只影响模型的配置。使用 from_pretrained()加载模型权重。
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, FlaxAutoModelForSeq2SeqLM
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("t5-base")
>>> model = FlaxAutoModelForSeq2SeqLM.from_config(config)
from_pretrained
<来源>
代码语言:javascript复制( *model_args **kwargs )
参数
-
pretrained_model_name_or_path
(str
或os.PathLike
)— 可以是:- 一个字符串,预训练模型的模型标识符,托管在 huggingface.co 上的模型存储库中。有效的模型标识符可以位于根级别,如
bert-base-uncased
,或者在用户或组织名称下命名空间,如dbmdz/bert-base-german-cased
。 - 指向使用 save_pretrained()保存的模型权重的目录的路径,例如,
./my_model_directory/
。 - 指向PyTorch 状态字典保存文件的路径或 url(例如,
./pt_model/pytorch_model.bin
)。在这种情况下,应将from_pt
设置为True
,并将配置对象提供为config
参数。使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型的加载路径比较慢。
- 一个字符串,预训练模型的模型标识符,托管在 huggingface.co 上的模型存储库中。有效的模型标识符可以位于根级别,如
-
model_args
(额外的位置参数,可选)— 将传递给底层模型__init__()
方法。 -
config
(PretrainedConfig,可选)— 模型使用的配置,而不是自动加载的配置。当以下情况时,可以自动加载配置:- 该模型是库提供的模型(使用预训练模型的模型标识符字符串加载)。
- 模型是使用 save_pretrained()保存的,并通过提供保存目录重新加载。
- 通过提供本地目录作为
pretrained_model_name_or_path
加载模型,并在目录中找到名为config.json的配置 JSON 文件。
-
cache_dir
(str
或os.PathLike
,可选)— 下载预训练模型配置应缓存的目录路径,如果不使用标准缓存。 -
from_pt
(bool
,可选,默认为False
)— 从 PyTorch 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 -
force_download
(bool
,可选,默认为False
)— 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。 -
resume_download
(bool
,可选,默认为False
)— 是否删除接收不完整的文件。如果存在这样的文件,将尝试恢复下载。 -
proxies
(Dict[str, str]
,可选)— 一个按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。每个请求都会使用代理。 -
output_loading_info(bool,
可选,默认为False
)— 是否还返回包含缺失键、意外键和错误消息的字典。 -
local_files_only(bool,
可选,默认为False
)— 是否仅查看本地文件(例如,不尝试下载模型)。 -
revision
(str
,可选,默认为"main"
)— 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
trust_remote_code
(bool
,可选,默认为False
) — 是否允许在 Hub 上定义自定义模型并在其自己的建模文件中执行。此选项应仅在您信任的存储库中设置为True
,并且您已经阅读了代码,因为它将在本地机器上执行 Hub 上存在的代码。 -
code_revision
(str
,可选,默认为"main"
) — 用于 Hub 上代码的特定修订版本,如果代码存储在与模型其余部分不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们在 huggingface.co 上使用基于 git 的系统来存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
kwargs
(额外的关键字参数,可选) — 可用于更新配置对象(在加载后)并初始化模型(例如,output_attentions=True
)。根据是否提供或自动加载了config
,行为会有所不同:- 如果提供了
config
,**kwargs
将直接传递给底层模型的__init__
方法(我们假设配置的所有相关更新已经完成) - 如果未提供配置,
kwargs
将首先传递给配置类初始化函数(from_pretrained())。与配置属性对应的kwargs
的每个键将用提供的kwargs
值覆盖该属性。不对应任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果提供了
从预训练模型实例化库中的一个模型类(带有序列到序列语言建模头)。
要实例化的模型类是根据配置对象的model_type
属性选择的(作为参数传递或从pretrained_model_name_or_path
加载,如果可能的话),或者当缺少时,通过在pretrained_model_name_or_path
上使用模式匹配来回退:
-
bart
— FlaxBartForConditionalGeneration (BART 模型) -
blenderbot
— FlaxBlenderbotForConditionalGeneration (Blenderbot 模型) -
blenderbot-small
— FlaxBlenderbotSmallForConditionalGeneration (BlenderbotSmall 模型) -
encoder-decoder
— FlaxEncoderDecoderModel (编码器解码器模型) -
longt5
— FlaxLongT5ForConditionalGeneration (LongT5 模型) -
marian
— FlaxMarianMTModel (Marian 模型) -
mbart
— FlaxMBartForConditionalGeneration (mBART 模型) -
mt5
— FlaxMT5ForConditionalGeneration (MT5 模型) -
pegasus
— FlaxPegasusForConditionalGeneration (Pegasus 模型) -
t5
— FlaxT5ForConditionalGeneration (T5 模型)
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, FlaxAutoModelForSeq2SeqLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForSeq2SeqLM.from_pretrained("t5-base")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForSeq2SeqLM.from_pretrained("t5-base", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/t5_pt_model_config.json")
>>> model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
... "./pt_model/t5_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForSequenceClassification
class transformers.AutoModelForSequenceClassification
< source >
代码语言:javascript复制( *args **kwargs )
这是一个通用的模型类,当使用 class method 或 class method 创建时,将实例化为库的模型类之一(带有序列分类头)。
这个类不能直接使用__init__()
实例化(会抛出错误)。
from_config
< source >
代码语言:javascript复制( **kwargs )
参数
-
config
(PretrainedConfig)— 选择要实例化的模型类基于配置类:- AlbertConfig 配置类:AlbertForSequenceClassification(ALBERT 模型)
- BartConfig 配置类:BartForSequenceClassification(BART 模型)
- BertConfig 配置类:BertForSequenceClassification(BERT 模型)
- BigBirdConfig 配置类:BigBirdForSequenceClassification(BigBird 模型)
- BigBirdPegasusConfig 配置类:BigBirdPegasusForSequenceClassification(BigBird-Pegasus 模型)
- BioGptConfig 配置类:BioGptForSequenceClassification(BioGpt 模型)
- BloomConfig 配置类:BloomForSequenceClassification(BLOOM 模型)
- CTRLConfig 配置类:CTRLForSequenceClassification(CTRL 模型)
- CamembertConfig 配置类:CamembertForSequenceClassification(CamemBERT 模型)
- CanineConfig 配置类:CanineForSequenceClassification(CANINE 模型)
- ConvBertConfig 配置类:ConvBertForSequenceClassification(ConvBERT 模型)
- Data2VecTextConfig 配置类:Data2VecTextForSequenceClassification(Data2VecText 模型)
- DebertaConfig 配置类:DebertaForSequenceClassification(DeBERTa 模型)
- DebertaV2Config 配置类:DebertaV2ForSequenceClassification(DeBERTa-v2 模型)
- DistilBertConfig 配置类:DistilBertForSequenceClassification(DistilBERT 模型)
- ElectraConfig 配置类:ElectraForSequenceClassification(ELECTRA 模型)
- ErnieConfig 配置类:ErnieForSequenceClassification(ERNIE 模型)
- ErnieMConfig 配置类:ErnieMForSequenceClassification(ErnieM 模型)
- EsmConfig 配置类:EsmForSequenceClassification(ESM 模型)
- FNetConfig 配置类:FNetForSequenceClassification(FNet 模型)
- FalconConfig 配置类:FalconForSequenceClassification(Falcon 模型)
- FlaubertConfig 配置类:FlaubertForSequenceClassification(FlauBERT 模型)
- FunnelConfig 配置类:FunnelForSequenceClassification(Funnel Transformer 模型)
- GPT2Config 配置类:GPT2ForSequenceClassification(OpenAI GPT-2 模型)
- GPTBigCodeConfig 配置类:GPTBigCodeForSequenceClassification(GPTBigCode 模型)
- GPTJConfig 配置类:GPTJForSequenceClassification(GPT-J 模型)
- GPTNeoConfig 配置类:GPTNeoForSequenceClassification(GPT Neo 模型)
- GPTNeoXConfig 配置类: GPTNeoXForSequenceClassification (GPT NeoX 模型)
- IBertConfig 配置类: IBertForSequenceClassification (I-BERT 模型)
- LEDConfig 配置类: LEDForSequenceClassification (LED 模型)
- LayoutLMConfig 配置类: LayoutLMForSequenceClassification (LayoutLM 模型)
- LayoutLMv2Config 配置类: LayoutLMv2ForSequenceClassification (LayoutLMv2 模型)
- LayoutLMv3Config 配置类: LayoutLMv3ForSequenceClassification (LayoutLMv3 模型)
- LiltConfig 配置类: LiltForSequenceClassification (LiLT 模型)
- LlamaConfig 配置类: LlamaForSequenceClassification (LLaMA 模型)
- LongformerConfig 配置类: LongformerForSequenceClassification (Longformer 模型)
- LukeConfig 配置类: LukeForSequenceClassification (LUKE 模型)
- MBartConfig 配置类: MBartForSequenceClassification (mBART 模型)
- MPNetConfig 配置类: MPNetForSequenceClassification (MPNet 模型)
- MT5Config 配置类: MT5ForSequenceClassification (MT5 模型)
- MarkupLMConfig 配置类: MarkupLMForSequenceClassification (MarkupLM 模型)
- MegaConfig 配置类: MegaForSequenceClassification (MEGA 模型)
- MegatronBertConfig 配置类: MegatronBertForSequenceClassification (Megatron-BERT 模型)
- MistralConfig 配置类: MistralForSequenceClassification (Mistral 模型)
- MixtralConfig 配置类: MixtralForSequenceClassification (Mixtral 模型)
- MobileBertConfig 配置类: MobileBertForSequenceClassification (MobileBERT 模型)
- MptConfig 配置类: MptForSequenceClassification (MPT 模型)
- MraConfig 配置类: MraForSequenceClassification (MRA 模型)
- MvpConfig 配置类: MvpForSequenceClassification (MVP 模型)
- NezhaConfig 配置类: NezhaForSequenceClassification (Nezha 模型)
- NystromformerConfig 配置类: NystromformerForSequenceClassification (Nyströmformer 模型)
- OPTConfig 配置类: OPTForSequenceClassification (OPT 模型)
- OpenAIGPTConfig 配置类: OpenAIGPTForSequenceClassification (OpenAI GPT 模型)
- OpenLlamaConfig 配置类: OpenLlamaForSequenceClassification (OpenLlama 模型)
- PLBartConfig 配置类: PLBartForSequenceClassification (PLBart 模型)
- PerceiverConfig 配置类: PerceiverForSequenceClassification (Perceiver 模型)
- PersimmonConfig 配置类: PersimmonForSequenceClassification (Persimmon 模型)
- PhiConfig 配置类:PhiForSequenceClassification(Phi 模型)
- QDQBertConfig 配置类:QDQBertForSequenceClassification(QDQBert 模型)
- Qwen2Config 配置类:Qwen2ForSequenceClassification(Qwen2 模型)
- ReformerConfig 配置类:ReformerForSequenceClassification(Reformer 模型)
- RemBertConfig 配置类:RemBertForSequenceClassification(RemBERT 模型)
- RoCBertConfig 配置类:RoCBertForSequenceClassification(RoCBert 模型)
- RoFormerConfig 配置类:RoFormerForSequenceClassification(RoFormer 模型)
- RobertaConfig 配置类:RobertaForSequenceClassification(RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类:RobertaPreLayerNormForSequenceClassification(RoBERTa-PreLayerNorm 模型)
- SqueezeBertConfig 配置类:SqueezeBertForSequenceClassification(SqueezeBERT 模型)
- T5Config 配置类:T5ForSequenceClassification(T5 模型)
- TapasConfig 配置类:TapasForSequenceClassification(TAPAS 模型)
- TransfoXLConfig 配置类:TransfoXLForSequenceClassification(Transformer-XL 模型)
- UMT5Config 配置类:UMT5ForSequenceClassification(UMT5 模型)
- XLMConfig 配置类:XLMForSequenceClassification(XLM 模型)
- XLMRobertaConfig 配置类:XLMRobertaForSequenceClassification(XLM-RoBERTa 模型)
- XLMRobertaXLConfig 配置类:XLMRobertaXLForSequenceClassification(XLM-RoBERTa-XL 模型)
- XLNetConfig 配置类:XLNetForSequenceClassification(XLNet 模型)
- XmodConfig 配置类:XmodForSequenceClassification(X-MOD 模型)
- YosoConfig 配置类:YosoForSequenceClassification(YOSO 模型)
从配置实例化库中的一个模型类(带有序列分类头)。
注意:从配置文件加载模型不会加载模型权重。它只影响模型的配置。使用 from_pretrained()来加载模型权重。
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, AutoModelForSequenceClassification
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("bert-base-cased")
>>> model = AutoModelForSequenceClassification.from_config(config)
from_pretrained
<来源>
代码语言:javascript复制( *model_args **kwargs )
参数
-
pretrained_model_name_or_path
(str
或os.PathLike
)— 可以是:- 一个字符串,预训练模型的模型 ID,托管在 huggingface.co 上的模型存储库内。有效的模型 ID 可以位于根级别,如
bert-base-uncased
,或者命名空间在用户或组织名称下,如dbmdz/bert-base-german-cased
。 - 一个包含使用 save_pretrained()保存的模型权重的目录的路径,例如,
./my_model_directory/
。 - 一个TensorFlow 索引检查点文件的路径或 URL(例如,
./tf_model/model.ckpt.index
)。在这种情况下,from_tf
应设置为True
,并且应将配置对象提供为config
参数。这种加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并加载 PyTorch 模型要慢。
- 一个字符串,预训练模型的模型 ID,托管在 huggingface.co 上的模型存储库内。有效的模型 ID 可以位于根级别,如
-
model_args
(额外的位置参数,可选)— 将传递给底层模型__init__()
方法。 -
config
(PretrainedConfig,可选)— 用于替代自动加载的配置的模型配置。当以下情况时,配置可以自动加载:- 该模型是库提供的模型(使用预训练模型的模型 ID字符串加载)。
- 该模型是使用 save_pretrained()保存的,并通过提供保存目录来重新加载。
- 通过提供本地目录作为
pretrained_model_name_or_path
加载模型,并在目录中找到名为config.json的配置 JSON 文件。
-
state_dict
(Dict[str, torch.Tensor],可选)— 用于替代从保存的权重文件加载的状态字典的状态字典。 如果要从预训练配置创建模型但加载自己的权重,则可以使用此选项。但在这种情况下,您应该检查使用 save_pretrained()和 from_pretrained()是否不是更简单的选项。 -
cache_dir
(str
oros.PathLike
, optional) — 下载的预训练模型配置应缓存在其中的目录路径,如果不使用标准缓存。 -
from_tf
(bool
, optional, defaults toFalse
) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 -
force_download
(bool
, optional, defaults toFalse
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。 -
resume_download
(bool
, optional, defaults toFalse
) — 是否删除接收不完整的文件。如果存在这样的文件,将尝试恢复下载。 -
proxies
(Dict[str, str]
, optional) — 用于每个请求的代理服务器的协议或端点的字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理将用于每个请求。 -
output_loading_info(bool,
optional, defaults toFalse
) — 是否还返回包含缺失键、意外键和错误消息的字典。 -
local_files_only(bool,
optional, defaults toFalse
) — 是否仅查看本地文件(例如,不尝试下载模型)。 -
revision
(str
, optional, defaults to"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
trust_remote_code
(bool
, optional, defaults toFalse
) — 是否允许在 Hub 上定义自定义模型的建模文件。此选项应仅在您信任的存储库中设置为True
,并且您已阅读了代码,因为它将在本地机器上执行 Hub 上存在的代码。 -
code_revision
(str
, optional, defaults to"main"
) — 用于 Hub 上代码的特定修订版本,如果代码位于与模型其余部分不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
kwargs
(额外的关键字参数,可选) — 可用于更新配置对象(加载后)并启动模型(例如,output_attentions=True
)。根据是否提供了config
,行为会有所不同:- 如果提供了
config
,**kwargs
将直接传递给底层模型的__init__
方法(我们假设所有相关的配置更新已经完成) - 如果未提供配置,
kwargs
将首先传递给配置类初始化函数(from_pretrained())。与配置属性对应的kwargs
的每个键将用提供的kwargs
值覆盖该属性。不对应任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果提供了
从预训练模型实例化库中的一个模型类(带有序列分类头)。
要实例化的模型类是根据配置对象的 model_type
属性选择的(如果可能,作为参数传递或从 pretrained_model_name_or_path
加载),或者当缺少时,通过在 pretrained_model_name_or_path
上使用模式匹配来回退:
-
albert
— AlbertForSequenceClassification (ALBERT 模型) -
bart
— BartForSequenceClassification (BART 模型) -
bert
— BertForSequenceClassification (BERT 模型) -
big_bird
— BigBirdForSequenceClassification (BigBird 模型) -
bigbird_pegasus
— BigBirdPegasusForSequenceClassification (BigBird-Pegasus 模型) -
biogpt
— BioGptForSequenceClassification (BioGpt 模型) -
bloom
— BloomForSequenceClassification (BLOOM 模型) -
camembert
— CamembertForSequenceClassification (CamemBERT 模型) -
canine
— CanineForSequenceClassification (CANINE 模型) -
code_llama
— LlamaForSequenceClassification (CodeLlama 模型) -
convbert
— ConvBertForSequenceClassification (ConvBERT 模型) -
ctrl
— CTRLForSequenceClassification (CTRL 模型) -
data2vec-text
— Data2VecTextForSequenceClassification (Data2VecText 模型) -
deberta
— DebertaForSequenceClassification (DeBERTa 模型) -
deberta-v2
— DebertaV2ForSequenceClassification (DeBERTa-v2 模型) -
distilbert
— DistilBertForSequenceClassification (DistilBERT 模型) -
electra
— ElectraForSequenceClassification (ELECTRA 模型) -
ernie
— ErnieForSequenceClassification (ERNIE 模型) -
ernie_m
— ErnieMForSequenceClassification (ErnieM 模型) -
esm
— EsmForSequenceClassification (ESM 模型) -
falcon
— FalconForSequenceClassification (Falcon 模型) -
flaubert
— FlaubertForSequenceClassification (FlauBERT 模型) -
fnet
— FNetForSequenceClassification (FNet 模型) -
funnel
— FunnelForSequenceClassification (Funnel Transformer model) -
gpt-sw3
— GPT2ForSequenceClassification (GPT-Sw3 model) -
gpt2
— GPT2ForSequenceClassification (OpenAI GPT-2 model) -
gpt_bigcode
— GPTBigCodeForSequenceClassification (GPTBigCode model) -
gpt_neo
— GPTNeoForSequenceClassification (GPT Neo model) -
gpt_neox
— GPTNeoXForSequenceClassification (GPT NeoX model) -
gptj
— GPTJForSequenceClassification (GPT-J model) -
ibert
— IBertForSequenceClassification (I-BERT model) -
layoutlm
— LayoutLMForSequenceClassification (LayoutLM model) -
layoutlmv2
— LayoutLMv2ForSequenceClassification (LayoutLMv2 model) -
layoutlmv3
— LayoutLMv3ForSequenceClassification (LayoutLMv3 model) -
led
— LEDForSequenceClassification (LED model) -
lilt
— LiltForSequenceClassification (LiLT model) -
llama
— LlamaForSequenceClassification (LLaMA model) -
longformer
— LongformerForSequenceClassification (Longformer model) -
luke
— LukeForSequenceClassification (LUKE model) -
markuplm
— MarkupLMForSequenceClassification (MarkupLM model) -
mbart
— MBartForSequenceClassification (mBART model) -
mega
— MegaForSequenceClassification (MEGA model) -
megatron-bert
— MegatronBertForSequenceClassification (Megatron-BERT model) -
mistral
— MistralForSequenceClassification (Mistral model) -
mixtral
— MixtralForSequenceClassification (Mixtral model) -
mobilebert
— MobileBertForSequenceClassification (MobileBERT model) -
mpnet
— MPNetForSequenceClassification (MPNet model) -
mpt
— MptForSequenceClassification (MPT model) -
mra
— MraForSequenceClassification (MRA 模型) -
mt5
— MT5ForSequenceClassification (MT5 模型) -
mvp
— MvpForSequenceClassification (MVP 模型) -
nezha
— NezhaForSequenceClassification (Nezha 模型) -
nystromformer
— NystromformerForSequenceClassification (Nyströmformer 模型) -
open-llama
— OpenLlamaForSequenceClassification (OpenLlama 模型) -
openai-gpt
— OpenAIGPTForSequenceClassification (OpenAI GPT 模型) -
opt
— OPTForSequenceClassification (OPT 模型) -
perceiver
— PerceiverForSequenceClassification (Perceiver 模型) -
persimmon
— PersimmonForSequenceClassification (Persimmon 模型) -
phi
— PhiForSequenceClassification (Phi 模型) -
plbart
— PLBartForSequenceClassification (PLBart 模型) -
qdqbert
— QDQBertForSequenceClassification (QDQBert 模型) -
qwen2
— Qwen2ForSequenceClassification (Qwen2 模型) -
reformer
— ReformerForSequenceClassification (Reformer 模型) -
rembert
— RemBertForSequenceClassification (RemBERT 模型) -
roberta
— RobertaForSequenceClassification (RoBERTa 模型) -
roberta-prelayernorm
— RobertaPreLayerNormForSequenceClassification (RoBERTa-PreLayerNorm 模型) -
roc_bert
— RoCBertForSequenceClassification (RoCBert 模型) -
roformer
— RoFormerForSequenceClassification (RoFormer 模型) -
squeezebert
— SqueezeBertForSequenceClassification (SqueezeBERT 模型) -
t5
— T5ForSequenceClassification (T5 模型) -
tapas
— TapasForSequenceClassification (TAPAS 模型) -
transfo-xl
— TransfoXLForSequenceClassification (Transformer-XL 模型) -
umt5
— UMT5ForSequenceClassification (UMT5 模型) -
xlm
— XLMForSequenceClassification (XLM 模型) -
xlm-roberta
— XLMRobertaForSequenceClassification (XLM-RoBERTa 模型) -
xlm-roberta-xl
— XLMRobertaXLForSequenceClassification (XLM-RoBERTa-XL 模型) -
xlnet
— XLNetForSequenceClassification (XLNet 模型) -
xmod
— XmodForSequenceClassification (X-MOD 模型) -
yoso
— YosoForSequenceClassification (YOSO 模型)
默认情况下,该模型处于评估模式,使用 model.eval()
(例如,dropout 模块被停用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, AutoModelForSequenceClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForSequenceClassification.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForSequenceClassification
class transformers.TFAutoModelForSequenceClassification
< source >
代码语言:javascript复制( *args **kwargs )
这是一个通用的模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,将被实例化为库中的一个模型类(带有序列分类头)
这个类不能直接使用 __init__()
实例化(会抛出错误)。
from_config
< source >
代码语言:javascript复制( **kwargs )
参数
-
config
(PretrainedConfig) — 实例化的模型类基于配置类进行选择:- AlbertConfig 配置类: TFAlbertForSequenceClassification (ALBERT 模型)
- BartConfig 配置类: TFBartForSequenceClassification (BART 模型)
- BertConfig 配置类: TFBertForSequenceClassification (BERT 模型)
- CTRLConfig 配置类: TFCTRLForSequenceClassification (CTRL 模型)
- CamembertConfig 配置类: TFCamembertForSequenceClassification (CamemBERT 模型)
- ConvBertConfig 配置类: TFConvBertForSequenceClassification (ConvBERT 模型)
- DebertaConfig 配置类: TFDebertaForSequenceClassification (DeBERTa 模型)
- DebertaV2Config 配置类: TFDebertaV2ForSequenceClassification (DeBERTa-v2 模型)
- DistilBertConfig 配置类: TFDistilBertForSequenceClassification (DistilBERT 模型)
- ElectraConfig 配置类: TFElectraForSequenceClassification (ELECTRA 模型)
- EsmConfig 配置类: TFEsmForSequenceClassification (ESM 模型)
- FlaubertConfig 配置类: TFFlaubertForSequenceClassification (FlauBERT 模型)
- FunnelConfig 配置类: TFFunnelForSequenceClassification (Funnel Transformer 模型)
- GPT2Config 配置类: TFGPT2ForSequenceClassification (OpenAI GPT-2 模型)
- GPTJConfig 配置类: TFGPTJForSequenceClassification (GPT-J 模型)
- LayoutLMConfig 配置类: TFLayoutLMForSequenceClassification (LayoutLM 模型)
- LayoutLMv3Config 配置类: TFLayoutLMv3ForSequenceClassification (LayoutLMv3 模型)
- LongformerConfig 配置类: TFLongformerForSequenceClassification (Longformer 模型)
- MPNetConfig 配置类: TFMPNetForSequenceClassification (MPNet 模型)
- MobileBertConfig 配置类: TFMobileBertForSequenceClassification (MobileBERT 模型)
- OpenAIGPTConfig 配置类: TFOpenAIGPTForSequenceClassification (OpenAI GPT 模型)
- RemBertConfig 配置类: TFRemBertForSequenceClassification (RemBERT 模型)
- RoFormerConfig 配置类: TFRoFormerForSequenceClassification (RoFormer 模型)
- RobertaConfig 配置类: TFRobertaForSequenceClassification (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: TFRobertaPreLayerNormForSequenceClassification (RoBERTa-PreLayerNorm 模型)
- TapasConfig 配置类: TFTapasForSequenceClassification (TAPAS 模型)
- TransfoXLConfig 配置类: TFTransfoXLForSequenceClassification (Transformer-XL 模型)
- XLMConfig 配置类: TFXLMForSequenceClassification (XLM 模型)
- XLMRobertaConfig 配置类: TFXLMRobertaForSequenceClassification (XLM-RoBERTa 模型)
- XLNetConfig 配置类: TFXLNetForSequenceClassification (XLNet 模型)
从配置实例化库中的一个模型类(带有序列分类头)。
注意: 从配置文件加载模型不会加载模型权重。它只影响模型的配置。使用 from_pretrained()来加载模型权重。
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, TFAutoModelForSequenceClassification
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("bert-base-cased")
>>> model = TFAutoModelForSequenceClassification.from_config(config)
from_pretrained
<来源>
代码语言:javascript复制( *model_args **kwargs )
参数
-
pretrained_model_name_or_path
(str
或os.PathLike
) — 可以是:- 一个字符串,预训练模型的模型 id,托管在 huggingface.co 上的模型存储库中。有效的模型 id 可以位于根级别,如
bert-base-uncased
,或者命名空间在用户或组织名称下,如dbmdz/bert-base-german-cased
。 - 一个目录的路径,其中包含使用 save_pretrained()保存的模型权重,例如,
./my_model_directory/
。 - 路径或 URL 到PyTorch 状态字典保存文件(例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应设置为True
,并且应提供配置对象作为config
参数。这种加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- 一个字符串,预训练模型的模型 id,托管在 huggingface.co 上的模型存储库中。有效的模型 id 可以位于根级别,如
-
model_args
(额外的位置参数,可选) — 将传递给底层模型__init__()
方法。 -
config
(PretrainedConfig,可选) — 用于模型的配置,而不是自动加载的配置。当以下情况自动加载配置时:- 模型是库提供的模型(使用预训练模型的模型 ID字符串加载)。
- 模型是使用 save_pretrained()保存的,并通过提供保存目录重新加载。
- 通过提供本地目录作为
pretrained_model_name_or_path
加载模型,并在目录中找到名为config.json的配置 JSON 文件。
-
cache_dir
(str
或os.PathLike
,可选) — 下载预训练模型配置应该缓存在其中的目录路径,如果不使用标准缓存。 -
from_pt
(bool
, optional, defaults toFalse
) — 从 PyTorch 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 -
force_download
(bool
, optional, defaults toFalse
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。 -
resume_download
(bool
, optional, defaults toFalse
) — 是否删除接收不完整的文件。如果存在这样的文件,将尝试恢复下载。 -
proxies
(Dict[str, str]
, optional) — 一个代理服务器字典,按协议或端点使用,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理在每个请求上使用。 -
output_loading_info(bool,
optional, defaults toFalse
) — 是否返回包含缺少键、意外键和错误消息的字典。 -
local_files_only(bool,
optional, defaults toFalse
) — 是否仅查看本地文件(例如,不尝试下载模型)。 -
revision
(str
, optional, defaults to"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
trust_remote_code
(bool
, optional, defaults toFalse
) — 是否允许在 Hub 上定义自定义模型的建模文件。此选项应仅对您信任的存储库设置为True
,并且您已阅读代码,因为它将在本地机器上执行 Hub 上存在的代码。 -
code_revision
(str
, optional, defaults to"main"
) — 用于 Hub 上代码的特定修订版本,如果代码存储在与模型其余部分不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
kwargs
(额外的关键字参数,可选) — 可用于更新配置对象(加载后)并启动模型(例如,output_attentions=True
)。根据是否提供或自动加载config
,行为会有所不同:- 如果提供了
config
配置,**kwargs
将直接传递给底层模型的__init__
方法(我们假设配置的所有相关更新已经完成)。 - 如果未提供配置,
kwargs
将首先传递给配置类初始化函数(from_pretrained())。kwargs
的每个键对应一个配置属性,将用提供的kwargs
值覆盖该属性。不对应任何配置属性的剩余键将传递给基础模型的__init__
函数。
- 如果提供了
从预训练模型实例化库中的一个模型类(带有序列分类头)。
要实例化的模型类是根据配置对象的 model_type
属性选择的(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能的话),或者当缺少时,通过在 pretrained_model_name_or_path
上使用模式匹配来回退:
-
albert
— TFAlbertForSequenceClassification (ALBERT 模型) -
bart
— TFBartForSequenceClassification (BART 模型) -
bert
— TFBertForSequenceClassification (BERT 模型) -
camembert
— TFCamembertForSequenceClassification (CamemBERT 模型) -
convbert
— TFConvBertForSequenceClassification (ConvBERT 模型) -
ctrl
— TFCTRLForSequenceClassification (CTRL 模型) -
deberta
— TFDebertaForSequenceClassification (DeBERTa 模型) -
deberta-v2
— TFDebertaV2ForSequenceClassification (DeBERTa-v2 模型) -
distilbert
— TFDistilBertForSequenceClassification (DistilBERT 模型) -
electra
— TFElectraForSequenceClassification (ELECTRA 模型) -
esm
— TFEsmForSequenceClassification (ESM 模型) -
flaubert
— TFFlaubertForSequenceClassification (FlauBERT 模型) -
funnel
— TFFunnelForSequenceClassification (漏斗变换器模型) -
gpt-sw3
— TFGPT2ForSequenceClassification (GPT-Sw3 模型) -
gpt2
— TFGPT2ForSequenceClassification (OpenAI GPT-2 模型) -
gptj
— TFGPTJForSequenceClassification (GPT-J 模型) -
layoutlm
— TFLayoutLMForSequenceClassification (LayoutLM 模型) -
layoutlmv3
— TFLayoutLMv3ForSequenceClassification (LayoutLMv3 模型) -
longformer
— TFLongformerForSequenceClassification (Longformer 模型) -
mobilebert
— TFMobileBertForSequenceClassification (MobileBERT 模型) -
mpnet
— TFMPNetForSequenceClassification (MPNet 模型) -
openai-gpt
— TFOpenAIGPTForSequenceClassification (OpenAI GPT 模型) -
rembert
— TFRemBertForSequenceClassification (RemBERT 模型) -
roberta
— TFRobertaForSequenceClassification (RoBERTa 模型) -
roberta-prelayernorm
— TFRobertaPreLayerNormForSequenceClassification (RoBERTa-PreLayerNorm 模型) -
roformer
— TFRoFormerForSequenceClassification (RoFormer 模型) -
tapas
— TFTapasForSequenceClassification (TAPAS 模型) -
transfo-xl
— TFTransfoXLForSequenceClassification (Transformer-XL 模型) -
xlm
— TFXLMForSequenceClassification (XLM 模型) -
xlm-roberta
— TFXLMRobertaForSequenceClassification (XLM-RoBERTa 模型) -
xlnet
— TFXLNetForSequenceClassification (XLNet 模型)
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, TFAutoModelForSequenceClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForSequenceClassification.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForSequenceClassification
class transformers.FlaxAutoModelForSequenceClassification
< source >
代码语言:javascript复制( *args **kwargs )
这是一个通用的模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,将被实例化为库中的一个模型类(带有序列分类头)。
这个类不能直接使用 __init__()
实例化 (会报错)。
from_config
< source >
代码语言:javascript复制( **kwargs )
参数
-
config
(PretrainedConfig) — 根据配置类选择要实例化的模型类:- AlbertConfig 配置类: FlaxAlbertForSequenceClassification (ALBERT 模型)
- BartConfig 配置类: FlaxBartForSequenceClassification (BART 模型)
- BertConfig 配置类: FlaxBertForSequenceClassification (BERT 模型)
- BigBirdConfig 配置类:FlaxBigBirdForSequenceClassification(BigBird 模型)
- DistilBertConfig 配置类:FlaxDistilBertForSequenceClassification(DistilBERT 模型)
- ElectraConfig 配置类:FlaxElectraForSequenceClassification(ELECTRA 模型)
- MBartConfig 配置类:FlaxMBartForSequenceClassification(mBART 模型)
- RoFormerConfig 配置类:FlaxRoFormerForSequenceClassification(RoFormer 模型)
- RobertaConfig 配置类:FlaxRobertaForSequenceClassification(RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类:FlaxRobertaPreLayerNormForSequenceClassification(RoBERTa-PreLayerNorm 模型)
- XLMRobertaConfig 配置类:FlaxXLMRobertaForSequenceClassification(XLM-RoBERTa 模型)
从配置实例化库中的一个模型类(带有序列分类头)。
注意:从配置文件加载模型不会加载模型权重。它只影响模型的配置。使用 from_pretrained()来加载模型权重。
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, FlaxAutoModelForSequenceClassification
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("bert-base-cased")
>>> model = FlaxAutoModelForSequenceClassification.from_config(config)
from_pretrained
< source >
代码语言:javascript复制( *model_args **kwargs )
参数
-
pretrained_model_name_or_path
(str
或os.PathLike
) — 可以是:- 一个字符串,预训练模型的模型 ID,托管在 huggingface.co 上的模型存储库内。有效的模型 ID 可以位于根级,如
bert-base-uncased
,或者命名空间在用户或组织名称下,如dbmdz/bert-base-german-cased
。 - 一个包含使用 save_pretrained()保存的模型权重的目录的路径,例如,
./my_model_directory/
。 - 一个PyTorch state_dict 保存文件的路径或 URL(例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应设置为True
,并且应提供配置对象作为config
参数。这种加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- 一个字符串,预训练模型的模型 ID,托管在 huggingface.co 上的模型存储库内。有效的模型 ID 可以位于根级,如
-
model_args
(额外的位置参数,可选) — 将传递给底层模型的__init__()
方法。 -
config
(PretrainedConfig,可选)- 模型使用的配置,而不是自动加载的配置。当以下情况时,配置可以自动加载:- 该模型是库提供的模型(使用预训练模型的模型 ID字符串加载)。
- 该模型是使用 save_pretrained()保存的,并通过提供保存目录重新加载。
- 通过提供本地目录作为
pretrained_model_name_or_path
加载模型,并在目录中找到名为config.json的配置 JSON 文件。
-
cache_dir
(str
或os.PathLike
,可选)- 下载的预训练模型配置应该缓存在其中的目录路径,如果不应使用标准缓存。 -
from_pt
(bool
,可选,默认为False
)- 从 PyTorch 检查点保存文件加载模型权重(参见pretrained_model_name_or_path
参数的文档字符串)。 -
force_download
(bool
,可选,默认为False
)- 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。 -
resume_download
(bool
,可选,默认为False
)- 是否删除接收不完整的文件。如果存在这样的文件,将尝试恢复下载。 -
proxies
(Dict[str, str]
,可选)- 要使用的代理服务器的字典,按协议或端点划分,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理将在每个请求上使用。 -
output_loading_info(bool,
可选,默认为False
)- 是否返回一个包含缺失键、意外键和错误消息的字典。 -
local_files_only(bool,
可选,默认为False
)- 是否仅查看本地文件(例如,不尝试下载模型)。 -
revision
(str
,可选,默认为"main"
)- 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们在 huggingface.co 上使用基于 git 的系统来存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
trust_remote_code
(bool
,可选,默认为False
)- 是否允许在 Hub 上定义自定义模型的自定义建模文件。此选项应仅对您信任的存储库设置为True
,并且您已阅读了代码,因为它将在本地机器上执行 Hub 上存在的代码。 -
code_revision
(str
,可选,默认为"main"
)- 在 Hub 上使用的代码的特定修订版本,如果代码存储在与模型其余部分不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们在 huggingface.co 上使用基于 git 的系统来存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
kwargs
(额外的关键字参数,可选)- 可用于更新配置对象(在加载后)并启动模型(例如,output_attentions=True
)。根据是否提供config
或自动加载的情况而表现不同:- 如果提供了
config
,**kwargs
将直接传递给底层模型的__init__
方法(我们假设配置的所有相关更新已经完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数(from_pretrained())。与配置属性对应的kwargs
的每个键将用提供的kwargs
值覆盖该属性。不对应任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果提供了
从预训练模型中实例化库的模型类之一(带有序列分类头)。
要实例化的模型类是根据配置对象的model_type
属性选择的(如果可能,作为参数传递或从pretrained_model_name_or_path
加载),或者当缺少时,通过在pretrained_model_name_or_path
上使用模式匹配来回退:
-
albert
— FlaxAlbertForSequenceClassification(ALBERT 模型) -
bart
— FlaxBartForSequenceClassification(BART 模型) -
bert
— FlaxBertForSequenceClassification(BERT 模型) -
big_bird
— FlaxBigBirdForSequenceClassification(BigBird 模型) -
distilbert
— FlaxDistilBertForSequenceClassification(DistilBERT 模型) -
electra
— FlaxElectraForSequenceClassification(ELECTRA 模型) -
mbart
— FlaxMBartForSequenceClassification(mBART 模型) -
roberta
— FlaxRobertaForSequenceClassification(RoBERTa 模型) -
roberta-prelayernorm
— FlaxRobertaPreLayerNormForSequenceClassification(RoBERTa-PreLayerNorm 模型) -
roformer
— FlaxRoFormerForSequenceClassification(RoFormer 模型) -
xlm-roberta
— FlaxXLMRobertaForSequenceClassification(XLM-RoBERTa 模型)
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, FlaxAutoModelForSequenceClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForSequenceClassification.from_pretrained("bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForSequenceClassification.from_pretrained("bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForSequenceClassification.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForMultipleChoice
class transformers.AutoModelForMultipleChoice
< source >
代码语言:javascript复制( *args **kwargs )
这是一个通用的模型类,当使用 from_pretrained()类方法或 from_config()类方法创建时,将被实例化为库的模型类之一(带有多选头)。
这个类不能直接使用__init__()
实例化(会抛出错误)。
from_config
< source >
代码语言:javascript复制( **kwargs )
参数
-
config
(PretrainedConfig) — 要实例化的模型类是根据配置类选择的:- AlbertConfig 配置类:AlbertForMultipleChoice(ALBERT 模型)
- BertConfig 配置类:BertForMultipleChoice(BERT 模型)
- BigBirdConfig 配置类:BigBirdForMultipleChoice(BigBird 模型)
- CamembertConfig 配置类:CamembertForMultipleChoice(CamemBERT 模型)
- CanineConfig 配置类:CanineForMultipleChoice(CANINE 模型)
- ConvBertConfig 配置类:ConvBertForMultipleChoice(ConvBERT 模型)
- Data2VecTextConfig 配置类:Data2VecTextForMultipleChoice(Data2VecText 模型)
- DebertaV2Config 配置类:DebertaV2ForMultipleChoice(DeBERTa-v2 模型)
- DistilBertConfig 配置类:DistilBertForMultipleChoice(DistilBERT 模型)
- ElectraConfig 配置类:ElectraForMultipleChoice(ELECTRA 模型)
- ErnieConfig 配置类:ErnieForMultipleChoice(ERNIE 模型)
- ErnieMConfig 配置类:ErnieMForMultipleChoice(ErnieM 模型)
- FNetConfig 配置类:FNetForMultipleChoice(FNet 模型)
- FlaubertConfig 配置类:FlaubertForMultipleChoice(FlauBERT 模型)
- FunnelConfig 配置类:FunnelForMultipleChoice(Funnel Transformer 模型)
- IBertConfig 配置类:IBertForMultipleChoice(I-BERT 模型)
- LongformerConfig 配置类:LongformerForMultipleChoice(Longformer 模型)
- LukeConfig 配置类:LukeForMultipleChoice(LUKE 模型)
- MPNetConfig 配置类: MPNetForMultipleChoice (MPNet 模型)
- MegaConfig 配置类: MegaForMultipleChoice (MEGA 模型)
- MegatronBertConfig 配置类: MegatronBertForMultipleChoice (Megatron-BERT 模型)
- MobileBertConfig 配置类: MobileBertForMultipleChoice (MobileBERT 模型)
- MraConfig 配置类: MraForMultipleChoice (MRA 模型)
- NezhaConfig 配置类: NezhaForMultipleChoice (Nezha 模型)
- NystromformerConfig 配置类: NystromformerForMultipleChoice (Nyströmformer 模型)
- QDQBertConfig 配置类: QDQBertForMultipleChoice (QDQBert 模型)
- RemBertConfig 配置类: RemBertForMultipleChoice (RemBERT 模型)
- RoCBertConfig 配置类: RoCBertForMultipleChoice (RoCBert 模型)
- RoFormerConfig 配置类: RoFormerForMultipleChoice (RoFormer 模型)
- RobertaConfig 配置类: RobertaForMultipleChoice (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: RobertaPreLayerNormForMultipleChoice (RoBERTa-PreLayerNorm 模型)
- SqueezeBertConfig 配置类: SqueezeBertForMultipleChoice (SqueezeBERT 模型)
- XLMConfig 配置类: XLMForMultipleChoice (XLM 模型)
- XLMRobertaConfig 配置类:XLMRobertaForMultipleChoice(XLM-RoBERTa 模型)
- XLMRobertaXLConfig 配置类:XLMRobertaXLForMultipleChoice(XLM-RoBERTa-XL 模型)
- XLNetConfig 配置类:XLNetForMultipleChoice(XLNet 模型)
- XmodConfig 配置类:XmodForMultipleChoice(X-MOD 模型)
- YosoConfig 配置类:YosoForMultipleChoice(YOSO 模型)
从配置中实例化库中的一个模型类(带有多选头)。
注意:从配置文件加载模型不会加载模型权重。它只影响模型的配置。使用 from_pretrained()来加载模型权重。
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, AutoModelForMultipleChoice
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("bert-base-cased")
>>> model = AutoModelForMultipleChoice.from_config(config)
from_pretrained
< source >
代码语言:javascript复制( *model_args **kwargs )
参数
-
pretrained_model_name_or_path
(str
或os.PathLike
)— 可以是以下之一:- 一个字符串,托管在 huggingface.co 模型存储库中的预训练模型的模型 id。有效的模型 id 可以位于根级别,如
bert-base-uncased
,或者在用户或组织名称下命名空间化,如dbmdz/bert-base-german-cased
。 - 一个目录的路径,其中包含使用 save_pretrained()保存的模型权重,例如,
./my_model_directory/
。 - 一个路径或 url 到一个tensorflow 索引检查点文件(例如,
./tf_model/model.ckpt.index
)。在这种情况下,from_tf
应设置为True
,并且应将配置对象提供为config
参数。这种加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并加载 PyTorch 模型要慢。
- 一个字符串,托管在 huggingface.co 模型存储库中的预训练模型的模型 id。有效的模型 id 可以位于根级别,如
-
model_args
(额外的位置参数,可选)— 将传递给底层模型的__init__()
方法。 -
config
(PretrainedConfig,可选)— 用于替代自动加载的配置的模型配置。当以下情况时,配置可以自动加载:- 该模型是库提供的模型(使用预训练模型的模型 id字符串加载)。
- 该模型是使用 save_pretrained()保存的,并通过提供保存目录重新加载。
- 通过提供本地目录作为
pretrained_model_name_or_path
加载模型,并且在目录中找到名为config.json的配置 JSON 文件。
-
state_dict
(Dict[str, torch.Tensor],可选)— 一个状态字典,用于替代从保存的权重文件加载的状态字典。 如果您想从预训练配置创建模型但加载自己的权重,则可以使用此选项。不过,在这种情况下,您应该检查使用 save_pretrained()和 from_pretrained()是否不是更简单的选项。 -
cache_dir
(str
或os.PathLike
,可选) — 下载预训练模型配置应缓存的目录路径,如果不使用标准缓存。 -
from_tf
(bool
,可选,默认为False
) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 -
force_download
(bool
,可选,默认为False
)— 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。 -
resume_download
(bool
,可选,默认为False
)— 是否删除接收不完整的文件。如果存在这样的文件,将尝试恢复下载。 -
proxies
(Dict[str, str]
,可选)— 一个代理服务器字典,按协议或端点使用,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。这些代理在每个请求中使用。 -
output_loading_info(bool,
可选,默认为False
) — 是否还返回包含缺失键、意外键和错误消息的字典。 -
local_files_only(bool,
可选,默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 -
revision
(str
,可选,默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
trust_remote_code
(bool
,可选,默认为False
)— 是否允许在 Hub 上定义自定义模型的建模文件。此选项应仅对您信任的存储库设置为True
,并且您已经阅读了代码,因为它将在本地机器上执行 Hub 上存在的代码。 -
code_revision
(str
,可选,默认为"main"
) — 用于 Hub 上代码的特定修订版本,如果代码位于与模型其余部分不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 -
kwargs
(附加关键字参数,可选) — 可用于更新配置对象(加载后)并启动模型(例如,output_attentions=True
)。根据是否提供或自动加载config
,其行为有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设配置的所有相关更新已经完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数(from_pretrained())。kwargs
中与配置属性对应的每个键将用提供的kwargs
值覆盖该属性。不对应任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的一个模型类(带有多选头)。
要实例化的模型类是根据配置对象的 model_type
属性选择的(如果可能,作为参数传递或从 pretrained_model_name_or_path
加载),或者当缺失时,通过在 pretrained_model_name_or_path
上使用模式匹配来回退:
-
albert
— AlbertForMultipleChoice (ALBERT 模型) -
bert
— BertForMultipleChoice (BERT 模型) -
big_bird
— BigBirdForMultipleChoice (BigBird 模型) -
camembert
— CamembertForMultipleChoice (CamemBERT 模型) -
canine
— CanineForMultipleChoice (CANINE 模型) -
convbert
— ConvBertForMultipleChoice (ConvBERT 模型) -
data2vec-text
— Data2VecTextForMultipleChoice (Data2VecText 模型) -
deberta-v2
— DebertaV2ForMultipleChoice (DeBERTa-v2 模型) -
distilbert
— DistilBertForMultipleChoice (DistilBERT 模型) -
electra
— ElectraForMultipleChoice (ELECTRA 模型) -
ernie
— ErnieForMultipleChoice (ERNIE 模型) -
ernie_m
— ErnieMForMultipleChoice (ErnieM 模型) -
flaubert
— FlaubertForMultipleChoice (FlauBERT 模型) -
fnet
— FNetForMultipleChoice (FNet 模型) -
funnel
— FunnelForMultipleChoice (Funnel Transformer 模型) -
ibert
— IBertForMultipleChoice (I-BERT 模型) -
longformer
— LongformerForMultipleChoice (Longformer 模型) -
luke
— LukeForMultipleChoice (LUKE 模型) -
mega
— MegaForMultipleChoice (MEGA 模型) -
megatron-bert
— MegatronBertForMultipleChoice (Megatron-BERT 模型) -
mobilebert
— MobileBertForMultipleChoice (MobileBERT 模型) -
mpnet
— MPNetForMultipleChoice (MPNet 模型) -
mra
— MraForMultipleChoice (MRA 模型) -
nezha
— NezhaForMultipleChoice (Nezha 模型) -
nystromformer
— NystromformerForMultipleChoice (Nyströmformer 模型) -
qdqbert
— QDQBertForMultipleChoice(QDQBert 模型) -
rembert
— RemBertForMultipleChoice(RemBERT 模型) -
roberta
— RobertaForMultipleChoice(RoBERTa 模型) -
roberta-prelayernorm
— RobertaPreLayerNormForMultipleChoice(RoBERTa-PreLayerNorm 模型) -
roc_bert
— RoCBertForMultipleChoice(RoCBert 模型) -
roformer
— RoFormerForMultipleChoice(RoFormer 模型) -
squeezebert
— SqueezeBertForMultipleChoice(SqueezeBERT 模型) -
xlm
— XLMForMultipleChoice(XLM 模型) -
xlm-roberta
— XLMRobertaForMultipleChoice(XLM-RoBERTa 模型) -
xlm-roberta-xl
— XLMRobertaXLForMultipleChoice(XLM-RoBERTa-XL 模型) -
xlnet
— XLNetForMultipleChoice(XLNet 模型) -
xmod
— XmodForMultipleChoice(X-MOD 模型) -
yoso
— YosoForMultipleChoice(YOSO 模型)
默认情况下,模型处于评估模式,使用 model.eval()
(例如,关闭了 dropout 模块)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式
示例:
代码语言:javascript复制>>> from transformers import AutoConfig, AutoModelForMultipleChoice
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForMultipleChoice.from_pretrained("bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForMultipleChoice.from_pretrained("bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForMultipleChoice.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForMultipleChoice
class transformers.TFAutoModelForMultipleChoice
< source >
代码语言:javascript复制( *args **kwargs )
这是一个通用的模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,将实例化为库中的模型类之一(带有多选头)。
这个类不能直接使用 __init__()
实例化(会抛出错误)。
from_config
< source >
代码语言:javascript复制( **kwargs )
参数
-
config
(PretrainedConfig) — 根据配置类选择要实例化的模型类:- AlbertConfig 配置类:TFAlbertForMultipleChoice(ALBERT 模型)
- BertConfig 配置类:TFBertForMultipleChoice(BERT 模型)
- CamembertConfig 配置类:TFCamembertForMultipleChoice(CamemBERT 模型)
- ConvBertConfig 配置类: TFConvBertForMultipleChoice (ConvBERT 模型)
- DebertaV2Config 配置类: TFDebertaV2ForMultipleChoice (DeBERTa-v2 模型)
- DistilBertConfig 配置类: TFDistilBertForMultipleChoice (DistilBERT 模型)
- ElectraConfig 配置类: TFElectraForMultipleChoice (ELECTRA 模型)
- FlaubertConfig 配置类: TFFlaubertForMultipleChoice (FlauBERT 模型)
- FunnelConfig 配置类: TFFunnelForMultipleChoice (Funnel Transformer 模型)
- LongformerConfig 配置类: TFLongformerForMultipleChoice (Longformer 模型)
- MPNetConfig 配置类: TFMPNetForMultipleChoice (MPNet 模型)
- MobileBertConfig 配置类: TFMobileBertForMultipleChoice (MobileBERT 模型)
- RemBertConfig 配置类: TFRemBertForMultipleChoice (RemBERT 模型)
- RoFormerConfig 配置类: TFRoFormerForMultipleChoice (RoFormer 模型)
- RobertaConfig 配置类: TFRobertaForMultipleChoice (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: TFRobertaPreLayerNormForMultipleChoice (RoBERTa-PreLayerNorm 模型)
- XLMConfig 配置类: TFXLMForMultipleChoice (XLM 模型)
- XLMRobertaConfig 配置类: TFXLMRobertaForMultipleChoice (XLM-RoBERTa 模型)
- XLNetConfig 配置类:TFXLNetForMultipleChoice(XLNet 模型)
从配置实例化库中的一个模型类(带有多选头)。
注意:从配