JAX 中文文档(十四)

2024-06-22 08:45:44 浏览数 (1)

原文:jax.readthedocs.io/en/latest/

jax.scipy 模块

原文:jax.readthedocs.io/en/latest/jax.scipy.html

jax.scipy.cluster

| vq(obs, code_book[, check_finite]) | 将观测值分配给代码簿中的代码。 | ## jax.scipy.fft

dct(x[, type, n, axis, norm])

计算输入的离散余弦变换

dctn(x[, type, s, axes, norm])

计算输入的多维离散余弦变换

idct(x[, type, n, axis, norm])

计算输入的离散余弦变换的逆变换

| idctn(x[, type, s, axes, norm]) | 计算输入的多维离散余弦变换的逆变换 | ## jax.scipy.integrate

| trapezoid(y[, x, dx, axis]) | 使用复合梯形法则沿指定轴积分。 | ## jax.scipy.interpolate

| RegularGridInterpolator(points, values[, …]) | 对正规矩形网格上的点进行插值。 | ## jax.scipy.linalg

block_diag(*arrs)

从输入数组创建块对角矩阵。

cho_factor(a[, lower, overwrite_a, check_finite])

基于 Cholesky 的线性求解因式分解

cho_solve(c_and_lower, b[, overwrite_b, …])

使用 Cholesky 分解解线性系统

cholesky(a[, lower, overwrite_a, check_finite])

计算矩阵的 Cholesky 分解。

det(a[, overwrite_a, check_finite])

计算矩阵的行列式

eigh()

计算 Hermitian 矩阵的特征值和特征向量

eigh_tridiagonal(d, e, *[, eigvals_only, …])

解对称实三对角矩阵的特征值问题

expm(A, *[, upper_triangular, max_squarings])

计算矩阵指数

expm_frechet()

计算矩阵指数的 Frechet 导数

funm(A, func[, disp])

评估矩阵值函数

hessenberg()

计算矩阵的 Hessenberg 形式

hilbert(n)

创建阶数为 n 的 Hilbert 矩阵。

inv(a[, overwrite_a, check_finite])

返回方阵的逆矩阵

lu()

计算 LU 分解

lu_factor(a[, overwrite_a, check_finite])

基于 LU 的线性求解因式分解

lu_solve(lu_and_piv, b[, trans, …])

使用 LU 分解解线性系统

polar(a[, side, method, eps, max_iterations])

计算极分解

qr()

计算数组的 QR 分解

rsf2csf(T, Z[, check_finite])

将实数舒尔形式转换为复数舒尔形式。

schur(a[, output])

计算舒尔分解

solve(a, b[, lower, overwrite_a, …])

解线性方程组

solve_triangular(a, b[, trans, lower, …])

解上(或下)三角线性方程组

sqrtm(A[, blocksize])

计算矩阵的平方根

svd()

计算奇异值分解

| toeplitz(c[, r]) | 构造 Toeplitz 矩阵 | ## jax.scipy.ndimage

| map_coordinates(input, coordinates, order[, …]) | 使用插值将输入数组映射到新坐标。 | ## jax.scipy.optimize

minimize(fun, x0[, args, tol, options])

最小化一个或多个变量的标量函数。

| OptimizeResults(x, success, status, fun, …) | 优化结果对象。 | ## jax.scipy.signal

fftconvolve(in1, in2[, mode, axes])

使用快速傅里叶变换(FFT)卷积两个 N 维数组。

convolve(in1, in2[, mode, method, precision])

两个 N 维数组的卷积。

convolve2d(in1, in2[, mode, boundary, …])

两个二维数组的卷积。

correlate(in1, in2[, mode, method, precision])

两个 N 维数组的互相关。

correlate2d(in1, in2[, mode, boundary, …])

两个二维数组的互相关。

csd(x, y[, fs, window, nperseg, noverlap, …])

使用 Welch 方法估计交叉功率谱密度(CSD)。

detrend(data[, axis, type, bp, overwrite_data])

从数据中移除线性或分段线性趋势。

istft(Zxx[, fs, window, nperseg, noverlap, …])

执行逆短时傅里叶变换(ISTFT)。

stft(x[, fs, window, nperseg, noverlap, …])

计算短时傅里叶变换(STFT)。

| welch(x[, fs, window, nperseg, noverlap, …]) | 使用 Welch 方法估计功率谱密度(PSD)。 | ## jax.scipy.spatial.transform

Rotation(quat)

三维旋转。

| Slerp(times, timedelta, rotations, rotvecs) | 球面线性插值旋转。 | ## jax.scipy.sparse.linalg

bicgstab(A, b[, x0, tol, atol, maxiter, M])

使用双共轭梯度稳定迭代解决 Ax = b。

cg(A, b[, x0, tol, atol, maxiter, M])

使用共轭梯度法解决 Ax = b。

| gmres(A, b[, x0, tol, atol, restart, …]) | GMRES 解决线性系统 A x = b,给定 A 和 b。 | ## jax.scipy.special

bernoulli(n)

生成前 N 个伯努利数。

beta()

贝塔函数

betainc(a, b, x)

正则化的不完全贝塔函数。

betaln(a, b)

贝塔函数绝对值的自然对数

digamma(x)

Digamma 函数

entr(x)

熵函数

erf(x)

误差函数

erfc(x)

误差函数的补函数

erfinv(x)

误差函数的反函数

exp1(x)

指数积分函数。

expi

指数积分函数。

expit(x)

逻辑 sigmoid(expit)函数

expn

广义指数积分函数。

factorial(n[, exact])

阶乘函数

gamma(x)

伽马函数。

gammainc(a, x)

正则化的下不完全伽马函数。

gammaincc(a, x)

正则化的上不完全伽马函数。

gammaln(x)

伽马函数绝对值的自然对数。

gammasgn(x)

伽马函数的符号。

hyp1f1

1F1 超几何函数。

i0(x)

修改贝塞尔函数零阶。

i0e(x)

指数缩放的修改贝塞尔函数零阶。

i1(x)

修改贝塞尔函数一阶。

i1e(x)

指数缩放的修改贝塞尔函数一阶。

log_ndtr

对数正态分布函数。

logit

对数几率函数。

logsumexp()

对数-总和-指数归约。

lpmn(m, n, z)

第一类相关勒让德函数(ALFs)。

lpmn_values(m, n, z, is_normalized)

第一类相关勒让德函数(ALFs)。

multigammaln(a, d)

多变量伽马函数的自然对数。

ndtr(x)

正态分布函数。

ndtri§

正态分布函数的反函数。

poch

Pochhammer 符号。

polygamma(n, x)

多次伽马函数。

spence(x)

斯宾斯函数,也称实数域下的二元对数函数。

sph_harm(m, n, theta, phi[, n_max])

计算球谐函数。

xlog1py

计算 x*log(1 y),当 x=0 时返回 0。

xlogy

计算 x*log(y),当 x=0 时返回 0。

zeta

赫维茨 ζ 函数。

kl_div(p, q)

库尔巴克-莱布勒散度。

| rel_entr(p, q) | 相对熵函数。 | ## jax.scipy.stats

mode(a[, axis, nan_policy, keepdims])

计算数组沿轴的众数(最常见的值)。

rankdata(a[, method, axis, nan_policy])

计算数组沿轴的排名。

sem(a[, axis, ddof, nan_policy, keepdims])

计算均值的标准误差。

jax.scipy.stats.bernoulli

logpmf(k, p[, loc])

伯努利对数概率质量函数。

pmf(k, p[, loc])

伯努利概率质量函数。

cdf(k, p)

伯努利累积分布函数。

| ppf(q, p) | 伯努利百分位点函数。 | ### jax.scipy.stats.beta

logpdf(x, a, b[, loc, scale])

Beta 对数概率分布函数。

pdf(x, a, b[, loc, scale])

Beta 概率分布函数。

cdf(x, a, b[, loc, scale])

Beta 累积分布函数。

logcdf(x, a, b[, loc, scale])

Beta 对数累积分布函数。

sf(x, a, b[, loc, scale])

Beta 分布生存函数。

| logsf(x, a, b[, loc, scale]) | Beta 分布对数生存函数。 | ### jax.scipy.stats.betabinom

logpmf(k, n, a, b[, loc])

Beta-二项式对数概率质量函数。

| pmf(k, n, a, b[, loc]) | Beta-二项式概率质量函数。 | ### jax.scipy.stats.binom

logpmf(k, n, p[, loc])

二项式对数概率质量函数。

| pmf(k, n, p[, loc]) | 二项式概率质量函数。 | ### jax.scipy.stats.cauchy

logpdf(x[, loc, scale])

柯西对数概率分布函数。

pdf(x[, loc, scale])

柯西概率分布函数。

cdf(x[, loc, scale])

柯西累积分布函数。

logcdf(x[, loc, scale])

柯西对数累积分布函数。

sf(x[, loc, scale])

柯西分布对数生存函数。

logsf(x[, loc, scale])

柯西对数生存函数。

isf(q[, loc, scale])

柯西分布逆生存函数。

| ppf(q[, loc, scale]) | 柯西分布分位点函数。 | ### jax.scipy.stats.chi2

logpdf(x, df[, loc, scale])

卡方分布对数概率分布函数。

pdf(x, df[, loc, scale])

卡方概率分布函数。

cdf(x, df[, loc, scale])

卡方累积分布函数。

logcdf(x, df[, loc, scale])

卡方对数累积分布函数。

sf(x, df[, loc, scale])

卡方生存函数。

| logsf(x, df[, loc, scale]) | 卡方对数生存函数。 | ### jax.scipy.stats.dirichlet

logpdf(x, alpha)

狄利克雷对数概率分布函数。

| pdf(x, alpha) | 狄利克雷概率分布函数。 | ### jax.scipy.stats.expon

logpdf(x[, loc, scale])

指数对数概率分布函数。

| pdf(x[, loc, scale]) | 指数概率分布函数。 | ### jax.scipy.stats.gamma

logpdf(x, a[, loc, scale])

伽玛对数概率分布函数。

pdf(x, a[, loc, scale])

伽玛概率分布函数。

cdf(x, a[, loc, scale])

伽玛累积分布函数。

logcdf(x, a[, loc, scale])

伽玛对数累积分布函数。

sf(x, a[, loc, scale])

伽玛生存函数。

| logsf(x, a[, loc, scale]) | 伽玛对数生存函数。 | ### jax.scipy.stats.gennorm

cdf(x, beta)

广义正态累积分布函数。

logpdf(x, beta)

广义正态对数概率分布函数。

| pdf(x, beta) | 广义正态概率分布函数。 | ### jax.scipy.stats.geom

logpmf(k, p[, loc])

几何对数概率质量函数。

| pmf(k, p[, loc]) | 几何概率质量函数。 | ### jax.scipy.stats.laplace

cdf(x[, loc, scale])

拉普拉斯累积分布函数。

logpdf(x[, loc, scale])

拉普拉斯对数概率分布函数。

| pdf(x[, loc, scale]) | 拉普拉斯概率分布函数。 | ### jax.scipy.stats.logistic

cdf(x[, loc, scale])

Logistic 累积分布函数。

isf(x[, loc, scale])

Logistic 分布逆生存函数。

logpdf(x[, loc, scale])

Logistic 对数概率分布函数。

pdf(x[, loc, scale])

Logistic 概率分布函数。

ppf(x[, loc, scale])

Logistic 分位点函数。

| sf(x[, loc, scale]) | Logistic 分布生存函数。 | ### jax.scipy.stats.multinomial

logpmf(x, n, p)

多项式对数概率质量函数。

| pmf(x, n, p) | 多项分布概率质量函数。 | ### jax.scipy.stats.multivariate_normal

logpdf(x, mean, cov[, allow_singular])

多元正态分布对数概率分布函数。

| pdf(x, mean, cov) | 多元正态分布概率分布函数。 | ### jax.scipy.stats.nbinom

logpmf(k, n, p[, loc])

负二项分布对数概率质量函数。

| pmf(k, n, p[, loc]) | 负二项分布概率质量函数。 | ### jax.scipy.stats.norm

logpdf(x[, loc, scale])

正态分布对数概率分布函数。

pdf(x[, loc, scale])

正态分布概率分布函数。

cdf(x[, loc, scale])

正态分布累积分布函数。

logcdf(x[, loc, scale])

正态分布对数累积分布函数。

ppf(q[, loc, scale])

正态分布百分点函数。

sf(x[, loc, scale])

正态分布生存函数。

logsf(x[, loc, scale])

正态分布对数生存函数。

| isf(q[, loc, scale]) | 正态分布逆生存函数。 | ### jax.scipy.stats.pareto

logpdf(x, b[, loc, scale])

帕累托对数概率分布函数。

| pdf(x, b[, loc, scale]) | 帕累托分布概率分布函数。 | ### jax.scipy.stats.poisson

logpmf(k, mu[, loc])

泊松分布对数概率质量函数。

pmf(k, mu[, loc])

泊松分布概率质量函数。

| cdf(k, mu[, loc]) | 泊松分布累积分布函数。 | ### jax.scipy.stats.t

logpdf(x, df[, loc, scale])

学生 t 分布对数概率分布函数。

| pdf(x, df[, loc, scale]) | 学生 t 分布概率分布函数。 | ### jax.scipy.stats.truncnorm

cdf(x, a, b[, loc, scale])

截断正态分布累积分布函数。

logcdf(x, a, b[, loc, scale])

截断正态分布对数累积分布函数。

logpdf(x, a, b[, loc, scale])

截断正态分布对数概率分布函数。

logsf(x, a, b[, loc, scale])

截断正态分布对数生存函数。

pdf(x, a, b[, loc, scale])

截断正态分布概率分布函数。

| sf(x, a, b[, loc, scale]) | 截断正态分布对数生存函数。 | ### jax.scipy.stats.uniform

logpdf(x[, loc, scale])

均匀分布对数概率分布函数。

pdf(x[, loc, scale])

均匀分布概率分布函数。

cdf(x[, loc, scale])

均匀分布累积分布函数。

ppf(q[, loc, scale])

均匀分布百分点函数。

jax.scipy.stats.gaussian_kde

gaussian_kde(dataset[, bw_method, weights])

高斯核密度估计器

gaussian_kde.evaluate(points)

对给定点评估高斯核密度估计器。

gaussian_kde.integrate_gaussian(mean, cov)

加权高斯积分分布。

gaussian_kde.integrate_box_1d(low, high)

在给定限制下积分分布。

gaussian_kde.integrate_kde(other)

集成两个高斯核密度估计分布的乘积。

gaussian_kde.resample(key[, shape])

从估计的概率密度函数中随机采样数据集

gaussian_kde.pdf(x)

概率密度函数

gaussian_kde.logpdf(x)

对数概率密度函数

jax.scipy.stats.vonmises

logpdf(x, kappa)

von Mises 对数概率分布函数。

| pdf(x, kappa) | von Mises 概率分布函数。 | ### jax.scipy.stats.wrapcauchy

logpdf(x, c)

Wrapped Cauchy 对数概率分布函数。

pdf(x, c)

Wrapped Cauchy 概率分布函数。

jax.scipy.stats.bernoulli.logpmf

原文:jax.readthedocs.io/en/latest/_autosummary/jax.scipy.stats.bernoulli.logpmf.html

代码语言:javascript复制
jax.scipy.stats.bernoulli.logpmf(k, p, loc=0)

伯努利对数概率质量函数。

scipy.stats.bernoulli 的 JAX 实现 logpmf

伯努利概率质量函数定义如下

[begin{split}f(k) = begin{cases} 1 - p, & k = 0 p, & k = 1 0, & mathrm{otherwise} end{cases}end{split}]

参数:

  • k (Array | ndarray | bool | number | bool | int | float | complex) – arraylike,要评估 PMF 的值
  • p (Array | ndarray | bool | number | bool | int | float | complex) – arraylike,分布形状参数
  • loc (Array | ndarray | bool | number | bool | int | float | complex) – arraylike,分布偏移量

返回值:

logpmf 值的数组

返回类型:

Array

另请参阅

  • jax.scipy.stats.bernoulli.cdf()
  • jax.scipy.stats.bernoulli.pmf()
  • jax.scipy.stats.bernoulli.ppf()

jax.scipy.stats.bernoulli.pmf

原文:jax.readthedocs.io/en/latest/_autosummary/jax.scipy.stats.bernoulli.pmf.html

代码语言:javascript复制
jax.scipy.stats.bernoulli.pmf(k, p, loc=0)

伯努利概率质量函数。

scipy.stats.bernoulli pmf 的 JAX 实现

伯努利概率质量函数定义为

[begin{split}f(k) = begin{cases} 1 - p, & k = 0 p, & k = 1 0, & mathrm{otherwise} end{cases}end{split}]

参数:

  • k (数组 | ndarray | 布尔 | 数值 | 布尔 | 整数 | 浮点数 | 复数*) – 类似数组,要评估 PMF 的值
  • p (数组 | ndarray | 布尔 | 数值 | 布尔 | 整数 | 浮点数 | 复数*) – 类似数组,分布形状参数
  • loc (数组 | ndarray | 布尔 | 数值 | 布尔 | 整数 | 浮点数 | 复数*) – 类似数组,分布偏移

返回:

pmf 值数组

返回类型:

数组

参见

  • jax.scipy.stats.bernoulli.cdf()
  • jax.scipy.stats.bernoulli.logpmf()
  • jax.scipy.stats.bernoulli.ppf()

jax.scipy.stats.bernoulli.cdf

原文:jax.readthedocs.io/en/latest/_autosummary/jax.scipy.stats.bernoulli.cdf.html

代码语言:javascript复制
jax.scipy.stats.bernoulli.cdf(k, p)

伯努利累积分布函数。

scipy.stats.bernoulli 的 JAX 实现 cdf

伯努利累积分布函数被定义为:

[f_{cdf}(k, p) = sum_{i=0}^k f_{pmf}(k, p)]

其中 (f_{pmf}(k, p)) 是伯努利概率质量函数 jax.scipy.stats.bernoulli.pmf()

参数:

  • k (Array | ndarray | bool | number | bool | int | float | complex) – 数组,用于评估 CDF 的值
  • p (Array | ndarray | bool | number | bool | int | float | complex) – 数组,分布形状参数
  • loc – 数组,分布偏移

返回:

cdf 值的数组

返回类型:

Array

另请参见

  • jax.scipy.stats.bernoulli.logpmf()
  • jax.scipy.stats.bernoulli.pmf()
  • jax.scipy.stats.bernoulli.ppf()

jax.scipy.stats.bernoulli.ppf

原文:jax.readthedocs.io/en/latest/_autosummary/jax.scipy.stats.bernoulli.ppf.html

代码语言:javascript复制
jax.scipy.stats.bernoulli.ppf(q, p)

伯努利百分点函数。

JAX 实现的 scipy.stats.bernoulli ppf

百分点函数是累积分布函数的反函数,jax.scipy.stats.bernoulli.cdf()

参数:

  • k – arraylike,评估 PPF 的值
  • p (Array | ndarray | bool | number | bool | int | float | complex) – arraylike,分布形状参数
  • loc – arraylike,分布偏移
  • q (Array | ndarray | bool | number | bool | int | float | complex)

返回:

ppf 值数组

返回类型:

Array

另见

  • jax.scipy.stats.bernoulli.cdf()
  • jax.scipy.stats.bernoulli.logpmf()
  • jax.scipy.stats.bernoulli.pmf()

jax.lax 模块

原文:jax.readthedocs.io/en/latest/jax.lax.html

jax.lax 是支持诸如 jax.numpy 等库的基本操作的库。通常会定义转换规则,例如 JVP 和批处理规则,作为对 jax.lax 基元的转换。

许多基元都是等价于 XLA 操作的薄包装,详细描述请参阅XLA 操作语义文档。

在可能的情况下,优先使用诸如 jax.numpy 等库,而不是直接使用 jax.laxjax.numpy API 遵循 NumPy,因此比 jax.lax API 更稳定,更不易更改。

Operators

abs(x)

按元素绝对值:(|x|)。

acos(x)

按元素求反余弦:(mathrm{acos}(x))。

acosh(x)

按元素求反双曲余弦:(mathrm{acosh}(x))。

add(x, y)

按元素加法:(x y)。

after_all(*operands)

合并一个或多个 XLA 令牌值。

approx_max_k(operand, k[, …])

以近似方式返回 operand 的最大 k 值及其索引。

approx_min_k(operand, k[, …])

以近似方式返回 operand 的最小 k 值及其索引。

argmax(operand, axis, index_dtype)

计算沿着 axis 的最大元素的索引。

argmin(operand, axis, index_dtype)

计算沿着 axis 的最小元素的索引。

asin(x)

按元素求反正弦:(mathrm{asin}(x))。

asinh(x)

按元素求反双曲正弦:(mathrm{asinh}(x))。

atan(x)

按元素求反正切:(mathrm{atan}(x))。

atan2(x, y)

两个变量的按元素反正切:(mathrm{atan}({x over y}))。

atanh(x)

按元素求反双曲正切:(mathrm{atanh}(x))。

batch_matmul(lhs, rhs[, precision])

批量矩阵乘法。

bessel_i0e(x)

指数缩放修正贝塞尔函数 (0) 阶:(mathrm{i0e}(x) = e^{-|x|} mathrm{i0}(x))

bessel_i1e(x)

指数缩放修正贝塞尔函数 (1) 阶:(mathrm{i1e}(x) = e^{-|x|} mathrm{i1}(x))

betainc(a, b, x)

按元素的正则化不完全贝塔积分。

bitcast_convert_type(operand, new_dtype)

按元素位转换。

bitwise_and(x, y)

按位与运算:(x wedge y)。

bitwise_not(x)

按位取反:(neg x)。

bitwise_or(x, y)

按位或运算:(x vee y)。

bitwise_xor(x, y)

按位异或运算:(x oplus y)。

population_count(x)

按元素计算 popcount,即每个元素中设置的位数。

broadcast(operand, sizes)

广播数组,添加新的前导维度。

broadcast_in_dim(operand, shape, …)

包装 XLA 的 BroadcastInDim 操作符。

broadcast_shapes()

返回经过 NumPy 广播后的形状。

broadcast_to_rank(x, rank)

添加 1 的前导维度,使 x 的等级为 rank。

broadcasted_iota(dtype, shape, dimension)

iota的便捷封装器。

cbrt(x)

元素级立方根:(sqrt[3]{x})。

ceil(x)

元素级向上取整:(leftlceil x rightrceil)。

clamp(min, x, max)

元素级 clamp 函数。

clz(x)

元素级计算前导零的个数。

collapse(operand, start_dimension[, …])

将数组的维度折叠为单个维度。

complex(x, y)

元素级构造复数:(x jy)。

concatenate(operands, dimension)

沿指定维度连接一系列数组。

conj(x)

元素级复数的共轭函数:(overline{x})。

conv(lhs, rhs, window_strides, padding[, …])

conv_general_dilated的便捷封装器。

convert_element_type(operand, new_dtype)

元素级类型转换。

conv_dimension_numbers(lhs_shape, rhs_shape, …)

将卷积维度编号转换为 ConvDimensionNumbers。

conv_general_dilated(lhs, rhs, …[, …])

带有可选扩展的通用 n 维卷积运算符。

conv_general_dilated_local(lhs, rhs, …[, …])

带有可选扩展的通用 n 维非共享卷积运算符。

conv_general_dilated_patches(lhs, …[, …])

提取符合 conv_general_dilated 接受域的补丁。

conv_transpose(lhs, rhs, strides, padding[, …])

计算 N 维卷积的“转置”的便捷封装器。

conv_with_general_padding(lhs, rhs, …[, …])

conv_general_dilated的便捷封装器。

cos(x)

元素级余弦函数:(mathrm{cos}(x))。

cosh(x)

元素级双曲余弦函数:(mathrm{cosh}(x))。

cumlogsumexp(operand[, axis, reverse])

沿轴计算累积 logsumexp。

cummax(operand[, axis, reverse])

沿轴计算累积最大值。

cummin(operand[, axis, reverse])

沿轴计算累积最小值。

cumprod(operand[, axis, reverse])

沿轴计算累积乘积。

cumsum(operand[, axis, reverse])

沿轴计算累积和。

digamma(x)

元素级 digamma 函数:(psi(x))。

div(x, y)

元素级除法:(x over y)。

dot(lhs, rhs[, precision, …])

向量/向量,矩阵/向量和矩阵/矩阵乘法。

dot_general(lhs, rhs, dimension_numbers[, …])

通用的点积/收缩运算符。

dynamic_index_in_dim(operand, index[, axis, …])

对 dynamic_slice 的便捷封装,用于执行整数索引。

dynamic_slice(operand, start_indices, …)

封装了 XLA 的 DynamicSlice 操作符。

dynamic_slice_in_dim(operand, start_index, …)

方便地封装了应用于单个维度的 lax.dynamic_slice()。

dynamic_update_index_in_dim(operand, update, …)

方便地封装了 dynamic_update_slice(),用于在单个 axis 中更新大小为 1 的切片。

dynamic_update_slice(operand, update, …)

封装了 XLA 的 DynamicUpdateSlice 操作符。

dynamic_update_slice_in_dim(operand, update, …)

方便地封装了 dynamic_update_slice(),用于在单个 axis 中更新一个切片。

eq(x, y)

元素级相等:(x = y)。

erf(x)

元素级误差函数:(mathrm{erf}(x))。

erfc(x)

元素级补充误差函数:(mathrm{erfc}(x) = 1 - mathrm{erf}(x))。

erf_inv(x)

元素级反误差函数:(mathrm{erf}^{-1}(x))。

exp(x)

元素级指数函数:(e^x)。

expand_dims(array, dimensions)

将任意数量的大小为 1 的维度插入到数组中。

expm1(x)

元素级运算 (e^{x} - 1)。

fft(x, fft_type, fft_lengths)

floor(x)

元素级向下取整:(leftlfloor x rightrfloor)。

full(shape, fill_value[, dtype, sharding])

返回填充值为 fill_value 的形状数组。

full_like(x, fill_value[, dtype, shape, …])

基于示例数组 x 创建类似于 np.full 的完整数组。

gather(operand, start_indices, …[, …])

Gather 操作符。

ge(x, y)

元素级大于或等于:(x geq y)。

gt(x, y)

元素级大于:(x > y)。

igamma(a, x)

元素级正则化不完全 gamma 函数。

igammac(a, x)

元素级补充正则化不完全 gamma 函数。

imag(x)

提取复数的虚部:(mathrm{Im}(x))。

index_in_dim(operand, index[, axis, keepdims])

方便地封装了 lax.slice(),用于执行整数索引。

index_take(src, idxs, axes)

integer_pow(x, y)

元素级幂运算:(x^y),其中 (y) 是固定整数。

iota(dtype, size)

封装了 XLA 的 Iota 操作符。

is_finite(x)

元素级 (mathrm{isfinite})。

le(x, y)

元素级小于或等于:(x leq y)。

lgamma(x)

元素级对数 gamma 函数:(mathrm{log}(Gamma(x)))。

log(x)

元素级自然对数:(mathrm{log}(x))。

log1p(x)

元素级 (mathrm{log}(1 x))。

logistic(x)

元素级 logistic(sigmoid)函数:(frac{1}{1 e^{-x}})。

lt(x, y)

元素级小于:(x < y)。

max(x, y)

元素级最大值:(mathrm{max}(x, y))

min(x, y)

元素级最小值:(mathrm{min}(x, y))

mul(x, y)

元素级乘法:(x times y)。

ne(x, y)

按位不等于:(x neq y)。

neg(x)

按位取负:(-x)。

nextafter(x1, x2)

返回 x1 在 x2 方向上的下一个可表示的值。

pad(operand, padding_value, padding_config)

对数组应用低、高和/或内部填充。

polygamma(m, x)

按位多次 gamma 函数:(psi^{(m)}(x))。

population_count(x)

按位人口统计,统计每个元素中设置的位数。

pow(x, y)

按位幂运算:(x^y)。

random_gamma_grad(a, x)

Gamma 分布导数的按位计算。

real(x)

按位提取实部:(mathrm{Re}(x))。

reciprocal(x)

按位倒数:(1 over x)。

reduce(operands, init_values, computation, …)

封装了 XLA 的 Reduce 运算符。

reduce_precision(operand, exponent_bits, …)

封装了 XLA 的 ReducePrecision 运算符。

reduce_window(operand, init_value, …[, …])

rem(x, y)

按位取余:(x bmod y)。

reshape(operand, new_sizes[, dimensions])

封装了 XLA 的 Reshape 运算符。

rev(operand, dimensions)

封装了 XLA 的 Rev 运算符。

rng_bit_generator(key, shape[, dtype, algorithm])

无状态的伪随机数位生成器。

rng_uniform(a, b, shape)

有状态的伪随机数生成器。

round(x[, rounding_method])

按位四舍五入。

rsqrt(x)

按位倒数平方根:(1 over sqrt{x})。

scatter(operand, scatter_indices, updates, …)

Scatter-update 运算符。

scatter_add(operand, scatter_indices, …[, …])

Scatter-add 运算符。

scatter_apply(operand, scatter_indices, …)

Scatter-apply 运算符。

scatter_max(operand, scatter_indices, …[, …])

Scatter-max 运算符。

scatter_min(operand, scatter_indices, …[, …])

Scatter-min 运算符。

scatter_mul(operand, scatter_indices, …[, …])

Scatter-multiply 运算符。

shift_left(x, y)

按位左移:(x ll y)。

shift_right_arithmetic(x, y)

按位算术右移:(x gg y)。

shift_right_logical(x, y)

按位逻辑右移:(x gg y)。

sign(x)

按位符号函数。

sin(x)

按位正弦函数:(mathrm{sin}(x))。

sinh(x)

按位双曲正弦函数:(mathrm{sinh}(x))。

slice(operand, start_indices, limit_indices)

封装了 XLA 的 Slice 运算符。

slice_in_dim(operand, start_index, limit_index)

lax.slice() 的单维度应用封装。

sort()

封装了 XLA 的 Sort 运算符。

sort_key_val(keys, values[, dimension, …])

沿着dimension排序keys并对values应用相同的置换。

sqrt(x)

逐元素平方根:(sqrt{x})。

square(x)

逐元素平方:(x²)。

squeeze(array, dimensions)

从数组中挤出任意数量的大小为 1 的维度。

sub(x, y)

逐元素减法:(x - y)。

tan(x)

逐元素正切:(mathrm{tan}(x))。

tanh(x)

逐元素双曲正切:(mathrm{tanh}(x))。

top_k(operand, k)

返回operand最后一轴上的前k个值及其索引。

transpose(operand, permutation)

包装 XLA 的Transpose运算符。

zeros_like_array(x)

zeta(x, q)

逐元素 Hurwitz zeta 函数:(zeta(x, q))

控制流操作符

associative_scan(fn, elems[, reverse, axis])

使用关联二元操作并行执行扫描。

cond(pred, true_fun, false_fun, *operands[, …])

根据条件应用true_fun或false_fun。

fori_loop(lower, upper, body_fun, init_val, *)

通过归约到jax.lax.while_loop()从lower到upper循环。

map(f, xs)

在主要数组轴上映射函数。

scan(f, init[, xs, length, reverse, unroll, …])

在主要数组轴上扫描函数并携带状态。

select(pred, on_true, on_false)

根据布尔谓词在两个分支之间选择。

select_n(which, *cases)

从多个情况中选择数组值。

switch(index, branches, *operands[, operand])

根据index应用恰好一个branches。

while_loop(cond_fun, body_fun, init_val)

在cond_fun为 True 时重复调用body_fun。

自定义梯度操作符

stop_gradient(x)

停止梯度计算。

custom_linear_solve(matvec, b, solve[, …])

使用隐式定义的梯度执行无矩阵线性求解。

custom_root(f, initial_guess, solve, …[, …])

可微分求解函数的根。

并行操作符

all_gather(x, axis_name, *[, …])

在所有副本中收集x的值。

all_to_all(x, axis_name, split_axis, …[, …])

映射轴的实例化和映射不同轴。

pdot(x, y, axis_name[, pos_contract, …])

psum(x, axis_name, *[, axis_index_groups])

在映射的轴axis_name上进行全归约求和。

psum_scatter(x, axis_name, *[, …])

像psum(x, axis_name),但每个设备仅保留部分结果。

pmax(x, axis_name, *[, axis_index_groups])

在映射的轴axis_name上计算全归约最大值。

pmin(x, axis_name, *[, axis_index_groups])

在映射的轴axis_name上计算全归约最小值。

pmean(x, axis_name, *[, axis_index_groups])

在映射的轴axis_name上计算全归约均值。

ppermute(x, axis_name, perm)

根据置换 perm 执行集体置换。

pshuffle(x, axis_name, perm)

使用替代置换编码的 jax.lax.ppermute 的便捷包装器

pswapaxes(x, axis_name, axis, *[, …])

将 pmapped 轴 axis_name 与非映射轴 axis 交换。

axis_index(axis_name)

返回沿映射轴 axis_name 的索引。

与分片相关的操作符

with_sharding_constraint(x, shardings)

在 jitted 计算中约束数组的分片机制

线性代数操作符 (jax.lax.linalg)

cholesky(x, *[, symmetrize_input])

Cholesky 分解。

eig(x, *[, compute_left_eigenvectors, …])

一般矩阵的特征分解。

eigh(x, *[, lower, symmetrize_input, …])

Hermite 矩阵的特征分解。

hessenberg(a)

将方阵约化为上 Hessenberg 形式。

lu(x)

带有部分主元列主元分解。

householder_product(a, taus)

单元 Householder 反射的乘积。

qdwh(x, *[, is_hermitian, max_iterations, …])

基于 QR 的动态加权 Halley 迭代进行极分解。

qr(x, *[, full_matrices])

QR 分解。

schur(x, *[, compute_schur_vectors, …])

svd()

奇异值分解。

triangular_solve(a, b, *[, left_side, …])

三角解法。

tridiagonal(a, *[, lower])

将对称/Hermitian 矩阵约化为三对角形式。

tridiagonal_solve(dl, d, du, b)

计算三对角线性系统的解。

参数类

代码语言:javascript复制
class jax.lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)

描述卷积的批量、空间和特征维度。

参数:

  • lhs_spec (Sequence[int]) – 包含非负整数维度编号的元组,其中包括(批量维度,特征维度,空间维度…)。
  • rhs_spec (Sequence[int]) – 包含非负整数维度编号的元组,其中包括(输出特征维度,输入特征维度,空间维度…)。
  • out_spec (Sequence[int]) – 包含非负整数维度编号的元组,其中包括(批量维度,特征维度,空间维度…)。
代码语言:javascript复制
jax.lax.ConvGeneralDilatedDimensionNumbers

alias of tuple[str, str, str] | ConvDimensionNumbers | None

代码语言:javascript复制
class jax.lax.GatherDimensionNumbers(offset_dims, collapsed_slice_dims, start_index_map)

描述了传递给 XLA 的 Gather 运算符 的维度号参数。有关维度号含义的详细信息,请参阅 XLA 文档。

Parameters:

  • offset_dims (tuple[int, …**]) – gather 输出中偏移到从操作数切片的数组中的维度的集合。必须是升序整数元组,每个代表输出的一个维度编号。
  • collapsed_slice_dims (tuple[int, …**]) – operand 中具有 slice_sizes[i] == 1 的维度 i 的集合,这些维度不应在 gather 输出中具有对应维度。必须是一个升序整数元组。
  • start_index_map (tuple[int, …**]) – 对于 start_indices 中的每个维度,给出应该被切片的操作数中对应的维度。必须是一个大小等于 start_indices.shape[-1] 的整数元组。

与 XLA 的 GatherDimensionNumbers 结构不同,index_vector_dim 是隐含的;总是存在一个索引向量维度,且它必须始终是最后一个维度。要收集标量索引,请添加大小为 1 的尾随维度。

代码语言:javascript复制
class jax.lax.GatherScatterMode(value)

描述了如何处理 gather 或 scatter 中的越界索引。

可能的值包括:

CLIP:

索引将被夹在最近的范围内值上,即整个要收集的窗口都在范围内。

FILL_OR_DROP:

如果收集窗口的任何部分越界,则返回整个窗口,即使其他部分原本在界内的元素也将用常量填充。如果分散窗口的任何部分越界,则整个窗口将被丢弃。

PROMISE_IN_BOUNDS:

用户承诺索引在范围内。不会执行额外检查。实际上,根据当前的 XLA 实现,这意味着越界的 gather 将被夹在范围内,但越界的 scatter 将被丢弃。如果索引越界,则梯度将不正确。

代码语言:javascript复制
class jax.lax.Precision(value)

lax 函数的精度枚举

JAX 函数的精度参数通常控制加速器后端(即 TPU 和 GPU)上的数组计算速度和精度之间的权衡。成员包括:

默认:

最快模式,但最不准确。在 bfloat16 中执行计算。别名:'default''fastest''bfloat16'

高:

较慢但更准确。以 3 个 bfloat16 传递执行 float32 计算,或在可用时使用 tensorfloat32。别名:'high''bfloat16_3x''tensorfloat32'

最高:

最慢但最准确。根据适用情况在 float32 或 float64 中执行计算。别名:'highest''float32'

代码语言:javascript复制
jax.lax.PrecisionLike

别名为 str | Precision | tuple[str, str] | tuple[Precision, Precision] | None

代码语言:javascript复制
class jax.lax.RoundingMethod(value)

一个枚举。

代码语言:javascript复制
class jax.lax.ScatterDimensionNumbers(update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims)

描述了对 XLA 的 Scatter 操作符 的维度编号参数。有关维度编号含义的更多详细信息,请参阅 XLA 文档。

参数:

  • update_window_dims (Sequence[int]) – 更新中作为窗口维度的维度集合。必须是整数元组,按升序排列,每个表示一个维度编号。
  • inserted_window_dims (Sequence[int]) – 必须插入更新形状的大小为 1 的窗口维度集合。必须是整数元组,按升序排列,每个表示输出的维度编号的镜像图。这些是 gather 情况下 collapsed_slice_dims 的镜像图。
  • scatter_dims_to_operand_dims (Sequence[int]) – 对于 scatter_indices 中的每个维度,给出 operand 中对应的维度。必须是整数序列,大小等于 scatter_indices.shape[-1]。

与 XLA 的 ScatterDimensionNumbers 结构不同,index_vector_dim 是隐式的;总是有一个索引向量维度,并且它必须始终是最后一个维度。要分散标量索引,添加一个尺寸为 1 的尾随维度。

jax.random 模块

原文:jax.readthedocs.io/en/latest/jax.random.html

伪随机数生成的实用程序。

jax.random 包提供了多种例程,用于确定性生成伪随机数序列。

基本用法

代码语言:javascript复制
>>> seed = 1701
>>> num_steps = 100
>>> key = jax.random.key(seed)
>>> for i in range(num_steps):
...   key, subkey = jax.random.split(key)
...   params = compiled_update(subkey, params, next(batches)) 

PRNG keys

与 NumPy 和 SciPy 用户习惯的 有状态 伪随机数生成器(PRNGs)不同,JAX 随机函数都要求作为第一个参数传递一个显式的 PRNG 状态。随机状态由我们称之为 key 的特殊数组元素类型描述,通常由 jax.random.key() 函数生成:

代码语言:javascript复制
>>> from jax import random
>>> key = random.key(0)
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0] 

然后,可以在 JAX 的任何随机数生成例程中使用该 key:

代码语言:javascript复制
>>> random.uniform(key)
Array(0.41845703, dtype=float32) 

请注意,使用 key 不会修改它,因此重复使用相同的 key 将导致相同的结果:

代码语言:javascript复制
>>> random.uniform(key)
Array(0.41845703, dtype=float32) 

如果需要新的随机数,可以使用 jax.random.split() 生成新的子 key:

代码语言:javascript复制
>>> key, subkey = random.split(key)
>>> random.uniform(subkey)
Array(0.10536897, dtype=float32) 

注意

类型化的 key 数组,例如上述 key<fry>,在 JAX v0.4.16 中引入。在此之前,key 通常以 uint32 数组表示,其最终维度表示 key 的位级表示。

两种形式的 key 数组仍然可以通过 jax.random 模块创建和使用。新式的类型化 key 数组使用 jax.random.key() 创建。传统的 uint32 key 数组使用 jax.random.PRNGKey() 创建。

要在两者之间进行转换,使用 jax.random.key_data()jax.random.wrap_key_data()。当与 JAX 外部系统(例如将数组导出为可序列化格式)交互或将 key 传递给基于 JAX 的库时,可能需要传统的 key 格式。

否则,建议使用类型化的 key。传统 key 相对于类型化 key 的注意事项包括:

  • 它们有一个额外的尾维度。
  • 它们具有数字数据类型 (uint32),允许进行通常不用于 key 的操作,例如整数算术。
  • 它们不包含有关 RNG 实现的信息。当传统 key 传递给 jax.random 函数时,全局配置设置确定 RNG 实现(参见下文的“高级 RNG 配置”)。

要了解更多关于此升级以及 key 类型设计的信息,请参阅 JEP 9263。

高级

设计和背景

TLDR:JAX PRNG = Threefry counter PRNG 一个功能数组导向的 分裂模型

更多详细信息,请参阅 docs/jep/263-prng.md。

总结一下,JAX PRNG 还包括但不限于以下要求:

  1. 确保可重现性,
  2. 良好的并行化,无论是向量化(生成数组值)还是多副本、多核计算。特别是它不应在随机函数调用之间使用顺序约束。
高级 RNG 配置

JAX 提供了几种 PRNG 实现。可以通过可选的 impl 关键字参数选择特定的实现。如果在密钥构造函数中没有传递 impl 选项,则实现由全局 jax_default_prng_impl 配置标志确定。

  • 默认,“threefry2x32”: 基于 Threefry 哈希函数构建的基于计数器的 PRNG。
  • 实验性 一种仅包装了 XLA 随机位生成器(RBG)算法的 PRNG。请参阅 TF 文档。
    • “rbg” 使用 ThreeFry 进行分割,并使用 XLA RBG 进行数据生成。
    • “unsafe_rbg” 仅用于演示目的,使用 RBG 进行分割(使用未经测试的虚构算法)和生成。

    这些实验性实现生成的随机流尚未经过任何经验随机性测试(例如 Big Crush)。生成的随机比特可能会在 JAX 的不同版本之间变化。

不使用默认 RNG 的可能原因是:

  1. 可能编译速度较慢(特别是对于 Google Cloud TPU)
  2. 在 TPU 上执行速度较慢
  3. 不支持高效的自动分片/分区

这里是一个简短的总结:

属性

Threefry

Threefry*

rbg

unsafe_rbg

rbg**

unsafe_rbg**

在 TPU 上最快

可以高效分片(使用 pjit)

在分片中相同

在 CPU/GPU/TPU 上相同

在 JAX/XLA 版本间相同

(*): 设置了jax_threefry_partitionable=1

(**): 设置了XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1

“rbg” 和 “unsafe_rbg” 之间的区别在于,“rbg” 用于生成随机值时使用了较不稳定/研究较少的哈希函数(但不用于 jax.random.split 或 jax.random.fold_in),而 “unsafe_rbg” 还额外在 jax.random.split 和 jax.random.fold_in 中使用了更不稳定的哈希函数。因此,在不同密钥生成的随机流质量方面不那么安全。

要了解有关 jax_threefry_partitionable 的更多信息,请参阅jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers

API 参考

密钥创建与操作

PRNGKey(seed, *[, impl])

给定整数种子创建伪随机数生成器(PRNG)密钥。

key(seed, *[, impl])

给定整数种子创建伪随机数生成器(PRNG)密钥。

key_data(密钥)

恢复 PRNG 密钥数组下的密钥数据位。

wrap_key_data(key_bits_array, *[, impl])

将密钥数据位数组包装成 PRNG 密钥数组。

fold_in(key, data)

将数据折叠到 PRNG 密钥中,形成新的 PRNG 密钥。

split(key[, num])

将 PRNG 密钥按添加一个前导轴拆分为 num 个新密钥。

clone(key)

克隆一个密钥以便重复使用。

随机抽样器

ball(key, d[, p, shape, dtype])

从单位 Lp 球中均匀采样。

bernoulli(key[, p, shape])

采样给定形状和均值的伯努利分布随机值。

beta(key, a, b[, shape, dtype])

采样给定形状和浮点数数据类型的贝塔分布随机值。

binomial(key, n, p[, shape, dtype])

采样给定形状和浮点数数据类型的二项分布随机值。

bits(key[, shape, dtype])

以无符号整数的形式采样均匀比特。

categorical(key, logits[, axis, shape])

从分类分布中采样随机值。

cauchy(key[, shape, dtype])

采样给定形状和浮点数数据类型的柯西分布随机值。

chisquare(key, df[, shape, dtype])

采样给定形状和浮点数数据类型的卡方分布随机值。

choice(key, a[, shape, replace, p, axis])

从给定数组中生成随机样本。

dirichlet(key, alpha[, shape, dtype])

采样给定形状和浮点数数据类型的狄利克雷分布随机值。

double_sided_maxwell(key, loc, scale[, …])

从双边 Maxwell 分布中采样。

exponential(key[, shape, dtype])

采样给定形状和浮点数数据类型的指数分布随机值。

f(key, dfnum, dfden[, shape, dtype])

采样给定形状和浮点数数据类型的 F 分布随机值。

gamma(key, a[, shape, dtype])

采样给定形状和浮点数数据类型的伽马分布随机值。

generalized_normal(key, p[, shape, dtype])

从广义正态分布中采样。

geometric(key, p[, shape, dtype])

采样给定形状和浮点数数据类型的几何分布随机值。

gumbel(key[, shape, dtype])

采样给定形状和浮点数数据类型的 Gumbel 分布随机值。

laplace(key[, shape, dtype])

采样给定形状和浮点数数据类型的拉普拉斯分布随机值。

loggamma(key, a[, shape, dtype])

采样给定形状和浮点数数据类型的对数伽马分布随机值。

logistic(key[, shape, dtype])

采样给定形状和浮点数数据类型的 logistic 随机值。

lognormal(key[, sigma, shape, dtype])

采样给定形状和浮点数数据类型的对数正态分布随机值。

maxwell(key[, shape, dtype])

从单边 Maxwell 分布中采样。

multivariate_normal(key, mean, cov[, shape, …])

采样给定均值和协方差的多变量正态分布随机值。

normal(key[, shape, dtype])

采样给定形状和浮点数数据类型的标准正态分布随机值。

orthogonal(key, n[, shape, dtype])

从正交群 O(n) 中均匀采样。

pareto(key, b[, shape, dtype])

采样给定形状和浮点数数据类型的帕累托分布随机值。

permutation(key, x[, axis, independent])

返回随机排列的数组或范围。

poisson(key, lam[, shape, dtype])

采样给定形状和整数数据类型的泊松分布随机值。

rademacher(key[, shape, dtype])

从 Rademacher 分布中采样。

randint(key, shape, minval, maxval[, dtype])

用给定的形状和数据类型在[minval, maxval)范围内示例均匀随机整数值。

[rayleigh(key, scale[, shape, dtype])

用给定的形状和浮点数数据类型示例瑞利随机值。

t(key, df[, shape, dtype])

用给定的形状和浮点数数据类型示例学生 t 分布随机值。

triangular(key, left, mode, right[, shape, …])

用给定的形状和浮点数数据类型示例三角形随机值。

truncated_normal(key, lower, upper[, shape, …])

用给定的形状和数据类型示例截断标准正态随机值。

uniform(key[, shape, dtype, minval, maxval])

用给定的形状和数据类型在[minval, maxval)范围内示例均匀随机值。

[wald(key, mean[, shape, dtype])

用给定的形状和浮点数数据类型示例瓦尔德随机值。

weibull_min(key, scale, concentration[, …])

从威布尔分布中采样。

jax.sharding 模块

原文:jax.readthedocs.io/en/latest/jax.sharding.html

代码语言:javascript复制
class jax.sharding.Sharding

描述了jax.Array如何跨设备布局。

代码语言:javascript复制
property addressable_devices: set[Device]

Sharding中由当前进程可寻址的设备集合。

代码语言:javascript复制
addressable_devices_indices_map(global_shape)

从可寻址设备到它们包含的数组数据切片的映射。

addressable_devices_indices_map 包含适用于可寻址设备的device_indices_map部分。

参数:

global_shape (tuple[int, …**])

返回类型:

Mapping[Device, tuple[slice, …] | None]

代码语言:javascript复制
property device_set: set[Device]

这个Sharding跨越的设备集合。

在多控制器 JAX 中,设备集合是全局的,即包括来自其他进程的不可寻址设备。

代码语言:javascript复制
devices_indices_map(global_shape)

返回从设备到它们包含的数组切片的映射。

映射包括所有全局设备,即包括来自其他进程的不可寻址设备。

参数:

global_shape (tuple[int, …**])

返回类型:

Mapping[Device, tuple[slice, …]]

代码语言:javascript复制
is_equivalent_to(other, ndim)

如果两个分片等效,则返回True

如果它们在相同设备上放置了相同的逻辑数组分片,则两个分片是等效的。

例如,如果NamedShardingPositionalSharding都将数组的相同分片放置在相同的设备上,则它们可能是等效的。

参数:

  • self (Sharding)
  • other (Sharding)
  • ndim (int)

返回类型:

bool

代码语言:javascript复制
property is_fully_addressable: bool

此分片是否是完全可寻址的?

如果当前进程能够寻址Sharding中列出的所有设备,则分片是完全可寻址的。在多进程 JAX 中,is_fully_addressable 等效于 “is_local”。

代码语言:javascript复制
property is_fully_replicated: bool

此分片是否完全复制?

如果每个设备都有整个数据的完整副本,则分片是完全复制的。

代码语言:javascript复制
property memory_kind: str | None

返回分片的内存类型。

代码语言:javascript复制
shard_shape(global_shape)

返回每个设备上数据的形状。

此函数返回的分片形状是从global_shape和分片属性计算得出的。

参数:

global_shape (tuple[int, …**])

返回类型:

tuple[int, …]

代码语言:javascript复制
with_memory_kind(kind)

返回具有指定内存类型的新分片实例。

参数:

kind (str)

返回类型:

分片

代码语言:javascript复制
class jax.sharding.SingleDeviceSharding

基类:分片

一个将其数据放置在单个设备上的分片

参数:

device – 单个设备

示例

代码语言:javascript复制
>>> single_device_sharding = jax.sharding.SingleDeviceSharding(
...     jax.devices()[0]) 
代码语言:javascript复制
property device_set: set[Device]

分片跨越的设备集。

在多控制器 JAX 中,设备集是全局的,即包括来自其他进程的非可寻址设备。

代码语言:javascript复制
devices_indices_map(global_shape)

返回从设备到每个包含的数组片段的映射。

映射包括所有全局设备,即包括来自其他进程的非可寻址设备。

参数:

global_shape (tuple[int, …**])

返回类型:

映射[设备, tuple[slice, …]]

代码语言:javascript复制
property is_fully_addressable: bool

此分片是否完全可寻址?

如果当前进程可以寻址分片中命名的所有设备,则称分片完全可寻址。is_fully_addressable在多进程 JAX 中等同于“is_local”。

代码语言:javascript复制
property is_fully_replicated: bool

此分片是否完全复制?

如果每个设备都有整个数据的完整副本,则分片完全复制。

代码语言:javascript复制
property memory_kind: str | None

返回分片的内存类型。

代码语言:javascript复制
with_memory_kind(kind)

返回具有指定内存类型的新分片实例。

参数:

kind (str)

返回类型:

单设备分片

代码语言:javascript复制
class jax.sharding.NamedSharding

基类:分片

一个NamedSharding使用命名轴来表示分片。

一个NamedSharding是设备Mesh和描述如何跨该网格对数组进行分片的PartitionSpec的组合。

一个Mesh是 JAX 设备的多维 NumPy 数组,其中网格的每个轴都有一个名称,例如 'x''y'

一个PartitionSpec是一个元组,其元素可以是None、一个网格轴或一组网格轴的元组。每个元素描述如何在零个或多个网格维度上对输入维度进行分区。例如,PartitionSpec('x', 'y')表示数据的第一维在网格的 x 轴上进行分片,第二维在网格的 y 轴上进行分片。

分布式数组和自动并行化(jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names)教程详细讲解了如何使用MeshPartitionSpec,包括更多细节和图示。

参数:

  • mesh – 一个jax.sharding.Mesh对象。
  • spec – 一个 jax.sharding.PartitionSpec 对象。

示例

代码语言:javascript复制
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> spec = P('x', 'y')
>>> named_sharding = jax.sharding.NamedSharding(mesh, spec) 
代码语言:javascript复制
property addressable_devices: set[Device]

当前进程可以访问的Sharding中的设备集。

代码语言:javascript复制
property device_set: set[Device]

Sharding跨越的设备集。

在多控制器 JAX 中,设备集是全局的,即包括来自其他进程的不可寻址设备。

代码语言:javascript复制
property is_fully_addressable: bool

此分片是否完全可寻址?

一个分片如果当前进程可以访问Sharding中列出的所有设备,则被视为完全可寻址。在多进程 JAX 中,is_fully_addressable等同于“is_local”。

代码语言:javascript复制
property is_fully_replicated: bool

此分片是否完全复制?

如果每个设备都有整个数据的完整副本,则称分片为完全复制。

代码语言:javascript复制
property memory_kind: str | None

返回分片的内存类型。

代码语言:javascript复制
property mesh

(self) -> object

代码语言:javascript复制
property spec

(self) -> object

代码语言:javascript复制
with_memory_kind(kind)

返回具有指定内存类型的新Sharding实例。

参数:

kind (str)

返回类型:

NamedSharding

代码语言:javascript复制
class jax.sharding.PositionalSharding(devices, *, memory_kind=None)

基类:Sharding

参数:

  • devices (Sequence*[xc.Device]* | np.ndarray)
  • memory_kind (str | None)
代码语言:javascript复制
property device_set: set[Device]

Sharding跨越的设备集。

在多控制器 JAX 中,设备集是全局的,即包括来自其他进程的不可寻址设备。

代码语言:javascript复制
property is_fully_addressable: bool

此分片是否完全可寻址?

一个分片如果当前进程可以访问Sharding中列出的所有设备,则被视为完全可寻址。在多进程 JAX 中,is_fully_addressable等同于“is_local”。

代码语言:javascript复制
property is_fully_replicated: bool

此分片是否完全复制?

如果每个设备都有整个数据的完整副本,则称分片为完全复制。

代码语言:javascript复制
property memory_kind: str | None

返回分片的内存类型。

代码语言:javascript复制
with_memory_kind(kind)

返回具有指定内存类型的新Sharding实例。

参数:

kind (str)

返回类型:

PositionalSharding

代码语言:javascript复制
class jax.sharding.PmapSharding

基类:Sharding

描述了jax.pmap()使用的分片。

代码语言:javascript复制
classmethod default(shape, sharded_dim=0, devices=None)

创建一个PmapSharding,与jax.pmap()使用的默认放置方式匹配。

参数:

  • shape (tuple[int, …**]) – 输入数组的形状。
  • sharded_dim (int") – 输入数组进行分片的维度。默认为 0。
  • devicesSequence[Device] | None) – 可选的设备序列。如果省略,隐含的
  • usedpmap 使用的设备顺序是) – jax.local_devices()
  • of这是顺序) – jax.local_devices()

返回类型:

PmapSharding

代码语言:javascript复制
property device_set: set[Device]

这个Sharding跨越的设备集合。

在多控制器 JAX 中,设备集合是全局的,即包括其他进程的非可寻址设备。

代码语言:javascript复制
property devices

(self)-> ndarray

代码语言:javascript复制
devices_indices_map(global_shape)

返回设备到每个包含的数组切片的映射。

映射包括所有全局设备,即包括其他进程的非可寻址设备。

参数:

global_shape元组[int,…**]

返回类型:

Mapping[Device,元组[切片,…]]

代码语言:javascript复制
is_equivalent_to(other, ndim)

如果两个分片等效,则返回True

如果它们将相同的逻辑数组分片放置在相同的设备上,则两个分片是等效的。

例如,如果NamedShardingPositionalSharding将数组的相同分片放置在相同的设备上,则它们可能是等效的。

参数:

  • selfPmapSharding
  • otherPmapSharding
  • ndimint

返回类型:

布尔(“in Python v3.12”)

代码语言:javascript复制
property is_fully_addressable: bool

这个分片是否完全可寻址?

如果当前进程能够处理Sharding中命名的所有设备,则分片是完全可寻址的。在多进程 JAX 中,is_fully_addressable相当于“is_local”。

代码语言:javascript复制
property is_fully_replicated: bool

这个分片是否完全复制?

如果每个设备都有完整数据的副本,则分片是完全复制的。

代码语言:javascript复制
property memory_kind: str | None

返回分片的内存类型。

代码语言:javascript复制
shard_shape(global_shape)

返回每个设备上数据的形状。

此函数返回的分片形状是从global_shape和分片属性计算而来的。

参数:

global_shape元组[int,…**]

返回类型:

元组[int,…]

代码语言:javascript复制
property sharding_spec

(self)-> jax::ShardingSpec

代码语言:javascript复制
with_memory_kind(kind)

返回具有指定内存类型的新 Sharding 实例。

参数:

kindstr

代码语言:javascript复制
class jax.sharding.GSPMDSharding

基类:Sharding

代码语言:javascript复制
property device_set: set[Device]

这个Sharding跨越的设备集合。

在多控制器 JAX 中,设备集是全局的,即包括来自其他进程的不可寻址设备。

代码语言:javascript复制
property is_fully_addressable: bool

此分片是否完全可寻址?

如果当前进程可以访问Sharding中命名的所有设备,则分片是完全可寻址的。is_fully_addressable相当于多进程 JAX 中的“is_local”。

代码语言:javascript复制
property is_fully_replicated: bool

此分片是否完全复制?

一个分片是完全复制的,如果每个设备都有整个数据的完整副本。

代码语言:javascript复制
property memory_kind: str | None

返回分片的内存类型。

代码语言:javascript复制
with_memory_kind(kind)

返回具有指定内存类型的新 Sharding 实例。

参数:

kindstr

返回类型:

GSPMDSharding

代码语言:javascript复制
class jax.sharding.PartitionSpec(*partitions)

元组描述如何在设备网格上对数组进行分区。

每个元素都可以是None、字符串或字符串元组。有关更多详细信息,请参阅jax.sharding.NamedSharding的文档。

此类存在,以便 JAX 的 pytree 实用程序可以区分分区规范和应视为 pytrees 的元组。

代码语言:javascript复制
class jax.sharding.Mesh(devices, axis_names)

声明在此管理器范围内可用的硬件资源。

特别是,所有axis_names在管理块内都变成有效的资源名称,并且可以在jax.experimental.pjit.pjit()in_axis_resources参数中使用,还请参阅 JAX 的多进程编程模型(jax.readthedocs.io/en/latest/multi_process.html)和分布式数组与自动并行化教程(jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html

如果您在多线程中编译,请确保with Mesh上下文管理器位于线程将执行的函数内部。

参数:

  • devicesndarray) - 包含 JAX 设备对象(例如从jax.devices()获得的对象)的 NumPy ndarray 对象。
  • axis_namestuple[Any, …**]) - 资源轴名称序列,用于分配给devices参数的维度。其长度应与devices的秩匹配。

示例

代码语言:javascript复制
>>> from jax.experimental.pjit import pjit
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> inp = np.arange(16).reshape((8, 2))
>>> devices = np.array(jax.devices()).reshape(4, 2)
...
>>> # Declare a 2D mesh with axes `x` and `y`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> # Use the mesh object directly as a context manager.
>>> with global_mesh:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) 
代码语言:javascript复制
>>> # Initialize the Mesh and use the mesh as the context manager.
>>> with Mesh(devices, ('x', 'y')) as global_mesh:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) 
代码语言:javascript复制
>>> # Also you can use it as `with ... as ...`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> with global_mesh as m:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) 
代码语言:javascript复制
>>> # You can also use it as `with Mesh(...)`.
>>> with Mesh(devices, ('x', 'y')):
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) 

jax.debug 模块

原文:jax.readthedocs.io/en/latest/jax.debug.html

运行时值调试实用工具

jax.debug.print 和 jax.debug.breakpoint 描述了如何利用 JAX 的运行时值调试功能。

callback(callback, *args[, ordered])

调用可分阶段的 Python 回调函数。

print(fmt, *args[, ordered])

打印值,并在 JAX 函数中工作。

breakpoint(*[, backend, filter_frames, …])

在程序中某一点设置断点。

调试分片实用工具

能够在分段函数内(和外部)检查和可视化数组分片的函数。

inspect_array_sharding(value, *, callback)

在 JIT 编译函数内部启用检查数组分片。

visualize_array_sharding(arr, **kwargs)

可视化数组的分片。

visualize_sharding(shape, sharding, *[, …])

使用 rich 可视化 Sharding。

jax.dlpack 模块

原文:jax.readthedocs.io/en/latest/jax.dlpack.html

from_dlpack(external_array[, device, copy])

返回一个 DLPack 张量的 Array 表示形式。

to_dlpack(x[, stream, src_device, …])

返回一个封装了 Array x 的 DLPack 张量。

jax.distributed 模块

原文:jax.readthedocs.io/en/latest/jax.distributed.html

initialize([coordinator_address, …])

初始化 JAX 分布式系统。

shutdown()

关闭分布式系统。

jax.dtypes 模块

原文:jax.readthedocs.io/en/latest/jax.dtypes.html

bfloat16

bfloat16 浮点数值

canonicalize_dtype(dtype[, allow_extended_dtype])

根据config.x64_enabled配置将 dtype 转换为规范的 dtype。

float0

对应于相同名称的标量类型和 dtype 的 DType 类。

issubdtype(a, b)

如果第一个参数是类型代码在类型层次结构中较低/相等,则返回 True。

prng_key()

PRNG Key dtypes 的标量类。

result_type(*args[, return_weak_type_flag])

方便函数,用于应用 JAX 参数 dtype 提升。

scalar_type_of(x)

返回与 JAX 值关联的标量类型。

jax.flatten_util 模块

原文:jax.readthedocs.io/en/latest/jax.flatten_util.html

函数列表

-

ravel_pytree(pytree)

将一个数组的 pytree 展平(压缩)为一个 1D 数组。

jax.image 模块

原文:jax.readthedocs.io/en/latest/jax.image.html

图像操作函数。

更多的图像操作函数可以在建立在 JAX 之上的库中找到,例如 PIX。

图像操作函数

resize(image, shape, method[, antialias, …])

图像调整大小。

scale_and_translate(image, shape, …[, …])

对图像应用缩放和平移。

参数类

代码语言:javascript复制
class jax.image.ResizeMethod(value)

图像调整大小方法。

可能的取值包括:

NEAREST:

最近邻插值。

LINEAR:

线性插值。

LANCZOS3:

Lanczos 重采样,使用半径为 3 的核。

LANCZOS5:

Lanczos 重采样,使用半径为 5 的核。

CUBIC:

三次插值,使用 Keys 三次核。

jax.nn 模块

原文:jax.readthedocs.io/en/latest/jax.nn.html

  • jax.nn.initializers 模块

神经网络库常见函数。

激活函数

relu

线性整流单元激活函数。

relu6

线性整流单元 6 激活函数。

sigmoid(x)

Sigmoid 激活函数。

softplus(x)

Softplus 激活函数。

sparse_plus(x)

稀疏加法函数。

sparse_sigmoid(x)

稀疏 Sigmoid 激活函数。

soft_sign(x)

Soft-sign 激活函数。

silu(x)

SiLU(又称 swish)激活函数。

swish(x)

SiLU(又称 swish)激活函数。

log_sigmoid(x)

对数 Sigmoid 激活函数。

leaky_relu(x[, negative_slope])

泄漏整流线性单元激活函数。

hard_sigmoid(x)

硬 Sigmoid 激活函数。

hard_silu(x)

硬 SiLU(swish)激活函数。

hard_swish(x)

硬 SiLU(swish)激活函数。

hard_tanh(x)

硬tanh 激活函数。

elu(x[, alpha])

指数线性单元激活函数。

celu(x[, alpha])

连续可微的指数线性单元激活函数。

selu(x)

缩放的指数线性单元激活函数。

gelu(x[, approximate])

高斯误差线性单元激活函数。

glu(x[, axis])

门控线性单元激活函数。

squareplus(x[, b])

Squareplus 激活函数。

mish(x)

Mish 激活函数。

其他函数

softmax(x[, axis, where, initial])

Softmax 函数。

log_softmax(x[, axis, where, initial])

对数 Softmax 函数。

logsumexp()

对数-总和-指数归约。

standardize(x[, axis, mean, variance, …])

通过减去mean并除以(sqrt{mathrm{variance}})来标准化数组。

one_hot(x, num_classes, *[, dtype, axis])

对给定索引进行 One-hot 编码。

jax.nn.initializers 模块

原文:jax.readthedocs.io/en/latest/jax.nn.initializers.html

与 Keras 和 Sonnet 中定义一致的常见神经网络层初始化器。

初始化器

该模块提供了与 Keras 和 Sonnet 中定义一致的常见神经网络层初始化器。

初始化器是一个函数,接受三个参数:(key, shape, dtype),并返回一个具有形状shape和数据类型dtype的数组。参数key是一个 PRNG 密钥(例如来自jax.random.key()),用于生成初始化数组的随机数。

constant(value[, dtype])

构建一个返回常数值数组的初始化器。

delta_orthogonal([scale, column_axis, dtype])

构建一个用于增量正交核的初始化器。

glorot_normal([in_axis, out_axis, …])

构建一个 Glorot 正态初始化器(又称 Xavier 正态初始化器)。

glorot_uniform([in_axis, out_axis, …])

构建一个 Glorot 均匀初始化器(又称 Xavier 均匀初始化器)。

he_normal([in_axis, out_axis, batch_axis, dtype])

构建一个 He 正态初始化器(又称 Kaiming 正态初始化器)。

he_uniform([in_axis, out_axis, batch_axis, …])

构建一个 He 均匀初始化器(又称 Kaiming 均匀初始化器)。

lecun_normal([in_axis, out_axis, …])

构建一个 Lecun 正态初始化器。

lecun_uniform([in_axis, out_axis, …])

构建一个 Lecun 均匀初始化器。

normal([stddev, dtype])

构建一个返回实数正态分布随机数组的初始化器。

ones(key, shape[, dtype])

返回一个填充为一的常数数组的初始化器。

orthogonal([scale, column_axis, dtype])

构建一个返回均匀分布正交矩阵的初始化器。

truncated_normal([stddev, dtype, lower, upper])

构建一个返回截断正态分布随机数组的初始化器。

uniform([scale, dtype])

构建一个返回实数均匀分布随机数组的初始化器。

variance_scaling(scale, mode, distribution)

初始化器,根据权重张量的形状调整其尺度。

zeros(key, shape[, dtype])

返回一个填充零的常数数组的初始化器。

jax.ops 模块

原文:jax.readthedocs.io/en/latest/jax.ops.html

段落约简运算符

| segment_max(data, segment_ids[, …]) | 计算数组段内的最大值。 |

函数 jax.ops.index_update、jax.ops.index_add 等已在 JAX 0.2.22 中弃用,并已移除。请改用 JAX 数组上的 jax.numpy.ndarray.at 属性。

segment_min(data, segment_ids[, …])

segment_prod(data, segment_ids[, …])

segment_sum(data, segment_ids[, …])

jax.profiler 模块

原文:jax.readthedocs.io/en/latest/jax.profiler.html

跟踪和时间分析

描述了如何利用 JAX 的跟踪和时间分析功能进行程序性能分析。

start_server(port)

在指定端口启动分析器服务器。

start_trace(log_dir[, create_perfetto_link, …])

启动性能分析跟踪。

stop_trace()

停止当前正在运行的性能分析跟踪。

trace(log_dir[, create_perfetto_link, …])

上下文管理器,用于进行性能分析跟踪。

annotate_function(func[, name])

生成函数执行的跟踪事件的装饰器。

TraceAnnotation

在分析器中生成跟踪事件的上下文管理器。

StepTraceAnnotation(name, **kwargs)

在分析器中生成步骤跟踪事件的上下文管理器。

设备内存分析

请参阅设备内存分析,了解 JAX 的设备内存分析功能简介。

device_memory_profile([backend])

捕获 JAX 设备内存使用情况,格式为 pprof 协议缓冲区。

save_device_memory_profile(filename[, backend])

收集设备内存使用情况,并将其写入文件。

jax.stages 模块

原文:jax.readthedocs.io/en/latest/jax.stages.html

接口到编译执行过程的各个阶段。

JAX 转换,例如jax.jitjax.pmap,也支持一种通用的显式降阶和预编译执行 ahead of time 的方式。 该模块定义了代表这一过程各个阶段的类型。

有关更多信息,请参阅AOT walkthrough。

代码语言:javascript复制
class jax.stages.Wrapped(*args, **kwargs)

一个准备好进行追踪、降阶和编译的函数。

此协议反映了诸如jax.jit之类的函数的输出。 调用它会导致 JIT(即时)降阶、编译和执行。 它也可以在编译之前明确降阶,并在执行之前编译结果。

代码语言:javascript复制
__call__(*args, **kwargs)

执行包装的函数,根据需要进行降阶和编译。

代码语言:javascript复制
lower(*args, **kwargs)

明确为给定的参数降阶此函数。

一个降阶函数被从 Python 阶段化,并翻译为编译器的输入语言,可能以依赖于后端的方式。 它已准备好进行编译,但尚未编译。

返回:

一个Lowered实例,表示降阶。

返回类型:

降阶

代码语言:javascript复制
trace(*args, **kwargs)

明确为给定的参数追踪此函数。

一个追踪函数被从 Python 阶段化,并翻译为一个 jaxpr。 它已准备好进行降阶,但尚未降阶。

返回:

一个Traced实例,表示追踪。

返回类型:

追踪

代码语言:javascript复制
class jax.stages.Lowered(lowering, args_info, out_tree, no_kwargs=False)

降阶一个根据参数类型和值特化的函数。

降阶是一种准备好进行编译的计算。 此类将降阶与稍后编译和执行所需的剩余信息一起携带。 它还提供了一个通用的 API,用于查询 JAX 各种降阶路径(jit()pmap()等)中降阶计算的属性。

参数:

  • 降阶XlaLowering
  • args_infoAny
  • out_treePyTreeDef
  • no_kwargsbool
代码语言:javascript复制
as_text(dialect=None)

此降阶的人类可读文本表示。

旨在可视化和调试目的。 这不必是有效的也不一定可靠的序列化。 它直接传递给外部调用者。

参数:

方言str | ) – 可选字符串,指定一个降阶方言(例如,“stablehlo”)

返回类型:

str

代码语言:javascript复制
compile(compiler_options=None)

编译,并返回相应的Compiled实例。

参数:

compiler_options (dict[str, str | bool] | None)

返回类型:

Compiled

代码语言:javascript复制
compiler_ir(dialect=None)

这种降低的任意对象表示。

旨在调试目的。这不是有效的也不是可靠的序列化。输出在不同调用之间没有一致性的保证。

如果不可用,则返回None,例如基于后端、编译器或运行时。

参数:

dialect (str | None) – 可选字符串,指定一个降低方言(例如“stablehlo”)

返回类型:

Any | None

代码语言:javascript复制
cost_analysis()

执行成本估算的摘要。

旨在可视化和调试。此输出的对象是一些简单的数据结构,可以轻松打印或序列化(例如,带有数值叶的嵌套字典、列表和元组)。然而,它的结构可以是任意的:在 JAX 和 jaxlib 的不同版本甚至调用之间可能不一致。

如果不可用,则返回None,例如基于后端、编译器或运行时。

返回类型:

Any | None

代码语言:javascript复制
property in_tree: PyTreeDef

一对(位置参数、关键字参数)的树结构。

代码语言:javascript复制
class jax.stages.Compiled(executable, args_info, out_tree, no_kwargs=False)

编译后的函数专门针对类型/值进行了优化表示。

编译计算与可执行文件相关联,并提供执行所需的剩余信息。它还为查询 JAX 的各种编译路径和后端中编译计算属性提供了一个共同的 API。

参数:

  • args_info (Any)
  • out_tree (PyTreeDef)
代码语言:javascript复制
__call__(*args, **kwargs)

将自身作为函数调用。

代码语言:javascript复制
as_text()

这是可执行文件的人类可读文本表示。

旨在可视化和调试。这不是有效的也不是可靠的序列化。

如果不可用,则返回None,例如基于后端、编译器或运行时。

返回类型:

str | None

代码语言:javascript复制
cost_analysis()

执行成本估算的摘要。

旨在可视化和调试。此输出的对象是一些简单的数据结构,可以轻松打印或序列化(例如,带有数值叶的嵌套字典、列表和元组)。然而,它的结构可以是任意的:在 JAX 和 jaxlib 的不同版本甚至调用之间可能不一致。

如果不可用,则返回None,例如基于后端、编译器或运行时。

返回类型:

Any | None

代码语言:javascript复制
property in_tree: PyTreeDef

(位置参数,关键字参数) 的树结构。

代码语言:javascript复制
memory_analysis()

估计内存需求的摘要。

用于可视化和调试目的。由此输出的对象是一些简单的数据结构,可以轻松打印或序列化(例如嵌套的字典、列表和具有数字叶子的元组)。然而,其结构可以是任意的:在 JAX 和 jaxlib 的不同版本之间,甚至在不同调用之间可能是不一致的。

返回 None 如果不可用,例如基于后端、编译器或运行时。

返回类型:

任意 | None

代码语言:javascript复制
runtime_executable()

此可执行对象的任意对象表示。

用于调试目的。这不是有效也不是可靠的序列化。输出不能保证在不同调用之间的一致性。

返回 None 如果不可用,例如基于后端、编译器或运行时。

返回类型:

任意 | None

0 人点赞