JAX 中文文档(二)

2024-06-22 08:39:39 浏览数 (1)

原文:jax.readthedocs.io/en/latest/

JAX 教程

原文:jax.readthedocs.io/en/latest/tutorials.html

关键概念

原文:jax.readthedocs.io/en/latest/key-concepts.html

本节简要介绍了 JAX 包的一些关键概念。

JAX 数组 (jax.Array)

JAX 中的默认数组实现是 jax.Array。在许多方面,它与您可能熟悉的 NumPy 包中的 numpy.ndarray 类型相似,但它也有一些重要的区别。

数组创建

我们通常不直接调用 jax.Array 构造函数,而是通过 JAX API 函数创建数组。例如,jax.numpy 提供了类似 NumPy 风格的数组构造功能,如 jax.numpy.zeros()jax.numpy.linspace()jax.numpy.arange() 等。

代码语言:javascript复制
import jax
import jax.numpy as jnp

x = jnp.arange(5)
isinstance(x, jax.Array) 
代码语言:javascript复制
True 

如果您在代码中使用 Python 类型注解,jax.Array 是 jax 数组对象的适当注释(参见 jax.typing 以获取更多讨论)。

数组设备和分片

JAX 数组对象具有一个 devices 方法,允许您查看数组内容存储在哪里。在最简单的情况下,这将是单个 CPU 设备:

代码语言:javascript复制
x.devices() 
代码语言:javascript复制
{CpuDevice(id=0)} 

一般来说,数组可能会在多个设备上 分片,您可以通过 sharding 属性进行检查:

代码语言:javascript复制
x.sharding 
代码语言:javascript复制
SingleDeviceSharding(device=CpuDevice(id=0)) 

在这里,数组位于单个设备上,但通常情况下,JAX 数组可以分布在多个设备或者多个主机上。要了解更多关于分片数组和并行计算的信息,请参阅分片计算介绍## 变换

除了用于操作数组的函数外,JAX 还包括许多用于操作 JAX 函数的变换。这些变换包括

  • jax.jit(): 即时(JIT)编译;参见即时编译
  • jax.vmap(): 向量化变换;参见自动向量化
  • jax.grad(): 梯度变换;参见自动微分

以及其他几个。变换接受一个函数作为参数,并返回一个新的转换后的函数。例如,这是您可能如何对一个简单的 SELU 函数进行 JIT 编译:

代码语言:javascript复制
def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)
print(selu_jit(1.0)) 
代码语言:javascript复制
1.05 

通常情况下,您会看到使用 Python 的装饰器语法来应用变换以方便操作:

代码语言:javascript复制
@jax.jit
def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha) 

jit()vmap()grad() 等变换对于有效使用 JAX 至关重要,我们将在后续章节中详细介绍它们。## 跟踪

变换背后的魔法是跟踪器的概念。跟踪器是数组对象的抽象替身,传递给 JAX 函数,以提取函数编码的操作序列。

您可以通过打印转换后的 JAX 代码中的任何数组值来看到这一点;例如:

代码语言:javascript复制
@jax.jit
def f(x):
  print(x)
  return x   1

x = jnp.arange(5)
result = f(x) 
代码语言:javascript复制
Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace(level=1/0)> 

打印的值不是数组 x,而是代表 x 的关键属性的 Tracer 实例,比如它的 shapedtype。通过使用追踪值执行函数,JAX 可以确定函数编码的操作序列,然后在实际执行这些操作之前执行转换:例如 jit()vmap()grad() 可以将输入操作序列映射到变换后的操作序列。 ## Jaxprs

JAX 对操作序列有自己的中间表示形式,称为 jaxpr。jaxpr(JAX exPRession 的缩写)是一个函数程序的简单表示,包含一系列原始操作。

例如,考虑我们上面定义的 selu 函数:

代码语言:javascript复制
def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha) 

我们可以使用 jax.make_jaxpr() 实用程序来将该函数转换为一个 jaxpr,给定特定的输入:

代码语言:javascript复制
x = jnp.arange(5.0)
jax.make_jaxpr(selu)(x) 
代码语言:javascript复制
{ lambda ; a:f32[5]. let
    b:bool[5] = gt a 0.0
    c:f32[5] = exp a
    d:f32[5] = mul 1.6699999570846558 c
    e:f32[5] = sub d 1.6699999570846558
    f:f32[5] = pjit[
      name=_where
      jaxpr={ lambda ; g:bool[5] h:f32[5] i:f32[5]. let
          j:f32[5] = select_n g i h
        in (j,) }
    ] b a e
    k:f32[5] = mul 1.0499999523162842 f
  in (k,) } 

与 Python 函数定义相比,可以看出它编码了函数表示的精确操作序列。我们稍后将深入探讨 JAX 内部的 jaxprs:jaxpr 语言。 ## Pytrees

JAX 函数和转换基本上操作数组,但实际上编写处理数组集合的代码更为方便:例如,神经网络可能会将其参数组织在具有有意义键的数组字典中。与其逐案处理这类结构,JAX 依赖于 pytree 抽象来统一处理这些集合。

以下是一些可以作为 pytrees 处理的对象的示例:

代码语言:javascript复制
# (nested) list of parameters
params = [1, 2, (jnp.arange(3), jnp.ones(2))]

print(jax.tree.structure(params))
print(jax.tree.leaves(params)) 
代码语言:javascript复制
PyTreeDef([*, *, (*, *)])
[1, 2, Array([0, 1, 2], dtype=int32), Array([1., 1.], dtype=float32)] 
代码语言:javascript复制
# Dictionary of parameters
params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}

print(jax.tree.structure(params))
print(jax.tree.leaves(params)) 
代码语言:javascript复制
PyTreeDef({'W': *, 'b': *, 'n': *})
[Array([[1., 1.],
       [1., 1.]], dtype=float32), Array([0., 0.], dtype=float32), 5] 
代码语言:javascript复制
# Named tuple of parameters
from typing import NamedTuple

class Params(NamedTuple):
  a: int
  b: float

params = Params(1, 5.0)
print(jax.tree.structure(params))
print(jax.tree.leaves(params)) 
代码语言:javascript复制
PyTreeDef(CustomNode(namedtuple[Params], [*, *]))
[1, 5.0] 

JAX 提供了许多用于处理 PyTrees 的通用实用程序;例如函数 jax.tree.map() 可以用于将函数映射到树中的每个叶子,而 jax.tree.reduce() 可以用于在树中的叶子上应用约简操作。

你可以在《使用 pytrees 教程》中了解更多信息。

即时编译

原文:jax.readthedocs.io/en/latest/jit-compilation.html

在这一部分,我们将进一步探讨 JAX 的工作原理,以及如何使其性能卓越。我们将讨论 jax.jit() 变换,它将 JAX Python 函数进行即时编译,以便在 XLA 中高效执行。

如何工作 JAX 变换

在前一节中,我们讨论了 JAX 允许我们转换 Python 函数的能力。JAX 通过将每个函数减少为一系列原始操作来实现这一点,每个原始操作代表一种基本的计算单位。

查看函数背后原始操作序列的一种方法是使用 jax.make_jaxpr()

代码语言:javascript复制
import jax
import jax.numpy as jnp

global_list = []

def log2(x):
  global_list.append(x)
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2

print(jax.make_jaxpr(log2)(3.0)) 
代码语言:javascript复制
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) } 

文档的理解 Jaxprs 部分提供了有关上述输出含义的更多信息。

重要的是要注意,jaxpr 不捕获函数中存在的副作用:其中没有对 global_list.append(x) 的任何内容。这是一个特性,而不是一个错误:JAX 变换旨在理解无副作用(也称为函数纯粹)的代码。如果 纯函数副作用 是陌生的术语,这在

0 人点赞