原文:
jax.readthedocs.io/en/latest/
在 JAX 之上构建
原文:
jax.readthedocs.io/en/latest/building_on_jax.html
学习高级 JAX 使用的一种很好的方法是看看其他库如何使用 JAX,它们如何将库集成到其 API 中,它在数学上添加了什么功能,并且如何在其他库中用于计算加速。
以下是 JAX 功能如何用于跨多个领域和软件包定义加速计算的示例。
梯度计算
简单的梯度计算是 JAX 的一个关键特性。在JaxOpt 库中值和 grad 直接用于用户在其源代码中的多个优化算法中。
同样,上面提到的 Dynamax Optax 配对,是过去具有挑战性的梯度使估计方法的一个例子,Optax 的最大似然期望。
在多个设备上单核计算速度加快
在 JAX 中定义的模型然后可以被编译以通过 JIT 编译进行单次计算速度加快。相同的编译码然后可以被发送到 CPU 设备,GPU 或 TPU 设备以获得额外的速度加快,通常不需要额外的更改。 这允许平稳地从开发流程转入生产流程。在 Dynamax 中,线性状态空间模型求解器的计算密集型部分已jitted。 PyTensor 的一个更复杂的例子源于动态地编译 JAX 函数,然后jit 构造的函数。
使用并行化的单台和多台计算机加速
JAX 的另一个好处是使用pmap
和vmap
函数调用或装饰器轻松并行化计算。在 Dynamax 中,状态空间模型使用VMAP 装饰器进行并行化,其实际用例是多对象跟踪。
将 JAX 代码合并到您的工作流程中或您的用户工作流程中
JAX 非常可组合,并且可以以多种方式使用。 JAX 可以作为独立模式使用,用户自己定义所有计算。 但是其他模式,例如使用构建在 jax 上提供特定功能的库。 这些可以是定义特定类型的模型的库,例如神经网络或状态空间模型或其他,或者提供特定功能,例如优化。以下是每种模式的更具体的示例。
直接使用
Jax 可以直接导入和利用,以便在本网站上“从零开始”构建模型,例如在JAX 教程或使用 JAX 进行神经网络中展示的方法。如果您无法找到特定挑战的预建代码,或者希望减少代码库中的依赖项数量,这可能是最佳选择。
使用 JAX 暴露的可组合领域特定库
另一种常见方法是提供预建功能的包,无论是模型定义还是某种类型的计算。这些包的组合可以混合使用,以实现全面的端到端工作流程,定义模型并估计其参数。
一个例子是Flax,它简化了神经网络的构建。通常将 Flax 与Optax配对使用,其中 Flax 定义了神经网络架构,而 Optax 提供了优化和模型拟合能力。
另一个是Dynamax,它允许轻松定义状态空间模型。使用 Dynamax 可以使用Optax 进行最大似然估计,或者使用Blackjax 进行 MCMC 全贝叶斯后验估计。
用户完全隐藏 JAX
其他库选择完全包装 JAX 以适应其特定 API。例如,PyMC 和Pytensor就是一个例子,用户可能从未直接“看到”JAX,而是使用 PyMC 特定的 API 包装JAX 函数。
注:
原文:
jax.readthedocs.io/en/latest/notes.html
本节包含有关使用 JAX 相关主题的简短注释;另请参阅 JAX Enhancement Proposals (JEPs) 中更详细的设计讨论。
依赖和版本兼容性:
- API 兼容性概述了 JAX 在不同版本之间 API 兼容性的政策。
- Python 和 NumPy 版本支持政策概述了 JAX 与 Python 和 NumPy 的兼容性政策。
迁移和弃用事项:
- jax.Array 迁移总结了 jax v 0.4.1 中默认数组类型的更改。
内存和计算使用:
- 异步调度描述了 JAX 的异步调度模型。
- 并发性描述了 JAX 与其他 Python 并发性的交互方式。
- GPU 内存分配描述了 JAX 在 GPU 内存分配中的交互方式。
程序员保护栏:
- 等级提升警告描述了如何配置
jax.numpy
以避免隐式等级提升。
API 兼容性
原文:
jax.readthedocs.io/en/latest/api_compatibility.html
JAX 不断发展,我们希望能改进其 API。尽管如此,我们希望最大程度减少 JAX 用户社区的混乱,并尽量少做破坏性更改。
JAX 遵循三个月的废弃政策。当对 API 进行不兼容的更改时,我们将尽力遵守以下流程:
- 更改将在
CHANGELOG.md
中和被废弃 API 的文档字符串中公布,并且旧 API 将发出DeprecationWarning
。 - 在
jax
发布了废弃 API 后的三个月内,我们可能随时移除已废弃的 API。请注意,三个月是一个较短的时间界限,故意选择快于许多更成熟项目的时间界限。实际上,废弃可能需要更长时间,特别是如果某个功能有很多用户。如果三个月的废弃期变得问题重重,请与我们联系。
我们保留随时更改此政策的权利。
覆盖了什么内容?
仅涵盖公共的 JAX API,包括以下模块:
-
jax
-
jax.dlpack
-
jax.image
-
jax.lax
-
jax.nn
-
jax.numpy
-
jax.ops
-
jax.profiler
-
jax.random
(参见下文详细说明) -
jax.scipy
-
jax.tree_util
-
jax.test_util
这些模块中并非所有内容都是公开的。随着时间的推移,我们正在努力区分公共 API 和私有 API。公共 API 在 JAX 文档中有详细记录。此外,我们的目标是所有非公共 API 应以下划线作为前缀命名,尽管我们目前还未完全遵守这一规定。
未覆盖的内容是什么?
- 任何以下划线开头的内容。
-
jax._src
-
jax.core
-
jax.linear_util
-
jax.lib
-
jax.prng
-
jax.interpreters
-
jax.experimental
-
jax.example_libraries
-
jax.extend
(参见详情)
此列表并非详尽无遗。
数值和随机性
数值运算的确切值在 JAX 的不同版本中并不保证稳定。事实上,在给定的 JAX 版本、加速器平台上,在或不在 jax.jit
内部,等等,确切的数值计算不一定是稳定的。
对于固定 PRNG 密钥输入,jax.random
中伪随机函数的输出可能会在 JAX 不同版本间变化。兼容性政策仅适用于输出的分布。例如,表达式 jax.random.gumbel(jax.random.key(72))
在 JAX 的不同版本中可能返回不同的值,但 jax.random.gumbel
仍然是 Gumbel 分布的伪随机生成器。
我们尽量不频繁地更改伪随机值。当更改发生时,会在变更日志中公布,但不遵循废弃周期。在某些情况下,JAX 可能会暴露一个临时配置标志,用于回滚新行为,以帮助用户诊断和更新受影响的代码。此类标志将持续一段废弃时间。
Python 和 NumPy 版本支持政策
原文:
jax.readthedocs.io/en/latest/deprecation.html
对于 NumPy 和 SciPy 版本支持,JAX 遵循 Python 科学社区的 SPEC 0。
对于 Python 版本支持,我们听取了用户的意见,36 个月的支持窗口可能太短,例如由于新 CPython 版本到 Linux 供应商版本的延迟传播。因此,JAX 支持 Python 版本至少比 SPEC-0 推荐的长九个月。
这意味着我们至少支持:
- 在每个 JAX 发布前 45 个月内的所有较小的 Python 版本。例如:
- Python 3.9于 2020 年 10 月发布,并将至少在2024 年 7 月之前支持新的 JAX 发布。
- Python 3.10于 2021 年 10 月发布,并将至少在2025 年 7 月之前支持新的 JAX 发布。
- Python 3.11于 2022 年 10 月发布,并将至少在2026 年 7 月之前支持新的 JAX 发布。
- 在每个 JAX 发布前 24 个月内的所有较小的 NumPy 版本。例如:
- NumPy 1.22于 2021 年 12 月发布,并将至少在2023 年 12 月之前支持新的 JAX 发布。
- NumPy 1.23于 2022 年 6 月发布,并将至少在2024 年 6 月之前支持新的 JAX 发布。
- NumPy 1.24于 2022 年 12 月发布,并将至少在2024 年 12 月之前支持新的 JAX 发布。
- 在每个 JAX 发布前 24 个月内的所有较小的 SciPy 版本,从 SciPy 版本 1.9 开始。例如:
- Scipy 1.9于 2022 年 7 月发布,并将至少在2024 年 7 月之前支持新的 JAX 发布。
- Scipy 1.10于 2023 年 1 月发布,并将至少在2025 年 1 月之前支持新的 JAX 发布。
- Scipy 1.11于 2023 年 6 月发布,并将至少在2025 年 6 月之前支持新的 JAX 发布。
JAX 发布可以支持比本政策严格要求的更旧的 Python、NumPy 和 SciPy 版本,但对更旧版本的支持可能随时在列出的日期之后终止。
jax.Array
迁移
原文:
jax.readthedocs.io/en/latest/jax_array_migration.html
yashkatariya@
TL;DR
JAX 将其默认数组实现切换为新的 jax.Array
自版本 0.4.1 起。本指南解释了这一决定的背景,它可能对您的代码产生的影响,以及如何(临时)切换回旧行为。
发生了什么?
jax.Array
是 JAX 中统一的数组类型,包括 DeviceArray
、ShardedDeviceArray
和 GlobalDeviceArray
类型。jax.Array
类型有助于使并行成为 JAX 的核心特性,简化和统一了 JAX 的内部结构,并允许我们统一 jit
和 pjit
。如果你的代码没有涉及到 DeviceArray
、ShardedDeviceArray
和 GlobalDeviceArray
的区别,那就不需要进行任何更改。但是依赖于这些单独类细节的代码可能需要进行调整以适配统一的 jax.Array
。
迁移完成后,jax.Array
将成为 JAX 中唯一的数组类型。
本文介绍了如何将现有代码库迁移到 jax.Array
。有关如何使用 jax.Array
和 JAX 并行 API 的更多信息,请参阅 Distributed arrays and automatic parallelization 教程。
如何启用 jax.Array
?
你可以通过以下方式启用 jax.Array
:
设置 shell 环境变量 JAX_ARRAY
为真值(例如 1
);
如果你的代码使用 absl 解析标志,可以将布尔标志 jax_array
设置为真值;
在你的主文件顶部加入以下声明:
代码语言:javascript复制import jax
jax.config.update('jax_array', True)
如何判断 jax.Array
是否破坏了我的代码?
最简单的方法是禁用 jax.Array
,看看问题是否解决。
我如何暂时禁用 jax.Array
?
通过 2023 年 3 月 15 日,可以通过以下方式禁用 jax.Array
:
设置 shell 环境变量 JAX_ARRAY
为假值(例如 0
);
如果你的代码使用 absl 解析标志,可以将布尔标志 jax_array
设置为假值;
在你的主文件顶部加入以下声明:
代码语言:javascript复制import jax
jax.config.update('jax_array', False)
为什么创建 jax.Array
?
当前 JAX 有三种类型:DeviceArray
、ShardedDeviceArray
和 GlobalDeviceArray
。jax.Array
合并了这三种类型,并清理了 JAX 的内部结构,同时增加了新的并行特性。
我们还引入了一个新的 Sharding
抽象,描述了逻辑数组如何在一个或多个设备(如 TPU 或 GPU)上物理分片。这一变更还升级、简化并将 pjit
的并行性特性合并到 jit
中。使用 jit
装饰的函数将能够在分片数组上操作,而无需将数据复制到单个设备上。
使用 jax.Array
可以获得的功能:
- C
pjit
分派路径 - 逐操作并行性(即使数组分布在多台设备上,跨多个主机)
- 使用
pjit
/jit
更简单的批数据并行性。 - 可以完全利用 OpSharding 的灵活性,或者任何您想要的其他分片方式来创建不一定包含网格和分区规范的
Sharding
。 - 等等
示例:
代码语言:javascript复制import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
import numpy as np
x = jnp.arange(8)
# Let's say there are 8 devices in jax.devices()
mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x'))
sharded_x = jax.device_put(x, sharding)
# `matmul_sharded_x` and `sin_sharded_x` are sharded. `jit` is able to operate over a
# sharded array without copying data to a single device.
matmul_sharded_x = sharded_x @ sharded_x.T
sin_sharded_x = jnp.sin(sharded_x)
# Even jnp.copy preserves the sharding on the output.
copy_sharded_x = jnp.copy(sharded_x)
# double_out is also sharded
double_out = jax.jit(lambda x: x * 2)(sharded_x)
切换到 jax.Array
后可能会出现哪些问题?
新公共类型命名为 jax.Array
。
所有 isinstance(..., jnp.DeviceArray)
或 isinstance(.., jax.xla.DeviceArray)
以及其他 DeviceArray
的变体应该切换到使用 isinstance(..., jax.Array)
。
由于 jax.Array
可以表示 DA、SDA 和 GDA,您可以通过以下方式在 jax.Array
中区分这三种类型:
-
x.is_fully_addressable and len(x.sharding.device_set) == 1
– 这意味着jax.Array
类似于 DA。 -
x.is_fully_addressable and (len(x.sharding.device_set) > 1
– 这意味着jax.Array
类似于 SDA。 -
not x.is_fully_addressable
– 这意味着jax.Array
类似于 GDA,并跨多个进程。
对于 ShardedDeviceArray
,可以将 isinstance(..., pxla.ShardedDeviceArray)
转移到 isinstance(..., jax.Array) and x.is_fully_addressable and len(x.sharding.device_set) > 1
。
通常无法区分单设备数组上的 ShardedDeviceArray
与任何其他类型的单设备数组。
GDA 的 API 名称变更
GDA 的 local_shards
和 local_data
已经被弃用。
请使用与 jax.Array
和 GDA
兼容的 addressable_shards
和 addressable_data
。
创建 jax.Array
。
当 jax_array
标志为真时,所有 JAX 函数将输出 jax.Array
。如果您曾使用 GlobalDeviceArray.from_callback
、make_sharded_device_array
或 make_device_array
函数显式创建相应的 JAX 数据类型,则需要切换为使用 jax.make_array_from_callback()
或 jax.make_array_from_single_device_arrays()
。
对于 GDA:
GlobalDeviceArray.from_callback(shape, mesh, pspec, callback)
可以一对一地切换为 jax.make_array_from_callback(shape, jax.sharding.NamedSharding(mesh, pspec), callback)
。
如果您曾使用原始的 GDA 构造函数来创建 GDAs,则执行以下操作:
GlobalDeviceArray(shape, mesh, pspec, buffers)
可以变成 jax.make_array_from_single_device_arrays(shape, jax.sharding.NamedSharding(mesh, pspec), buffers)
。
对于 SDA:
make_sharded_device_array(aval, sharding_spec, device_buffers, indices)
可以变成 jax.make_array_from_single_device_arrays(shape, sharding, device_buffers)
。
要决定分片应该是什么,取决于您创建 SDA 的原因:
如果它被创建为 pmap
的输入,则分片可以是:jax.sharding.PmapSharding(devices, sharding_spec)
。
如果它被创建为 pjit 的输入,则分片可以是 jax.sharding.NamedSharding(mesh, pspec)
。
切换到 jax.Array
后对于主机本地输入的 pjit 有破坏性变更。
如果您完全使用 GDA 参数作为 pjit 的输入,则可以跳过此部分!