JAX 中文文档(十五)

2024-06-22 08:46:44 浏览数 (2)

原文:jax.readthedocs.io/en/latest/

jax.tree 模块

原文:jax.readthedocs.io/en/latest/jax.tree.html

用于处理树形容器数据结构的实用工具。

jax.tree 命名空间包含了来自 jax.tree_util 的实用工具的别名。

功能列表

all(tree, *[, is_leaf])

对树的所有叶子进行 all()操作。

flatten(tree[, is_leaf])

将一个 pytree 扁平化。

leaves(tree[, is_leaf])

获取一个 pytree 的叶子。

map(f, tree, *rest[, is_leaf])

将一个多输入函数映射到 pytree 参数上,生成一个新的 pytree。

reduce()

对树的叶子进行 reduce()操作。

structure(tree[, is_leaf])

获取一个 pytree 的 treedef。

transpose(outer_treedef, inner_treedef, …)

将具有树结构 (outer, inner) 的树转换为具有结构 (inner, outer) 的树。

unflatten(treedef, leaves)

根据 treedef 和叶子重构一个 pytree。

jax.tree_util 模块

原文:jax.readthedocs.io/en/latest/jax.tree_util.html

用于处理树状容器数据结构的实用工具。

该模块提供了一小组用于处理树状数据结构(例如嵌套元组、列表和字典)的实用函数。我们称这些结构为 pytrees。它们是树形的,因为它们是递归定义的(任何非 pytree 都是 pytree,即叶子,任何 pytree 的 pytrees 都是 pytree),并且可以递归地操作(映射操作不保留对象身份等价性,并且这些结构不能包含引用循环)。

被视为 pytree 节点的 Python 类型集合(例如可以映射而不是视为叶子的类型)是可扩展的。存在一个单一的模块级别的类型注册表,并且类层次结构被忽略。通过注册一个新的 pytree 节点类型,该类型实际上变得对此文件中的实用函数透明。

该模块的主要目的是支持用户定义的数据结构与 JAX 转换(例如 jit)之间的互操作性。这不是一个通用的树状数据结构处理库。

查看 JAX pytrees 注释以获取示例。

函数列表

Partial(func, *args, **kw)

在 pytrees 中工作的 functools.partial 的版本。

all_leaves(iterable[, is_leaf])

测试给定可迭代对象中的所有元素是否都是叶子。

build_tree(treedef, xs)

从嵌套的可迭代结构构建一个 treedef。

register_dataclass(nodetype, data_fields, …)

扩展了在 pytrees 中被视为内部节点的类型集合。

register_pytree_node(nodetype, flatten_func, …)

扩展了在 pytrees 中被视为内部节点的类型集合。

register_pytree_node_class(cls)

扩展了在 pytrees 中被视为内部节点的类型集合。

register_pytree_with_keys(nodetype, …[, …])

扩展了在 pytrees 中被视为内部节点的类型集合。

register_pytree_with_keys_class(cls)

扩展了在 pytrees 中被视为内部节点的类型集合。

register_static(cls)

将 cls 注册为没有叶子的 pytree。

tree_flatten_with_path(tree[, is_leaf])

像tree_flatten一样展平 pytree,但还返回每个叶子的键路径。

tree_leaves_with_path(tree[, is_leaf])

获取类似tree_leaves的 pytree 的叶子,并返回每个叶子的键路径。

tree_map_with_path(f, tree, *rest[, is_leaf])

对 pytree 键路径和参数执行多输入函数映射,生成新的 pytree。

treedef_children(treedef)

返回直接子节点的 treedef 列表。

treedef_is_leaf(treedef)

如果 treedef 表示叶子,则返回 True。

treedef_tuple(treedefs)

从子 treedefs 的可迭代对象制作一个元组 treedef。

keystr(keys)

辅助函数,用于漂亮地打印键的元组。

传统 API

现在通过jax.tree访问这些 API。

tree_all(tree, *[, is_leaf])

jax.tree.all()的别名。

tree_flatten(tree[, is_leaf])

jax.tree.flatten()的别名。

tree_leaves(tree[, is_leaf])

jax.tree.leaves()的别名。

tree_map(f, tree, *rest[, is_leaf])

jax.tree.map()的别名。

tree_reduce(function, tree[, initializer, …])

jax.tree.reduce()的别名。

tree_structure(tree[, is_leaf])

jax.tree.structure()的别名。

tree_transpose(outer_treedef, inner_treedef, …)

jax.tree.transpose()的别名。

tree_unflatten(treedef, leaves)

jax.tree.unflatten()的别名。

jax.typing 模块

原文:jax.readthedocs.io/en/latest/jax.typing.html

JAX 类型注解模块是 JAX 特定静态类型注解的存放地。这个子模块仍在开发中;要查看这里导出的类型背后的提案,请参阅jax.readthedocs.io/en/latest/jep/12049-type-annotations.html

当前可用的类型包括:

  • jax.Array: 适用于任何 JAX 数组或跟踪器的注解(即 JAX 变换中的数组表示)。
  • jax.typing.ArrayLike: 适用于任何安全隐式转换为 JAX 数组的值;这包括 jax.Arraynumpy.ndarray,以及 Python 内置数值类型(例如intfloat 等)和 numpy 标量值(例如 numpy.int32numpy.float64 等)。
  • jax.typing.DTypeLike: 适用于可以转换为 JAX 兼容 dtype 的任何值;这包括字符串(例如 ‘float32’、‘int32’)、标量类型(例如 float、np.float32)、dtype(例如 np.dtype(‘float32’))、或具有 dtype 属性的对象(例如 jnp.float32、jnp.int32)。

我们可能在将来的版本中添加其他类型。

JAX 类型注解最佳实践

在公共 API 函数中注释 JAX 数组时,我们建议使用 ArrayLike 来标注数组输入,使用 Array 来标注数组输出。

例如,您的函数可能如下所示:

代码语言:javascript复制
import numpy as np
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike

def my_function(x: ArrayLike) -> Array:
  # Runtime type validation, Python 3.10 or newer:
  if not isinstance(x, ArrayLike):
    raise TypeError(f"Expected arraylike input; got {x}")
  # Runtime type validation, any Python version:
  if not (isinstance(x, (np.ndarray, Array)) or np.isscalar(x)):
    raise TypeError(f"Expected arraylike input; got {x}")

  # Convert input to jax.Array:
  x_arr = jnp.asarray(x)

  # ... do some computation; JAX functions will return Array types:
  result = x_arr.sum(0) / x_arr.shape[0]

  # return an Array
  return result 

JAX 的大多数公共 API 遵循这种模式。特别需要注意的是,我们建议 JAX 函数不要接受序列,如listtuple,而应该接受数组,因为这样可以避免在像 jit() 这样的 JAX 变换中产生额外的开销,并且在类似批处理变换 vmap()jax.pmap() 中可能会表现出意外行为。更多信息,请参阅NumPy vs JAX 中的非数组输入。

成员列表

ArrayLike

适用于 JAX 数组类似对象的类型注解。

DTypeLike

别名为str | type[Any] | dtype | SupportsDType

jax.export 模块

原文:jax.readthedocs.io/en/latest/jax.export.html

Exported(fun_name, in_tree, in_avals, …)

降低为 StableHLO 的 JAX 函数。

DisabledSafetyCheck(_impl)

应在(反)序列化时跳过的安全检查。

函数

export(fun_jit, *[, platforms, …])

导出一个用于持久化序列化的 JAX 函数。

deserialize(blob)

反序列化一个已导出的对象。

minimum_supported_calling_convention_version

int([x]) -> integer int(x, base=10) -> integer

maximum_supported_calling_convention_version

int([x]) -> integer int(x, base=10) -> integer

default_export_platform()

获取默认的导出平台。

与形状多态性相关的函数

symbolic_shape(shape_spec, *[, constraints, …])

从字符串表示中构建一个符号形状。

symbolic_args_specs(args, shapes_specs[, …])

为导出构建一个 jax.ShapeDtypeSpec 参数规范的 pytree。

is_symbolic_dim§

检查一个维度是否是符号维度。

SymbolicScope([constraints_str])

标识用于符号表达式的作用域。

常量

代码语言:javascript复制
jax.export.minimum_supported_serialization_version

最小支持的序列化版本;参见调用约定版本。

代码语言:javascript复制
jax.export.maximum_supported_serialization_version

最大支持的序列化版本;参见调用约定版本。

jax.extend 模块

原文:jax.readthedocs.io/en/latest/jax.extend.html

JAX 扩展模块。

jax.extend 包提供了访问 JAX 内部机制的模块。参见 JEP #15856。

API 政策

与 公共 API 不同,这个包在发布版本之间 不提供兼容性保证。突破性变更将通过 JAX 项目变更日志 进行公告。

模块

  • jax.extend.ffi 模块
  • jax.extend.linear_util 模块
  • jax.extend.mlir 模块
  • jax.extend.random 模块

jax.extend.ffi 模块

原文:jax.readthedocs.io/en/latest/jax.extend.ffi.html

ffi_lowering(call_target_name, *[, …])

构建一个外部函数接口(FFI)目标的降低规则。

pycapsule(funcptr)

将一个 ctypes 函数指针包装在 PyCapsule 中。

jax.extend.linear_util 模块

原文:jax.readthedocs.io/en/latest/jax.extend.linear_util.html

StoreException

WrappedFun(f, transforms, stores, params, …)

表示要应用转换的函数 f。

cache(call, *[, explain])

用于将 WrappedFun 作为第一个参数的函数的记忆化装饰器。

merge_linear_aux(aux1, aux2)

transformation

向 WrappedFun 添加一个转换。

transformation_with_aux

向 WrappedFun 添加一个带有辅助输出的转换。

wrap_init(f[, params])

将函数 f 包装为 WrappedFun,适用于转换。

jax.extend.mlir 模块

原文:jax.readthedocs.io/en/latest/jax.extend.mlir.html

方言

中间表示

传递管理器

jax.extend.random 模块

原文:jax.readthedocs.io/en/latest/jax.extend.random.html

define_prng_impl(*, key_shape, seed, split, …)

seed_with_impl(impl, seed)

threefry2x32_p

threefry_2x32(keypair, count)

应用 Threefry 2x32 哈希函数。

threefry_prng_impl

指定 PRNG 密钥形状和操作。

rbg_prng_impl

指定 PRNG 密钥形状和操作。

unsafe_rbg_prng_impl

指定 PRNG 密钥形状和操作。

jax.example_libraries 模块

原文:jax.readthedocs.io/en/latest/jax.example_libraries.html

JAX 提供了一些小型的实验性机器学习库。这些库一部分提供工具,另一部分作为使用 JAX 构建此类库的示例。每个库的源代码行数不超过 300 行,因此请查看并根据需要进行调整!

注意

每个小型库的目的是灵感,而非规范。

为了达到这个目的,最好保持它们的代码示例简洁;因此,我们通常不会合并添加新功能的 PR。相反,请将您可爱的拉取请求和设计想法发送到更完整的库,如Haiku或Flax。

  • jax.example_libraries.optimizers 模块
  • jax.example_libraries.stax 模块

jax.example_libraries.optimizers 模块

原文:jax.readthedocs.io/en/latest/jax.example_libraries.optimizers.html

JAX 中如何编写优化器的示例。

您可能不想导入此模块!此库中的优化器仅供示例使用。如果您正在寻找功能完善的优化器库,两个不错的选择是 JAXopt 和 Optax。

此模块包含一些方便的优化器定义,特别是初始化和更新函数,可用于 ndarray 或任意嵌套的 tuple/list/dict 的 ndarray。

优化器被建模为一个 (init_fun, update_fun, get_params) 函数三元组,其中组件函数具有以下签名:

代码语言:javascript复制
init_fun(params)

Args:
  params: pytree representing the initial parameters.

Returns:
  A pytree representing the initial optimizer state, which includes the
  initial parameters and may also include auxiliary values like initial
  momentum. The optimizer state pytree structure generally differs from that
  of `params`. 
代码语言:javascript复制
update_fun(step, grads, opt_state)

Args:
  step: integer representing the step index.
  grads: a pytree with the same structure as `get_params(opt_state)`
    representing the gradients to be used in updating the optimizer state.
  opt_state: a pytree representing the optimizer state to be updated.

Returns:
  A pytree with the same structure as the `opt_state` argument representing
  the updated optimizer state. 
代码语言:javascript复制
get_params(opt_state)

Args:
  opt_state: pytree representing an optimizer state.

Returns:
  A pytree representing the parameters extracted from `opt_state`, such that
  the invariant `params == get_params(init_fun(params))` holds true. 

注意,优化器实现在 opt_state 的形式上具有很大的灵活性:它只需是 JaxTypes 的 pytree(以便可以将其传递给 api.py 中定义的 JAX 变换),并且它必须可以被 update_fun 和 get_params 消耗。

示例用法:

代码语言:javascript复制
opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
opt_state = opt_init(params)

def step(step, opt_state):
  value, grads = jax.value_and_grad(loss_fn)(get_params(opt_state))
  opt_state = opt_update(step, grads, opt_state)
  return value, opt_state

for i in range(num_steps):
  value, opt_state = step(i, opt_state) 
代码语言:javascript复制
class jax.example_libraries.optimizers.JoinPoint(subtree)

Bases: object

标记了两个连接(嵌套)的 pytree 之间的边界。

代码语言:javascript复制
class jax.example_libraries.optimizers.Optimizer(init_fn, update_fn, params_fn)

Bases: NamedTuple

参数:

  • init_fn (Callable[**[Any]**, OptimizerState*]*)
  • update_fn (Callable[**[int, Any, OptimizerState*]*,* OptimizerState]*)
  • params_fn (Callable[[OptimizerState]*,* Any])
代码语言:javascript复制
init_fn: Callable[[Any], OptimizerState]

字段 0 的别名

代码语言:javascript复制
params_fn: Callable[[OptimizerState], Any]

字段 2 的别名

代码语言:javascript复制
update_fn: Callable[[int, Any, OptimizerState], OptimizerState]

字段 1 的别名

代码语言:javascript复制
class jax.example_libraries.optimizers.OptimizerState(packed_state, tree_def, subtree_defs)

Bases: tuple

代码语言:javascript复制
packed_state

字段 0 的别名

代码语言:javascript复制
subtree_defs

字段 2 的别名

代码语言:javascript复制
tree_def

字段 1 的别名

代码语言:javascript复制
jax.example_libraries.optimizers.adagrad(step_size, momentum=0.9)

构建 Adagrad 的优化器三元组。

适应性次梯度方法用于在线学习和随机优化:www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf

参数:

  • step_size – 正标量,或者将迭代索引映射到正标量的可调用对象的步长表达式。
  • momentum – 可选,用于动量的正标量值

返回:

一个 (init_fun, update_fun, get_params) 三元组。

代码语言:javascript复制
jax.example_libraries.optimizers.adam(step_size, b1=0.9, b2=0.999, eps=1e-08)

构建 Adam 的优化器三元组。

参数:

  • step_size – 正的标量,或者一个可调用对象,表示将迭代索引映射到正的标量的步长计划。
  • b1 – 可选,一个正的标量值,用于 beta_1,第一个时刻估计的指数衰减率(默认为 0.9)。
  • b2 – 可选,一个正的标量值,用于 beta_2,第二个时刻估计的指数衰减率(默认为 0.999)。
  • eps – 可选,一个正的标量值,用于 epsilon,即数值稳定性的小常数(默认为 1e-8)。

返回:

一个 (init_fun, update_fun, get_params) 三元组。

代码语言:javascript复制
jax.example_libraries.optimizers.adamax(step_size, b1=0.9, b2=0.999, eps=1e-08)

为 AdaMax(基于无穷范数的 Adam 变体)构造优化器三元组。

参数:

  • step_size – 正的标量,或者一个可调用对象,表示将迭代索引映射到正的标量的步长计划。
  • b1 – 可选,一个正的标量值,用于 beta_1,第一个时刻估计的指数衰减率(默认为 0.9)。
  • b2 – 可选,一个正的标量值,用于 beta_2,第二个时刻估计的指数衰减率(默认为 0.999)。
  • eps – 可选,一个正的标量值,用于 epsilon,即数值稳定性的小常数(默认为 1e-8)。

返回:

一个 (init_fun, update_fun, get_params) 三元组。

代码语言:javascript复制
jax.example_libraries.optimizers.clip_grads(grad_tree, max_norm)

将存储为 pytree 结构的梯度裁剪到最大范数 max_norm。

代码语言:javascript复制
jax.example_libraries.optimizers.constant(step_size)

返回类型:

Callable[[int], float]

代码语言:javascript复制
jax.example_libraries.optimizers.exponential_decay(step_size, decay_steps, decay_rate)
代码语言:javascript复制
jax.example_libraries.optimizers.inverse_time_decay(step_size, decay_steps, decay_rate, staircase=False)
代码语言:javascript复制
jax.example_libraries.optimizers.l2_norm(tree)

计算一个 pytree 结构的数组的 l2 范数。适用于权重衰减。

代码语言:javascript复制
jax.example_libraries.optimizers.make_schedule(scalar_or_schedule)

参数:

scalar_or_schedule (float | Callable[**[int]**, float])

返回类型:

Callable[[int], float]

代码语言:javascript复制
jax.example_libraries.optimizers.momentum(step_size, mass)

为带动量的 SGD 构造优化器三元组。

参数:

  • step_size (Callable[**[int]**, float]) – 正的标量,或者一个可调用对象,表示将迭代索引映射到正的标量的步长计划。
  • mass (float) – 正的标量,表示动量系数。

返回:

一个 (init_fun, update_fun, get_params) 三元组。

代码语言:javascript复制
jax.example_libraries.optimizers.nesterov(step_size, mass)

为带有 Nesterov 动量的 SGD 构建优化器三元组。

参数:

  • step_sizeCallable[**[int]**, float]) – 正标量,或表示将迭代索引映射到正标量的步长计划的可调用对象。
  • massfloat) – 正标量,表示动量系数。

返回:

一个(init_fun, update_fun, get_params)三元组。

代码语言:javascript复制
jax.example_libraries.optimizers.optimizer(opt_maker)

装饰器,使定义为数组的优化器通用于容器。

使用此装饰器,您可以编写只对单个数组操作的 init、update 和 get_params 函数,并将它们转换为对参数 pytrees 进行操作的相应函数。有关示例,请参见 optimizers.py 中定义的优化器。

参数:

opt_makerCallable[[], tuple[Callable[**[Any]**, Any]**, Callable[**[int, Any, Any]**, Any]**, Callable[**[Any]**, Any]]]) –

返回一个返回(init_fun, update_fun, get_params)函数三元组的函数,该函数可能仅适用于 ndarrays,如

代码语言:javascript复制
init_fun  ::  ndarray  ->  OptStatePytree  ndarray
update_fun  ::  OptStatePytree  ndarray  ->  OptStatePytree  ndarray
get_params  ::  OptStatePytree  ndarray  ->  ndarray 

返回:

一个(init_fun, update_fun, get_params)函数三元组,这些函数按照任意 pytrees 进行操作,如

代码语言:javascript复制
init_fun  ::  ParameterPytree  ndarray  ->  OptimizerState
update_fun  ::  OptimizerState  ->  OptimizerState
get_params  ::  OptimizerState  ->  ParameterPytree  ndarray 

返回函数使用的 OptimizerState pytree 类型与ParameterPytree (OptStatePytree ndarray)相同,但可能出于性能考虑将状态存储为部分展平的数据结构。

返回类型:

Callable[[…], Optimizer]

代码语言:javascript复制
jax.example_libraries.optimizers.pack_optimizer_state(marked_pytree)

将标记的 pytree 转换为 OptimizerState。

unpack_optimizer_state 的逆操作。将一个带有 JoinPoints 的标记 pytree(其外部 pytree 的叶子表示为 JoinPoints)转换回一个 OptimizerState。这个函数用于在反序列化优化器状态时很有用。

参数:

marked_pytree – 一个包含 JoinPoint 叶子的 pytree,其保持更多 pytree。

返回:

输入参数的等效 OptimizerState。

代码语言:javascript复制
jax.example_libraries.optimizers.piecewise_constant(boundaries, values)

参数:

  • boundaries (任意)
  • values (任意)
代码语言:javascript复制
jax.example_libraries.optimizers.polynomial_decay(step_size, decay_steps, final_step_size, power=1.0)
代码语言:javascript复制
jax.example_libraries.optimizers.rmsprop(step_size, gamma=0.9, eps=1e-08)

为 RMSProp 构造优化器三元组。

参数:

step_size – 正标量,或者一个可调用函数,表示将迭代索引映射到正标量的步长计划。gamma:衰减参数。eps:Epsilon 参数。

返回:

一个(init_fun, update_fun, get_params)三元组。

代码语言:javascript复制
jax.example_libraries.optimizers.rmsprop_momentum(step_size, gamma=0.9, eps=1e-08, momentum=0.9)

为带动量的 RMSProp 构造优化器三元组。

这个优化器与 rmsprop 优化器分开,因为它需要跟踪额外的参数。

参数:

  • step_size – 正标量,或者一个可调用函数,表示将迭代索引映射到正标量的步长计划。
  • gamma – 衰减参数。
  • eps – Epsilon 参数。
  • momentum – 动量参数。

返回:

一个(init_fun, update_fun, get_params)三元组。

代码语言:javascript复制
jax.example_libraries.optimizers.sgd(step_size)

为随机梯度下降构造优化器三元组。

参数:

step_size – 正标量,或者一个可调用函数,表示将迭代索引映射到正标量的步长计划。

返回:

一个(init_fun, update_fun, get_params)三元组。

代码语言:javascript复制
jax.example_libraries.optimizers.sm3(step_size, momentum=0.9)

为 SM3 构造优化器三元组。

大规模学习的内存高效自适应优化。arxiv.org/abs/1901.11150

参数:

  • step_size – 正标量,或者一个可调用函数,表示将迭代索引映射到正标量的步长计划。
  • momentum – 可选,动量的正标量值

返回:

一个(init_fun, update_fun, get_params)三元组。

代码语言:javascript复制
jax.example_libraries.optimizers.unpack_optimizer_state(opt_state)

将一个 OptimizerState 转换为带有 JoinPoints 叶子的标记 pytree。

将一个 OptimizerState 转换为带有 JoinPoints 叶子的标记 pytree,以避免丢失信息。这个函数在序列化优化器状态时很有用。

参数:

opt_state – 一个 OptimizerState

返回:

一个带有 JoinPoint 叶子的 pytree,其包含第二级 pytree。

jax.example_libraries.stax 模块

原文:jax.readthedocs.io/en/latest/jax.example_libraries.stax.html

Stax 是一个从头开始的小而灵活的神经网络规范库。

您可能不想导入此模块!Stax 仅用作示例库。对于 JAX,还有许多其他功能更全面的神经网络库,包括来自 Google 的Flax 和来自 DeepMind 的Haiku。

代码语言:javascript复制
jax.example_libraries.stax.AvgPool(window_shape, strides=None, padding='VALID', spec=None)

用于创建池化层的层构造函数。

代码语言:javascript复制
jax.example_libraries.stax.BatchNorm(axis=(0, 1, 2), epsilon=1e-05, center=True, scale=True, beta_init=<function zeros>, gamma_init=<function ones>)

用于创建批量归一化层的层构造函数。

代码语言:javascript复制
jax.example_libraries.stax.Conv(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)

用于创建通用卷积层的层构造函数。

代码语言:javascript复制
jax.example_libraries.stax.Conv1DTranspose(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)

用于创建通用转置卷积层的层构造函数。

代码语言:javascript复制
jax.example_libraries.stax.ConvTranspose(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)

用于创建通用转置卷积层的层构造函数。

代码语言:javascript复制
jax.example_libraries.stax.Dense(out_dim, W_init=<function variance_scaling.<locals>.init>, b_init=<function normal.<locals>.init>)

用于创建密集(全连接)层的层构造函数。

代码语言:javascript复制
jax.example_libraries.stax.Dropout(rate, mode='train')

用于给定率创建丢弃层的层构造函数。

代码语言:javascript复制
jax.example_libraries.stax.FanInConcat(axis=-1)

用于创建扇入连接层的层构造函数。

代码语言:javascript复制
jax.example_libraries.stax.FanOut(num)

用于创建扇出层的层构造函数。

代码语言:javascript复制
jax.example_libraries.stax.GeneralConv(dimension_numbers, out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)

用于创建通用卷积层的层构造函数。

代码语言:javascript复制
jax.example_libraries.stax.GeneralConvTranspose(dimension_numbers, out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)

用于创建通用转置卷积层的层构造函数。

代码语言:javascript复制
jax.example_libraries.stax.MaxPool(window_shape, strides=None, padding='VALID', spec=None)

用于创建池化层的层构造函数。

代码语言:javascript复制
jax.example_libraries.stax.SumPool(window_shape, strides=None, padding='VALID', spec=None)

用于创建池化层的层构造函数。

代码语言:javascript复制
jax.example_libraries.stax.elementwise(fun, **fun_kwargs)

在其输入上逐元素应用标量函数的层。

代码语言:javascript复制
jax.example_libraries.stax.parallel(*layers)

并行组合层的组合器。

此组合器生成的层通常与 FanOut 和 FanInSum 层一起使用。

参数:

*layers – 一个层序列,每个都是(init_fun, apply_fun)对。

返回:

表示给定层序列的并行组合的新层,即(init_fun, apply_fun)对。特别地,返回的层接受一个输入序列,并返回一个与参数层长度相同的输出序列。

代码语言:javascript复制
jax.example_libraries.stax.serial(*layers)

串行组合层的组合器。

参数:

*layers – 一个层序列,每个都是(init_fun, apply_fun)对。

返回:

表示给定层序列的串行组合的新层,即(init_fun, apply_fun)对。

代码语言:javascript复制
jax.example_libraries.stax.shape_dependent(make_layer)

延迟层构造对直到输入形状已知的组合器。

参数:

make_layer – 一个以输入形状(正整数元组)为参数的单参数函数,返回一个(init_fun, apply_fun)对。

返回:

表示与 make_layer 返回的相同层的新层,但其构造被延迟直到输入形状已知。

jax.experimental 模块

原文:jax.readthedocs.io/en/latest/jax.experimental.html

jax.experimental.optix 已迁移到其自己的 Python 包中 (deepmind/optax)。

jax.experimental.ann 已迁移到 jax.lax

实验性模块

  • jax.experimental.array_api 模块
  • jax.experimental.checkify 模块
  • jax.experimental.host_callback 模块
  • jax.experimental.maps 模块
  • jax.experimental.pjit 模块
  • jax.experimental.sparse 模块
  • jax.experimental.jet 模块
  • jax.experimental.custom_partitioning 模块
  • jax.experimental.multihost_utils 模块
  • jax.experimental.compilation_cache 模块
  • jax.experimental.key_reuse 模块
  • jax.experimental.mesh_utils 模块
  • jax.experimental.serialize_executable 模块
  • jax.experimental.shard_map 模块

实验性 API

enable_x64([new_val])

实验性上下文管理器,临时启用 X64 模式。

disable_x64()

实验性上下文管理器,临时禁用 X64 模式。

jax.experimental.checkify.checkify(f[, errors])

在函数 f 中功能化检查调用,并可选地添加运行时错误检查。

jax.experimental.checkify.check(pred, msg, …)

检查谓词,如果谓词为假,则添加带有消息的错误。

jax.experimental.checkify.check_error(error)

如果 error 表示失败,则引发异常。

jax.experimental.array_api 模块

原文:jax.readthedocs.io/en/latest/jax.experimental.array_api.html

此模块包括对 Python 数组 API 标准 的实验性 JAX 支持。目前对此的支持是实验性的,且尚未完全完成。

示例用法:

代码语言:javascript复制
>>> from jax.experimental import array_api as xp

>>> xp.__array_api_version__
'2023.12'

>>> arr = xp.arange(1000)

>>> arr.sum()
Array(499500, dtype=int32) 

xp 命名空间是 jax.numpy 的数组 API 兼容版本,并实现了大部分标准中列出的 API。

jax.experimental.checkify 模块

原文:jax.readthedocs.io/en/latest/jax.experimental.checkify.html

API

checkify(f[, errors])

将检查调用功能化在函数 f 中,并可选择添加运行时错误检查。

check(pred, msg, *fmt_args, **fmt_kwargs)

检查一个断言,如果断言为 False,则添加带有消息 msg 的错误。

check_error(error)

如果 error 表示失败,则抛出异常。

Error(_pred, _code, _metadata, _payload)

JaxRuntimeError

user_checks

frozenset() -> 空的 frozenset 对象 frozenset(iterable) -> frozenset 对象

nan_checks

frozenset() -> 空的 frozenset 对象 frozenset(iterable) -> frozenset 对象

index_checks

frozenset() -> 空的 frozenset 对象 frozenset(iterable) -> frozenset 对象

div_checks

frozenset() -> 空的 frozenset 对象 frozenset(iterable) -> frozenset 对象

float_checks

frozenset() -> 空的 frozenset 对象 frozenset(iterable) -> frozenset 对象

automatic_checks

frozenset() -> 空的 frozenset 对象 frozenset(iterable) -> frozenset 对象

all_checks

frozenset() -> 空的 frozenset 对象 frozenset(iterable) -> frozenset 对象

jax.experimental.host_callback 模块

原文:jax.readthedocs.io/en/latest/jax.experimental.host_callback.html

在 JAX 加速器代码中调用 Python 函数的原语。

警告

自 2024 年 3 月 20 日起,host_callback API 已弃用。功能已被 新的 JAX 外部回调 所取代。请参阅 google/jax#20385。

此模块介绍了主机回调函数 call()id_tap()id_print(),它们将其参数从设备发送到主机,并在主机上调用用户定义的 Python 函数,可选地将结果返回到设备计算中。

我们展示了下面如何使用这些函数。我们从 call() 开始,并讨论从 JAX 调用 CPU 上任意 Python 函数的示例,例如使用 NumPy CPU 自定义核函数。然后我们展示了使用 id_tap()id_print(),它们的限制是不能将主机返回值传回设备。这些原语通常更快,因为它们与设备代码异步执行。特别是它们可用于连接到和调试 JAX 代码。

使用 call() 调用主机函数并将结果返回给设备

使用 call() 调用主机上的计算并将 NumPy 数组返回给设备上的计算。主机计算在以下情况下非常有用,例如当设备计算需要一些需要在主机上进行 I/O 的数据,或者它需要一个在主机上可用但不希望在 JAX 中编码的库时。例如,在 JAX 中一般矩阵的特征值分解在 TPU 上不起作用。我们可以从任何 JAX 加速计算中调用 Numpy 实现,使用主机计算:

代码语言:javascript复制
# This function runs on the host
def host_eig(m: np.ndarray) -> np.ndarray:
  return np.linalg.eigvals(m)

# This function is used in JAX
def device_fun(m):
  # We send "m" to the host, asking it to call "host_eig" and return the result.
  # We have to specify the result shape and dtype, either in the form of an
  # example return value or any object that has `shape` and `dtype` attributes,
  # e.g., a NumPy array or a `jax.ShapeDtypeStruct`.
  return hcb.call(host_eig, m,
                  # Given an input of shape (..., d, d), eig output has shape (..., d)
                  result_shape=jax.ShapeDtypeStruct(m.shape[:-1], m.dtype)) 

call() 函数和 Python 主机函数都接受一个参数并返回一个结果,但这些可以是 pytrees。注意,我们必须告诉 call() 从主机调用中期望的形状和 dtype,使用 result_shape 关键字参数。这很重要,因为设备代码是按照这个期望进行编译的。如果实际调用产生不同的结果形状,运行时会引发错误。通常,这样的错误以及主机计算引发的异常可能很难调试。请参见下面的调试部分。这对 call() 是一个问题,但对于 id_tap() 不是,因为对于后者,设备代码不期望返回值。

call() API 可以在 jit 或 pmap 计算内部使用,或在 cond/scan/while 控制流内部使用。当在 jax.pmap() 内部使用时,将从每个参与设备中分别调用主机:

代码语言:javascript复制
def host_sin(x, *, device):
  # The ``device`` argument is passed due to ``call_with_device=True`` below.
  print(f"Invoking host_sin with {x.shape} on {device}")
  return np.sin(x)

# Use pmap to run the computation on two devices
jax.pmap(lambda x: hcb.call(host_sin, x,
                            result_shape=x,
                            # Ask that the `host_sin` function be passed `device=dev`
                            call_with_device=True))(
         np.ones((2, 4), dtype=np.float32))

# prints (in arbitrary order)
# Invoking host_sin with (4,) on cpu:0
# Invoking host_sin with (4,) on cpu:1 

请注意,call()不支持任何 JAX 转换,但如下所示,可以利用现有的支持来自定义 JAX 中的导数规则。

使用id_tap()在主机上调用 Python 函数,不返回任何值。

id_tap()id_print()call()的特殊情况,当您只希望 Python 回调的副作用时。这些函数的优点是一旦参数已发送到主机,设备计算可以继续进行,而无需等待 Python 回调返回。对于id_tap(),您可以指定要调用的 Python 回调函数,而id_print()则使用一个内置回调,在主机的标准输出中打印参数。传递给id_tap()的 Python 函数接受两个位置参数(从设备计算中获取的值以及一个transforms元组,如下所述)。可选地,该函数可以通过关键字参数device传递设备从中获取的设备。

几个示例:

代码语言:javascript复制
def host_func(arg, transforms):
   ...do something with arg...

# calls host_func(2x, []) on host
id_tap(host_func, 2 * x)

# calls host_func((2x, 3x), [])
id_tap(host_func, (2 * x, 3 * x))  # The argument can be a pytree

# calls host_func(2x, [], device=jax.devices()[0])
id_tap(host_func, 2 * x, tap_with_device=True)  # Pass the device to the tap

# calls host_func(2x, [], what='activation')
id_tap(functools.partial(host_func, what='activation'), 2 * x)

# calls host_func(dict(x=x, y=y), what='data')
id_tap(lambda tap, transforms: host_func(tap, what='data'), dict(x=x, y=y)) 

所有上述示例都可以改用id_print(),只是id_print()会在主机上打印位置参数,以及任何额外的关键字参数和自动关键字参数transforms之间的区别。

使用barrier_wait()等待所有回调函数执行结束。

如果你的 Python 回调函数有副作用,可能需要等到计算完成,以确保副作用已被观察到。你可以使用barrier_wait()函数来实现这一目的:

代码语言:javascript复制
accumulator = []
def host_log(arg, transforms):
  # We just record the arguments in a list
  accumulator.append(arg)

def device_fun(x):
  id_tap(host_log, x)
  id_tap(host_log, 2. * x)

jax.jit(device_fun)(1.)
jax.jit(device_fun)(1.)

# At this point, we have started two computations, each with two
# taps, but they may not have yet executed.
barrier_wait()
# Now we know that all the computations started before `barrier_wait`
# on all devices, have finished, and all the callbacks have finished
# executing. 

请注意,barrier_wait()将在jax.local_devices()的每个设备上启动一个微小的计算,并等待所有这些计算的结果被接收。

一个替代方案是使用barrier_wait()仅等待计算结束,如果所有回调都是call()的话:

代码语言:javascript复制
accumulator = p[]
def host_log(arg):
  # We just record the arguments in a list
  accumulator.append(arg)
  return 0.  #  return something

def device_fun(c):
  y = call(host_log, x, result_shape=jax.ShapeDtypeStruct((), np.float32))
  z = call(host_log, 2. * x, result_shape=jax.ShapeDtypeStruct((), np.float32))
  return y   z  # return something that uses both results

res1 = jax.jit(device_fun)(1.)
res2 = jax.jit(device_fun)(1.)
res1.block_until_ready()
res2.block_until_ready() 

并行化转换下的行为

在存在jax.pmap()的情况下,代码将在多个设备上运行,并且每个设备将独立地执行其值。建议为id_print()id_tap()使用tap_with_device选项可能会有所帮助,以便查看哪个设备发送了哪些数据:

代码语言:javascript复制
jax.pmap(power3, devices=jax.local_devices()[:2])(np.array([3., 4.])
# device=cpu:0 what=x,x²: (3., 9.)  # from the first device
# device=cpu:1 what=x,x²: (4., 16.)  # from the second device 

使用jax.pmap()和多个主机上的多个设备时,每个主机将从其所有本地设备接收回调,带有与每个设备切片对应的操作数。对于call(),回调必须仅向每个设备返回与相应设备相关的结果切片。

当使用实验性的pjit.pjit()时,代码将在多个设备上运行,并在输入的不同分片上。当前主机回调的实现将确保单个设备将收集并输出整个操作数,在单个回调中。回调函数应返回整个数组,然后将其发送到发出输出的同一设备的单个进料中。然后,此设备负责将所需的分片发送到其他设备:

代码语言:javascript复制
with jax.sharding.Mesh(jax.local_devices()[:2], ["d"]):
  pjit.pjit(power3, in_shardings=(P("d"),),
            out_shardings=(P("d"),))(np.array([3., 4.]))

# device=TPU:0 what=x,x²: ( [3., 4.],
#                            [9., 16.] ) 

请注意,在一个设备上收集操作数可能会导致内存不足,如果操作数分布在多个设备上则情况类似。

当在多个设备上的多个主机上使用 pjit.pjit() 时,仅设备 0(相对于网格)上的主机将接收回调,其操作数来自所有参与设备上的所有主机。对于 call(),回调必须返回所有设备上所有主机的整个数组。

在 JAX 自动微分转换下的行为

在 JAX 自动微分转换下使用时,主机回调函数仅处理原始值。考虑以下示例:

代码语言:javascript复制
def power3(x):
  y = x * x
  # Print both 'x' and 'x²'. Must pack as a tuple.
  hcb.id_print((x, y), what="x,x²")
  return y * x

power3(3.)
# what: x,x² : (3., 9.) 

(您可以在 host_callback_test.HostCallbackTapTest.test_tap_transforms 中查看这些示例的测试。)

当在 jax.jvp() 下使用时,仅会有一个回调处理原始值:

代码语言:javascript复制
jax.jvp(power3, (3.,), (0.1,))
# what: x,x² : (3., 9.) 

类似地,对于 jax.grad(),我们仅从前向计算中得到一个回调:

代码语言:javascript复制
jax.grad(power3)(3.)
# what: x,x² : (3., 9.) 

如果您想在 jax.jvp() 中对切线进行回调处理,可以使用 custom_jvp。例如,您可以定义一个除了其 custom_jvp 会打印切线之外无趣的函数:

代码语言:javascript复制
@jax.custom_jvp
def print_tangents(arg):
  return None

@print_tangents.defjvp
def print_tangents_jvp(primals, tangents):
  arg_dot, = tangents
  hcb.id_print(arg_dot, what="tangents")
  return primals, tangents 

然后,您可以在想要触发切线的位置使用此函数:

代码语言:javascript复制
def power3_with_tangents(x):
  y = x * x
  # Print both 'x' and 'x²'. Must pack as a tuple.
  hcb.id_print((x, y), what="x,x²")
  print_tangents((x, y))
  return y * x

jax.jvp(power3_with_tangents, (3.,), (0.1,))
# what: x,x² : (3., 9.)
# what: tangents : (0.1, 0.6) 

您可以在 jax.grad() 中做类似的事情来处理余切。这时,您必须小心使用在其余计算中需要的余切值。因此,我们使 print_cotangents 返回其参数:

代码语言:javascript复制
@jax.custom_vjp
def print_cotangents(arg):
  # Must return the argument for which we want the cotangent.
  return arg

# f_fwd: a -> (b, residual)
def print_cotangents_fwd(arg):
  return print_cotangents(arg), None
# f_bwd: (residual, CT b) -> [CT a]
def print_cotangents_bwd(residual, ct_b):
  hcb.id_print(ct_b, what="cotangents", output_stream=testing_stream)
  return ct_b,

print_cotangents.defvjp(print_cotangents_fwd, print_cotangents_bwd)

def power3_with_cotangents(x):
  y = x * x
  # Print both 'x' and 'x²'. Must pack as a tuple.
  hcb.id_print((x, y), what="x,x²", output_stream=testing_stream)
  (x1, y1) = print_cotangents((x, y))
  # Must use the output of print_cotangents
  return y1 * x1

jax.grad(power3_with_cotangents)(3.)
# what: x,x² : (3., 9.)
# what: cotangents : (9., 3.) 

如果您使用 ad_checkpoint.checkpoint() 来重新生成反向传播的残差,则原始计算中的回调将被调用两次:

代码语言:javascript复制
jax.grad(lambda x: power3(ad_checkpoint.checkpoint(power3)(x)))(3.)
# what: x,x² : (3., 9.)
# what: x,x² : (27., 729.)
# what: x,x² : (3., 9.) 

这些回调依次是:内部 power3 的原始计算,外部 power3 的原始计算,以及内部 power3 的残差重新生成。

jax.vmap 下的行为

主机回调函数 id_print()id_tap() 支持矢量化转换 jax.vmap()

对于 jax.vmap(),回调的参数是批量处理的,并且回调函数会传递一个特殊的 transforms,其中包含转换描述符列表,格式为 ("batch", {"batch_dims": ...}),其中 ... 表示被触发值的批处理维度(每个参数一个条目,None 表示广播的参数)。

jax.vmap(power3)(np.array([2., 3.])) # transforms: [(‘batch’, {‘batch_dims’: (0, 0)})] what: x,x² : ([2., 3.], [4., 9.])

请参阅 id_tap()id_print()call() 的文档。

更多用法示例,请参阅 tests/host_callback_test.py

使用 call() 调用 TensorFlow 函数,支持反向模式自动微分

主机计算的另一个可能用途是调用为另一个框架编写的库,如 TensorFlow。在这种情况下,通过使用 jax.custom_vjp() 机制来支持主机回调的 JAX 自动微分变得有趣。

一旦理解了 JAX 自定义 VJP 和 TensorFlow autodiff 机制,这就相对容易做到。可以在 host_callback_to_tf_test.py 中的 call_tf_full_ad 函数中看到如何实现这一点。该示例还支持任意高阶微分。

请注意,如果只想从 JAX 调用 TensorFlow 函数,也可以使用 jax2tf.call_tf function。

使用 call() 在另一个设备上调用 JAX 函数,支持反向模式自动微分

我们可以使用主机计算来调用另一个设备上的 JAX 计算,并不奇怪。参数从加速器发送到主机,然后发送到将运行 JAX 主机计算的外部设备,然后将结果发送回原始加速器。

可以在 host_callback_test.py 中的 call_jax_other_device function 中看到如何实现这一点。

低级细节和调试

主机回调函数将按照在设备上执行发送操作的顺序执行。

多个设备的主机回调函数可能会交错执行。设备数据由 JAX 运行时管理的单独线程接收(每个设备一个线程)。运行时维护一个可配置大小的缓冲区(参见标志 --jax_host_callback_max_queue_byte_size)。当缓冲区满时,所有接收线程将被暂停,最终暂停设备上的计算。对于更多关于 outfeed 接收器运行时机制的细节,请参阅 runtime code。

要等待已经启动在设备上的计算的所有数据到达并被处理,可以使用 barrier_wait()

用户定义的回调函数抛出的异常以及它们的堆栈跟踪都会被记录,但接收线程不会停止。相反,最后一个异常被记录,并且随后的 barrier_wait() 将在任何一个 tap 函数中发生异常时引发 CallbackException。此异常将包含最后异常的文本和堆栈跟踪。

对于必须将结果返回给调用原点设备的回调函数(如call()),存在进一步的复杂性。这在 CPU/GPU 设备与 TPU 设备上处理方式不同。

在 CPU/GPU 设备上,为了避免设备计算因等待永远不会到达的结果而陷入困境,在处理回调过程中出现任何错误(无论是由用户代码自身引发还是由于返回值与期望返回形状不匹配而引发),我们会向设备发送一个形状为 int8[12345] 的“虚假”结果。这将导致设备计算中止,因为接收到的数据与其预期的数据不同。在 CPU 上,运行时将崩溃并显示特定的错误消息:

` 检查失败:buffer->length() == buffer_length (12345 vs. ...) `

在 GPU 上,这种失败会更加用户友好,并将其作为以下形式暴露给 Python 程序:

` RET_CHECK 失败 ... 输入源缓冲区形状为 s8[12345] 不匹配 ... `

要调试这些消息的根本原因,请参阅调试部分。

在 TPU 设备上,目前没有对输入源进行形状检查,因此我们采取更安全的方式,在出现错误时不发送此虚假结果。这意味着计算将会挂起,且不会引发异常(但回调函数中的任何异常仍将出现在日志中)。

当前实现使用 XLA 提供的出料机制。该机制本身在某种程度上相当原始,因为接收器必须准确知道每个传入数据包的形状和预期的数据包数量。这使得它在同一计算中难以用于多种数据类型,并且在非常量迭代次数的条件或循环中几乎不可能使用。此外,直接使用出料机制的代码无法由 JAX 进行转换。所有这些限制都通过主机回调函数得到解决。此处引入的 tapping API 可以轻松地用于多种目的共享出料机制,同时支持所有转换。

注意,在使用主机回调函数后,您不能直接使用 lax.outfeed。如果以后需要使用 lax.outfeed,则可能需要 stop_outfeed_receiver()

由于实际调用您的回调函数是从 C 接收器进行的,因此调试这些调用可能会很困难。特别是,堆栈跟踪不会包含调用代码。您可以使用标志 jax_host_callback_inline(或环境变量 JAX_HOST_CALLBACK_INLINE)确保回调函数的调用是内联的。这仅在调用位于非常量迭代次数的阶段上下文之外时有效(例如 jit() 或控制流原语)。

C 接收器 会在首次调用 id_tap() 时自动启动。为了正确停止它,在启动时注册了一个 atexit 处理程序,以带有日志名称“at_exit”调用 barrier_wait()

有几个环境变量可用于启用 C outfeed 接收器后端的日志记录(接收器后端)。

  • TF_CPP_MIN_LOG_LEVEL=0:将 INFO 日志打开,适用于以下所有内容。
  • TF_CPP_MIN_VLOG_LEVEL=3:将所有 VLOG 日志级别为 3 的行为设为 INFO 日志。这可能有些过多,但你将看到哪些模块记录了相关信息,然后你可以选择从哪些模块记录日志。
  • TF_CPP_VMODULE=<module_name>=3(模块名可以是 C 或 Python,不带扩展名)。

你还应该使用 --verbosity=2 标志,这样你就可以看到 Python 的日志。

例如,你可以尝试在 host_callback 模块中启用日志记录:TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=host_callback=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple

如果你想在更低级别的实现模块中启用日志记录,请尝试:TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,infeed_thunk=3,cpu_transfer_manager=3,cpu_runtime=3,xfeed_manager=3,pjrt_client=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple

(对于 bazel 测试,请使用 –test_arg=–vmodule=…

仍需完成:

  • 更多性能测试。
  • 探索在 TPU 上进行外部编译实现。
  • 探索在 CPU 和 GPU 上使用 XLA CustomCall 进行实现。

API

id_tap(tap_func, arg, *[, result, …])

主机回调 tap 原语,类似于带有 tap_func 调用的恒等函数。

id_print(arg, *[, result, tap_with_device, …])

类似于 id_tap(),带有打印 tap 函数。

call(callback_func, arg, *[, result_shape, …])

调用主机,并期望得到结果。

barrier_wait([logging_name])

阻塞调用线程,直到所有当前 outfeed 处理完毕。

CallbackException

表示某些回调函数发生异常。

jax.experimental.maps 模块

原文:jax.readthedocs.io/en/latest/jax.experimental.maps.html

API

xmap(fun, in_axes, out_axes, *[, …])

为使用命名数组轴的程序分配位置签名。

jax.experimental.pjit 模块

原文:jax.readthedocs.io/en/latest/jax.experimental.pjit.html

API

代码语言:javascript复制
jax.experimental.pjit.pjit(fun, in_shardings=UnspecifiedValue, out_shardings=UnspecifiedValue, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None)

使fun编译并自动跨多设备分区。

注意:此函数现在等同于 jax.jit,请改用其代替。返回的函数语义与fun相同,但编译为在多个设备(例如多个 GPU 或多个 TPU 核心)上并行运行的 XLA 计算。如果fun的 jitted 版本无法适应单个设备的内存,或者为了通过在多个设备上并行运行每个操作来加速fun,这将非常有用。

设备上的分区自动基于in_shardings中指定的输入分区传播以及out_shardings中指定的输出分区进行。这两个参数中指定的资源必须引用由jax.sharding.Mesh()上下文管理器定义的网格轴。请注意,pjit()应用时的网格定义将被忽略,并且返回的函数将使用每个调用站点可用的网格定义。

未经正确分区的pjit()函数输入将自动跨设备分区。在某些情况下,确保输入已经正确预分区可能会提高性能。例如,如果将一个pjit()函数的输出传递给另一个pjit()函数(或者在循环中使用同一个pjit()函数),请确保相关的out_shardings与相应的in_shardings匹配。

注意

多进程平台: 在诸如 TPU pods 的多进程平台上,pjit()可用于跨所有可用设备和进程运行计算。为实现此目的,pjit()设计为用于 SPMD Python 程序,其中每个进程运行相同的 Python 代码,以便所有进程按相同顺序运行相同的pjit()函数。

在此配置中运行时,网格应包含跨所有进程的设备。所有输入参数必须具有全局形状。fun仍将在网格中的所有设备上执行,包括来自其他进程的设备,并且将以全局视图处理跨多个进程展布的数据作为单个数组。

SPMD 模型还要求所有进程中运行的相同多进程pjit()函数必须按相同顺序运行,但可以与在单个进程中运行的任意操作交替进行。

参数:

  • funCallable) - 要编译的函数。应为纯函数,因为副作用只能执行一次。其参数和返回值应为数组、标量或其(嵌套的)标准 Python 容器(元组/列表/字典)。由 static_argnums 指示的位置参数可以是任何东西,只要它们是可散列的并且定义了相等操作。静态参数包含在编译缓存键中,这就是为什么必须定义哈希和相等运算符。
  • in_shardings – 与 fun 参数匹配的 pytree 结构,所有实际参数都替换为资源分配规范。还可以指定一个 pytree 前缀(例如,替换整个子树的一个值),在这种情况下,叶子将广播到该子树的所有值。 in_shardings 参数是可选的。JAX 将从输入的 jax.Array 推断出分片,并在无法推断出分片时默认复制输入。 有效的资源分配规范包括:
    • Sharding,它将决定如何分区值。使用网格上下文管理器时,不需要此操作。
    • None 是一种特殊情况,其语义为:
      • 如果未提供网格上下文管理器,则 JAX 可以自由选择任何分片方式。对于 in_shardings,JAX 将其标记为复制,但此行为可能在将来更改。对于 out_shardings,我们将依赖于 XLA GSPMD 分区器来确定输出的分片方式。
      • 如果提供了网格上下文管理器,则 None 将意味着该值将复制到网格的所有设备上。
    • 为了向后兼容,in_shardings 仍支持接受 PartitionSpec。此选项只能与网格上下文管理器一起使用。
      • PartitionSpec,最多与分区值的秩相等长的元组。每个元素可以是 None,一个网格轴或网格轴的元组,并指定分配给分区值维度的资源集,与其在规范中的位置匹配。

    每个维度的大小必须是其分配的总资源数的倍数。

  • out_shardings – 类似于 in_shardings,但指定了函数输出的资源分配。out_shardings 参数是可选的。如果未指定,jax.jit() 将使用 GSPMD 的分片传播来确定如何分片输出。
  • static_argnumsint | Sequence [int] | None) – 可选的整数或整数集合,用于指定将哪些位置参数视为静态(编译时常量)。在 Python 中(在追踪期间),仅依赖于静态参数的操作将被常量折叠,因此相应的参数值可以是任何 Python 对象。 静态参数应该是可哈希的,即实现了 __hash____eq__,并且是不可变的。对于这些常量调用 jitted 函数时,使用不同的值将触发重新编译。不是数组或其容器的参数必须标记为静态。 如果未提供 static_argnums,则不将任何参数视为静态。
  • static_argnames (str | Iterable[str] | None) – 可选的字符串或字符串集合,指定要视为静态(编译时常量)的命名参数。有关详细信息,请参阅关于 static_argnums 的注释。如果未提供但设置了 static_argnums,则默认基于调用 inspect.signature(fun) 查找相应的命名参数。
  • donate_argnums (int | Sequence[int] | None) – 指定要“捐赠”给计算的位置参数缓冲区。如果计算结束后不再需要它们,捐赠参数缓冲区是安全的。在某些情况下,XLA 可以利用捐赠的缓冲区来减少执行计算所需的内存量,例如将您的一个输入缓冲区循环利用来存储结果。您不应重新使用捐赠给计算的缓冲区,如果尝试则 JAX 会引发错误。默认情况下,不会捐赠任何参数缓冲区。 如果既未提供 donate_argnums 也未提供 donate_argnames,则不会捐赠任何参数。如果未提供 donate_argnums,但提供了 donate_argnames,或者反之,则 JAX 使用 inspect.signature(fun) 查找与 donate_argnames 相对应的任何位置参数(或反之)。如果同时提供了 donate_argnumsdonate_argnames,则不使用 inspect.signature,并且只有在 donate_argnumsdonate_argnames 中列出的实际参数将被捐赠。 有关缓冲区捐赠的更多详情,请参阅FAQ。
  • 捐赠参数名 (str | Iterable[str] | None) – 一个可选的字符串或字符串集合,指定哪些命名参数将捐赠给计算。有关详细信息,请参见对 donate_argnums 的注释。如果未提供但设置了 donate_argnums,则默认基于调用 inspect.signature(fun) 查找相应的命名参数。
  • 保留未使用 (bool) – 如果为 False(默认值),JAX 确定 fun 未使用的参数 可能 会从生成的编译后 XLA 可执行文件中删除。这些参数将不会传输到设备,也不会提供给底层可执行文件。如果为 True,则不会剪枝未使用的参数。
  • 设备 (Device | None) – 此参数已弃用。请在将参数传递给 jit 之前将您需要的设备置于其上。可选,jit 函数将在其上运行的设备。 (可通过 jax.devices() 获取可用设备。)默认情况下,继承自 XLA 的 DeviceAssignment 逻辑,并通常使用 jax.devices()[0]
  • 后端 (str | None) – 此参数已弃用。请在将参数传递给 jit 之前将您需要的后端置于其前。可选,表示 XLA 后端的字符串:'cpu''gpu''tpu'
  • 内联 (bool)
  • 抽象轴 (Any | None)

返回:

fun 的包装版本,专为即时编译而设,并在每次调用点根据可用的网格自动分区。

返回类型:

JitWrapped

例如,卷积运算符可以通过单个 pjit() 应用自动分区到任意一组设备上:

代码语言:javascript复制
>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from jax.sharding import Mesh, PartitionSpec
>>> from jax.experimental.pjit import pjit
>>>
>>> x = jnp.arange(8, dtype=jnp.float32)
>>> f = pjit(lambda x: jax.numpy.convolve(x, jnp.asarray([0.5, 1.0, 0.5]), 'same'),
...         in_shardings=None, out_shardings=PartitionSpec('devices'))
>>> with Mesh(np.array(jax.devices()), ('devices',)):
...   print(f(x))  
[ 0.5  2.   4.   6.   8.  10.  12.  10. ] 

jax.experimental.sparse 模块

jax.readthedocs.io/en/latest/jax.experimental.sparse.html

jax.experimental.sparse 模块包括对 JAX 中稀疏矩阵操作的实验性支持。它正在积极开发中,API 可能会更改。主要提供的接口是 BCOO 稀疏数组类型和 sparsify() 变换。

批量坐标(BCOO)稀疏矩阵

JAX 中目前主要的高级稀疏对象是 BCOO,或者 批量坐标 稀疏数组,它提供与 JAX 变换兼容的压缩存储格式,特别是 JIT(例如 jax.jit())、批处理(例如 jax.vmap())和自动微分(例如 jax.grad())。

下面是一个从稠密数组创建稀疏数组的例子:

代码语言:javascript复制
>>> from jax.experimental import sparse
>>> import jax.numpy as jnp
>>> import numpy as np 
代码语言:javascript复制
>>> M = jnp.array([[0., 1., 0., 2.],
...                [3., 0., 0., 0.],
...                [0., 0., 4., 0.]]) 
代码语言:javascript复制
>>> M_sp = sparse.BCOO.fromdense(M) 
代码语言:javascript复制
>>> M_sp
BCOO(float32[3, 4], nse=4) 

使用 todense() 方法转换回稠密数组:

代码语言:javascript复制
>>> M_sp.todense()
Array([[0., 1., 0., 2.],
 [3., 0., 0., 0.],
 [0., 0., 4., 0.]], dtype=float32) 

BCOO 格式是标准 COO 格式的一种略微修改版本,密集表示可以在 dataindices 属性中看到:

代码语言:javascript复制
>>> M_sp.data  # Explicitly stored data
Array([1., 2., 3., 4.], dtype=float32) 
代码语言:javascript复制
>>> M_sp.indices # Indices of the stored data
Array([[0, 1],
 [0, 3],
 [1, 0],
 [2, 2]], dtype=int32) 

BCOO 对象具有类似数组的属性,以及稀疏特定的属性:

代码语言:javascript复制
>>> M_sp.ndim
2 
代码语言:javascript复制
>>> M_sp.shape
(3, 4) 
代码语言:javascript复制
>>> M_sp.dtype
dtype('float32') 
代码语言:javascript复制
>>> M_sp.nse  # "number of specified elements"
4 

BCOO 对象还实现了许多类数组的方法,允许您直接在 jax 程序中使用它们。例如,在这里我们计算转置矩阵向量乘积:

代码语言:javascript复制
>>> y = jnp.array([3., 6., 5.]) 
代码语言:javascript复制
>>> M_sp.T @ y
Array([18.,  3., 20.,  6.], dtype=float32) 
代码语言:javascript复制
>>> M.T @ y  # Compare to dense version
Array([18.,  3., 20.,  6.], dtype=float32) 

BCOO 对象设计成与 JAX 变换兼容,包括 jax.jit()jax.vmap()jax.grad() 等。例如:

代码语言:javascript复制
>>> from jax import grad, jit 
代码语言:javascript复制
>>> def f(y):
...   return (M_sp.T @ y).sum()
...
>>> jit(grad(f))(y)
Array([3., 3., 4.], dtype=float32) 

注意,正常情况下,jax.numpyjax.lax 函数不知道如何处理稀疏矩阵,因此尝试计算诸如 jnp.dot(M_sp.T, y) 的东西将导致错误(但请参见下一节)。

稀疏化变换

JAX 稀疏实现的一个主要目标是提供一种无缝从密集到稀疏计算切换的方法,而无需修改密集实现。这个稀疏实验通过 sparsify() 变换实现了这一目标。

考虑这个函数,它从矩阵和向量输入计算更复杂的结果:

代码语言:javascript复制
>>> def f(M, v):
...   return 2 * jnp.dot(jnp.log1p(M.T), v)   1
...
>>> f(M, y)
Array([17.635532,  5.158883, 17.09438 ,  7.591674], dtype=float32) 

如果我们直接传递稀疏矩阵到这个函数,将会导致错误,因为 jnp 函数不识别稀疏输入。然而,使用 sparsify(),我们得到一个接受稀疏矩阵的函数版本:

代码语言:javascript复制
>>> f_sp = sparse.sparsify(f) 
代码语言:javascript复制
>>> f_sp(M_sp, y)
Array([17.635532,  5.158883, 17.09438 ,  7.591674], dtype=float32) 

sparsify() 支持包括许多最常见的原语,例如:

  • 广义(批量)矩阵乘积和爱因斯坦求和(dot_general_p
  • 保持零的逐元素二元操作(例如 add_pmul_p 等)
  • 保持零的逐元素一元操作(例如 abs_pjax.lax.neg_p 等)
  • 求和约简(reduce_sum_p
  • 通用索引操作(slice_plax.dynamic_slice_plax.gather_p
  • 连接和堆叠(concatenate_p
  • 转置和重塑(transpose_preshape_psqueeze_pbroadcast_in_dim_p
  • 一些高阶函数(cond_pwhile_pscan_p
  • 一些简单的 1D 卷积(conv_general_dilated_p

几乎任何 jax.numpy 函数在 sparsify 转换中都可以使用,以操作稀疏数组。这组基元足以支持相对复杂的稀疏工作流程,如下一节所示。

示例:稀疏逻辑回归

作为更复杂稀疏工作流的示例,让我们考虑在 JAX 中实现的简单逻辑回归。请注意,以下实现与稀疏性无关:

代码语言:javascript复制
>>> import functools
>>> from sklearn.datasets import make_classification
>>> from jax.scipy import optimize 
代码语言:javascript复制
>>> def sigmoid(x):
...   return 0.5 * (jnp.tanh(x / 2)   1)
...
>>> def y_model(params, X):
...   return sigmoid(jnp.dot(X, params[1:])   params[0])
...
>>> def loss(params, X, y):
...   y_hat = y_model(params, X)
...   return -jnp.mean(y * jnp.log(y_hat)   (1 - y) * jnp.log(1 - y_hat))
...
>>> def fit_logreg(X, y):
...   params = jnp.zeros(X.shape[1]   1)
...   result = optimize.minimize(functools.partial(loss, X=X, y=y),
...                              x0=params, method='BFGS')
...   return result.x 
代码语言:javascript复制
>>> X, y = make_classification(n_classes=2, random_state=1701)
>>> params_dense = fit_logreg(X, y)
>>> print(params_dense)  
[-0.7298445   0.29893667  1.0248291  -0.44436368  0.8785025  -0.7724008
 -0.62893456  0.2934014   0.82974285  0.16838408 -0.39774987 -0.5071844
 0.2028872   0.5227761  -0.3739224  -0.7104083   2.4212713   0.6310087
 -0.67060554  0.03139788 -0.05359547] 

这会返回密集逻辑回归问题的最佳拟合参数。要在稀疏数据上拟合相同的模型,我们可以应用sparsify()转换:

代码语言:javascript复制
>>> Xsp = sparse.BCOO.fromdense(X)  # Sparse version of the input
>>> fit_logreg_sp = sparse.sparsify(fit_logreg)  # Sparse-transformed fit function
>>> params_sparse = fit_logreg_sp(Xsp, y)
>>> print(params_sparse)  
[-0.72971725  0.29878938  1.0246326  -0.44430563  0.8784217  -0.77225566
 -0.6288222   0.29335397  0.8293481   0.16820715 -0.39764675 -0.5069753
 0.202579    0.522672   -0.3740134  -0.7102678   2.4209507   0.6310593
 -0.670236    0.03132951 -0.05356663] 

稀疏 API 参考

sparsify(f[, use_tracer])

实验性稀疏化转换。

grad(fun[, argnums, has_aux])

jax.grad() 的稀疏版本

value_and_grad(fun[, argnums, has_aux])

jax.value_and_grad() 的稀疏版本

empty(shape[, dtype, index_dtype, sparse_format])

创建空稀疏数组。

eye(N[, M, k, dtype, index_dtype, sparse_format])

创建二维稀疏单位矩阵。

todense(arr)

将输入转换为密集矩阵。

random_bcoo(key, shape, *[, dtype, …])

生成随机 BCOO 矩阵。

JAXSparse(args, *, shape)

高级 JAX 稀疏对象的基类。

BCOO 数据结构

BCOOBatched COO format,是在 jax.experimental.sparse 中实现的主要稀疏数据结构。其操作与 JAX 的核心转换兼容,包括批处理(例如 jax.vmap())和自动微分(例如 jax.grad())。

BCOO(args, *, shape[, indices_sorted, …])

在 JAX 中实现的实验性批量 COO 矩阵

bcoo_broadcast_in_dim(mat, *, shape, …)

通过复制数据来扩展 BCOO 数组的大小和秩。

bcoo_concatenate(operands, *, dimension)

jax.lax.concatenate() 的稀疏实现

bcoo_dot_general(lhs, rhs, *, dimension_numbers)

一般的收缩操作。

bcoo_dot_general_sampled(A, B, indices, *, …)

在给定稀疏索引处计算输出的收缩操作。

bcoo_dynamic_slice(mat, start_indices, …)

jax.lax.dynamic_slice 的稀疏实现。

bcoo_extract(sparr, arr, *[, assume_unique])

根据稀疏数组的索引从密集数组中提取值。

bcoo_fromdense(mat, *[, nse, n_batch, …])

从密集矩阵创建 BCOO 格式的稀疏矩阵。

bcoo_gather(operand, start_indices, …[, …])

lax.gather 的 BCOO 版本。

bcoo_multiply_dense(sp_mat, v)

稀疏数组和密集数组的逐元素乘法。

bcoo_multiply_sparse(lhs, rhs)

两个稀疏数组的逐元素乘法。

bcoo_update_layout(mat, *[, n_batch, …])

更新 BCOO 矩阵的存储布局(即 n_batch 和 n_dense)。

bcoo_reduce_sum(mat, *, axes)

在给定轴上对数组元素求和。

bcoo_reshape(mat, *, new_sizes[, dimensions])

{func}jax.lax.reshape的稀疏实现。

bcoo_slice(mat, *, start_indices, limit_indices)

{func}jax.lax.slice的稀疏实现。

bcoo_sort_indices(mat)

对 BCOO 数组的索引进行排序。

bcoo_squeeze(arr, *, dimensions)

{func}jax.lax.squeeze的稀疏实现。

bcoo_sum_duplicates(mat[, nse])

对 BCOO 数组中的重复索引求和,返回一个排序后的索引数组。

bcoo_todense(mat)

将批量稀疏矩阵转换为密集矩阵。

bcoo_transpose(mat, *, permutation)

转置 BCOO 格式的数组。

BCSR 数据结构

BCSR批量压缩稀疏行格式,正在开发中。其操作与 JAX 的核心转换兼容,包括批处理(如jax.vmap())和自动微分(如jax.grad())。

BCSR(args, *, shape[, indices_sorted, …])

在 JAX 中实现的实验性批量 CSR 矩阵。

bcsr_dot_general(lhs, rhs, *, dimension_numbers)

通用收缩运算。

bcsr_extract(indices, indptr, mat)

从给定的 BCSR(indices, indptr)处的密集矩阵中提取值。

bcsr_fromdense(mat, *[, nse, n_batch, …])

从密集矩阵创建 BCSR 格式的稀疏矩阵。

bcsr_todense(mat)

将批量稀疏矩阵转换为密集矩阵。

其他稀疏数据结构

其他稀疏数据结构包括COOCSRCSC。这些是简单稀疏结构的参考实现,具有少数核心操作。它们的操作通常与自动微分转换(如jax.grad())兼容,但不与批处理转换(如jax.vmap())兼容。

COO(args, *, shape[, rows_sorted, cols_sorted])

在 JAX 中实现的实验性 COO 矩阵。

CSC(args, *, shape)

在 JAX 中实现的实验性 CSC 矩阵;API 可能会更改。

CSR(args, *, shape)

在 JAX 中实现的实验性 CSR 矩阵。

coo_fromdense(mat, *[, nse, index_dtype])

从密集矩阵创建 COO 格式的稀疏矩阵。

coo_matmat(mat, B, *[, transpose])

COO 稀疏矩阵与密集矩阵的乘积。

coo_matvec(mat, v[, transpose])

COO 稀疏矩阵与密集向量的乘积。

coo_todense(mat)

将 COO 格式的稀疏矩阵转换为密集矩阵。

csr_fromdense(mat, *[, nse, index_dtype])

从密集矩阵创建 CSR 格式的稀疏矩阵。

csr_matmat(mat, B, *[, transpose])

CSR 稀疏矩阵与密集矩阵的乘积。

csr_matvec(mat, v[, transpose])

CSR 稀疏矩阵与密集向量的乘积。

csr_todense(mat)

将 CSR 格式的稀疏矩阵转换为密集矩阵。

jax.experimental.sparse.linalg

稀疏线性代数例程。

spsolve(data, indices, indptr, b[, tol, reorder])

使用 QR 分解的稀疏直接求解器。

lobpcg_standard(A, X[, m, tol])

使用 LOBPCG 例程计算前 k 个标准特征值。

jax.experimental.sparse.BCOO

原文:jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.BCOO.html

代码语言:javascript复制
class jax.experimental.sparse.BCOO(args, *, shape, indices_sorted=False, unique_indices=False)

在 JAX 中实现的实验性批量 COO 矩阵

参数:

  • **(**data – 批量 COO 格式中的数据和索引。
  • indices**)** – 批量 COO 格式中的数据和索引。
  • shape (tuple[int, …**]) – 稀疏数组的形状。
  • args (tuple[Array,* Array]*)
  • indices_sorted (bool)
  • unique_indices (bool)
代码语言:javascript复制
data

形状为[*batch_dims, nse, *dense_dims]的 ndarray,包含稀疏矩阵中显式存储的数据。

类型:

jax.Array

代码语言:javascript复制
indices

形状为[*batch_dims, nse, n_sparse]的 ndarray,包含显式存储数据的索引。重复的条目将被求和。

类型:

jax.Array

示例

从稠密数组创建稀疏数组:

代码语言:javascript复制
>>> M = jnp.array([[0., 2., 0.], [1., 0., 4.]])
>>> M_sp = BCOO.fromdense(M)
>>> M_sp
BCOO(float32[2, 3], nse=3) 

检查内部表示:

代码语言:javascript复制
>>> M_sp.data
Array([2., 1., 4.], dtype=float32)
>>> M_sp.indices
Array([[0, 1],
 [1, 0],
 [1, 2]], dtype=int32) 

从稀疏数组创建稠密数组:

代码语言:javascript复制
>>> M_sp.todense()
Array([[0., 2., 0.],
 [1., 0., 4.]], dtype=float32) 

从 COO 数据和索引创建稀疏数组:

代码语言:javascript复制
>>> data = jnp.array([1., 3., 5.])
>>> indices = jnp.array([[0, 0],
...                      [1, 1],
...                      [2, 2]])
>>> mat = BCOO((data, indices), shape=(3, 3))
>>> mat
BCOO(float32[3, 3], nse=3)
>>> mat.todense()
Array([[1., 0., 0.],
 [0., 3., 0.],
 [0., 0., 5.]], dtype=float32) 
代码语言:javascript复制
__init__(args, *, shape, indices_sorted=False, unique_indices=False)

参数:

  • args (tuple[Array,* Array]*)
  • shape (Sequence[int])
  • indices_sorted (bool)
  • unique_indices (bool)

方法

__init__(args, *, shape[, indices_sorted, …])

astype(*args, **kwargs)

复制数组并转换为指定的 dtype。

block_until_ready()

from_scipy_sparse(mat, *[, index_dtype, …])

从scipy.sparse数组创建 BCOO 数组。

fromdense(mat, *[, nse, index_dtype, …])

从(稠密)Array创建 BCOO 数组。

reshape(*args, **kwargs)

返回具有新形状的相同数据的数组。

sort_indices()

返回索引排序后的矩阵副本。

sum(*args, **kwargs)

沿轴求和数组。

sum_duplicates([nse, remove_zeros])

返回重复索引求和后的数组副本。

todense()

创建数组的稠密版本。

transpose([axes])

创建包含转置的新数组。

tree_flatten()

tree_unflatten(aux_data, children)

update_layout(*[, n_batch, n_dense, …])

更新 BCOO 矩阵的存储布局(即 n_batch 和 n_dense)。

属性

T

dtype

n_batch

n_dense

n_sparse

ndim

nse

size

data

indices

shape

indices_sorted

unique_indices

jax.experimental.sparse.bcoo_broadcast_in_dim

原文:jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_broadcast_in_dim.html

代码语言:javascript复制
jax.experimental.sparse.bcoo_broadcast_in_dim(mat, *, shape, broadcast_dimensions)

通过复制数据扩展 BCOO 数组的大小和秩。

BCOO 相当于 jax.lax.broadcast_in_dim。

参数:

  • matBCOO) – BCOO 格式的数组。
  • shapetuple[int,* ]*) – 目标数组的形状。
  • broadcast_dimensionsSequence[int]) – 目标数组形状的维度,每个操作数(mat)形状对应一个维度。

返回:

包含目标数组的 BCOO 格式数组。

返回类型:

BCOO

jax.experimental.sparse.bcoo_concatenate

原文:jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_concatenate.html

代码语言:javascript复制
jax.experimental.sparse.bcoo_concatenate(operands, *, dimension)

稀疏实现的jax.lax.concatenate()函数

参数:

  • operandsSequence[BCOO]) – 要连接的 BCOO 数组序列。这些数组必须具有相同的形状,除了在维度轴上。此外,这些数组必须具有等效的批处理、稀疏和密集维度。
  • dimensionint) – 指定沿其连接数组的维度的正整数。维度必须是输入的批处理或稀疏维度之一;不支持沿密集维度的连接。

返回值:

包含输入数组连接的 BCOO 数组。

返回类型:

BCOO

jax.experimental.sparse.bcoo_dot_general

原文:jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_dot_general.html

代码语言:javascript复制
jax.experimental.sparse.bcoo_dot_general(lhs, rhs, *, dimension_numbers, precision=None, preferred_element_type=None)

一般的收缩操作。

参数:

  • lhsBCOO | Array) – 一个 ndarray 或 BCOO 格式的稀疏数组。
  • rhsBCOO | Array) – 一个 ndarray 或 BCOO 格式的稀疏数组。
  • dimension_numberstuple[tuple[Sequence[int]**, Sequence[int]], tuple[Sequence[int]**, Sequence[int]]]) – 一个形如((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))的元组的元组。
  • precisionNone) – 未使用
  • preferred_element_typeNone) – 未使用

返回:

一个包含结果的 ndarray 或 BCOO 格式的稀疏数组。如果两个输入都是稀疏的,结果将是稀疏的,类型为 BCOO。如果任一输入是密集的,结果将是密集的,类型为 ndarray。

返回类型:

BCOO | Array

jax.experimental.sparse.bcoo_dot_general_sampled

原文:jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_dot_general_sampled.html

代码语言:javascript复制
jax.experimental.sparse.bcoo_dot_general_sampled(A, B, indices, *, dimension_numbers)

给定稀疏索引处计算输出的收缩操作。

参数:

  • lhs – 一个 ndarray。
  • rhs – 一个 ndarray。
  • indicesArray) – BCOO 索引。
  • dimension_numberstuple[tuple[Sequence[int]**, Sequence[int]], tuple[Sequence[int]**, Sequence[int]]]) – 形式为 ((lhs 收缩维度,rhs 收缩维度),(lhs 批次维度,rhs 批次维度)) 的元组的元组。
  • AArray
  • BArray

返回:

BCOO 数据,包含结果的 ndarray。

返回类型:

Array

jax.experimental.sparse.bcoo_dynamic_slice

原文:jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_dynamic_slice.html

代码语言:javascript复制
jax.experimental.sparse.bcoo_dynamic_slice(mat, start_indices, slice_sizes)

{func}jax.lax.dynamic_slice的稀疏实现。

参数:

  • mat (BCOO) – 要切片的 BCOO 数组。
  • start_indices (Sequence[Any]) – 每个维度的标量索引列表。这些值可能是动态的。
  • slice_sizes (Sequence[int]) – 切片的大小。必须是非负整数序列,长度等于操作数的维度数。在 JIT 编译的函数内部,仅支持静态值(所有 JAX 数组在 JIT 内必须具有静态已知大小)。

返回:

包含切片的 BCOO 数组。

返回类型:

out

jax.experimental.sparse.bcoo_extract

原文:jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_extract.html

代码语言:javascript复制
jax.experimental.sparse.bcoo_extract(sparr, arr, *, assume_unique=None)

根据稀疏数组的索引从密集数组中提取值。

参数:

  • sparr (BCOO) – 用于输出的 BCOO 数组的索引。
  • arr (jax.typing.ArrayLike) – 形状与 self.shape 相同的 ArrayLike
  • assume_unique (bool | None) – 布尔值,默认为 sparr.unique_indices。如果为 True,则提取每个索引的值,即使索引包含重复项。如果为 False,则重复的索引将其值求和,并返回第一个索引的位置。

返回:

一个具有与 self 相同稀疏模式的 BCOO 数组。

返回类型:

提取的结果

jax.experimental.sparse.bcoo_fromdense

原文:jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_fromdense.html

代码语言:javascript复制
jax.experimental.sparse.bcoo_fromdense(mat, *, nse=None, n_batch=0, n_dense=0, index_dtype=<class 'jax.numpy.int32'>)

从密集矩阵创建 BCOO 格式的稀疏矩阵。

参数:

  • matArray)– 要转换为 BCOO 格式的数组。
  • nseint | None)– 每个批次中指定元素的数量
  • n_batchint)– 批次维度的数量(默认:0)
  • n_denseint)– 块维度的数量(默认:0)
  • index_dtypejax.typing.DTypeLike)– 稀疏索引的数据类型(默认:int32)

返回:

矩阵的 BCOO 表示。

返回类型:

mat_bcoo

jax.experimental.sparse.bcoo_gather

原文:jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_gather.html

代码语言:javascript复制
jax.experimental.sparse.bcoo_gather(operand, start_indices, dimension_numbers, slice_sizes, *, unique_indices=False, indices_are_sorted=False, mode=None, fill_value=None)

BCOO 版本的 lax.gather。

参数:

  • operand (BCOO)
  • start_indices (数组)
  • dimension_numbers (GatherDimensionNumbers)
  • slice_sizes (tuple[int, …**])
  • unique_indices (bool)
  • indices_are_sorted (bool)
  • mode (str | GatherScatterMode | None)

返回类型:

BCOO

0 人点赞