阅读(2117) (0)

PyTorch 大规模部署的功能

2020-09-10 17:00:46 更新
原文: https://pytorch.org/docs/stable/notes/large_scale_deployments.html

本说明讨论了一些扩展点和技巧,这些扩展点和技巧在较大的系统中运行 PyTorch 或在较大的组织中使用 PyTorch 操作多个系统时可能有用。

它不涉及将模型部署到生产中的主题。 检查 torch.jit 或相应的教程之一。

该注释假定您是从组织中的源代码构建 PyTorch 的,或者能够静态链接使用 PyTorch 时要加载的其他代码。 因此,许多钩子都公开为 C ++ API,可以在集中式位置(例如, 在静态初始化代码中。

整个机队的运营商配置文件

PyTorch 带有torch.autograd.profiler,能够测量各个操作员根据需要花费的时间。 可以使用相同的机制对运行 PyTorch 的任何进程进行“始终在线”测量。 收集有关在给定进程中或在整个机器组中运行的 PyTorch 工作负载的信息可能很有用。

可以使用torch::autograd::profiler::pushCallback添加用于任何操作员调用的新回调。 挂钩将使用描述调用上下文的torch::autograd::profiler::RecordFunction结构调用(例如<cite>名称</cite>)。 如果启用,RecordFunction::inputs()包含表示为torch::IValue变量类型的函数的参数。 注意,输入日志记录相对昂贵,因此必须显式启用。

操作员回调还可以访问at::getThreadLocalDebugInfo()接口,该接口返回指向包含调试信息的结构的指针。 该调试信息应该与相应的at::setThreadLocalDebugInfo(debug_info)调用一起设置。 调试信息通过前向传播(包括异步fork任务)和后向传播进行传播,对于将有关执行环境的一些额外信息(例如,模型 ID)从应用程序的高层传递到操作员回调非常有用。

调用回调会增加一些开销,因此通常随机抽样操作员调用很有用。 可以在每个回调的基础上启用 <cite>torch :: autograd :: profiler :: setSamplingProbability</cite> 指定的全局采样率。

请注意,pushCallbacksetSamplingProbability不是线程安全的,只有在没有运行 PyTorch 运算符时才能调用。 通常,在初始化过程中一次调用它们是一个好主意。

这是一个例子:

// Called somewhere in the program beginning
void init() {
    // Sample one in a hundred operator runs randomly
    torch::autograd::setSamplingProbability(0.01);
    pushCallback(
        &onFunctionEnter,
        &onFunctionExit,
        /* needs_inputs */ true,
        /* sampled */ true
    );
}


void onFunctionEnter(const RecordFunction& fn) {
    std::cerr << "Before function " << fn.name()
              << " with " << fn.inputs().size() << " inputs" << std::endl;
}


void onFunctionExit(const RecordFunction& fn) {
    std::cerr << "After function " << fn.name();
}

API 使用记录

在更广泛的生态系统中运行时(例如在托管的工作计划程序中),跟踪哪些二进制文件调用特定的 PyTorch API 通常很有用。 在几个重要的 API 点注入了简单的工具,这些工具会触发给定的回调。 由于通常在一次性 python 脚本中调用 PyTorch,因此对于每个 API 的给定进程,回调仅触发一次。

c10::SetAPIUsageHandler可用于注册 API 使用情况检测处理程序。 传递的参数将是“ api key”,用于标识使用的点,例如,用于 PyTorch 扩展名导入的python.import或触发 TorchScript 编译的torch.script.compile

SetAPIUsageLogger([](const std::string& event_name) {
    std::cerr << "API was used: " << event_name << std::endl;
});

开发人员注意:可以在代码中使用 C ++中的C10_LOG_API_USAGE_ONCE("my_api")或 Python 中的torch._C._log_api_usage_once("my.api")添加新的 API 触发点。

将元数据附加到已保存的 TorchScript 模型中

TorchScript 模块可以保存为存档文件,该文件将序列化的参数和模块代码捆绑为 TorchScript(请参见 torch.jit.save())。 将附加信息与模型捆绑在一起通常很方便,例如,模型生产者或辅助工件的描述。

可以通过将_extra_files<>参数传递给 torch.jit.save() 和torch::jit::load<>在存储过程中存储和检索任意二进制 Blob 来实现。 由于 TorchScript 文件是常规的 ZIP 存档,因此额外的信息将作为常规文件存储在存档的extra/<>目录中。

还有一个全局挂钩,可将其他文件附加到当前流程中生成的任何 TorchScript 存档中。 用生产者元数据标记模型可能很有用,类似于由数码相机产生的 JPEG 元数据。 用法示例如下所示:

SetExportModuleExtraFilesHook([](const script::Module&) {
    script::ExtraFilesMap files;
    files["producer_info.json"] = "{\"user\": \"" + getenv("USER") + "\"}";
    return files;
});

构建环境注意事项

TorchScript 的编译需要使用 python 的inspect.getsource<>调用,因此必须有权访问原始 python 文件。 在某些生产环境中,可能需要显式部署.py文件以及预编译的.pyc文件。

通用扩展点

PyTorch API 通常是松散耦合的,很容易用专用版本替换组件。 常见的扩展点包括:

  • 使用 C ++实现的自定义运算符-有关更多详细信息,请参见教程
  • 自定义数据读取通常可以通过调用相应的 python 库直接集成。 torch.utils.data 的现有功能可以通过扩展 Dataset 或  IterableDataset 加以利用。