原文:
jax.readthedocs.io/en/latest/
JAX 教程
原文:
jax.readthedocs.io/en/latest/tutorials.html
关键概念
原文:
jax.readthedocs.io/en/latest/key-concepts.html
本节简要介绍了 JAX 包的一些关键概念。
JAX 数组 (jax.Array
)
JAX 中的默认数组实现是 jax.Array
。在许多方面,它与您可能熟悉的 NumPy 包中的 numpy.ndarray
类型相似,但它也有一些重要的区别。
数组创建
我们通常不直接调用 jax.Array
构造函数,而是通过 JAX API 函数创建数组。例如,jax.numpy
提供了类似 NumPy 风格的数组构造功能,如 jax.numpy.zeros()
、jax.numpy.linspace()
、jax.numpy.arange()
等。
import jax
import jax.numpy as jnp
x = jnp.arange(5)
isinstance(x, jax.Array)
代码语言:javascript复制True
如果您在代码中使用 Python 类型注解,jax.Array
是 jax 数组对象的适当注释(参见 jax.typing
以获取更多讨论)。
数组设备和分片
JAX 数组对象具有一个 devices
方法,允许您查看数组内容存储在哪里。在最简单的情况下,这将是单个 CPU 设备:
x.devices()
代码语言:javascript复制{CpuDevice(id=0)}
一般来说,数组可能会在多个设备上 分片,您可以通过 sharding
属性进行检查:
x.sharding
代码语言:javascript复制SingleDeviceSharding(device=CpuDevice(id=0))
在这里,数组位于单个设备上,但通常情况下,JAX 数组可以分布在多个设备或者多个主机上。要了解更多关于分片数组和并行计算的信息,请参阅分片计算介绍## 变换
除了用于操作数组的函数外,JAX 还包括许多用于操作 JAX 函数的变换。这些变换包括
-
jax.jit()
: 即时(JIT)编译;参见即时编译 -
jax.vmap()
: 向量化变换;参见自动向量化 -
jax.grad()
: 梯度变换;参见自动微分
以及其他几个。变换接受一个函数作为参数,并返回一个新的转换后的函数。例如,这是您可能如何对一个简单的 SELU 函数进行 JIT 编译:
代码语言:javascript复制def selu(x, alpha=1.67, lambda_=1.05):
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
selu_jit = jax.jit(selu)
print(selu_jit(1.0))
代码语言:javascript复制1.05
通常情况下,您会看到使用 Python 的装饰器语法来应用变换以方便操作:
代码语言:javascript复制@jax.jit
def selu(x, alpha=1.67, lambda_=1.05):
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
jit()
、vmap()
、grad()
等变换对于有效使用 JAX 至关重要,我们将在后续章节中详细介绍它们。## 跟踪
变换背后的魔法是跟踪器的概念。跟踪器是数组对象的抽象替身,传递给 JAX 函数,以提取函数编码的操作序列。
您可以通过打印转换后的 JAX 代码中的任何数组值来看到这一点;例如:
代码语言:javascript复制@jax.jit
def f(x):
print(x)
return x 1
x = jnp.arange(5)
result = f(x)
代码语言:javascript复制Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace(level=1/0)>
打印的值不是数组 x
,而是代表 x
的关键属性的 Tracer
实例,比如它的 shape
和 dtype
。通过使用追踪值执行函数,JAX 可以确定函数编码的操作序列,然后在实际执行这些操作之前执行转换:例如 jit()
、vmap()
和 grad()
可以将输入操作序列映射到变换后的操作序列。 ## Jaxprs
JAX 对操作序列有自己的中间表示形式,称为 jaxpr。jaxpr(JAX exPRession 的缩写)是一个函数程序的简单表示,包含一系列原始操作。
例如,考虑我们上面定义的 selu
函数:
def selu(x, alpha=1.67, lambda_=1.05):
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
我们可以使用 jax.make_jaxpr()
实用程序来将该函数转换为一个 jaxpr,给定特定的输入:
x = jnp.arange(5.0)
jax.make_jaxpr(selu)(x)
代码语言:javascript复制{ lambda ; a:f32[5]. let
b:bool[5] = gt a 0.0
c:f32[5] = exp a
d:f32[5] = mul 1.6699999570846558 c
e:f32[5] = sub d 1.6699999570846558
f:f32[5] = pjit[
name=_where
jaxpr={ lambda ; g:bool[5] h:f32[5] i:f32[5]. let
j:f32[5] = select_n g i h
in (j,) }
] b a e
k:f32[5] = mul 1.0499999523162842 f
in (k,) }
与 Python 函数定义相比,可以看出它编码了函数表示的精确操作序列。我们稍后将深入探讨 JAX 内部的 jaxprs:jaxpr 语言。 ## Pytrees
JAX 函数和转换基本上操作数组,但实际上编写处理数组集合的代码更为方便:例如,神经网络可能会将其参数组织在具有有意义键的数组字典中。与其逐案处理这类结构,JAX 依赖于 pytree 抽象来统一处理这些集合。
以下是一些可以作为 pytrees 处理的对象的示例:
代码语言:javascript复制# (nested) list of parameters
params = [1, 2, (jnp.arange(3), jnp.ones(2))]
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
代码语言:javascript复制PyTreeDef([*, *, (*, *)])
[1, 2, Array([0, 1, 2], dtype=int32), Array([1., 1.], dtype=float32)]
代码语言:javascript复制# Dictionary of parameters
params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
代码语言:javascript复制PyTreeDef({'W': *, 'b': *, 'n': *})
[Array([[1., 1.],
[1., 1.]], dtype=float32), Array([0., 0.], dtype=float32), 5]
代码语言:javascript复制# Named tuple of parameters
from typing import NamedTuple
class Params(NamedTuple):
a: int
b: float
params = Params(1, 5.0)
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
代码语言:javascript复制PyTreeDef(CustomNode(namedtuple[Params], [*, *]))
[1, 5.0]
JAX 提供了许多用于处理 PyTrees 的通用实用程序;例如函数 jax.tree.map()
可以用于将函数映射到树中的每个叶子,而 jax.tree.reduce()
可以用于在树中的叶子上应用约简操作。
你可以在《使用 pytrees 教程》中了解更多信息。
即时编译
原文:
jax.readthedocs.io/en/latest/jit-compilation.html
在这一部分,我们将进一步探讨 JAX 的工作原理,以及如何使其性能卓越。我们将讨论 jax.jit()
变换,它将 JAX Python 函数进行即时编译,以便在 XLA 中高效执行。
如何工作 JAX 变换
在前一节中,我们讨论了 JAX 允许我们转换 Python 函数的能力。JAX 通过将每个函数减少为一系列原始操作来实现这一点,每个原始操作代表一种基本的计算单位。
查看函数背后原始操作序列的一种方法是使用 jax.make_jaxpr()
:
import jax
import jax.numpy as jnp
global_list = []
def log2(x):
global_list.append(x)
ln_x = jnp.log(x)
ln_2 = jnp.log(2.0)
return ln_x / ln_2
print(jax.make_jaxpr(log2)(3.0))
代码语言:javascript复制{ lambda ; a:f32[]. let
b:f32[] = log a
c:f32[] = log 2.0
d:f32[] = div b c
in (d,) }
文档的理解 Jaxprs 部分提供了有关上述输出含义的更多信息。
重要的是要注意,jaxpr 不捕获函数中存在的副作用:其中没有对 global_list.append(x)
的任何内容。这是一个特性,而不是一个错误:JAX 变换旨在理解无副作用(也称为函数纯粹)的代码。如果 纯函数 和 副作用 是陌生的术语,这在