原文:
jax.readthedocs.io/en/latest/
jax.experimental.sparse.bcoo_multiply_dense
代码语言:javascript复制原文:
jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_multiply_dense.html
jax.experimental.sparse.bcoo_multiply_dense(sp_mat, v)
稀疏数组和稠密数组之间的逐元素乘法。
参数:
- lhs – 一个 BCOO 格式的数组。
- rhs – 一个 ndarray。
- sp_mat(BCOO)
- v(Array)
返回:
包含结果的 ndarray。
返回类型:
Array
jax.experimental.sparse.bcoo_multiply_sparse
代码语言:javascript复制jax.experimental.sparse.bcoo_multiply_sparse
jax.experimental.sparse.bcoo_multiply_sparse(lhs, rhs)
两个稀疏数组的逐元素乘积。
参数:
- lhs (BCOO) – 一个 BCOO 格式的数组。
- rhs (BCOO) – 一个 BCOO 格式的数组。
返回值:
包含结果的 BCOO 格式数组。
返回类型:
BCOO
jax.experimental.sparse.bcoo_update_layout
代码语言:javascript复制原文:
jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_update_layout.html
jax.experimental.sparse.bcoo_update_layout(mat, *, n_batch=None, n_dense=None, on_inefficient='error')
更新 BCOO 矩阵的存储布局(即 n_batch 和 n_dense)。
在许多情况下,可以在不引入不必要的存储开销的情况下完成此操作。然而,增加 mat.n_batch
或 mat.n_dense
将导致存储效率非常低下,许多零值都是显式存储的,除非新的批处理或密集维度的大小为 0 或 1。在这种情况下,bcoo_update_layout
将引发 SparseEfficiencyError
。可以通过指定 on_inefficient
参数来消除此警告。
参数:
- mat(BCOO) – BCOO 数组
- n_batch(int | None) – 可选参数(整数),输出矩阵中批处理维度的数量。如果为 None,则 n_batch = mat.n_batch。
- n_dense(int | None) – 可选参数(整数),输出矩阵中密集维度的数量。如果为 None,则 n_dense = mat.n_dense。
- on_inefficient(str | None) – 可选参数(字符串),其中之一
['error', 'warn', None]
。指定在重新配置效率低下的情况下的行为。这被定义为结果表示的大小远大于输入表示的情况。
返回:
BCOO 数组
表示与输入相同的稀疏数组的 BCOO 数组,具有指定的布局。 mat_out.todense()
将与 mat.todense()
在适当的精度上匹配。
返回类型:
mat_out
jax.experimental.sparse.bcoo_reduce_sum
代码语言:javascript复制原文:
jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_reduce_sum.html
jax.experimental.sparse.bcoo_reduce_sum(mat, *, axes)
对给定轴上的数组元素求和。
参数:
- mat(BCOO) – 一个 BCOO 格式的数组。
- shape – 目标数组的形状。
- axes(Sequence[int]) – 包含
mat
上进行求和的轴的元组、列表或 ndarray。
返回:
包含结果的 BCOO 格式数组。
返回类型:
BCOO
jax.experimental.sparse.bcoo_reshape
代码语言:javascript复制原文:
jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_reshape.html
jax.experimental.sparse.bcoo_reshape(mat, *, new_sizes, dimensions=None)
稀疏实现的{func}jax.lax.reshape
。
参数:
- operand – 待重塑的 BCOO 数组。
- new_sizes (Sequence[int]) – 指定结果形状的整数序列。最终数组的大小必须与输入的大小相匹配。这必须指定为批量、稀疏和密集维度不混合的形式。
- dimensions (Sequence[int] | None) – 可选的整数序列,指定输入形状的排列顺序。如果指定,长度必须与
operand.shape
相匹配。此外,维度必须仅在 mat 的相似维度之间进行排列:批量、稀疏和密集维度不能混合排列。 - mat (BCOO)
返回:
重塑后的数组。
返回类型:
输出
jax.experimental.sparse.bcoo_slice
代码语言:javascript复制原文:
jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_slice.html
jax.experimental.sparse.bcoo_slice(mat, *, start_indices, limit_indices, strides=None)
{func}jax.lax.slice
的稀疏实现。
参数:
- mat (BCOO) – 待重新形状的 BCOO 数组。
- 起始索引 (Sequence[int]) – 长度为 mat.ndim 的整数序列,指定每个切片的起始索引。
- 限制索引 (Sequence[int]) – 长度为 mat.ndim 的整数序列,指定每个切片的结束索引
- 步幅 (Sequence[int] | None) – (未实现) 长度为 mat.ndim 的整数序列,指定每个切片的步幅
返回:
包含切片的 BCOO 数组。
返回类型:
输出
jax.experimental.sparse.bcoo_sort_indices
代码语言:javascript复制链接:
jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_sort_indices.html
jax.experimental.sparse.bcoo_sort_indices(mat)
排序一个 BCOO 数组的索引。
参数:
mat(BCOO)– BCOO 数组
返回:
带有已排序索引的 BCOO 数组。
返回类型:
mat_out
jax.experimental.sparse.bcoo_squeeze
代码语言:javascript复制
jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_squeeze.html
jax.experimental.sparse.bcoo_squeeze(arr, *, dimensions)
{func}jax.lax.squeeze
的稀疏实现。
从数组中挤出任意数量的大小为 1 的维度。
参数:
- arr (BCOO) – 要重新塑形的 BCOO 数组。
- 维度 (Sequence[int]) – 指定要挤压的整数序列。
返回:
重新塑形的数组。
返回类型:
out
jax.experimental.sparse.bcoo_sum_duplicates
代码语言:javascript复制原文
jax.experimental.sparse.bcoo_sum_duplicates(mat, nse=None)
对 BCOO 数组内的重复索引求和,返回一个带有排序索引的数组。
参数:
- mat (BCOO) – BCOO 数组
- nse (int | None) – 整数(可选)。输出矩阵中指定元素的数量。这必须指定以使 bcoo_sum_duplicates 兼容 JIT 和其他 JAX 变换。如果未指定,将根据数据和索引数组的内容计算最佳 nse。如果指定的 nse 大于必要的数量,将使用标准填充值填充数据和索引数组。如果小于必要的数量,将从输出矩阵中删除数据元素。
返回:
BCOO 数组具有排序索引且无重复索引。
返回类型:
mat_out
jax.experimental.sparse.bcoo_todense
代码语言:javascript复制原文:
jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_todense.html
jax.experimental.sparse.bcoo_todense(mat)
将批处理稀疏矩阵转换为稠密矩阵。
参数:
mat(BCOO)– BCOO 矩阵。
返回:
mat
的稠密版本。
返回类型:
mat_dense
jax.experimental.sparse.bcoo_transpose
代码语言:javascript复制原文:
jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_transpose.html
jax.experimental.sparse.bcoo_transpose(mat, *, permutation)
转置 BCOO 格式的数组。
参数:
- mat (BCOO) – 一个 BCOO 格式的数组。
- permutation (Sequence[int]) – 一个元组、列表或 ndarray,其中包含对
mat
的轴进行排列的置换,顺序为批处理、稀疏和稠密维度。返回数组的第 i 个轴对应于mat
的编号为 permutation[i] 的轴。目前,转置置换不支持将批处理轴与非批处理轴混合,也不支持将稠密轴与非稠密轴混合。
返回:
BCOO 格式的数组。
返回类型:
BCOO
jax.experimental.jet 模块
原文:
jax.readthedocs.io/en/latest/jax.experimental.jet.html
Jet 是一个实验性模块,用于更高阶的自动微分,不依赖于重复的一阶自动微分。
如何?通过截断的泰勒多项式的传播。考虑一个函数 ( f = g circ h ),某个点 ( x ) 和某个偏移 ( v )。一阶自动微分(如 jax.jvp()
)从对 ((h(x), partial h(x)[v])) 的计算得到对 ((f(x), partial f(x)[v])) 的计算。
jet()
实现了更高阶的类似方法:给定元组
((h_0, … h_K) := (h(x), partial h(x)[v], partial² h(x)[v, v], …, partial^K h(x)[v,…,v])),
代表在 ( x ) 处 ( h ) 的 ( K ) 阶泰勒近似,jet()
返回在 ( x ) 处 ( f ) 的 ( K ) 阶泰勒近似,
((f_0, …, f_K) := (f(x), partial f(x)[v], partial² f(x)[v, v], …, partial^K f(x)[v,…,v])).
更具体地说,jet()
计算
[f_0, (f_1, . . . , f_K) = texttt{jet} (f, h_0, (h_1, . . . , h_K))]
因此可用于 ( f ) 的高阶自动微分。详细内容请参见 这些注释。
注
通过贡献 优秀的原始规则 来改进 jet()
。
API
代码语言:javascript复制jax.experimental.jet.jet(fun, primals, series)
泰勒模式高阶自动微分。
参数:
- fun – 要进行微分的函数。其参数应为数组、标量或标准 Python 容器中的数组或标量。应返回一个数组、标量或标准 Python 容器中的数组或标量。
- primals – 应评估
fun
泰勒近似值的原始值。应该是参数的元组或列表,并且其长度应与fun
的位置参数数量相等。 - 系列 – 更高阶的泰勒级数系数。原始数据和系列数据组成了一个截断的泰勒多项式。应该是一个元组或列表,其长度决定了截断的泰勒多项式的阶数。
返回:
一个 (primals_out, series_out)
对,其中 primals_out
是 fun(*primals)
的值,primals_out
和 series_out
一起构成了 ( f(h(cdot)) ) 的截断泰勒多项式。primals_out
的值具有与 primals
相同的 Python 树结构,series_out
的值具有与 series
相同的 Python 树结构。
例如:
代码语言:javascript复制>>> import jax
>>> import jax.numpy as np
考虑函数 ( h(z) = z³ ),( x = 0.5 ),和前几个泰勒系数 ( h_0=x³ ),( h_1=3x² ),( h_2=6x )。让 ( f(y) = sin(y) )。
代码语言:javascript复制>>> h0, h1, h2 = 0.5**3., 3.*0.5**2., 6.*0.5
>>> f, df, ddf = np.sin, np.cos, lambda *args: -np.sin(*args)
jet()
根据法阿·迪布鲁诺公式返回 ( f(h(z)) = sin(z³) ) 的泰勒系数:
>>> f0, (f1, f2) = jet(f, (h0,), ((h1, h2),))
>>> print(f0, f(h0))
0.12467473 0.12467473
代码语言:javascript复制>>> print(f1, df(h0) * h1)
0.7441479 0.74414825
代码语言:javascript复制>>> print(f2, ddf(h0) * h1 ** 2 df(h0) * h2)
2.9064622 2.9064634
jax.experimental.custom_partitioning 模块
原文:
jax.readthedocs.io/en/latest/jax.experimental.custom_partitioning.html
API
代码语言:javascript复制jax.experimental.custom_partitioning.custom_partitioning(fun, static_argnums=())
在 XLA 图中插入一个 CustomCallOp,并使用自定义的 SPMD 降低规则。
代码语言:javascript复制@custom_partitioning
def f(*args):
return ...
def propagate_user_sharding(mesh, user_shape):
'''Update the sharding of the op from a user's shape.sharding.'''
user_sharding = jax.tree.map(lambda x: x.sharding, user_shape)
def partition(mesh, arg_shapes, result_shape):
def lower_fn(*args):
... builds computation on per-device shapes ...
result_shardings = jax.tree.map(lambda x: x.sharding, result_shape)
arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
# result_sharding and arg_shardings may optionally be modified and the
# partitioner will insert collectives to reshape.
return mesh, lower_fn, result_sharding, arg_shardings
def infer_sharding_from_operands(mesh, arg_shapes, shape):
'''Compute the result sharding from the sharding of the operands.'''
arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
f.def_partition(partition, propagate_user_sharding, infer_sharding_from_operands)
def_partition
的参数如下:
-
propagate_user_sharding
:一个可调用对象,接受用户(在 DAG 中)的分片并返回一个新的 NamedSharding 的建议。默认实现只是返回建议的分片。 -
partition
:一个可调用对象,接受 SPMD 建议的分片形状和分片规格,并返回网格、每个分片的降低函数以及最终的输入和输出分片规格(SPMD 分片器将重新分片输入以匹配)。返回网格以允许在未提供网格时配置集体的 axis_names。 -
infer_sharding_from_operands
:一个可调用对象,从每个参数选择的NamedSharding
中计算输出的NamedSharding
。 -
decode_shardings
:当设置为 True 时,如果可能,从输入中转换pyGSPMDSharding``s to ``NamedSharding
。如果用户未提供上下文网格,则可能无法执行此操作。
可以使用 static_argnums 将位置参数指定为静态参数。JAX 使用 inspect.signature(fun)
来解析这些位置参数。
示例
例如,假设我们想增强现有的 jax.numpy.fft.fft
。该函数计算 N 维输入沿最后一个维度的离散 Fourier 变换,并且在前 N-1 维度上进行批处理。但是,默认情况下,它会忽略输入的分片并在所有设备上收集输入。然而,由于 jax.numpy.fft.fft
在前 N-1 维度上进行批处理,这是不必要的。我们将创建一个新的 my_fft
操作,它不会改变前 N-1 维度上的分片,并且仅在需要时沿最后一个维度收集输入。
import jax
from jax.sharding import NamedSharding
from jax.experimental.custom_partitioning import custom_partitioning
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
from jax.sharding import Mesh
from jax.numpy.fft import fft
import regex as re
import numpy as np
# Pattern to detect all-gather or dynamic-slice in the generated HLO
_PATTERN = '(dynamic-slice|all-gather)'
# For an N-D input, keeps sharding along the first N-1 dimensions
# but replicate along the last dimension
def supported_sharding(sharding, shape):
rank = len(shape.shape)
max_shared_dims = min(len(sharding.spec), rank-1)
names = tuple(sharding.spec[:max_shared_dims]) tuple(None for _ in range(rank - max_shared_dims))
return NamedSharding(sharding.mesh, P(*names))
def partition(mesh, arg_shapes, result_shape):
result_shardings = jax.tree.map(lambda x: x.sharding, result_shape)
arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
return mesh, fft, supported_sharding(arg_shardings[0], arg_shapes[0]), (supported_sharding(arg_shardings[0], arg_shapes[0]),)
def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
return supported_sharding(arg_shardings[0], arg_shapes[0])
@custom_partitioning
def my_fft(x):
return fft(x)
my_fft.def_partition(
infer_sharding_from_operands=infer_sharding_from_operands,
partition=partition)
现在创建一个沿第一个轴分片的二维数组,通过 my_fft
处理它,并注意它仍按预期进行分片,并且与 fft
的输出相同。但是,检查 HLO(使用 lower(x).compile().runtime_executable().hlo_modules()
)显示 my_fft
不创建任何全收集或动态切片,而 fft
则创建。
with Mesh(np.array(jax.devices()), ('x',)):
x = np.asarray(np.random.randn(32*1024, 1024), dtype=np.complex64)
y = pjit(lambda x: x, in_shardings=None, out_shardings=P('x'))(x)
pjit_my_fft = pjit(my_fft, in_shardings=P('x'), out_shardings=P('x'))
pjit_fft = pjit(fft, in_shardings=P('x'), out_shardings=P('x'))
print(pjit_my_fft(y))
print(pjit_fft(y))
# dynamic-slice or all-gather are not present in the HLO for my_fft, because x is a 2D array
assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is None)
# dynamic-slice or all-gather are present in the HLO for fft
assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is not None)
代码语言:javascript复制# my_fft
[[-38.840824 0.j -40.649452 11.845365j
...
-1.6937828 0.8402481j 15.999859 -4.0156755j]]
# jax.numpy.fft.fft
[[-38.840824 0.j -40.649452 11.845365j
...
-1.6937828 0.8402481j 15.999859 -4.0156755j]]
由于 supported_sharding
中的逻辑,my_fft
也适用于一维数组。但是,在这种情况下,my_fft
的 HLO 显示动态切片,因为最后一个维度是计算 FFT 的维度,在计算之前需要在所有设备上复制。
with Mesh(np.array(jax.devices()), ('x',)):
x = np.asarray(np.random.randn(32*1024*1024), dtype=np.complex64)
y = pjit(lambda x: x, in_shardings=None, out_shardings=P('x'))(x)
pjit_my_fft = pjit(my_fft, in_shardings=P('x'), out_shardings=P('x'))
pjit_fft = pjit(fft, in_shardings=P('x'), out_shardings=P('x'))
print(pjit_my_fft(y))
print(pjit_fft(y))
# dynamic-slice or all-gather are present in the HLO for my_fft, because x is a 1D array
assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is None)
# dynamic-slice or all-gather are present in the HLO for fft
assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is not None)
代码语言:javascript复制# my_fft
[ 7.217285 0.j -3012.4937 4287.635j -405.83594 3042.984j
... 1422.4502 7271.4297j -405.84033 -3042.983j
-3012.4963 -4287.6343j]
# jax.numpy.fft.fft
[ 7.217285 0.j -3012.4937 4287.635j -405.83594 3042.984j
... 1422.4502 7271.4297j -405.84033 -3042.983j
-3012.4963 -4287.6343j]
jax.experimental.multihost_utils 模块
原文:
jax.readthedocs.io/en/latest/jax.experimental.multihost_utils.html
用于跨多个主机同步和通信的实用程序。
多主机工具 API 参考
broadcast_one_to_all(in_tree[, is_source]) | 从源主机(默认为主机 0)向所有其他主机广播数据。 |
---|---|
sync_global_devices(name) | 在所有主机/设备之间创建屏障。 |
process_allgather(in_tree[, tiled]) | 从各个进程收集数据。 |
assert_equal(in_tree[, fail_message]) | 验证所有主机具有相同的值树。 |
host_local_array_to_global_array(…) | 将主机本地值转换为全局分片的 jax.Array。 |
global_array_to_host_local_array(…) | 将全局 jax.Array 转换为主机本地 jax.Array。 |
jax.experimental.compilation_cache 模块
原文:
jax.readthedocs.io/en/latest/jax.experimental.compilation_cache.html
JAX 磁盘编译缓存。
API
代码语言:javascript复制jax.experimental.compilation_cache.compilation_cache.is_initialized()
已废弃。
返回缓存是否已启用。初始化可以延迟,因此不会检查初始化状态。该名称保留以确保向后兼容性。
返回类型:
bool
代码语言:javascript复制jax.experimental.compilation_cache.compilation_cache.initialize_cache(path)
此 API 已废弃;请使用set_cache_dir
替代。
设置路径。为了生效,在调用get_executable_and_time()
和put_executable_and_time()
之前应该调用此方法。
返回类型:
无
代码语言:javascript复制jax.experimental.compilation_cache.compilation_cache.set_cache_dir(path)
设置持久化编译缓存目录。
调用此方法后,jit 编译的函数将保存到路径中,因此如果进程重新启动或再次运行,则无需重新编译。这也告诉 Jax 在编译之前从哪里查找已编译的函数。
返回类型:
无
代码语言:javascript复制jax.experimental.compilation_cache.compilation_cache.reset_cache()
返回到原始未初始化状态。
返回类型:
无
jax.experimental.key_reuse 模块
原文:
jax.readthedocs.io/en/latest/jax.experimental.key_reuse.html
实验性密钥重用检查
此模块包含用于检测 JAX 程序中随机密钥重用的实验性功能。它正在积极开发中,并且这里的 API 可能会发生变化。下面的使用需要 JAX 版本 0.4.26 或更新版本。
可以通过 jax_debug_key_reuse
配置启用密钥重用检查。全局设置如下:
>>> jax.config.update('jax_debug_key_reuse', True)
或者可以通过 jax.debug_key_reuse()
上下文管理器在本地启用。启用后,使用相同的密钥两次将导致 KeyReuseError
:
>>> import jax
>>> with jax.debug_key_reuse(True):
... key = jax.random.key(0)
... val1 = jax.random.normal(key)
... val2 = jax.random.normal(key)
Traceback (most recent call last):
...
KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0
目前密钥重用检查器处于实验阶段,但未来我们可能会默认启用它。
jax.experimental.mesh_utils 模块
原文:
jax.readthedocs.io/en/latest/jax.experimental.mesh_utils.html
用于构建设备网格的实用工具。
API
create_device_mesh(mesh_shape[, devices, …]) | 为 jax.sharding.Mesh 创建一个高性能的设备网格。 |
---|---|
create_hybrid_device_mesh(mesh_shape, …[, …]) | 创建一个用于混合(例如 ICI 和 DCN)并行性的设备网格。 |
jax.experimental.serialize_executable 模块
原文:
jax.readthedocs.io/en/latest/jax.experimental.serialize_executable.html
为预编译二进制文件提供了 Pickling 支持。
API
serialize(compiled) | 序列化编译后的二进制文件。 |
---|---|
deserialize_and_load(serialized, in_tree, …) | 从序列化的可执行文件构建一个 jax.stages.Compiled 对象。 |
jax.experimental.shard_map 模块
原文:
jax.readthedocs.io/en/latest/jax.experimental.shard_map.html
API
shard_map(f, mesh, in_specs, out_specs[, …]) | 将一个函数映射到数据的分片上。 |
---|
jax.lib 模块
原文:
jax.readthedocs.io/en/latest/jax.lib.html
jax.lib 包是一组内部工具和类型,用于连接 JAX 的 Python 前端和其 XLA 后端。
jax.lib.xla_bridge
default_backend() | 返回默认 XLA 后端的平台名称。 |
---|---|
get_backend([platform]) | |
get_compile_options(num_replicas, num_partitions) | 返回用于编译的选项,从标志值派生而来。 |
jax.lib.xla_client
变更日志
原文:
jax.readthedocs.io/en/latest/changelog.html
最佳查看此处。
jax 0.4.31
jaxlib 0.4.31
- Bug 修复
- 修复了一个 bug,导致 jit 在快速路径中错误处理负的静态参数。
jax 0.4.30(2024 年 6 月 18 日)
- 变更
- JAX 支持 ml_dtypes >= 0.2。在 0.4.29 版本中,ml_dtypes 版本已提升到 0.4.0,但此次发布已回滚,以便 TensorFlow 和 JAX 的用户有足够时间迁移到更新的 TensorFlow 版本。
-
jax.experimental.mesh_utils
现在可以为 TPU v5e 创建高效的网格。 - 现在,jax 直接依赖于 jaxlib。这一变更由 CUDA 插件开关驱动:不再存在多个 jaxlib 变体。您可以通过
pip install jax
安装仅支持 CPU 的 jax,无需额外的内容。 - 添加了导出和序列化 JAX 函数的 API。此功能曾存在于
jax.experimental.export
中(正在弃用),现在将位于jax.export
中。请参阅文档。
- 弃用信息
- 内部漂亮打印工具
jax.core.pp_*
已弃用,并将在将来的版本中移除。 - 对追踪器的哈希化已弃用,并将在未来的 JAX 版本中导致
TypeError
。这在先前的 JAX 版本中是一种情况,但在最近几个 JAX 版本中出现了意外的退化。 -
jax.experimental.export
已弃用。请改用jax.export
。参见迁移指南。 - 在大多数情况下,现在已弃用将数组作为 dtype 的传递方式;例如,对于数组
x
和y
,x.astype(y)
将引发警告。要消除警告,请使用x.astype(y.dtype)
。 -
jax.xla_computation
已弃用,并将在将来的版本中移除。请使用 AOT API 以获得与jax.xla_computation
相同的功能。-
jax.xla_computation(fn)(*args, **kwargs)
可以替换为jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')
。 - 您还可以使用
jax.stages.Lowered
的.out_info
属性来获取输出信息(例如树结构、形状和 dtype)。 - 对于跨后端的降低,您可以将
jax.xla_computation(fn, backend='tpu')(*args, **kwargs)
替换为jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')
。
-
- 内部漂亮打印工具
jaxlib 0.4.30(2024 年 6 月 18 日)
- 不再支持单片 CUDA jaxlibs。您必须使用基于插件的安装方式(
pip install jax[cuda12]
或pip install jax[cuda12_local]
)。
jax 0.4.29(2024 年 6 月 10 日)
- 变更
- 我们预计这将是支持单片 CUDA jaxlib 的 JAX 和 jaxlib 的最后一个版本发布。未来的版本将使用基于插件的 CUDA jaxlib(例如
pip install jax[cuda12]
)。 - JAX 现在要求 ml_dtypes 版本为 0.4.0 或更新。
- 移除了对旧版
jax.experimental.export
API 的向后兼容支持。不再可以使用from jax.experimental.export import export
,而应改为from jax.experimental import export
。已自 0.4.24 版本起弃用该功能。 - 在
jax.tree.all()
和jax.tree_util.tree_all()
中添加了is_leaf
参数。
- 我们预计这将是支持单片 CUDA jaxlib 的 JAX 和 jaxlib 的最后一个版本发布。未来的版本将使用基于插件的 CUDA jaxlib(例如
- 弃用
- 弃用了
jax.sharding.XLACompatibleSharding
。请使用jax.sharding.Sharding
。 -
jax.experimental.Exported.in_shardings
已重命名为jax.experimental.Exported.in_shardings_hlo
。out_shardings
也是如此。旧名称将在 3 个月后移除。 - 移除了一些先前弃用的 API:
- 来自
jax.core
:non_negative_dim
,DimSize
,Shape
- 来自
jax.lax
:tie_in
- 来自
jax.nn
:normalize
- 来自
jax.interpreters.xla
:backend_specific_translations
,translations
,register_translation
,xla_destructure
,TranslationRule
,TranslationContext
,XlaOp
。
- 来自
-
jax.numpy.linalg.matrix_rank()
的tol
参数即将弃用并很快将被移除。请改用rtol
。 -
jax.numpy.linalg.pinv()
的rcond
参数即将弃用并很快将被移除。请改用rtol
。 - 已移除不推荐使用的
jax.config
子模块。要配置 JAX,请使用import jax
,然后通过jax.config
引用配置对象。 -
jax.random
API 现在不再接受批量键,先前一些 API 无意中接受了。未来建议在这些情况下显式使用jax.vmap()
。 - 在
jax.scipy.special.beta()
中,为了与其他beta
API 保持一致性,已将x
和y
参数重命名为a
和b
。
- 弃用了
- 新功能
- 添加了
jax.experimental.Exported.in_shardings_jax()
来构建可以与存储在Exported
对象中的 HloShardings 在 JAX API 中使用的 shardings。
- 添加了
jaxlib 0.4.29(2024 年 6 月 10 日)
- Bug 修复
- 修复了 XLA 不正确分片某些连接操作的 bug,表现为累积归约输出不正确(#21403)。
- 修复了 XLA:CPU 错误编译某些矩阵乘法融合的 bug(https://github.com/openxla/xla/pull/13301)。
- 修复了 GPU 上的编译器崩溃(https://github.com/google/jax/issues/21396)。
- 弃用
jax.tree.map(f, None, non-None)
现在会发出DeprecationWarning
,并且在未来的 jax 版本中将引发错误。None
只是其自身的树前缀。为保留当前行为,您可以请求jax.tree.map
将None
视为叶子值,方法是写:jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)
。
jax 0.4.28(2024 年 5 月 9 日)
- Bug 修复
- 撤销了导致 Equinox 失效的
make_jaxpr
更改(#21116)。
- 撤销了导致 Equinox 失效的
- 弃用与移除
-
jax.numpy.sort()
和jax.numpy.argsort()
的kind
参数现已移除。请改用stable=True
或stable=False
。 - 从
jax.experimental.pallas.gpu
模块中移除了get_compute_capability
。请改用由jax.devices()
或jax.local_devices()
返回的 GPU 设备的compute_capability
属性。 -
jax.numpy.reshape()
的newshape
参数已被弃用,并将很快移除。请改用shape
。
-
- 变更
- 本版本 jaxlib 的最低版本为 0.4.27。
jaxlib 0.4.28 (2024 年 5 月 9 日)
- Bug 修复
- 修复了在 Python 3.10 或更早版本中的数组和 JIT Python 对象类型名称中的内存损坏 bug。
- 修复了在 CUDA 12.4 下的警告
' ptx84' is not a recognized feature for this target
。 - 修复了 CPU 上的缓慢编译问题。
- 变更
- 现在的 Windows 构建使用 Clang 而不是 MSVC。
jax 0.4.27 (2024 年 5 月 7 日)
- 新功能
- 新增了
jax.numpy.unstack()
和jax.numpy.cumulative_sum()
,遵循其在 2023 年标准的数组 API 中的添加,这很快将被 NumPy 采纳。 - 新增了一个新的配置选项
jax_cpu_collectives_implementation
,用于选择 CPU 后端使用的跨进程集合操作的实现。可用选项为'none'
(默认)、'gloo'
和'mpi'
(需要 jaxlib 0.4.26)。如果设置为'none'
,则禁用跨进程集合操作。
- 新增了
- 变更
-
jax.pure_callback()
、jax.experimental.io_callback()
和jax.debug.callback()
现在使用jax.Array
而不是np.ndarray
。您可以通过在传递给回调之前通过jax.tree.map(np.asarray, args)
转换参数来恢复旧的行为。 -
complex_arr.astype(bool)
现在遵循与 NumPy 相同的语义,当complex_arr
等于0 0j
时返回 False,否则返回 True。 -
core.Token
现在是一个包装jax.Array
的非平凡类。可以创建并将其传递到计算中,以建立依赖关系。已移除了单例对象core.token
,现在用户应该创建和使用新的core.Token
对象。 - 在 GPU 上,默认情况下,Threefry PRNG 实现不再降低为内核调用。这种选择可以在编译时减少运行时内存使用。可以通过
jax.config.update('jax_threefry_gpu_kernel_lowering', True)
恢复先前的行为,即产生内核调用。如果新的默认行为导致问题,请报告 bug。否则,我们计划在未来的版本中移除此标志。
-
- 废弃和移除
- Pallas 现在完全采用 XLA 编译 GPU 上的内核。通过 Triton Python API 的旧降低通路已被移除,
JAX_TRITON_COMPILE_VIA_XLA
环境变量不再起作用。 -
jax.numpy.clip()
现在具有新的参数签名:a
、a_min
和a_max
已被弃用,改用x
(仅位置参数)、min
和max
(#20550)。 - JAX 数组的
device()
方法已被移除,自 JAX v0.4.21 弃用后。请改用arr.devices()
。 - 对于
jax.nn.softmax()
和jax.nn.log_softmax()
,initial
参数已弃用;现在支持不设置 softmax 的空输入。 - 在
jax.jit()
中,传递无效的static_argnums
或static_argnames
现在会导致错误,而不是警告。 - 最低的 jaxlib 版本现在是 0.4.23。
-
jax.numpy.hypot()
函数现在在传递复数输入时会发出弃用警告。在弃用完成时,将会引发错误。 - 标量参数传递给
jax.numpy.nonzero()
、jax.numpy.where()
及其相关函数现在会引发错误,这与 NumPy 中的类似变更一致。 - 配置选项
jax_cpu_enable_gloo_collectives
已不推荐使用。请改用jax.config.update('jax_cpu_collectives_implementation', 'gloo')
。 - 在 JAX v0.4.22 中弃用并移除了
jax.Array.device_buffer
和jax.Array.device_buffers
方法。改用jax.Array.addressable_shards
和jax.Array.addressable_data()
。 -
jax.numpy.where
的condition
、x
和y
参数现在只能按位置传递,这是在 JAX v0.4.21 中关键字被弃用后的变更。 - 现在在
jax.lax.linalg
中函数的非数组参数必须通过关键字指定。之前会引发 DeprecationWarning。 - 现在在几个
jax.numpy
的 API 中(包括apply_along_axis()
、apply_over_axes()
、inner()
、outer()
、cross()
、kron()
和lexsort()
),需要使用类似数组的参数。
- Pallas 现在完全采用 XLA 编译 GPU 上的内核。通过 Triton Python API 的旧降低通路已被移除,
- Bug 修复
- 当
copy=True
时,jax.numpy.astype()
现在总是返回一个副本。之前当输出数组的 dtype 与输入数组相同时,不会进行复制。这可能会导致一些内存使用增加。默认值设置为copy=False
以保持向后兼容性。
- 当
jaxlib 0.4.27 (2024 年 5 月 7 日)
jax 0.4.26 (2024 年 4 月 3 日)
- 新功能
- 添加了
jax.numpy.trapezoid()
,跟随 NumPy 2.0 中此函数的添加。
- 添加了
- 变更
- 复数值
jax.numpy.geomspace()
现在选择与 NumPy 2.0 一致的对数螺旋分支。 - 在
jax.vmap
下,lax.rng_bit_generator
的行为,以及'rbg'
和'unsafe_rbg'
的 PRNG 实现,已发生变化,使得在密钥上进行映射只会从批处理中的第一个密钥生成随机数。 - 文档现在使用
jax.random.key
构造 PRNG 密钥数组,而不是jax.random.PRNGKey
。
- 复数值
- 弃用和移除
-
jax.tree_map()
已弃用;请改用jax.tree.map
,或者为了与旧版 JAX 向后兼容性,请使用jax.tree_util.tree_map()
。 -
jax.clear_backends()
因其名字不确保做其名义暗示的操作,可能导致意外后果而被弃用,例如,它不会销毁现有的后端或释放相应的资源。如果只想清理编译缓存,请使用jax.clear_caches()
。为了向后兼容性或者确实需要切换/重新初始化默认后端,请使用jax.extend.backend.clear_backends()
。 - 废弃了
jax.experimental.maps
模块和jax.experimental.maps.xmap
。请使用jax.experimental.shard_map
或在表达 SPMD 设备并行计算时使用带有spmd_axis_name
参数的jax.vmap
。 - 废弃了
jax.experimental.host_callback
模块。请改用新的 JAX 外部回调。添加了JAX_HOST_CALLBACK_LEGACY
标志以帮助过渡到新的回调。参见 #20385 进行讨论。 - 将无法转换为 JAX 数组的参数传递给
jax.numpy.array_equal()
和jax.numpy.array_equiv()
现在会导致异常。 - 移除了废弃标志
jax_parallel_functions_output_gda
。该标志早已废弃且无效;其使用对操作无影响。 - 先前弃用的导入
jax.interpreters.ad.config
和jax.interpreters.ad.source_info_util
现已移除。请改用jax.config
和jax.extend.source_info_util
。 - JAX 导出不再支持旧的序列化版本。自 2023 年 10 月 27 日起支持版本 9,并自 2024 年 2 月 1 日起成为默认版本。详见版本描述。此更改可能会影响将 JAX 序列化版本设置为低于 9 的客户端。
-
jaxlib 0.4.26(2024 年 4 月 3 日)
- 更改
- JAX 现在仅支持 CUDA 12.1 或更新版本。不再支持 CUDA 11.8。
- JAX 现在支持 NumPy 2.0。
jax 0.4.25(2024 年 2 月 26 日)
- 新功能
- 增加了对 CUDA 数组接口 的导入支持(需要 jaxlib 0.4.24)。
- JAX 数组现在支持 NumPy 风格的标量布尔索引,例如
x[True]
或x[False]
。 - 新增了
jax.tree
模块,提供了更便捷的接口来引用jax.tree_util
中的函数。 -
jax.tree.transpose()
(即jax.tree_util.tree_transpose()
)现在接受inner_treedef=None
,在这种情况下,内部 treedef 将自动推断。
- 更改
- Pallas 现在使用 XLA 而不是 Triton Python API 来编译 Triton 内核。您可以通过将
JAX_TRITON_COMPILE_VIA_XLA
环境变量设置为"0"
来恢复到旧行为。 -
jax.interpreters.xla
中几个在 v0.4.24 中移除的废弃 API 在 v0.4.25 中重新添加,包括backend_specific_translations
、translations
、register_translation
、xla_destructure
、TranslationRule
、TranslationContext
和XLAOp
。这些仍被视为废弃,将来会在更好的替代品可用时再次移除。参见 #19816 进行讨论。
- Pallas 现在使用 XLA 而不是 Triton Python API 来编译 Triton 内核。您可以通过将
- 废弃与移除
-
jax.numpy.linalg.solve()
现在对于批处理的 1D 解法(b.ndim > 1
)显示废弃警告。将来将将这些视为批处理的 2D 解法。 - 将非标量数组转换为 Python 标量现在会引发错误,无论数组的大小如何。在非标量大小为 1 的数组的情况下,之前会引发弃用警告。这与 NumPy 中的类似弃用相似。
- 先前弃用的配置 API 已经根据标准的 3 个月弃用周期被移除(请参见 API 兼容性)。这些包括
-
jax.config.config
对象和 -
jax.config
的define_*_state
和DEFINE_*
方法。
-
- 通过
import jax.config
导入jax.config
子模块已经被弃用。配置 JAX 请使用import jax
,然后通过jax.config
引用配置对象。 - 最低的 jaxlib 版本现在是 0.4.20。
-
jaxlib 0.4.25(2024 年 2 月 26 日)
jax 0.4.24(2024 年 2 月 6 日)
- 变更
- JAX 降级到 StableHLO 不再依赖于物理设备。如果您的原语在降级规则中使用
custom_partitioning
或 JAX 回调,即传递给mlir.register_lowering
的rule
参数的函数,则将原语添加到jax._src.dispatch.prim_requires_devices_during_lowering
集合中。这是因为custom_partitioning
和 JAX 回调需要物理设备在降级过程中创建Sharding
。这是一个临时状态,直到我们可以在没有物理设备的情况下创建Sharding
。 -
jax.numpy.argsort()
和jax.numpy.sort()
现在支持stable
和descending
参数。 - 对形状多态性处理的若干更改(用于
jax.experimental.jax2tf
和jax.experimental.export
中):- 更清晰地打印符号表达式(#19227)
- 增加了在维度变量上指定符号约束的功能。这使得形状多态性更加表达,并且提供了一个方法来解决不等式推理中的限制。参见 https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints。
- 随着符号约束的增加(#19235),我们现在认为来自不同作用域的维度变量是不同的,即使它们具有相同的名称。来自不同作用域的符号表达式不能相互作用,例如,在算术操作中。作用域由
jax.experimental.jax2tf.convert()
,jax.experimental.export.symbolic_shape()
,jax.experimental.export.symbolic_args_specs()
引入。符号表达式e
的作用域可以通过e.scope
读取,并传递给上述函数以指导它们在给定作用域中构建符号表达式。请参阅 https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints。 - 简化和加快等式比较,如果它们的差异的标准化形式减少为 0,则认为两个符号维度相等(#19231;请注意,这可能导致用户可见的行为变化)
- 改进了不确定的不等式比较的错误消息 (#19235)。
-
core.non_negative_dim
API(最近引入)已弃用,引入了core.max_dim
和core.min_dim
(#18953) 用于表示符号维度的max
和min
。您可以使用core.max_dim(d, 0)
代替core.non_negative_dim(d)
。 -
shape_poly.is_poly_dim
已弃用,改为使用export.is_symbolic_dim
(#19282)。 -
export.args_specs
已弃用,应使用export.symbolic_args_specs ({jax-issue}
#19283)
。 -
shape_poly.PolyShape
和jax2tf.PolyShape
已弃用,应使用字符串来指定多态形状 (#19284)。 - JAX 默认的本地序列化版本现在是 9。这对
jax.experimental.jax2tf
和jax.experimental.export
非常重要。请参阅 版本号说明。
- 重构了
jax.experimental.export
的 API。现在应使用from jax.experimental import export
而不是from jax.experimental.export import export
。旧的导入方式将在 3 个月的弃用期后停止支持。 - 添加了
jax.scipy.stats.sem()
。 - 带有
return_inverse = True
的jax.numpy.unique()
返回重塑为输入维度的反向索引,遵循 NumPy 2.0 中类似的更改numpy.unique()
。 -
jax.numpy.sign()
现在对非零复数输入返回x / abs(x)
。这与 NumPy 2.0 版本中numpy.sign()
的行为一致。 - 带有
return_sign=True
的jax.scipy.special.logsumexp()
现在使用 NumPy 2.0 中的复数符号约定x / abs(x)
。这与 SciPy v1.13 中的scipy.special.logsumexp()
的行为一致。 - JAX 现在支持布尔型 DLPack 类型的导入和导出。之前布尔值无法导入,并且以整数形式导出。
- JAX 降级到 StableHLO 不再依赖于物理设备。如果您的原语在降级规则中使用
- 弃用和移除:
- 删除了许多先前弃用的函数,遵循标准的 3 个月弃用周期(请参阅 API 兼容性)。
- 从
jax.core
中移除:TracerArrayConversionError
、TracerIntegerConversionError
、UnexpectedTracerError
、as_hashable_function
、collections
、dtypes
、lu
、map
、namedtuple
、partial
、pp
、ref
、safe_zip
、safe_map
、source_info_util
、total_ordering
、traceback_util
、tuple_delete
、tuple_insert
和zip
。 - 从
jax.lax
中移除:dtypes
、itertools
、naryop
、naryop_dtype_rule
、standard_abstract_eval
、standard_naryop
、standard_primitive
、standard_unop
、unop
和unop_dtype_rule
。 -
jax.linear_util
子模块及其所有内容。 -
jax.prng
子模块及其所有内容。 - 来自
jax.random
:PRNGKeyArray
、KeyArray
、default_prng_impl
、threefry_2x32
、threefry2x32_key
、threefry2x32_p
、rbg_key
和unsafe_rbg_key
。 - 来自
jax.tree_util
:register_keypaths
、AttributeKeyPathEntry
和GetItemKeyPathEntry
。 - 来自
jax.interpreters.xla
:backend_specific_translations
、translations
、register_translation
、xla_destructure
、TranslationRule
、TranslationContext
、axis_groups
、ShapedArray
、ConcreteArray
、AxisEnv
、backend_compile
和XLAOp
。 - 来自
jax.numpy
:NINF
、NZERO
、PZERO
、row_stack
、issubsctype
、trapz
和in1d
。 - 来自
jax.scipy.linalg
:tril
和triu
。
- 从
- 已弃用的方法
PRNGKeyArray.unsafe_raw_array
已被移除。请使用jax.random.key_data()
替代。 -
bool(empty_array)
现在引发错误,而不是返回False
。这之前会引发弃用警告,并遵循 NumPy 中类似的更改。 - 弃用了对 mhlo MLIR 方言的支持。JAX 不再使用 mhlo 方言,而是改用 stablehlo。将来将删除指称“mhlo”的 API。请改用“stablehlo”方言。
-
jax.random
:直接将批处理密钥传递给随机数生成函数(如bits()
、gamma()
等)已弃用,并将发出FutureWarning
。请使用jax.vmap
进行显式批处理。 - 弃用了
jax.lax.tie_in()
:自 JAX v0.2.0 以来已成为无操作。
- 删除了许多先前弃用的函数,遵循标准的 3 个月弃用周期(请参阅 API 兼容性)。
jaxlib 0.4.24(2024 年 2 月 6 日)
- 变更
- JAX 现在支持 CUDA 12.3 和 CUDA 11.8。不再支持 CUDA 12.2。
-
cost_analysis
现在可以与交叉编译的Compiled
对象一起使用(例如,在非 TPU 计算机上使用.lower().compile()
编译为云 TPU 时使用拓扑对象)。 - 添加了CUDA 数组接口导入支持(需要 jax 0.4.25)。
jax 0.4.23(2023 年 12 月 13 日)
jaxlib 0.4.23(2023 年 12 月 13 日)
- 修复了导致 GPU 编译器在编译期间产生冗长日志的错误。
jax 0.4.22(2023 年 12 月 13 日)
- 弃用内容
- JAX 数组的
device_buffer
和device_buffers
属性已弃用。显式缓冲区已被更灵活的数组分片接口取代,但以前的输出可以通过以下方式恢复:-
arr.device_buffer
变为arr.addressable_data(0)
-
arr.device_buffers
变为[x.data for x in arr.addressable_shards]
-
- JAX 数组的
jaxlib 0.4.22(2023 年 12 月 13 日)
jax 0.4.21(2023 年 12 月 4 日)
- 新特性
- 添加了
jax.nn.squareplus
。
- 添加了
- 变更
- 最低 jaxlib 版本现在为 0.4.19。
- 现在发布的 Wheels 使用 clang 而不是 gcc 构建。
- 在调用
jax.distributed.initialize()
之前,强制确保设备后端未初始化。 - 在云 TPU 环境中自动化
jax.distributed.initialize()
的参数。
- 弃用内容
- 从
jax.scipy.linalg.solve()
中删除了先前弃用的sym_pos
参数。请改用assume_a='pos'
。 - 将
None
传递给jax.array()
或jax.asarray()
,无论是直接传递还是在列表或元组中传递,已被弃用并现在引发FutureWarning
。当前转换为 NaN,在将来将引发TypeError
。 - 通过关键字参数传递
condition
、x
和y
参数给jax.numpy.where
已被弃用,以匹配numpy.where
。 - 传递给
jax.numpy.array_equal()
和jax.numpy.array_equiv()
的参数如果不能转换为 JAX 数组,则已被弃用并现在引发DeprecationWaning
。当前函数返回 False,在将来将引发异常。 - JAX 数组的
device()
方法已被弃用。根据上下文,可能替换为以下之一:-
jax.Array.devices()
返回数组使用的所有设备集。 -
jax.Array.sharding
给出了数组使用的分片配置。
-
- 从
jaxlib 0.4.21 (2023 年 12 月 4 日)
- 变更
- 为了添加分布式 CPU 支持的准备工作,JAX 现在将 CPU 设备与 GPU 和 TPU 设备相同对待,即:
-
jax.devices()
包括分布式作业中所有设备,即使这些设备不在当前进程中也包括在内。jax.local_devices()
仍然只包括当前进程中的设备,因此如果对jax.devices()
的更改影响到您,您可能更希望使用jax.local_devices()
。 - CPU 设备现在在分布式作业中接收全局唯一的 ID 号码;以前 CPU 设备将接收进程本地的 ID 号码。
- 每个 CPU 设备的
process_index
现在将与同一进程中的任何 GPU 或 TPU 设备匹配;以前 CPU 设备的process_index
总是 0。
-
- 在 NVIDIA GPU 上,JAX 现在优先选择 Jacobi SVD 求解器用于大小不超过 1024x1024 的矩阵。与非 Jacobi 版本相比,Jacobi 求解器似乎更快。
- 为了添加分布式 CPU 支持的准备工作,JAX 现在将 CPU 设备与 GPU 和 TPU 设备相同对待,即:
- Bug 修复
- 当传递具有非有限值的数组给非对称特征分解时发生错误/挂起(#18226)。现在,具有非有限值的数组将产生由 NaN 组成的输出数组。
jax 0.4.20 (2023 年 11 月 2 日)
jaxlib 0.4.20 (2023 年 11 月 2 日)
- Bug 修复
- 修复了 E4M3 和 E5M2 float8 类型之间的一些类型混淆。
jax 0.4.19 (2023 年 10 月 19 日)
- 新功能
- 添加了
jax.typing.DTypeLike
,可用于注释可转换为 JAX 数据类型的对象。 - 添加了
jax.numpy.fill_diagonal
。
- 添加了
- 变更
- 现在 JAX 要求 SciPy 1.9 或更新版本。
- Bug 修复
- 在多控制器分布式 JAX 程序中,只有进程 0 将写入持久编译缓存条目。如果缓存放置在网络文件系统(如 GCS)上,则修复了写入争用问题。
- 当决定已安装的 cusolver 和 cufft 版本是否至少与 JAX 构建的版本一样新时,版本检查现在不再考虑补丁版本。
jaxlib 0.4.19 (2023 年 10 月 19 日)
- 变更
- jaxlib 现在始终优先使用通过 pip 安装的 NVIDIA CUDA 库(nvidia-… packages),而不管
LD_LIBRARY_PATH
中命名的其他 CUDA 安装。如果因此导致问题且意图是使用系统安装的 CUDA,则解决方法是移除 pip 安装的 CUDA 库包。
- jaxlib 现在始终优先使用通过 pip 安装的 NVIDIA CUDA 库(nvidia-… packages),而不管
jax 0.4.18(2023 年 10 月 6 日)
jaxlib 0.4.18(2023 年 10 月 6 日)
- 变更:
- CUDA jaxlibs 现在依赖于用户安装兼容的 NCCL 版本。如果使用推荐的
cuda12_pip
安装,NCCL 应会自动安装。目前需要 NCCL 2.16 或更新版本。 - 现在我们提供 Linux aarch64 wheels,包括带有和不带有 NVIDIA GPU 支持的版本。
-
jax.Array.item()
现在支持可选的索引参数。
- CUDA jaxlibs 现在依赖于用户安装兼容的 NCCL 版本。如果使用推荐的
- 弃用:
- 一些
jax.lax
中的内部实用程序和无意导出已被弃用,并将在将来的版本中移除。-
jax.lax.dtypes
: 使用jax.dtypes
替代。 -
jax.lax.itertools
:使用itertools
替代。 -
naryop
、naryop_dtype_rule
、standard_abstract_eval
、standard_naryop
、standard_primitive
、standard_unop
、unop
和unop_dtype_rule
是内部实用程序,现在已弃用且没有替代。
-
- 一些
- Bug 修复
- 修复了云 TPU 回归,因 smem 导致编译 OOM。
jax 0.4.17(2023 年 10 月 3 日)
- 新功能
- 新增了
jax.numpy.bitwise_count()
函数,与最近添加到 NumPy 的类似函数的 API 匹配。
- 新增了
- 弃用:
- 移除了弃用的模块
jax.abstract_arrays
及其所有内容。 -
jax.random
中的命名键构造函数已被弃用。改为向jax.random.PRNGKey()
或jax.random.key()
传递impl
参数:-
random.threefry2x32_key(seed)
变为random.PRNGKey(seed, impl='threefry2x32')
-
random.rbg_key(seed)
变为random.PRNGKey(seed, impl='rbg')
-
random.unsafe_rbg_key(seed)
变为random.PRNGKey(seed, impl='unsafe_rbg')
-
- 移除了弃用的模块
- 变更:
- CUDA:JAX 现在会验证其找到的 CUDA 库是否至少与 JAX 构建时使用的 CUDA 库一样新。如果发现较旧的库,JAX 将引发异常,因为这比神秘的故障和崩溃更可取。
- 移除了“未找到 GPU/TPU”的警告。而是在 Linux 上,如果发现但未使用 NVIDIA GPU 或 Google TPU,并且未指定
--jax_platforms
,则发出警告。 -
jax.scipy.stats.mode()
现在在跨尺寸为 0 的轴上取模时返回 0 计数,与 SciPy 1.11 中scipy.stats.mode
的行为相匹配。 - 大多数
jax.numpy
函数和属性现在都具有完全定义的类型存根。以前,这些函数中的许多被静态类型检查器(如mypy
和pytype
)视为Any
。
jaxlib 0.4.17(2023 年 10 月 3 日)
- 变更:
- Python 3.12 wheels 已添加到此版本中。
- CUDA 12 wheels 现在需要 CUDA 12.2 或更新版本以及 cuDNN 8.9.4 或更新版本。
- Bug 修复:
- 修复了当 JAX CPU 后端初始化时,ABSL 输出大量日志的问题。
jax 0.4.16(2023 年 9 月 18 日)
- 变更:
- 添加了
jax.numpy.ufunc
,以及jax.numpy.frompyfunc()
,它可以将任何标量值函数转换为类似于numpy.ufunc()
的对象,具有outer()
、reduce()
、accumulate()
、at()
和reduceat()
等方法(#17054)。 - 添加了
jax.scipy.integrate.trapezoid()
。 - 在非 IPython 环境下:当引发异常时,JAX 现在会从回溯中过滤掉其内部帧的整体。(之前会出现“未过滤堆栈跟踪”)。这应该会产生更友好的堆栈跟踪。详见 此处 的示例。此行为可以通过设置
JAX_TRACEBACK_FILTERING=remove_frames
(用于两个单独的未过滤/过滤后的堆栈跟踪,即旧的行为)或JAX_TRACEBACK_FILTERING=off
(用于一个未过滤的堆栈跟踪)来改变。 - jax2tf 默认序列化版本现在是 7,引入了新的形状 安全断言。
- 传递给
jax.sharding.Mesh
的设备应该是可哈希的。这特别适用于模拟设备或用户创建的设备。jax.devices()
已经是可哈希的。
- 添加了
- 破坏性变更:
- jax2tf 现在默认使用本地序列化。请查阅 jax2tf 文档 获取详细信息以及覆盖默认设置的机制。
- 选项
--jax_coordination_service
已被移除。现在总是True
。 -
jax.jaxpr_util
已从公共 JAX 命名空间中移除。 -
JAX_USE_PJRT_C_API_ON_TPU
不再生效(即它总是默认为 true)。 - 自 2021 年 12 月引入的兼容性标志
--jax_host_callback_ad_transforms
已被移除。
- 弃用:
- 根据 NumPy NEP-52,几个
jax.numpy
API 已经被弃用:-
jax.numpy.NINF
已经被弃用。请改用-jax.numpy.inf
。 -
jax.numpy.PZERO
已经被弃用。请改用0.0
。 -
jax.numpy.NZERO
已经被弃用。请改用-0.0
。 -
jax.numpy.issubsctype(x, t)
已经被弃用。请改用jax.numpy.issubdtype(x.dtype, t)
。 -
jax.numpy.row_stack
已经被弃用。请改用jax.numpy.vstack
。 -
jax.numpy.in1d
已经被弃用。请改用jax.numpy.isin
。 -
jax.numpy.trapz
已经被弃用。请改用jax.scipy.integrate.trapezoid
。
-
-
jax.scipy.linalg.tril
和jax.scipy.linalg.triu
已经被弃用,遵循 SciPy 的做法。请改用jax.numpy.tril
和jax.numpy.triu
。 -
jax.lax.prod
已经在 JAX v0.4.11 中被移除,之前已被弃用。请改用内置的math.prod
。 - 从
jax.interpreters.xla
导出的一些与为自定义 JAX 原语定义 HLO 降低规则有关的内容已经被弃用。应该使用jax.interpreters.mlir
中的 StableHLO 降低实用工具来定义自定义原语。 - 在经过三个月的弃用期后,以下先前弃用的函数已被移除:
-
jax.abstract_arrays.ShapedArray
: 使用jax.core.ShapedArray
。 -
jax.abstract_arrays.raise_to_shaped
: 使用jax.core.raise_to_shaped
。 -
jax.numpy.alltrue
: 使用jax.numpy.all
。 -
jax.numpy.sometrue
: 使用jax.numpy.any
。 -
jax.numpy.product
: 使用jax.numpy.prod
。 -
jax.numpy.cumproduct
: 使用jax.numpy.cumprod
。
-
- 根据 NumPy NEP-52,几个
- 弃用/移除:
- 内部子模块
jax.prng
现已弃用。其内容可在jax.extend.random
中找到。 - 内部子模块路径
jax.linear_util
已被弃用。请使用jax.extend.linear_util
替代(jax.extend 的一部分:扩展模块)。 -
jax.random.PRNGKeyArray
和jax.random.KeyArray
已弃用。请在类型注释中使用jax.Array
,并在运行时使用jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key)
来检测类型化的 PRNG 密钥。 - 方法
PRNGKeyArray.unsafe_raw_array
已弃用。请改用jax.random.key_data()
。 -
jax.experimental.pjit.with_sharding_constraint
已弃用。请使用jax.lax.with_sharding_constraint
替代。 - 内部工具函数
jax.core.is_opaque_dtype
和jax.core.has_opaque_dtype
已被移除。不透明数据类型已重命名为扩展数据类型;请使用jnp.issubdtype(dtype, jax.dtypes.extended)
替代(自 jax v0.4.14 起可用)。 - 实用工具函数
jax.interpreters.xla.register_collective_primitive
已被移除。在最新的 JAX 发行版中,此实用工具无任何作用,可以安全移除其调用。 - 内部子模块路径
jax.linear_util
已被弃用。请使用jax.extend.linear_util
替代(jax.extend 的一部分:扩展模块)。
- 内部子模块
jaxlib 0.4.16(2023 年 9 月 18 日)
- 变更:
- 在 NVIDIA GPU 上,通过实验性的 jax 稀疏 API 进行的稀疏 CSR 矩阵乘法不再使用确定性算法。此更改是为了与 CUDA 12.2.1 兼容性而进行的。
- Bug 修复:
- 修复了由于关于乱序段和 IMAGE_REL_AMD64_ADDR32NB 重定位的致命 LLVM 错误而在 Windows 上崩溃的问题(https://github.com/openxla/xla/commit/cb732a921f0c4184995cbed82394931011d12bd4)。
jax 0.4.14(2023 年 7 月 27 日)
- 变更:
-
jax.jit
接受donate_argnames
作为参数。其语义类似于static_argnames
。如果既不提供donate_argnums
也不提供donate_argnames
,则不会捐赠任何参数。如果不提供donate_argnums
但提供了donate_argnames
,或者反之,则 JAX 使用inspect.signature(fun)
来查找与donate_argnames
(或其反向)相对应的任何位置参数。如果同时提供了donate_argnums
和donate_argnames
,则不使用inspect.signature
,并且只有实际参数列在donate_argnums
或donate_argnames
中将被捐赠。 -
jax.random.gamma()
已重新设计为更高效的算法,具有更健壮的端点行为(#16779)。这意味着给定key
的值序列在 JAX v0.4.13 和 v0.4.14 之间的gamma
和相关抽样器(包括jax.random.ball()
、jax.random.beta()
、jax.random.chisquare()
、jax.random.dirichlet()
、jax.random.generalized_normal()
、jax.random.loggamma()
、jax.random.t()
)将发生变化。
-
- 删除项:
- 自弃用以来已超过 3 个月的
in_axis_resources
和out_axis_resources
已从 pjit 中删除。请使用in_shardings
和out_shardings
进行替换。这是一个安全和简单的名称替换。它不会改变任何当前的 pjit 语义,也不会破坏任何代码。您仍然可以将PartitionSpecs
传递给in_shardings
和out_shardings
。
- 自弃用以来已超过 3 个月的
- 弃用项:
- 根据 https://jax.readthedocs.io/en/latest/deprecation.html,已删除对 Python 3.8 的支持。
- 根据 https://jax.readthedocs.io/en/latest/deprecation.html,JAX 现在要求 NumPy 1.22 或更新版本。
- 不再支持通过位置向
jax.numpy.ndarray.at()
传递可选参数,已在 JAX 版本 0.4.7 中被弃用。例如,不再使用x.at[i].get(True)
,而是使用x.at[i].get(indices_are_sorted=True)
。 - 以下
jax.Array
方法在 JAX v0.4.5 中被弃用后已被移除:-
jax.Array.broadcast
: 改用jax.lax.broadcast()
。 -
jax.Array.broadcast_in_dim
: 改用jax.lax.broadcast_in_dim()
。 -
jax.Array.split
: 使用jax.numpy.split()
替代。
-
- 在之前的弃用之后,以下 API 已被移除:
-
jax.ad
: 使用jax.interpreters.ad
。 -
jax.curry
: 使用curry = lambda f: partial(partial, f)
。 -
jax.partial_eval
: 使用jax.interpreters.partial_eval
。 -
jax.pxla
: 使用jax.interpreters.pxla
。 -
jax.xla
: 使用jax.interpreters.xla
。 -
jax.ShapedArray
: 使用jax.core.ShapedArray
。 -
jax.interpreters.pxla.device_put
: 使用jax.device_put()
。 -
jax.interpreters.pxla.make_sharded_device_array
: 使用jax.make_array_from_single_device_arrays()
。 -
jax.interpreters.pxla.ShardedDeviceArray
: 使用jax.Array
。 -
jax.numpy.DeviceArray
: 使用jax.Array
。 -
jax.stages.Compiled.compiler_ir
: 使用jax.stages.Compiled.as_text()
。
-
- 破坏性变更:
- JAX 现在要求 ml_dtypes 版本 0.2.0 或更新版本。
- 为了修复一个边缘情况,调用
jax.lax.cond()
时,如果第二个和第三个参数是可调用的,则使用五个参数总是解析为文档中记录的 “common operands”cond
行为,即使其他操作数也是可调用的。参见 #16413。 - 已删除无效配置选项
jax_array
和jax_jit_pjit_api_merge
。这些选项默认情况下自许多版本以来都为 true。
- 新功能:
- JAX 现在支持配置标志
--jax_serialization_version
和环境变量JAX_SERIALIZATION_VERSION
来控制序列化版本(#16746)。 - 在形状多态性存在的情况下,jax2tf 现在生成检查某些形状约束的代码,如果序列化版本至少为 7。详见 https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism。
- JAX 现在支持配置标志
jaxlib 0.4.14(2023 年 7 月 27 日)
- 弃用
- 根据 https://jax.readthedocs.io/en/latest/deprecation.html,不再支持 Python 3.8。
jax 0.4.13(2023 年 6 月 22 日)
- 更改
-
jax.jit
现在允许将None
传递给in_shardings
和out_shardings
。语义如下:- 对于
in_shardings
,JAX 将其标记为复制,但这种行为可能会在将来更改。 - 对于
out_shardings
,我们将依赖于 XLA GSPMD 分区器来确定输出的分片。
- 对于
-
jax.experimental.pjit.pjit
也允许将None
传递给in_shardings
和out_shardings
。语义如下:- 如果未提供网格上下文管理器,则 JAX 可自由选择所需的分片方式。
- 对于
in_shardings
,JAX 将其标记为复制,但这种行为可能会在将来更改。 - 对于
out_shardings
,我们将依赖于 XLA GSPMD 分区器来确定输出的分片。
- 对于
- 如果提供了网格上下文管理器,
None
将意味着该值将在网格的所有设备上复制。
- 如果未提供网格上下文管理器,则 JAX 可自由选择所需的分片方式。
- Executable.cost_analysis() 在 Cloud TPU 上可用
- 如果正在使用非允许的
jaxlib
插件,则添加了警告。 - 添加了
jax.tree_util.tree_leaves_with_path
。 -
None
不是jax.experimental.multihost_utils.host_local_array_to_global_array
或jax.experimental.multihost_utils.global_array_to_host_local_array
的有效输入。如果您希望复制您的输入,请使用jax.sharding.PartitionSpec()
。
-
- Bug 修复
- 在 CUDA 12 发布中修复了错误的轮子名称(#16362);正确的轮子名称为
cudnn89
而不是cudnn88
。
- 在 CUDA 12 发布中修复了错误的轮子名称(#16362);正确的轮子名称为
- 弃用
jax.experimental.jax2tf.convert()
的native_serialization_strict_checks
参数已被弃用,推荐使用新的native_serializaation_disabled_checks
(#16347)。
jaxlib 0.4.13(2023 年 6 月 22 日)
- 更改
- 将 Windows 仅 CPU 轮子添加到
jaxlib
Pypi 发布中。
- 将 Windows 仅 CPU 轮子添加到
- Bug 修复
-
__cuda_array_interface__
在之前的 jaxlib 版本中出现问题,现已修复(#16440)。 - 并行 CUDA 内核跟踪现在默认启用于 NVIDIA GPU。
-
jax 0.4.12(2023 年 6 月 8 日)
- 更改
- 添加了
scipy.spatial.transform.Rotation
和scipy.spatial.transform.Slerp
- 添加了
- 弃用
-
jax.abstract_arrays
及其内容已被弃用。请参阅:mod:jax.core
中的相关功能。 -
jax.numpy.alltrue
:使用jax.numpy.all
。这遵循了 NumPy 版本 1.25.0 中numpy.alltrue
的弃用。 -
jax.numpy.sometrue
:使用jax.numpy.any
。这遵循了 NumPy 版本 1.25.0 中numpy.sometrue
的弃用。 -
jax.numpy.product
:使用jax.numpy.prod
。这遵循了 NumPy 版本 1.25.0 中numpy.product
的弃用。 -
jax.numpy.cumproduct
:使用jax.numpy.cumprod
。这遵循了 NumPy 版本 1.25.0 中numpy.cumproduct
的弃用。 -
jax.sharding.OpShardingSharding
已被移除,因为它已经弃用了 3 个月。
-
jaxlib 0.4.12 (2023 年 6 月 8 日)
- 变更
- 包含了 Hopper(SM 版本 9.0 )GPU 的 PTX/SASS。之前的 jaxlib 版本应该可以在 Hopper 上工作,但第一次执行 JAX 操作时可能会有较长的 JIT 编译延迟。
- Bug 修复
- 修复了在 Python 3.11 下 JAX 生成的 Python 回溯中源代码行信息不正确的问题。
- 修复了在 JAX 生成的 Python 回溯的帧中打印本地变量时崩溃的问题(#16027)。
jax 0.4.11 (2023 年 5 月 31 日)
- 弃用
- 根据 API 兼容性政策,在 3 个月的弃用期后,已移除以下 API:
-
jax.experimental.PartitionSpec
:使用jax.sharding.PartitionSpec
。 -
jax.experimental.maps.Mesh
:使用jax.sharding.Mesh
。 -
jax.experimental.pjit.NamedSharding
:使用jax.sharding.NamedSharding
。 -
jax.experimental.pjit.PartitionSpec
:使用jax.sharding.PartitionSpec
。 -
jax.experimental.pjit.FROM_GDA
。请将分片的jax.Array
对象作为输入传递,并删除pjit
的可选in_shardings
参数。 -
jax.interpreters.pxla.PartitionSpec
:使用jax.sharding.PartitionSpec
。 -
jax.interpreters.pxla.Mesh
:使用jax.sharding.Mesh
。 -
jax.interpreters.xla.Buffer
:使用jax.Array
。 -
jax.interpreters.xla.Device
:使用jax.Device
。 -
jax.interpreters.xla.DeviceArray
:使用jax.Array
。 -
jax.interpreters.xla.device_put
:使用jax.device_put
。 -
jax.interpreters.xla.xla_call_p
:使用jax.experimental.pjit.pjit_p
。 -
with_sharding_constraint
的axis_resources
参数已被移除。请改用shardings
。
-
- 根据 API 兼容性政策,在 3 个月的弃用期后,已移除以下 API:
jaxlib 0.4.11 (2023 年 5 月 31 日)
- 变更
- 向
Device
添加了memory_stats()
方法。如果支持,它将返回一个包含字符串统计名称和整数值的字典,例如"bytes_in_use"
,如果平台不支持内存统计,则返回 None。具体的统计数据可能因平台而异。目前仅在 Cloud TPU 上实现。 - 重新添加了对 CPU 设备上 Python 缓冲协议(
memoryview
)的支持。
- 向
jax 0.4.10 (2023 年 5 月 11 日)
jaxlib 0.4.10 (2023 年 5 月 11 日)
- 变更
- 修复了阻止上一个版本在 Mac M1 上运行的
'apple-m1' is not a recognized processor for this target (ignoring processor)
问题。
- 修复了阻止上一个版本在 Mac M1 上运行的
jax 0.4.9 (2023 年 5 月 9 日)
- 变更
-
experimental_cpp_jit
、experimental_cpp_pjit
和experimental_cpp_pmap
标志已被移除。它们现在始终开启。 - TPU 上奇异值分解(SVD)的准确性已经提高(需要 jaxlib 0.4.9)。
-
- 废弃功能
-
jax.experimental.gda_serialization
已废弃,并已重命名为jax.experimental.array_serialization
。请更改您的导入以使用jax.experimental.array_serialization
。 -
pjit
的in_axis_resources
和out_axis_resources
参数已废弃。请分别使用in_shardings
和out_shardings
。 - 函数
jax.numpy.msort
已被移除。自 JAX v0.4.1 起已被废弃。请使用jnp.sort(a, axis=0)
代替。 -
in_parts
和out_parts
参数已从jax.xla_computation
中移除,因为它们只与sharded_jit
一起使用,并且sharded_jit
已不再使用。 - 自从很久以来未被使用,
instantiate_const_outputs
参数已从jax.xla_computation
中移除。
-
jaxlib 0.4.9(2023 年 5 月 9 日)
jax 0.4.8(2023 年 3 月 29 日)
- 破坏性变更
- Cloud TPU 运行时的一个重要组件已升级。这使得以下新功能在 Cloud TPU 上可用:
-
jax.debug.print()
、jax.debug.callback()
和jax.debug.breakpoint()
现在在 Cloud TPU 上可用。 - 自动 TPU 内存碎片整理
在新的运行时组件上,不再支持
jax.experimental.host_callback()
在 Cloud TPU 上的使用。如果新的jax.debug
API 不能满足您的需求,请在JAX 问题跟踪器上提出问题。 旧的运行时组件将通过设置环境变量JAX_USE_PJRT_C_API_ON_TPU=false
至少在接下来的三个月内可用。如果您发现需要出于任何原因禁用新的运行时,请在JAX 问题跟踪器上告知我们。 -
- Cloud TPU 运行时的一个重要组件已升级。这使得以下新功能在 Cloud TPU 上可用:
- 变更
- 最低 jaxlib 版本已从 0.4.6 提升至 0.4.7。
- 废弃功能
- 支持 CUDA 11.4 已被移除。JAX GPU 版本仅支持 CUDA 11.8 和 CUDA 12。如果使用旧版 CUDA 构建 jaxlib 可能会正常工作。
-
pmap
的global_arg_shapes
参数仅适用于sharded_jit
,已从pmap
中移除。请迁移到pjit
并从pmap
中移除global_arg_shapes
。
jax 0.4.7(2023 年 3 月 27 日)
- 变更
- 根据 https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration,不再支持禁用
jax.config.jax_array
。 - 不再支持禁用
jax.config.jax_jit_pjit_api_merge
。 -
jax.experimental.jax2tf.convert()
现在支持native_serialization
参数,使用 JAX 的本机降级到 StableHLO 以获取整个 JAX 函数的 StableHLO 模块,而不是将每个 JAX 原语降级到 TensorFlow 操作。这简化了内部操作,并增加了您序列化内容与 JAX 本机语义匹配的信心。详见文档。作为这一变更的一部分,配置标志--jax2tf_default_experimental_native_lowering
已重命名为--jax2tf_native_serialization
。 - JAX 现在依赖于
ml_dtypes
,其中包含类似于 bfloat16 的 NumPy 类型的定义。这些定义以前是 JAX 的内部部分,但已拆分为一个单独的包,以便与其他项目共享。 - JAX 现在要求使用 NumPy 1.21 或更新版本以及 SciPy 1.7 或更新版本。
- 根据 https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration,不再支持禁用
- 弃用信息
- 类型
jax.numpy.DeviceArray
已弃用。请改用jax.Array
,它是其别名。 - 类型
jax.interpreters.pxla.ShardedDeviceArray
已弃用。请改用jax.Array
。 - 通过位置传递额外参数给
jax.numpy.ndarray.at()
已被弃用。例如,不要使用x.at[i].get(True)
,而是使用x.at[i].get(indices_are_sorted=True)
-
jax.interpreters.xla.device_put
已被弃用。请使用jax.device_put
。 -
jax.interpreters.pxla.device_put
已被弃用。请使用jax.device_put
。 -
jax.experimental.pjit.FROM_GDA
已被弃用。请将分片的 jax.Arrays 作为输入,并移除 pjit 中的in_shardings
参数,因为它是可选的。
- 类型
jaxlib 0.4.7(2023 年 3 月 27 日)
变更:
- jaxlib 现在依赖于
ml_dtypes
,其中包含类似于 bfloat16 的 NumPy 类型的定义。这些定义以前是 JAX 的内部部分,但已拆分为一个单独的包,以便与其他项目共享。
jax 0.4.6(2023 年 3 月 9 日)
- 变更
-
jax.tree_util
现在包含一组允许用户为其自定义 pytree 节点定义键的 API。-
tree_flatten_with_path
可以展平树并返回每个叶子及其键路径。 -
tree_map_with_path
可以映射一个接受键路径作为参数的函数。 -
register_pytree_with_keys
用于注册自定义 pytree 节点中键路径和叶子的外观。 -
keystr
用于漂亮地打印键路径。
-
-
jax2tf.call_tf()
现在有一个新参数output_shape_dtype
(默认为None
),可用于声明结果的输出形状和类型。这使得jax2tf.call_tf()
能够在形状多态性存在的情况下工作。(#14734)
-
- 弃用信息
-
jax.tree_util
中的旧键路径 API 已被弃用,并将在 2023 年 3 月 10 日后的 3 个月内移除:-
register_keypaths
:请使用jax.tree_util.register_pytree_with_keys()
替代。 -
AttributeKeyPathEntry
:请改用GetAttrKey
。 -
GetitemKeyPathEntry
:请改用SequenceKey
或DictKey
。
-
-
jaxlib 0.4.6(2023 年 3 月 9 日)
jax 0.4.5(2023 年 3 月 2 日)
- 弃用信息
-
jax.sharding.OpShardingSharding
已重命名为jax.sharding.GSPMDSharding
。jax.sharding.OpShardingSharding
将在 2023 年 2 月 17 日后的 3 个月内移除。 - 下列
jax.Array
方法已被弃用,并将在 2023 年 2 月 23 日后的 3 个月内移除:-
jax.Array.broadcast
:请使用jax.lax.broadcast()
替代。 -
jax.Array.broadcast_in_dim
:请使用jax.lax.broadcast_in_dim()
替代。 -
jax.Array.split
:请使用jax.numpy.split()
替代。
-
-
jax 0.4.4(2023 年 2 月 16 日)
- 变更
-
jit
和pjit
的实现已合并。合并 jit 和 pjit 改变了 JAX 的内部实现,但不影响 JAX 的公共 API。之前,jit
是一种最终风格的原语。最终风格意味着尽可能延迟创建 jaxpr 并将变换堆叠在一起。随着jit
-pjit
实现的合并,jit
变成了一种初始风格的原语,这意味着我们尽早追踪到 jaxpr。更多信息请参见 autodidax 中的这一部分。转移到初始风格应该简化 JAX 的内部实现,并使得动态形状等功能的开发更加容易。你只能通过环境变量来禁用它,即os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'
。由于它影响到 JAX 的导入时机,因此必须通过环境变量禁用它,在导入 jax 之前就需要禁用它。 -
with_sharding_constraint
的axis_resources
参数已弃用。请改用shardings
。如果你将其作为参数使用,则无需更改。如果你将其作为关键字参数使用,请改用shardings
。axis_resources
将在 2023 年 2 月 13 日后的 3 个月内删除。 - 添加了
jax.typing
模块,用于 JAX 函数的类型注解工具。 - 下列名称已被弃用:
-
jax.xla.Device
和jax.interpreters.xla.Device
: 使用jax.Device
。 -
jax.experimental.maps.Mesh
. 使用jax.sharding.Mesh
替代。 -
jax.experimental.pjit.NamedSharding
: 使用jax.sharding.NamedSharding
。 -
jax.experimental.pjit.PartitionSpec
: 使用jax.sharding.PartitionSpec
。 -
jax.interpreters.pxla.Mesh
: 使用jax.sharding.Mesh
。 -
jax.interpreters.pxla.PartitionSpec
: 使用jax.sharding.PartitionSpec
。
-
-
- Breaking Changes
jax.numpy.sum
等的initial
参数现在要求是一个标量,与对应的 NumPy API 保持一致。以前的行为是对非标量initial
值进行广播,这是一个意外的实现细节(#14446)。
jaxlib 0.4.4(2023 年 2 月 16 日)
- Breaking changes
- 默认的
jaxlib
构建中已移除对 NVIDIA Kepler 系列 GPU 的支持。如果需要 Kepler 支持,可以通过使用 Kepler 支持的源码构建jaxlib
(通过build.py
的--cuda_compute_capabilities=sm_35
选项),不过请注意 CUDA 12 已完全停止支持 Kepler GPU。
- 默认的
jax 0.4.3(2023 年 2 月 8 日)
- Breaking changes
- 删除了
jax.scipy.linalg.polar_unitary()
,这是一个已弃用的 JAX 扩展到 scipy API 的函数。请改用jax.scipy.linalg.polar()
。
- 删除了
- Changes
- 添加了
jax.scipy.stats.rankdata()
。
- 添加了
jaxlib 0.4.3(2023 年 2 月 8 日)
jax.Array
现在具有非阻塞的is_ready()
方法,如果数组已准备就绪则返回True
(参见jax.block_until_ready()
)。
jax 0.4.2(2023 年 1 月 24 日)
- Breaking changes
- 删除了
jax.experimental.callback
- 在存在
jax2tf
形状多态性的情况下,对带有维度的操作进行了泛化处理,通过将符号维度转换为 JAX 数组来在更多场景下工作。现在,涉及符号维度和np.ndarray
的操作在结果用作形状值时可能会引发错误(#14106)。 - 现在,
jaxpr
对象在设置属性时会引发错误,以避免问题变异(#14102)
- 删除了
- 变更
-
jax2tf.call_tf()
现在有一个新参数has_side_effects
(默认为True
),可用于声明实例是否可以被 JAX 优化(如死代码消除)删除或复制(#13980)。 - 为了支持
jax2tf
形状多态性的floordiv
和mod
,我们增加了更多支持。之前,存在符号维度时某些除法操作会导致错误(#14108)。
-
jaxlib 0.4.2(2023 年 1 月 24 日)
- 变更
- 设置
JAX_USE_PJRT_C_API_ON_TPU=1
可启用新的 Cloud TPU 运行时,具备自动设备内存碎片整理功能。
- 设置
jax 0.4.1(2022 年 12 月 13 日)
- 变更
- 根据 JAX 的 Python 和 NumPy 版本支持政策,不再支持 Python 3.7。
- 我们引入了
jax.Array
,它是 JAX 中的统一数组类型,涵盖了DeviceArray
、ShardedDeviceArray
和GlobalDeviceArray
类型。jax.Array
类型有助于使并行成为 JAX 的核心特性,简化和统一 JAX 内部结构,并允许我们统一jit
和pjit
。jax.Array
已在 JAX 0.4 中默认启用,并对pjit
API 进行了一些破坏性更改。jax.Array 迁移指南 可帮助您将代码库迁移到jax.Array
。您还可以查看Distributed arrays and automatic parallelization 教程,以理解新概念。 -
PartitionSpec
和Mesh
现在不再处于实验阶段。新的 API 端点是jax.sharding.PartitionSpec
和jax.sharding.Mesh
。jax.experimental.maps.Mesh
和jax.experimental.PartitionSpec
已被弃用,并将在三个月内移除。 -
with_sharding_constraint
的新公共端点是jax.lax.with_sharding_constraint
。 - 如果与
jax.config
一起使用 ABSL 标志,那么在最初从 ABSL 标志填充 JAX 配置选项后,就不再读取或写入 ABSL 标志值。此更改改进了读取jax.config
选项的性能,这些选项在 JAX 中广泛使用。 -
jax2tf.call_tf
函数现在使用与嵌入 JAX 计算相同平台的第一个 TF 设备进行 TF 降级。以前,它使用的是 JAX 默认后端的第 0 个设备。 - 现在,一些
jax.numpy
函数的参数已标记为仅限位置参数,与 NumPy 匹配。 -
jnp.msort
现已废弃,遵循 numpy 1.24 中np.msort
的废弃。它将在未来的版本中移除,符合 API 兼容性策略。可以用jnp.sort(a, axis=0)
替换。
jaxlib 0.4.1 (2022 年 12 月 13 日)
- 变更
- 支持 Python 3.7 已被放弃,符合 JAX 的 Python 和 NumPy 版本支持政策。
-
XLA_PYTHON_CLIENT_MEM_FRACTION=.XX
的行为已更改,现在分配总 GPU 内存的 XX%来预分配,而不是以前使用当前可用 GPU 内存来计算预分配。有关更多详情,请参阅GPU memory allocation。 - 废弃的方法
.block_host_until_ready()
已被移除。请改用.block_until_ready()
。
jax 0.4.0 (2022 年 12 月 12 日)
- 此版本已被撤回。
jaxlib 0.4.0 (2022 年 12 月 12 日)
- 此版本已被撤回。
jax 0.3.25 (2022 年 11 月 15 日)
- 变更
-
jax.numpy.linalg.pinv()
现在支持hermitian
选项。 -
jax.scipy.linalg.hessenberg()
现在仅在 CPU 上支持。需要 jaxlib > 0.3.24。 - 新函数
jax.lax.linalg.hessenberg()
,jax.lax.linalg.tridiagonal()
和jax.lax.linalg.householder_product()
已添加。Householder 约简目前仅支持 CPU,三对角约简支持 CPU 和 GPU。 - 现在更经济地计算非方阵的
svd
和jax.numpy.linalg.pinv
的梯度。
-
- 突破性变更
- 删除了
jax_experimental_name_stack
配置选项。 - 将字符串
axis_names
参数转换为jax.experimental.maps.Mesh
构造函数的单例元组,而不是将字符串解包为字符轴名称序列。
- 删除了
jaxlib 0.3.25 (2022 年 11 月 15 日)
- 变更
- 添加了对 CPU 和 GPU 上三对角约简的支持。
- 添加了对 CPU 上上 Hessenberg 约简的支持。
- Bug 修复
- 修复了一个 bug,导致 JAX 捕获的回溯中的帧被错误地映射到 Python 3.10 下的源行。
jax 0.3.24 (2022 年 11 月 4 日)
- 变更
- JAX 导入速度应更快。现在我们懒惰地导入 scipy,这在 JAX 的导入时间中占据了相当大的部分。
- 设置环境变量
JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=$N
可以用于限制写入持久缓存的缓存条目数量。默认情况下,编译时间超过 1 秒的计算将被缓存。- 添加了
jax.scipy.stats.mode()
。
- 添加了
- 如果在 TPU 上未指定顺序,则
pmap
的默认设备顺序现在与单进程作业的jax.devices()
匹配。以前两种排序不同,可能导致不必要的复制或内存不足错误。要求排序一致简化了问题。
- 突破性变更
-
jax.numpy.gradient()
现在像jax.numpy
中的大多数其他函数一样,禁止传递列表或元组以替代数组(#12958)。 -
jax.numpy.linalg
和jax.numpy.fft
中的函数现在统一要求输入为数组形式:即不能使用列表和元组代替数组。部分属于#7737。
-
- 弃用
jax.sharding.MeshPspecSharding
已重命名为jax.sharding.NamedSharding
。jax.sharding.MeshPspecSharding
名称将在 3 个月内删除。
jaxlib 0.3.24(2022 年 11 月 4 日)
- 更改
- 现在在 CPU 上可以使用缓冲器捐赠。这可能会破坏在 CPU 上标记缓冲区进行捐赠但依赖捐赠未实现的代码。
jax 0.3.23(2022 年 10 月 12 日)
- 更改
- 更新 Colab TPU 驱动程序版本以支持新的 jaxlib 发布。
jax 0.3.22(2022 年 10 月 11 日)
- 更改
- 在 TPU 初始化中添加
JAX_PLATFORMS=tpu,cpu
作为默认设置,因此如果无法初始化 TPU,JAX 将引发错误而不是回退到 CPU。设置JAX_PLATFORMS=''
以覆盖此行为并自动选择可用的后端(原始默认),或设置JAX_PLATFORMS=cpu
以始终使用 CPU,而不管 TPU 是否可用。
- 在 TPU 初始化中添加
- 弃用
- JAX v0.3.8 中弃用的几个测试工具现已从
jax.test_util
中移除。
- JAX v0.3.8 中弃用的几个测试工具现已从
jaxlib 0.3.22(2022 年 10 月 11 日)
jax 0.3.21(2022 年 9 月 30 日)
- GitHub 提交记录。
- 更改
- 持久化编译缓存现在在出错时会发出警告而不是抛出异常(#12582),所以如果缓存出现问题,程序可以继续执行。设置
JAX_RAISE_PERSISTENT_CACHE_ERRORS=true
可以恢复此行为。
- 持久化编译缓存现在在出错时会发出警告而不是抛出异常(#12582),所以如果缓存出现问题,程序可以继续执行。设置
jax 0.3.20(2022 年 9 月 28 日)
- Bug 修复:
- 添加了上一个发布版本中缺失的
.pyi
文件(#12536)。 - 修复了
jax
0.3.19 与其固定的 libtpu 版本之间的不兼容性(#12550)。需要 jaxlib 0.3.20。 - 修复了
setup.py
注释中pip
的错误网址(#12528)。
- 添加了上一个发布版本中缺失的
jaxlib 0.3.20(2022 年 9 月 28 日)
- GitHub 提交记录。
- Bug 修复
- 修复通过
jax_cuda_visible_devices
在分布式作业中限制可见 CUDA 设备的支持。此功能对于 GPU 上的 JAX/SLURM 集成非常重要(#12533)。
- 修复通过
jax 0.3.19(2022 年 9 月 27 日)
- GitHub 提交记录。
- 需要的 jaxlib 版本修复。
jax 0.3.18(2022 年 9 月 26 日)
- GitHub 提交记录。
- 更改
- 提前编译和编译功能(在#7733中跟踪)是稳定和公开的。查看概述和
jax.stages
的 API 文档。 - 引入了
jax.Array
,用于 JAX 中数组类型的isinstance
检查和类型注释。请注意,这包括了对jax.numpy.ndarray
在 JAX 内部对象中如何工作的一些微妙更改,因为jax.numpy.ndarray
现在是jax.Array
的简单别名。
- 提前编译和编译功能(在#7733中跟踪)是稳定和公开的。查看概述和
- 破坏性变更
-
jax._src
不再导入公共jax
命名空间。这可能会打破使用 JAX 内部功能的用户。 - 已删除
jax.soft_pmap
。请改用pjit
或xmap
。jax.soft_pmap
未记录文档。如果有文档记录,将提供弃用期。
-
jax 0.3.17(2022 年 8 月 31 日)
- GitHub 提交记录。
- 错误修复
- 修复了
lax.pow
的梯度在指数为零时的特殊情况问题(#12041)
- 修复了
- 破坏性变更
jax.checkpoint()
,又称jax.remat()
,不再支持concrete
选项,遵循前一个版本的弃用;请参阅JEP 11830。
- 变更
- 添加了
jax.pure_callback()
,允许从编译函数(例如用jax.jit
或jax.pmap
装饰的函数)调用纯 Python 函数。
- 添加了
- 弃用:
- 已移除不推荐使用的
DeviceArray.tile()
方法。使用jax.numpy.tile()
代替(#11944)。 - 已弃用
DeviceArray.to_py()
。请改用np.asarray(x)
。
- 已移除不推荐使用的
jax 0.3.16
- GitHub 提交记录。
- 破坏性变更
- 支持 NumPy 1.19 已被移除,根据弃用政策。请升级到 NumPy 1.20 或更新版本。
- 变更
- 添加了
jax.debug
,包括用于运行时值调试的实用程序,如jax.debug.print()
和jax.debug.breakpoint()
。 - 添加了用于运行时值调试的新文档
- 添加了
- 弃用
- 移除了
jax.mask()
和jax.shapecheck()
API。详见#11557。 - 移除了
jax.experimental.loops
。可查看#10278获取替代 API。 -
jax.tree_util.tree_multimap()
已移除。自 JAX 版本 0.3.5 起已被弃用,jax.tree_util.tree_map()
是直接替换。 - 删除了
jax.experimental.stax
;它长期以来一直是jax.example_libraries.stax
的弃用别名。 - 移除了
jax.experimental.optimizers
;它长期以来一直是jax.example_libraries.optimizers
的弃用别名。 -
jax.checkpoint()
,又称jax.remat()
,有了一个新的默认实现,意味着旧的实现已被弃用;请参阅JEP 11830。
- 移除了
jax 0.3.15(2022 年 7 月 22 日)
- GitHub 提交记录。
- 变更
-
jax.test_util
中已移除JaxTestCase
和JaxTestLoader
类,自 v0.3.1 起已弃用(#11248)。 - 添加了
jax.scipy.gaussian_kde
(#11237)。 - JAX 数组与内置集合(
dict
、list
、set
、tuple
)之间的二元操作现在在所有情况下都会引发TypeError
。以前的某些情况(特别是相等性和不等式)会返回与 NumPy 中类似操作不一致的布尔标量(#11234)。 - 几个作为顶级 JAX 包导入的
jax.tree_util
例程现已弃用,并将根据 API 兼容性政策在未来的 JAX 发布版本中移除。-
jax.treedef_is_leaf()
已弃用,推荐使用jax.tree_util.treedef_is_leaf()
。 -
jax.tree_flatten()
已弃用,推荐使用jax.tree_util.tree_flatten()
。 -
jax.tree_leaves()
已弃用,推荐使用jax.tree_util.tree_leaves()
。 -
jax.tree_structure()
已弃用,推荐使用jax.tree_util.tree_structure()
。 -
jax.tree_transpose()
已弃用,推荐使用jax.tree_util.tree_transpose()
。 -
jax.tree_unflatten()
已弃用,推荐使用jax.tree_util.tree_unflatten()
。
-
-
jax.scipy.linalg.solve()
的sym_pos
参数已弃用,推荐使用assume_a='pos'
,遵循scipy.linalg.solve()
中类似的弃用。
-
jaxlib 0.3.15(2022 年 7 月 22 日)
- GitHub 提交。
jax 0.3.14(2022 年 6 月 27 日)
- GitHub 提交。
- 破坏性变更
-
jax.experimental.compilation_cache.initialize_cache()
现在不再支持max_cache_size_ bytes
,并且不会将其作为输入。 - 当平台初始化失败时,
JAX_PLATFORMS
现在会引发异常。
-
- 变更
- 解决了与 NumPy 1.23 的兼容性问题。
-
jax.numpy.linalg.slogdet()
现在接受一个可选的method
参数,允许选择基于 LU 分解或基于 QR 分解的实现。 -
jax.numpy.linalg.qr()
现在支持mode="raw"
。 - 在对 JAX 数组使用
pickle
、copy.copy
和copy.deepcopy
时,现在支持更完整的支持(#10659)。特别是:- 当对
DeviceArray
使用pickle
和deepcopy
时,以前返回np.ndarray
对象,现在返回DeviceArray
对象。对于deepcopy
,复制的数组位于与原始数组相同的设备上。对于pickle
,反序列化的数组将位于默认设备上。 - 在函数转换(即跟踪代码)内部,
deepcopy
和copy
以前是空操作。现在它们使用与DeviceArray.copy()
相同的机制。 - 对跟踪数组进行
pickle
操作现在会导致显式的ConcretizationTypeError
。
- 当对
- 在 TPU 上,奇异值分解(SVD)和对称/Hermitian 特征分解的实现应显著更快,特别是对于超过 1000x1000 大小的矩阵。现在都使用了谱分裂与征算法进行特征分解(QDWH-eig)。
-
jax.numpy.ldexp()
现在不再将所有输入默认提升为 float64,而是对于 int32 或更小的整数输入,提升为 float32 (#10921)。 - 添加了一个
create_perfetto_link
选项到jax.profiler.start_trace()
和jax.profiler.start_trace()
。使用时,分析器将生成一个链接到 Perfetto UI 以查看跟踪信息。 - 更改了
jax.profiler.start_server(...)()
的语义,将 keepalive 全局存储,而不再要求用户保留引用。 - 添加了
jax.random.generalized_normal()
。 - 添加了
jax.random.ball()
。 - 添加了
jax.default_device()
。 - 添加了一个
python -m jax.collect_profile
脚本,手动捕获程序跟踪,作为 TensorBoard UI 的替代方法。 - 添加了一个
jax.named_scope
上下文管理器,向 Python 程序添加分析器元数据(类似于jax.named_call
)。 - 在 scatter-update 操作(即 :attr:
jax.numpy.ndarray.at
)中,不安全的隐式 dtype 转换已弃用,现在会产生FutureWarning
。在将来的版本中,这将变成一个错误。一个不安全的隐式转换的例子是jnp.zeros(4, dtype=int).at[0].set(1.5)
,其中1.5
之前会被静默截断为1
。 -
jax.experimental.compilation_cache.initialize_cache()
现在支持 gcs 存储桶路径作为输入。 - 添加了
jax.scipy.stats.gennorm()
。 -
jax.numpy.roots()
现在在strip_zeros=False
时,在系数有前导零时行为更佳 (#11215)。
jaxlib 0.3.14(2022 年 6 月 27 日)。
- GitHub 提交记录。
- x86-64 Mac wheels 现在要求 Mac OS 10.14(Mojave)或更新版本。Mac OS 10.14 发布于 2018 年,因此这不应该是一个非常繁重的要求。
- 捆绑的 NCCL 版本更新到 2.12.12,修复了一些死锁问题。
- Python flatbuffers 包不再是 jaxlib 的依赖项。
jax 0.3.13(2022 年 5 月 16 日)。
- GitHub 提交记录。
jax 0.3.12(2022 年 5 月 15 日)。
- GitHub 提交记录。
- 变更:
- 修复了 #10717。
jax 0.3.11(2022 年 5 月 15 日)。
- GitHub 提交记录。
- 变更:
jax.lax.eigh()
现在接受一个可选的sort_eigenvalues
参数,允许用户在 TPU 上选择不排序特征值。
- 弃用:
-
jax.lax.linalg
中的函数现在要求非数组参数必须作为关键字参数传递。为了向后兼容,将关键字参数作为位置参数传递将会得到警告,但在未来的 JAX 发布中,将会导致失败。大多数用户应该优先考虑使用jax.numpy.linalg
。 -
jax.scipy.linalg.polar_unitary()
,这是 JAX 对 scipy API 的扩展,已被弃用。请改用jax.scipy.linalg.polar()
。
-
jax 0.3.10 (2022 年 5 月 3 日)
- GitHub 提交记录.
jaxlib 0.3.10 (2022 年 5 月 3 日)
- GitHub 提交记录.
- 变更
- TF 提交记录 修复了 MHLO 规范化器中的问题,该问题导致某些程序的常量折叠花费很长时间或崩溃。
jax 0.3.9 (2022 年 5 月 2 日)
- GitHub 提交记录.
- 变更
- 增加了对 GlobalDeviceArray 的完全异步检查点支持。
jax 0.3.8 (2022 年 4 月 29 日)
- GitHub 提交记录.
- 变更
- 在 TPU 上,
jax.numpy.linalg.svd()
现在使用 qdwh-svd 求解器。 - 在 TPU 上,
jax.numpy.linalg.cond()
现在接受复数输入。 - 在 TPU 上,
jax.numpy.linalg.pinv()
现在接受复数输入。 - 在 TPU 上,
jax.numpy.linalg.matrix_rank()
现在接受复数输入。 - 已添加
jax.scipy.cluster.vq.vq()
。 -
jax.experimental.maps.mesh
已删除。请使用jax.experimental.maps.Mesh
。请参阅 此处 获取更多信息。 - 当
mode='r'
时,jax.scipy.linalg.qr()
现在返回一个长度为 1 的元组,而不是原始数组,以匹配scipy.linalg.qr
的行为(#10452) -
jax.numpy.take_along_axis()
现在接受一个可选的mode
参数,用于指定超出边界索引的行为。默认情况下,超出边界的索引会返回无效值(例如 NaN)。在 JAX 的早期版本中,无效的索引会被夹在范围内。可以通过传递mode="clip"
恢复先前的行为。 -
jax.numpy.take()
现在默认为mode="fill"
,这会对超出索引范围的位置返回无效值(例如 NaN)。 - 散点操作,例如
x.at[...].set(...)
,现在具有"drop"
语义。这对散点操作本身没有影响,但这意味着在进行微分时,散点的梯度对超出边界的索引的余切为零。以前超出边界的索引在梯度中被夹在范围内,这在数学上是不正确的。 -
jax.numpy.take_along_axis()
现在如果其索引不是整数类型将会引发TypeError
,与numpy.take_along_axis()
的行为一致。先前非整数索引会被静默转换为整数。 -
jax.numpy.ravel_multi_index()
现在如果其dims
参数不是整数类型将会引发TypeError
,与numpy.ravel_multi_index()
的行为一致。先前非整数dims
参数会被静默转换为整数。 -
jax.numpy.split()
现在如果其axis
参数不是整数类型将会引发TypeError
,与numpy.split()
的行为一致。先前非整数axis
参数会被静默转换为整数。 -
jax.numpy.indices()
现在如果其维度不是整数类型将会引发TypeError
,与numpy.indices()
的行为一致。先前非整数维度会被静默转换为整数。 -
jax.numpy.diag()
现在如果其k
参数不是整数类型将会引发TypeError
,与numpy.diag()
的行为一致。先前非整数k
参数会被静默转换为整数。 - 添加了
jax.random.orthogonal()
。
- 在 TPU 上,
- 已过时:
- 许多
jax.test_util
中可用的函数和对象现已过时,并将在导入时引发警告。包括cases_from_list
、check_close
、check_eq
、device_under_test
、format_shape_dtype_string
、rand_uniform
、skip_on_devices
、with_config
、xla_bridge
和_default_tolerance
(#10389)。这些以及先前过时的JaxTestCase
、JaxTestLoader
和BufferDonationTestCase
将在未来的 JAX 发布中移除。大多数这些实用程序可以通过调用标准的 Python 和 NumPy 测试实用程序来替换,如unittest
、absl.testing
、numpy.testing
等。可以通过公共 API(例如jax.devices()
)来替换 JAX 特定的功能,如设备检查。许多已过时的实用程序仍然存在于jax._src.test_util
中,但这些不是公共 API,因此可能在未来的发布中更改或移除,而不另行通知。
- 许多
jax 0.3.7(2022 年 4 月 15 日)
- GitHub 提交记录。
- 变更:
- 修复了当传递给
jax.numpy.take_along_axis()
的索引广播时的性能问题(#10281)。 -
jax.scipy.special.expit()
和jax.scipy.special.logit()
现在要求其参数为标量或 JAX 数组。它们现在还将整数参数提升为浮点数。 -
DeviceArray.tile()
方法已弃用,因为 numpy 数组没有tile()
方法。作为替代,请使用jax.numpy.tile()
(#10266)。
- 修复了当传递给
jaxlib 0.3.7(2022 年 4 月 15 日)
- 变更:
- Linux 版本现在符合
manylinux2014
标准,而不是manylinux2010
。
- Linux 版本现在符合
jax 0.3.6(2022 年 4 月 12 日)
- GitHub 提交记录。
- 变更:
- 将 libtpu 轮子升级到修复初始化 TPU pod 时挂起的版本。修复了 #10218。
- 弃用:
jax.experimental.loops
将被弃用。参见 #10278 了解替代 API。
jax 0.3.5(2022 年 4 月 7 日)
- GitHub 提交记录。
- 变更:
- 添加了
jax.random.loggamma()
并改进了对小参数值的jax.random.beta()
和jax.random.dirichlet()
的行为(#9906)。 -
lax_numpy
私有子模块不再暴露在jax.numpy
命名空间中(#10029)。 - 添加了数组创建例程
jax.numpy.frombuffer()
、jax.numpy.fromfunction()
和jax.numpy.fromstring()
(#10049)。 -
DeviceArray.copy()
现在返回DeviceArray
而不是np.ndarray
(#10069) - 添加了
jax.scipy.linalg.rsf2csf()
-
jax.experimental.sharded_jit
已被弃用,并将很快移除。
- 添加了
- 弃用:
-
jax.nn.normalize()
将被弃用。请使用jax.nn.standardize()
替代(#9899)。 -
jax.tree_util.tree_multimap()
已弃用。请使用jax.tree_util.tree_map()
替代(#5746)。 -
jax.experimental.sharded_jit
已弃用。请使用pjit
替代。
-
jaxlib 0.3.5(2022 年 4 月 7 日)
- 修复了 bug
- 修复了一个 bug,双精度复杂到实数 IRFFT 在 GPU 上会改变其输入缓冲区(#9946)。
- 修复了复杂散布常量折叠错误(#10159)
jax 0.3.4(2022 年 3 月 18 日)
- GitHub 提交记录。
jax 0.3.3(2022 年 3 月 17 日)
- GitHub 提交记录。
jax 0.3.2(2022 年 3 月 16 日)
- GitHub 提交记录。
- 变更:
- 函数
jax.ops.index_update
、jax.ops.index_add
在 0.2.22 中已弃用。请使用JAX 数组上的.at
属性,例如,x.at[idx].set(y)
。 - 将
jax.experimental.ann.approx_*_k
移至jax.lax
。这些函数是jax.lax.top_k
的优化替代品。 -
jax.numpy.broadcast_arrays()
和jax.numpy.broadcast_to()
现在要求标量或类数组输入,并在传递列表时将失败(部分 #7737)。 - 标准的
jax[tpu]
安装现在可以与 Cloud TPU v4 VMs 一起使用。 -
pjit
现在支持在 CPU 上运行(除了之前的 TPU 和 GPU 支持)。
- 函数
jaxlib 0.3.2 (2022 年 3 月 16 日)
- 更改
XlaComputation.as_hlo_text()
现在支持通过传递布尔标志print_large_constants=True
打印大常量。
- 弃用:
JAX
数组上的.block_host_until_ready()
方法已弃用。请改用.block_until_ready()
。
jax 0.3.1 (2022 年 2 月 18 日)
GitHub 提交记录。
更改:
jax.test_util.JaxTestCase
和 jax.test_util.JaxTestLoader
现在已弃用。建议直接使用 parametrized.TestCase
进行替换。对于依赖于自定义断言(如 JaxTestCase.assertAllClose()
)的测试,请使用标准的 numpy 测试工具,如numpy.testing.assert_allclose()
,它们直接与 JAX 数组一起工作(#9620)。
jax.test_util.JaxTestCase
现在默认设置 jax_numpy_rank_promotion='raise'
(#9562)。要恢复以前的行为,请使用新的 jax.test_util.with_config
装饰器:
@jtu.with_config(jax_numpy_rank_promotion='allow')
class MyTestCase(jtu.JaxTestCase):
...
添加了 jax.scipy.linalg.schur()
、jax.scipy.linalg.sqrtm()
、jax.scipy.signal.csd()
、jax.scipy.signal.stft()
、jax.scipy.signal.welch()
。
jax 0.3.0 (2022 年 2 月 10 日)
- GitHub 提交记录。
- 更改
- jax 版本已升级至 0.3.0. 请参阅设计文档以获取说明。
jaxlib 0.3.0 (2022 年 2 月 10 日)
- 更改
- 现在需要 Bazel 5.0.0 来构建 jaxlib。
- jaxlib 版本已升级至 0.3.0. 请参阅设计文档以获取说明。
jax 0.2.28 (2022 年 2 月 1 日)
- GitHub 提交记录。
- 如果未传递
dialect=
,jax.jit(f).lower(...).compiler_ir()
现在默认为 MHLO 方言。 -
jax.jit(f).lower(...).compiler_ir(dialect='mhlo')
现在返回 MLIRir.Module
对象,而不是其字符串表示。
- 如果未传递
jaxlib 0.1.76 (2022 年 1 月 27 日)
- 新功能
- 包括为 NVidia 计算能力 8.0 的 GPU(例如 A100)预编译的 SASS。删除了计算能力 6.1 的预编译 SASS,以避免增加计算能力的数量:具有计算能力 6.1 的 GPU 可以使用 6.0 的 SASS。
- 使用 jaxlib 0.1.76,JAX 默认使用 MHLO MLIR 方言作为其主要目标编译器 IR。
- Breaking changes
- 不再支持 NumPy 1.18,根据弃用策略。请升级到支持的 NumPy 版本。
- Bug 修复
- 修复了一个 bug,即由不同路径构造的表面相同的 pytreedef 对象不会被视为相等(#9066)。
- JAX jit 缓存要求两个静态参数具有相同的类型以进行缓存命中(#9311)。
jax 0.2.27(2022 年 1 月 18 日)
- GitHub 提交。
- Breaking changes:
- 不再支持 NumPy 1.18,根据弃用策略。请升级到支持的 NumPy 版本。
- host_callback 原语已简化,取消了 hcb.id_tap 和 id_print 的特殊自动微分处理。从现在开始,只有原始值被 tap。可以通过设置
JAX_HOST_CALLBACK_AD_TRANSFORMS
环境变量或--jax_host_callback_ad_transforms
标志来获取旧的行为(在有限时间内)。此外,增加了如何使用 JAX 自定义 AD API 实现旧行为的文档(#8678)。 - 排序现在与 NumPy 的行为匹配,无论位表示如何,对于
0.0
和NaN
都是如此。特别是,现在0.0
和-0.0
被视为等价,而之前-0.0
被视为小于0.0
。此外,所有的NaN
表示现在都被视为等价,并且按照这些位模式排序到数组的末尾。以前,负数的NaN
值被排序到数组的前面,并且具有不同内部位表示的NaN
值不被视为等价,根据这些位模式排序(#9178)。 -
jax.numpy.unique()
现在在处理NaN
值时与 NumPy 版本 1.21 及更新版本的np.unique
一样:在唯一化的输出中最多只会出现一个NaN
值(#9184)。
- Bug 修复:
- 现在 host_callback 支持 ad_checkpoint.checkpoint(#8907)。
- 新功能:
- 添加了
jax.block_until_ready
({jax-issue}`#8941)。 - 添加了一个新的调试标志/环境变量
JAX_DUMP_IR_TO=/path
。如果设置了,JAX 会将它为每个计算生成的 MHLO/HLO IR 转储到给定路径下的文件。 - 添加了
jax.ensure_compile_time_eval
到公共 API(#7987)。 - jax2tf 现在支持一个标志 jax2tf_associative_scan_reductions,用于改变关联约简的降低,例如 jnp.cumsum,在 CPU 和 GPU 上的行为(使用关联扫描)。更多细节请参见 jax2tf README(#9189)。
- 添加了
jaxlib 0.1.75(2021 年 12 月 8 日)
- 新功能:
- 支持 python 3.10。
jax 0.2.26(2021 年 12 月 8 日)
- GitHub 提交记录。
- 错误修复:
- 对
jax.ops.segment_sum
的越界索引现在将使用FILL_OR_DROP
语义处理,如文档中所述。这主要影响反向模式导数,其中与越界索引对应的梯度现在将返回为 0。(#8634)。 - jax2tf 现在会强制转换代码,使其在 jax.jit 下的代码片段使用 XLA,例如大多数 jax.numpy 函数(#7839)。
- 对
jaxlib 0.1.74(2021 年 11 月 17 日)
- 在 GPU 之间启用点对点复制。以前,GPU 复制通过主机反弹,这通常更慢。
- 增加了实验性的 MLIR Python 绑定,供 JAX 使用。
jax 0.2.25(2021 年 11 月 10 日)
- GitHub 提交记录。
- 新功能:
- (实验性)
jax.distributed.initialize
暴露多主机 GPU 后端。 -
jax.random.permutation
支持新的independent
关键字参数(#8430)
- (实验性)
- 破坏性更改
- 将
jax.experimental.stax
移至jax.example_libraries.stax
- 将
jax.experimental.optimizers
移至jax.example_libraries.optimizers
- 将
- 新功能:
- 添加了
jax.lax.linalg.qdwh
。
- 添加了
jax 0.2.24(2021 年 10 月 19 日)
- GitHub 提交记录。
- 新功能:
jax.random.choice
和jax.random.permutation
现在支持多维数组和可选的axis
参数(#8158)。
- 破坏性更改:
- 现在
jax.numpy.take
和jax.numpy.take_along_axis
要求数组样式的输入(参见 #7737)。
- 现在
jaxlib 0.1.73(2021 年 10 月 18 日)
现在支持多个 cuDNN 版本的 jaxlib GPU cuda11
轮。
- cuDNN 8.2 或更新版本。如果您的 cuDNN 安装足够新,请使用 cuDNN 8.2 轮,因为它支持额外的功能。
- cuDNN 8.0.5 或更新版本。
破坏性更改:
GPU jaxlib 的安装命令如下:
代码语言:javascript复制pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
# Installs the wheel compatible with Cuda 11 and cudnn 8.2 or newer.
pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
# Installs the wheel compatible with Cuda 11 and cudnn 8.0.5 or newer.
pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
jax 0.2.22(2021 年 10 月 12 日)
- GitHub 提交记录。
- 破坏性更改
-
jax.pmap
的静态参数现在必须是可哈希的。 在jax.jit
上长期不允许非哈希静态参数,但在jax.pmap
上仍然允许;jax.pmap
使用对象标识比较非哈希静态参数。 这种行为可能会导致一些问题,因为使用对象身份比较来比较参数会导致每次对象身份变化时重新编译。现在我们禁止非可哈希参数:如果jax.pmap
的用户希望通过对象身份比较静态参数,他们可以在其对象上定义__hash__
和__eq__
方法,或者将其对象包装在具有对象身份语义的对象中。另一种选择是使用functools.partial
将非可哈希的静态参数封装到函数对象中。 -
jax.util.partial
是一个意外导出的内容,已被移除。请使用 Python 标准库中的functools.partial
替代。
-
- Deprecations
- 函数
jax.ops.index_update
、jax.ops.index_add
等已被弃用,并将在未来的 JAX 版本中移除。请改用 JAX 数组上的.at
属性,例如x.at[idx].set(y)
。目前,这些函数会产生DeprecationWarning
。
- 函数
- New features:
- 优化的 C 代码路径现在是使用 jaxlib 0.1.72 或更新版本时的默认设置,用于提高
pmap
的调度时间。可以使用--experimental_cpp_pmap
标志(或JAX_CPP_PMAP
环境变量)禁用该功能。 -
jax.numpy.unique
现在支持一个可选的fill_value
参数(#8121)。
- 优化的 C 代码路径现在是使用 jaxlib 0.1.72 或更新版本时的默认设置,用于提高
jaxlib 0.1.72 (Oct 12, 2021)
- Breaking changes:
- CUDA 10.2 和 CUDA 10.1 的支持已被移除。Jaxlib 现在支持 CUDA 11.1 。
- Bug fixes:
- 修复了 https://github.com/google/jax/issues/7461,在所有平台上由于 XLA 编译器内部的错误缓冲区别名而导致错误的输出。
jax 0.2.21 (Sept 23, 2021)
- GitHub commits.
- Breaking Changes
-
jax.api
已被移除。之前作为jax.api.*
可用的函数现在被别名为jax.*
中的函数;请直接使用jax.*
中的函数。 -
jax.partial
和jax.lax.partial
是意外导出的内容,已被移除。请使用 Python 标准库中的functools.partial
替代。 - 布尔标量索引现在会引发
TypeError
;之前这些操作会静默返回错误的结果(#7925)。 - 许多
jax.numpy
函数现在要求数组样式的输入,如果传递列表将会报错(#7747 #7802 #7907)。查看 #7737 以了解此更改背后的原因讨论。 - 当在
jax.jit
等转换内部时,jax.numpy.array
总是将其生成的数组分阶段到跟踪的计算中。以前的jax.numpy.array
有时会在jax.jit
装饰器下生成一个设备上的数组。这种变化可能会破坏使用 JAX 数组执行必须静态知道形状或索引计算的代码;解决方法是改用经典的 NumPy 数组执行这些计算。 -
jnp.ndarray
现在是 JAX 数组的真正基类。特别地,对于标准的 numpy 数组x
,isinstance(x, jnp.ndarray)
现在会返回False
(#7927)。
-
- 新特性:
- 添加了
jax.numpy.insert()
的实现 (#7936)。
- 添加了
jax 0.2.20 (2021 年 9 月 2 日)
- GitHub 提交记录。
- Breaking Changes
-
jnp.poly*
函数现在要求数组样式的输入 (#7732)。 -
jnp.unique
和其他类似集合的操作现在要求数组样式的输入 (#7662)。
-
jaxlib 0.1.71 (2021 年 9 月 1 日)
- Breaking changes:
- 不再支持 CUDA 11.0 和 CUDA 10.1。Jaxlib 现在支持 CUDA 10.2 和 CUDA 11.1 。
jax 0.2.19 (2021 年 8 月 12 日)
- GitHub 提交记录。
- Breaking changes:
- 支持 NumPy 1.17 已经被废弃,按照弃用政策。请升级到支持的 NumPy 版本。
- 在 JAX 数组的多个操作的实现周围添加了
jit
装饰器。这加快了常见操作如x 2**40
)。解决方法是将常数转换为显式类型(例如np.float64(2**40)
)。
- 新特性:
- 改进了对需要在数组计算中使用维度大小的操作在 jax2tf 中的形状多态支持,例如
jnp.mean
。 (#7317)。
- 改进了对需要在数组计算中使用维度大小的操作在 jax2tf 中的形状多态支持,例如
- Bug 修复:
- 上一个版本的泄漏的追踪错误 (#7613)。
jaxlib 0.1.70 (2021 年 8 月 9 日)
- Breaking changes:
- 支持 Python 3.6 已经被废弃,按照弃用政策。请升级到支持的 Python 版本。
- 支持 NumPy 1.17 已经被废弃,按照弃用政策。请升级到支持的 NumPy 版本。
- 现在主机回调机制每个本地设备使用一个线程来调用 Python 回调。以前所有设备共用一个线程。这意味着现在回调可能交错调用。仍然会按顺序调用一个设备对应的所有回调。
jax 0.2.18(2021 年 7 月 21 日)
- GitHub 提交记录。
- Breaking 变更:
- 根据弃用策略,不再支持 Python 3.6。请升级到支持的 Python 版本。
- jaxlib 最低版本现在是 0.1.69。
-
jax.dlpack.from_dlpack()
的backend
参数已移除。
- 新功能:
- 添加了极分解(
jax.scipy.linalg.polar()
)。
- 添加了极分解(
- Bug 修复:
- 加强了对 lax.argmin 和 lax.argmax 的检查,以确保它们不会使用无效的
axis
值或空的减少维度。 (#7196)
- 加强了对 lax.argmin 和 lax.argmax 的检查,以确保它们不会使用无效的
jaxlib 0.1.69(2021 年 7 月 9 日)
- 修复了 TFRT CPU 后端中导致结果不正确的错误。
jax 0.2.17(2021 年 7 月 9 日)
- GitHub 提交记录。
- Bug 修复:
- 对于 jaxlib <= 0.1.68,默认使用较旧的“stream_executor” CPU 运行时,以解决#7229,这导致 CPU 上由于并发问题输出错误结果。
- 新功能:
- 新的 SciPy 函数
jax.scipy.special.sph_harm()
。 - 反向模式自动微分函数(
jax.grad()
,jax.value_and_grad()
,jax.vjp()
和jax.linear_transpose()
)支持一个参数,指示在后向传递中应该对哪些命名轴进行求和,如果它们在前向传递中被广播。这使得可以在 maps 内部以非每个示例的方式使用这些 API(最初仅jax.experimental.maps.xmap()
)(#6950)。
- 新的 SciPy 函数
jax 0.2.16(2021 年 6 月 23 日)
- GitHub 提交记录。
jax 0.2.15(2021 年 6 月 23 日)
- GitHub 提交记录。
- 新功能:
- #7042 使用了 TFRT CPU 后端,在 CPU 上显著提升了分派性能。
-
jax2tf.convert()
支持布尔型不等式和 min/max 函数(#6956)。 - 新的 SciPy 函数
jax.scipy.special.lpmn_values()
。
- Breaking 变更:
- 根据弃用策略,不再支持 NumPy 1.16。
- Bug 修复:
- 修复了阻止从 JAX 到 TF 再到 JAX 回传的错误:
jax2tf.call_tf(jax2tf.convert)
(#6947)。
- 修复了阻止从 JAX 到 TF 再到 JAX 回传的错误:
jaxlib 0.1.68(2021 年 6 月 23 日)
- Bug 修复:
- 修复了 TFRT CPU 后端中将 TPU 缓冲区传输到 CPU 时出现 NaN 的错误。
jax 0.2.14(2021 年 6 月 10 日)
- GitHub 提交记录。
- 新功能:
-
jax2tf.convert()
现在支持pjit
和sharded_jit
。 - 新的配置选项 JAX_TRACEBACK_FILTERING 控制 JAX 如何过滤回溯信息。
- 在足够新的 IPython 版本中,默认启用了使用
__tracebackhide__
的新的回溯过滤模式。 -
jax2tf.convert()
在算术操作中使用未知维度时,即使在形状多态性中,也支持形状多态性,例如jnp.reshape(-1)
(#6827)。 -
jax2tf.convert()
现在在 TF 操作中生成具有位置信息的自定义属性。在 jax2tf 之后 XLA 生成的代码具有与 JAX/XLA 相同的位置信息。 - 新的 SciPy 函数
jax.scipy.special.lpmn()
。
-
- Bug fixes:
-
jax2tf.convert()
现在确保对于 Python 标量和选择 32 位 vs. 64 位计算时使用相同的类型规则,如 JAX(#6883)。 -
jax2tf.convert()
现在正确地将enable_xla
转换参数限定范围到仅在即时转换期间应用(#6720)。 -
jax2tf.convert()
现在使用XlaDot
TensorFlow 操作来转换lax.dot_general
,以提高与 JAX 数值精度的一致性(#6717)。 -
jax2tf.convert()
现在支持复数的不等式比较和最小/最大值(#6892)。
-
jaxlib 0.1.67(2021 年 5 月 17 日)
jaxlib 0.1.66(2021 年 5 月 11 日)
- 新特性:
- 现在支持在所有 CUDA 11 版本(11.1 或更高版本)上使用 CUDA 11.1 wheels。 NVIDIA 现在承诺从 CUDA 11.1 开始兼容 CUDA 小版本更新。这意味着 JAX 可以发布一个兼容 CUDA 11.2 和 11.3 的单个 CUDA 11.1 wheel。 不再为 CUDA 11.2(或更高版本)发布单独的 jaxlib 版本;对于这些版本,请使用 CUDA 11.1 wheel(cuda111)。
- Jaxlib 现在在 CUDA wheels 中捆绑
libdevice.10.bc
。不需要指定 CUDA 安装路径来查找此文件。 -
jit()
实现自动支持静态关键字参数。 - 添加了对预转换异常跟踪的支持。
- 初步支持从
jit()
转换的计算中剪枝未使用的参数。剪枝仍在进行中。 - 改进了
PyTreeDef
对象的字符串表示。 - 添加了对 XLA 可变 ReduceWindow 的支持。
- Bug fixes:
- 修复了在远程云 TPU 支持中传递大量参数时的 bug。
- 修复了一个问题,即
jit()
转换的函数未触发 JAX 垃圾回收。
jax 0.2.13(2021 年 5 月 3 日)
- GitHub 提交。
- 新特性:
- 结合 jaxlib 0.1.66 使用时,
jax.jit()
现在支持静态关键字参数。新增了static_argnames
选项以指定关键字参数为静态。 -
jax.nonzero()
现在有一个新的可选参数size
,允许在jit
内使用 (#6501)。 -
jax.numpy.unique()
现在支持axis
参数 (#6532)。 -
jax.experimental.host_callback.call()
现在支持pjit.pjit
(#6569)。 - 添加了
jax.scipy.linalg.eigh_tridiagonal()
,用于计算三对角矩阵的特征值。目前仅支持特征值。 - 异常中筛选和未筛选的堆栈跟踪顺序已更改。从 JAX 转换代码中抛出的异常现在附带有过滤后的回溯,
UnfilteredStackTrace
异常包含原始跟踪作为过滤异常的__cause__
。现在,筛选的堆栈跟踪也适用于 Python 3.6。 - 如果由反向模式自动微分转换的代码引发异常,JAX 现在尝试附加一个
JaxStackTraceBeforeTransformation
对象作为异常的__cause__
,该对象包含在正向传递中创建原始操作的堆栈跟踪。需要 jaxlib 0.1.66。
- 结合 jaxlib 0.1.66 使用时,
- 破坏性变更:
- 下列函数名称已更改。仍然存在别名,因此不应该破坏现有代码,但别名最终将被移除,请更改您的代码。
-
host_id
–>process_index()
-
host_count
–>process_count()
-
host_ids
–>range(jax.process_count())
-
- 同样地,
local_devices()
的参数已从host_id
重命名为process_index
。 - 除了函数之外的
jax.jit()
参数现在标记为仅限关键字。此更改旨在防止在向jit
添加参数时意外破坏代码。
- 下列函数名称已更改。仍然存在别名,因此不应该破坏现有代码,但别名最终将被移除,请更改您的代码。
- Bug 修复:
- 现在
jax2tf.convert()
在带有整数输入的函数梯度存在时能正常工作 (#6360)。 - 修复了
jax2tf.call_tf()
在与捕获的tf.Variable
结合使用时的断言失败 (#6572)。
- 现在
jaxlib 0.1.65(2021 年 4 月 7 日)
jax 0.2.12(2021 年 4 月 1 日)
- GitHub 提交记录。
- 新功能
- 新的分析 API:
jax.profiler.start_trace()
,jax.profiler.stop_trace()
和jax.profiler.trace()
-
jax.lax.reduce()
现在可微分。
- 新的分析 API:
- 破坏性变更:
- 最低的 jaxlib 版本现在是 0.1.64。
- 一些分析器 API 名称已更改。仍然存在别名,因此不应该破坏现有代码,但别名最终将被移除,请更改您的代码。
-
TraceContext
–>TraceAnnotation()
-
StepTraceContext
–>StepTraceAnnotation()
-
trace_function
–>annotate_function()
-
- 无法禁用全局分析。有关更多信息,请参阅 omnistaging。
- Python 整数大于最大的
int64
值现在在所有情况下都会导致溢出,而不是在某些情况下静默转换为uint64
(#6047)。 - 在非 X64 模式下,超出
int32
可表示范围的 Python 整数现在将导致OverflowError
,而不是静默截断其值。
- Bug 修复:
-
host_callback
现在支持参数和结果中的空数组(#6262)。 -
jax.random.randint()
在超出限制范围时会剪切而不是包裹,现在可以生成指定 dtype 的整数的全部范围(#5868)。
-
jax 0.2.11(2021 年 3 月 23 日)
- GitHub 提交记录。
- 新特性:
- #6112 添加了上下文管理器:
jax.enable_checks
,jax.check_tracer_leaks
,jax.debug_nans
,jax.debug_infs
,jax.log_compiles
。 - #6085 添加了
jnp.delete
- #6112 添加了上下文管理器:
- Bug 修复:
- #6136 泛化了
jax.flatten_util.ravel_pytree
以处理整数 dtype。 - #6129 修复了处理像
enum.IntEnums
这样的一些常量的错误 - #6145 修复了不完全贝塔函数批处理问题
- #6014 修复了追踪过程中的 H2D 传输问题
- #6165 在将一些大的 Python 整数转换为浮点数时避免 OverflowErrors
- #6136 泛化了
- 破坏性变更:
- jaxlib 最小版本现在是 0.1.62。
jaxlib 0.1.64(2021 年 3 月 18 日)
jaxlib 0.1.63(2021 年 3 月 17 日)
jax 0.2.10(2021 年 3 月 5 日)
- GitHub 提交记录。
- 新特性:
-
jax.scipy.stats.chi2()
现在作为具有 logpdf 和 pdf 方法的分布可用。 -
jax.scipy.stats.betabinom()
现在作为具有 logpmf 和 pmf 方法的分布可用。 - 添加了
jax.experimental.jax2tf.call_tf()
以从 JAX 调用 TensorFlow 函数(#5627)和README。 - 扩展了
lax.pad
的批处理规则以支持填充值的批处理。
-
- Bug 修复:
jax.numpy.take()
正确处理负索引(#5768)
- 破坏性变更:
- 调整了 JAX 的提升规则,使提升更一致且不受 JIT 影响。特别是,当适当时,二进制操作现在可以产生弱类型值。更改的主要用户可见效果是某些操作的输出精度与之前不同;例如表达式
jnp.bfloat16(1) 0.1 * jnp.arange(10)
以前返回float64
数组,现在返回bfloat16
数组。JAX 的类型提升行为在类型提升语义中描述。 -
jax.numpy.linspace()
现在计算整数值的地板,即向负无穷取整,而不是向 0 取整。此更改是为了与 NumPy 1.20.0 保持一致。 -
jax.numpy.i0()
不再接受复数。之前该函数计算复数参数的绝对值。此更改是为了与 NumPy 1.20.0 的语义保持一致。 - 几个
jax.numpy
函数不再接受元组或列表作为数组参数的替代:jax.numpy.pad()
,jax.numpy.ravel
,jax.numpy.repeat()
,jax.numpy.reshape()
。通常情况下,应使用标量或数组参数调用jax.numpy
函数。
- 调整了 JAX 的提升规则,使提升更一致且不受 JIT 影响。特别是,当适当时,二进制操作现在可以产生弱类型值。更改的主要用户可见效果是某些操作的输出精度与之前不同;例如表达式
jaxlib 0.1.62 (2021 年 3 月 9 日)
- 新特性:
- 在 x86-64 机器上,默认情况下构建 jaxlib wheels 需要 AVX 指令。如果要在不支持 AVX 的机器上使用 JAX,可以使用
build.py
的--target_cpu_features
标志从源代码构建 jaxlib。--target_cpu_features
还替换了--enable_march_native
。
- 在 x86-64 机器上,默认情况下构建 jaxlib wheels 需要 AVX 指令。如果要在不支持 AVX 的机器上使用 JAX,可以使用
jaxlib 0.1.61 (2021 年 2 月 12 日)
jaxlib 0.1.60 (2021 年 2 月 3 日)
- 错误修复:
- 修复了将 CPU DeviceArrays 转换为 NumPy 数组时的内存泄漏问题。在 jaxlib 发布的 0.1.58 和 0.1.59 版本中存在该内存泄漏。
-
bool
,int8
和uint8
现在被认为是安全的,可以转换为bfloat16
NumPy 扩展类型。
jax 0.2.9 (2021 年 1 月 26 日)
- GitHub 提交记录.
- 新特性:
- 扩展
jax.experimental.loops
模块以支持 pytrees。改进了错误检查和错误消息。 - 添加
jax.experimental.enable_x64()
和jax.experimental.disable_x64()
。这些是上下文管理器,允许在会话中临时启用/禁用 X64 模式。
- 扩展
- 破坏性变更:
jax.ops.segment_sum()
现在在性能考虑下删除超出范围的段 ID,而不是将它们包装到段 ID 空间。
jaxlib 0.1.59 (2021 年 1 月 15 日)
jax 0.2.8 (2021 年 1 月 12 日)
- GitHub 提交记录.
- 新特性:
- 添加
jax.closure_convert()
用于与高阶自定义导数函数一起使用。 (#5244) - 添加
jax.experimental.host_callback.call()
以调用主机上的自定义 Python 函数并将结果返回到设备计算中。 (#5243)
- 添加
- 错误修复:
-
jax.numpy.arccosh
现在对复数输入返回与numpy.arccosh
相同的分支(#5156)。 - 现在
host_callback.id_tap
在jax.pmap
中也可以使用。对于id_tap
和id_print
,现在有一个可选参数,可以请求将值从中提取的设备作为关键字参数传递给 tap 函数(#5182)。
-
- 破坏性更改:
-
jax.numpy.pad
现在接受关键字参数。位置参数constant_values
已被移除。此外,传递不受支持的关键字参数将引发错误。 -
jax.experimental.host_callback.id_tap()
的更改(#5243):- 删除了对
jax.experimental.host_callback.id_tap()
的kwargs
支持(这种支持已经被弃用几个月了)。 - 更改了
jax.experimental.host_callback.id_print()
中元组的打印方式,使用了(
而不是‘‘
。 - 在
jax.experimental.host_callback.id_print()
存在 JVP 的情况下,更改了打印元组的方式,现在使用了一对主元和切线。以前是分别打印主元和切线。 - 删除了
host_callback.outfeed_receiver
(这不再需要,并且几个月前已被弃用)。
- 删除了对
-
- 新功能:
- 为
inf
的调试添加了一个新标志,类似于NaN
的标志(#5224)。
- 为
jax 0.2.7(2020 年 12 月 4 日)
- GitHub 提交。
- 新功能:
- 添加了
jax.device_put_replicated
。 - 向
jax.experimental.sharded_jit
添加了多主机支持。 - 增加对
jax.numpy.linalg.eig
计算的特征值的微分支持。 - 增加了对在 Windows 平台上构建的支持。
- 在
jax.pmap
中添加了对通用in_axes
和out_axes
的支持。 - 添加了对
jax.numpy.linalg.slogdet
的复数支持。
- 添加了
- Bug 修复:
- 修复
jax.numpy.sinc
在零点处高于二阶导数的问题。 - 修复了在转置规则中的符号零的一些难以命中的 bug。
- 修复
- 破坏性更改:
- 已删除
jax.experimental.optix
,改为独立的optax
Python 包。 - 使用非元组序列索引 JAX 数组现在会引发
TypeError
。这种类型的索引自从 Numpy v1.16 和 JAX v0.2.4 开始已经被弃用。参见 #4564。
- 已删除
jax 0.2.6(2020 年 11 月 18 日)
GitHub 提交。
新功能:
- 为
jax.experimental.jax2tf
转换器的形状多态跟踪添加了支持。参见 README.md。
破坏性更改清理:
对于 jax.jit
和 xla_computation
中的非可哈希静态参数,现在会引发错误。参见 cb48f42。
改善了类型提升行为的一致性(#4744):
- 将复杂的 Python 标量添加到 JAX 浮点数会保留 JAX 浮点数的精度。例如,
jnp.float32(1) 1j
现在返回complex64
,而之前返回的是complex128
。 - 当涉及到包含
uint64
、有符号整型和第三种类型的三个或更多术语的类型提升时,现在与参数顺序无关。例如:jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)
和jnp.result_type(jnp.float16, jnp.uint64, jnp.int64)
都返回float16
,之前第一个返回float64
,第二个返回float16
。
(未记录的) jax.lax_linalg
线性代数模块现在公开为 jax.lax.linalg
。
jax.random.PRNGKey
现在在 JIT 编译内外产生相同的结果 (#4877)。这需要在几个特定情况下更改给定种子的结果:
- 使用
jax_enable_x64=False
时,作为 Python 整数传递的负数种子现在在 JIT 模式外返回不同的结果。例如,jax.random.PRNGKey(-1)
以前返回[4294967295, 4294967295]
,现在返回[0, 4294967295]
。这与 JIT 中的行为一致。 - JIT 外部的
int64
不能表示的范围外的种子现在会导致OverflowError
而不是TypeError
。这与 JIT 中的行为一致。
要恢复在 jax_enable_x64=False
时以前针对负整数返回的键,可以使用:
key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)
当尝试访问已删除其值的 DeviceArray
时,现在会引发 RuntimeError
而不是 ValueError
。
jaxlib 0.1.58 (2021 年 1 月 12 日)
- 修复了 JAX 有时返回平台特定类型(如
np.cint
)而不是标准类型(如np.int32
)的 Bug (#4903)。 - 修复了在执行某些 int16 操作时常量折叠导致崩溃的问题 (#4971)。
- 在
pytree.flatten()
中添加了一个is_leaf
谓词。
jaxlib 0.1.57 (2020 年 11 月 12 日)
- 修复了 GPU wheels 中的 manylinux2010 兼容性问题。
- 将 CPU FFT 实现从 Eigen 切换到 PocketFFT。
- 修复了 bfloat16 值哈希未正确初始化并可能更改的 Bug (#4651)。
- 添加了对将数组传递给 DLPack 时保留所有权的支持 (#4636)。
- 修复了批量三角求解的一个 Bug,对大于 128 但不是 128 的倍数的情况。
- 修复了在多个 GPU 上同时进行并发 FFT 时的 Bug (#3518)。
- 在分析器中修复了工具缺失的 Bug (#4427)。
- 放弃了对 CUDA 10.0 的支持。
jax 0.2.5 (2020 年 10 月 27 日)
- GitHub 提交记录。
- 改进:
- 确保
check_jaxpr
不执行 FLOPS。参见 #4650。 - 扩展了由 jax2tf 转换的 JAX 原语集。参见 primitives_with_limited_support.md。
- 确保
jax 0.2.4 (2020 年 10 月 19 日)
- GitHub 提交记录。
- 改进:
- 为
jax.experimental.host_callback
添加了对remat
的支持。参见 #4608。
- 为
- 弃用
- 现在,使用非元组序列进行索引已被弃用,遵循 Numpy 中的类似弃用。在将来的版本中,这将导致 TypeError。参见 #4564。
jaxlib 0.1.56 (2020 年 10 月 14 日)。
jax 0.2.3 (2020 年 10 月 14 日)。
- GitHub 提交记录。
- 由于需要暂时回退新的 jit 快速通路,因此又进行了一个新的发布。
jax 0.2.2 (2020 年 10 月 13 日)。
- GitHub 提交记录。
jax 0.2.1 (2020 年 10 月 6 日)。
- GitHub 提交记录。
- 改进:
- 作为全阶段的一个好处,即使
jax.experimental.host_callback.id_print()
/jax.experimental.host_callback.id_tap()
的结果未在计算中使用,也会按程序顺序执行 host_callback 函数。
- 作为全阶段的一个好处,即使
jax (0.2.0) (2020 年 9 月 23 日)。
- GitHub 提交记录。
- 改进:
- 默认情况下启用全阶段。参见 #3370 和 omnistaging。
jax (0.1.77) (2020 年 9 月 15 日)。
- 破坏性变更:
jax.experimental.host_callback.id_tap()
的新简化接口 (#4101)。
jaxlib 0.1.55 (2020 年 9 月 8 日)。
- 更新 XLA:
- 修复 DLPackManagedTensorToBuffer 中的错误 (#4196)。
jax 0.1.76 (2020 年 9 月 8 日)。
- GitHub 提交记录。
jax 0.1.75 (2020 年 7 月 30 日)。
- GitHub 提交记录。
- Bug 修复:
- 使 jnp.abs() 适用于无符号输入 (#3914)。
- 改进:
- 添加了“全阶段”行为,但在默认情况下已禁用 (#3370)。
jax 0.1.74 (2020 年 7 月 29 日)。
- GitHub 提交记录。
- 新功能:
- BFGS (#3101)。
- TPU 支持半精度算术 (#3878)。
- Bug 修复:
- 防止一些意外的 dtype 警告 (#3874)。
- 修复自定义导数中的多线程错误 (#3845, #3869)。
- 改进:
- 更快的 searchsorted 实现 (#3873)。
- 为 jax.numpy 排序算法提供更好的测试覆盖率 (#3836)。
jaxlib 0.1.52 (2020 年 7 月 22 日)。
- 更新 XLA。
jax 0.1.73 (2020 年 7 月 22 日)。
- GitHub 提交记录。
- jaxlib 的最低版本现在是 0.1.51。
- 新功能:
- jax.image.resize. (#3703)。
- hfft 和 ihfft (#3664)。
- jax.numpy.intersect1d (#3726)。
- jax.numpy.lexsort (#3812)。
- 当降低到 XLA 时,
lax.scan
和scan
原语支持一个unroll
参数用于循环展开 (#3738)。
- Bug 修复:
- 修复重复轴错误的约简 (#3618)。
- 修复 lax.pad 对输入维度大小为 0 的形状规则错误。 (#3608)。
- 使 psum 转置处理零余切 (#3653)。
- 修复在尺寸为 0 的轴上进行 reduce-prod 的 JVP 的形状错误 (#3729)。
- 支持通过 jax.lax.all_to_all 进行微分。
- 解决了 jax.scipy.special.zeta 中的 nan 问题。(#3777)
- 改进:
- 对 jax2tf 进行了许多改进。
- 重新实现了使用单次变量减少的 argmin/argmax。(#3611)
- 默认启用 XLA SPMD 分区。(#3151)
- 支持 0d 转置卷积。(#3643)
- 使低秩矩阵的 LU 梯度工作。
- 支持 jet 中的多结果和自定义 JVPs。
- 通用化了 reduce-window 的填充,支持(lo, hi)对。(#3728)
- 在 CPU 和 GPU 上实现复杂卷积。(#3735)
- 使 jnp.take 在空数组的空切片上工作。(#3751)
- 放宽了 dot_general 的维度排序规则。(#3778)
- 启用 GPU 的缓冲捐赠。(#3800)
- 为减少窗口操作添加了基本扩张和窗口扩张支持…(#3803)
jaxlib 0.1.51(2020 年 7 月 2 日)
- 更新 XLA。
- 添加了对 host_callback 的新运行时支持。
jax 0.1.72(2020 年 6 月 28 日)
- GitHub 提交记录。
- Bug 修复:
- 修复了前一个版本中引入的 odeint Bug,见 #3587。
jax 0.1.71(2020 年 6 月 25 日)
- GitHub 提交记录。
- 现在的 jaxlib 最低版本要求是 0.1.48。
- Bug 修复:
- 允许
jax.experimental.ode.odeint
动态函数在我们对其进行微分的值上进行闭包 #3562。
- 允许
jaxlib 0.1.50(2020 年 6 月 25 日)
- 增加了对 CUDA 11.0 的支持。
- 放弃对 CUDA 9.2 的支持(我们只支持最后四个 CUDA 版本)。
- 更新 XLA。
jaxlib 0.1.49(2020 年 6 月 19 日)
- Bug 修复:
- 修复了编译问题,可能导致编译速度慢(tensorflow/tensorflow)。
jaxlib 0.1.48(2020 年 6 月 12 日)
- 新特性:
- 增加了快速回溯收集的支持。
- 增加了对设备堆分析的初步支持。
- 为
bfloat16
类型实现了np.nextafter
。 - CPU 和 GPU 上的 Complex128 支持 FFT。
- Bug 修复:
- 改进了在 GPU 上
tanh
的 float64 精度。 - GPU 上的 float64 散布现在更快了。
- 在 CPU 上的复杂矩阵乘法应该更快了。
- CPU 上的稳定排序现在实际上是稳定的了。
- CPU 后端的并发 Bug 修复。
- 改进了在 GPU 上
jax 0.1.70(2020 年 6 月 8 日)
- GitHub 提交记录。
- 新特性:
lax.switch
引入了带有多分支的索引条件,并与cond
原语的泛化一起使用 #3318。
jax 0.1.69(2020 年 6 月 3 日)
- GitHub 提交记录。
jax 0.1.68(2020 年 5 月 21 日)
- GitHub 提交记录。
- 新特性:
lax.cond()
支持单操作数形式,作为两个分支的参数 #2993。
- 注意事项改动:
jax.experimental.host_callback.id_tap()
原语的transforms
关键字格式已更改 #3132。
jax 0.1.67(2020 年 5 月 12 日)
- GitHub 提交记录。
- 新功能:
- 支持使用
axis_index_groups
对 pmapped 轴的子集进行缩减 #2382。 - 实验性支持从编译代码调用和打印主机端 Python 函数。参见 id_print 和 id_tap(#3006)。
- 支持使用
- 显著变更:
- 从
jax.numpy
导出的名称的可见性已加强。这可能会破坏之前无意中使用这些名称的代码。
- 从
jaxlib 0.1.47(2020 年 5 月 8 日)
- 修复 outfeed 引起的崩溃。
jax 0.1.66(2020 年 5 月 5 日)
- GitHub 提交记录。
- 新功能:
- 支持在
pmap()
上使用in_axes=None
进行缩减 #2896。
- 支持在
jaxlib 0.1.46(2020 年 5 月 5 日)
- 修复 Mac OS X 上线性代数函数的崩溃(#432)。
- 修复使用 AVX512 指令时因操作系统或虚拟化程序禁用而导致的非法指令崩溃问题(#2906)。
jax 0.1.65(2020 年 4 月 30 日)
- GitHub 提交记录。
- 新功能:
- 对奇异矩阵行列式的微分 #2809。
- Bug 修复:
- 修复
odeint()
对于具有时间依赖动态的常微分方程的时间微分问题 #2817,并添加 ODE CI 测试。 - 修复
lax_linalg.qr()
的微分问题 #2867。
- 修复
jaxlib 0.1.45(2020 年 4 月 21 日)
- 修复段错误:#2755
- 在 Sort HLO 上通过 Plumb 选项支持稳定性。
jax 0.1.64(2020 年 4 月 21 日)
- GitHub 提交记录。
- 新功能:
- 添加函数式索引更新的语法糖 #2684。
- 添加
jax.numpy.linalg.multi_dot()
#2726。 - 添加
jax.numpy.unique()
#2760。 - 添加
jax.numpy.rint()
#2724。 - 添加
jax.numpy.rint()
#2724。 - 为
jax.experimental.jet()
添加更多原始规则。
- Bug 修复:
- 修复
logaddexp()
和logaddexp2()
在零处的微分问题 #2107。 - 在没有
jit()
的情况下改进反向模式自动微分的内存使用情况 #2719。
- 修复
- 更好的错误修复:
- 改进
lax.while_loop()
的反向模式微分的错误消息 #2129。
- 改进
jaxlib 0.1.44(2020 年 4 月 16 日)
- 修复了一个 bug,即当存在多个不同型号的 GPU 时,JAX 只会编译适用于第一个 GPU 的程序。
- 修复了
batch_group_count
卷积的错误。 - 为更多 GPU 版本添加了预编译的 SASS,以避免启动时 PTX 编译挂起。
jax 0.1.63 (2020 年 4 月 12 日)
- GitHub 提交记录。
- 添加了
jax.custom_jvp
和jax.custom_vjp
,来源于 #2026,请参阅教程笔记本。弃用了jax.custom_transforms
并将其从文档中删除(尽管它仍然可用)。 - 添加了
scipy.sparse.linalg.cg
#2566。 - 更改了 Tracers 的打印方式,以显示更多有用的调试信息 #2591。
- 修复了
jax.numpy.isclose
正确处理nan
和inf
的方式 #2501。 - 添加了几个
jax.experimental.jet
的新规则 #2537。 - 当未提供
scale
/center
时,修复了jax.experimental.stax.BatchNorm
。 - 修复了
jax.numpy.einsum
中广播的一些缺失情况 #2512。 - 通过并行前缀扫描实现了
jax.numpy.cumsum
和jax.numpy.cumprod
,并使reduce_prod
对任意阶数可微分 #2596 #2597。 - 在
conv_general_dilated
中添加了batch_group_count
#2635。 - 为
test_util.check_grads
添加了文档字符串 #2656。 - 添加了
callback_transform
#2665。 - 实现了
rollaxis
、convolve
/correlate
的 1 维和 2 维、copysign
、trunc
、roots
以及quantile
/percentile
的插值选项。
jaxlib 0.1.43 (2020 年 3 月 31 日)
- 修复了 GPU 上 Resnet-50 的性能回归问题。
jax 0.1.62 (2020 年 3 月 21 日)
- GitHub 提交记录。
- JAX 已停止支持 Python 3.5。请升级到 Python 3.6 或更新版本。
- 删除了内部函数
lax._safe_mul
,该函数实现了约定0. * nan == 0.
。此更改意味着在某些程序被微分时会产生 nan,而不是以前产生正确值,尽管这确保了对其他程序产生 nan 而不是静默的不正确结果。详见 #2447 和 #1052。 - 添加了一个
all_gather
并行便利函数。 - 在核心代码中增加了更多类型注解。
jaxlib 0.1.42 (2020 年 3 月 19 日)
- jaxlib 0.1.41 由于 API 不兼容性破坏了云 TPU 支持。此版本修复了这个问题。
- JAX 已停止支持 Python 3.5。请升级到 Python 3.6 或更新版本。
jax 0.1.61 (2020 年 3 月 17 日)
- GitHub 提交记录。
- 修复 Python 3.5 支持。这将是 JAX 或 jaxlib 版本的最后一个支持 Python 3.5 的版本。
jax 0.1.60(2020 年 3 月 17 日)
- GitHub 提交。
- 新功能:
-
jax.pmap()
增加了static_broadcast_argnums
参数,该参数允许用户指定应该作为编译时常数处理的参数,并应广播到所有设备。它类似于jax.jit()
中的static_argnums
。 - 改善了错误消息,以防止错误地在全局状态中保存跟踪器。
- 添加了
jax.nn.one_hot()
实用函数。 - 添加了
jax.experimental.jet
,用于更快的高阶自动微分。 - 对
jax.lax.broadcast_in_dim()
的参数进行了更多正确性检查。
-
- 最小 jaxlib 版本现已是 0.1.41。
jaxlib 0.1.40(2020 年 3 月 4 日)
- 添加了 Jaxlib 对 TensorFlow 分析仪的实验性支持,该分析仪允许从 TensorBoard 跟踪 CPU 和 GPU 计算。
- 包括多主机 GPU 计算支持的原型,该计算通过 NCCL 通信。
- 改善了在 GPU 上的 NCCL 集合性能。
- 添加了 TopK、CustomCallWithoutLayout、CustomCallWithLayout、IGammaGradA 和 RandomGamma 实现。
- 支持在 XLA 编译时已知的设备分配。
jax 0.1.59(2020 年 2 月 11 日)
- GitHub 提交。
- 重大更改
- 最小 jaxlib 版本现已是 0.1.38。
- 简化
Jaxpr
,通过删除Jaxpr.freevars
和Jaxpr.bound_subjaxprs
。调用基本功能(xla_call
、xla_pmap
、sharded_call
和remat_call
)获取一个新的参数call_jaxpr
,它具有一个完全闭合(无constvars
)的 jaxpr。此外,还添加了一个新的字段call_primitive
到基本功能。
- 新功能:
- 反向模式自动微分(例如
grad
)对lax.cond
的支持,使其在两种模式下都可微分(#2091) - JAX 现在支持 DLPack,它允许以零副本方式共享 CPU 和 GPU 数组与其他库(例如 PyTorch)。
- JAX GPU DeviceArrays 现在支持
__cuda_array_interface__
,这是另一种用于与 CuPy 和 Numba 等库共享 GPU 数组的零副本协议。 - JAX 的 CPU 设备缓冲区现在实现了 Python 缓冲区协议,这允许 JAX 和 NumPy 之间的零副本缓冲区共享。
- 添加了名为 JAX_SKIP_SLOW_TESTS 的环境变量,以跳过已知为慢的测试。
- 反向模式自动微分(例如
jaxlib 0.1.39(2020 年 2 月 11 日)
- 更新 XLA。
jaxlib 0.1.38(2020 年 1 月 29 日)
- 不再支持 CUDA 9.0。
- 默认构建 CUDA 10.2 的轮。
jax 0.1.58(2020 年 1 月 28 日)
- [GitHub GitHub 提交。
- 重大更改
- JAX 已弃用对 Python 2 的支持,因为 Python 2 于 2020 年 1 月 1 日达到生命周期结束。请更新到 Python 3.5 或更新版本。
- 新功能
- 正向模式自动微分(
jvp
)对 while 循环的支持(#1980) - 新的 NumPy 和 SciPy 功能:
-
jax.numpy.fft.fft2()
jax.numpy.fft.ifft2()
jax.numpy.fft.rfft()
jax.numpy.fft.irfft()
jax.numpy.fft.rfft2()
jax.numpy.fft.irfft2()
jax.numpy.fft.rfftn()
jax.numpy.fft.irfftn()
jax.numpy.fft.fftfreq()
jax.numpy.fft.rfftfreq()
jax.numpy.linalg.matrix_rank()
jax.numpy.linalg.matrix_power()
jax.scipy.special.betainc()
- 现在在 GPU 上进行批次 Cholesky 分解时使用了更高效的批次核心。
- 正向模式自动微分(
显著的错误修复
- 使用 Python 3 升级后,JAX 不再依赖于
fastcache
,这应该有助于安装。