原文:
jax.readthedocs.io/en/latest/
jax.scipy 模块
原文:
jax.readthedocs.io/en/latest/jax.scipy.html
jax.scipy.cluster
| vq
(obs, code_book[, check_finite]) | 将观测值分配给代码簿中的代码。 | ## jax.scipy.fft
dct(x[, type, n, axis, norm]) | 计算输入的离散余弦变换 |
---|---|
dctn(x[, type, s, axes, norm]) | 计算输入的多维离散余弦变换 |
idct(x[, type, n, axis, norm]) | 计算输入的离散余弦变换的逆变换 |
| idctn
(x[, type, s, axes, norm]) | 计算输入的多维离散余弦变换的逆变换 | ## jax.scipy.integrate
| trapezoid
(y[, x, dx, axis]) | 使用复合梯形法则沿指定轴积分。 | ## jax.scipy.interpolate
| RegularGridInterpolator
(points, values[, …]) | 对正规矩形网格上的点进行插值。 | ## jax.scipy.linalg
block_diag(*arrs) | 从输入数组创建块对角矩阵。 |
---|---|
cho_factor(a[, lower, overwrite_a, check_finite]) | 基于 Cholesky 的线性求解因式分解 |
cho_solve(c_and_lower, b[, overwrite_b, …]) | 使用 Cholesky 分解解线性系统 |
cholesky(a[, lower, overwrite_a, check_finite]) | 计算矩阵的 Cholesky 分解。 |
det(a[, overwrite_a, check_finite]) | 计算矩阵的行列式 |
eigh() | 计算 Hermitian 矩阵的特征值和特征向量 |
eigh_tridiagonal(d, e, *[, eigvals_only, …]) | 解对称实三对角矩阵的特征值问题 |
expm(A, *[, upper_triangular, max_squarings]) | 计算矩阵指数 |
expm_frechet() | 计算矩阵指数的 Frechet 导数 |
funm(A, func[, disp]) | 评估矩阵值函数 |
hessenberg() | 计算矩阵的 Hessenberg 形式 |
hilbert(n) | 创建阶数为 n 的 Hilbert 矩阵。 |
inv(a[, overwrite_a, check_finite]) | 返回方阵的逆矩阵 |
lu() | 计算 LU 分解 |
lu_factor(a[, overwrite_a, check_finite]) | 基于 LU 的线性求解因式分解 |
lu_solve(lu_and_piv, b[, trans, …]) | 使用 LU 分解解线性系统 |
polar(a[, side, method, eps, max_iterations]) | 计算极分解 |
qr() | 计算数组的 QR 分解 |
rsf2csf(T, Z[, check_finite]) | 将实数舒尔形式转换为复数舒尔形式。 |
schur(a[, output]) | 计算舒尔分解 |
solve(a, b[, lower, overwrite_a, …]) | 解线性方程组 |
solve_triangular(a, b[, trans, lower, …]) | 解上(或下)三角线性方程组 |
sqrtm(A[, blocksize]) | 计算矩阵的平方根 |
svd() | 计算奇异值分解 |
| toeplitz
(c[, r]) | 构造 Toeplitz 矩阵 | ## jax.scipy.ndimage
| map_coordinates
(input, coordinates, order[, …]) | 使用插值将输入数组映射到新坐标。 | ## jax.scipy.optimize
minimize(fun, x0[, args, tol, options]) | 最小化一个或多个变量的标量函数。 |
---|
| OptimizeResults
(x, success, status, fun, …) | 优化结果对象。 | ## jax.scipy.signal
fftconvolve(in1, in2[, mode, axes]) | 使用快速傅里叶变换(FFT)卷积两个 N 维数组。 |
---|---|
convolve(in1, in2[, mode, method, precision]) | 两个 N 维数组的卷积。 |
convolve2d(in1, in2[, mode, boundary, …]) | 两个二维数组的卷积。 |
correlate(in1, in2[, mode, method, precision]) | 两个 N 维数组的互相关。 |
correlate2d(in1, in2[, mode, boundary, …]) | 两个二维数组的互相关。 |
csd(x, y[, fs, window, nperseg, noverlap, …]) | 使用 Welch 方法估计交叉功率谱密度(CSD)。 |
detrend(data[, axis, type, bp, overwrite_data]) | 从数据中移除线性或分段线性趋势。 |
istft(Zxx[, fs, window, nperseg, noverlap, …]) | 执行逆短时傅里叶变换(ISTFT)。 |
stft(x[, fs, window, nperseg, noverlap, …]) | 计算短时傅里叶变换(STFT)。 |
| welch
(x[, fs, window, nperseg, noverlap, …]) | 使用 Welch 方法估计功率谱密度(PSD)。 | ## jax.scipy.spatial.transform
Rotation(quat) | 三维旋转。 |
---|
| Slerp
(times, timedelta, rotations, rotvecs) | 球面线性插值旋转。 | ## jax.scipy.sparse.linalg
bicgstab(A, b[, x0, tol, atol, maxiter, M]) | 使用双共轭梯度稳定迭代解决 Ax = b。 |
---|---|
cg(A, b[, x0, tol, atol, maxiter, M]) | 使用共轭梯度法解决 Ax = b。 |
| gmres
(A, b[, x0, tol, atol, restart, …]) | GMRES 解决线性系统 A x = b
,给定 A 和 b。 | ## jax.scipy.special
bernoulli(n) | 生成前 N 个伯努利数。 |
---|---|
beta() | 贝塔函数 |
betainc(a, b, x) | 正则化的不完全贝塔函数。 |
betaln(a, b) | 贝塔函数绝对值的自然对数 |
digamma(x) | Digamma 函数 |
entr(x) | 熵函数 |
erf(x) | 误差函数 |
erfc(x) | 误差函数的补函数 |
erfinv(x) | 误差函数的反函数 |
exp1(x) | 指数积分函数。 |
expi | 指数积分函数。 |
expit(x) | 逻辑 sigmoid(expit)函数 |
expn | 广义指数积分函数。 |
factorial(n[, exact]) | 阶乘函数 |
gamma(x) | 伽马函数。 |
gammainc(a, x) | 正则化的下不完全伽马函数。 |
gammaincc(a, x) | 正则化的上不完全伽马函数。 |
gammaln(x) | 伽马函数绝对值的自然对数。 |
gammasgn(x) | 伽马函数的符号。 |
hyp1f1 | 1F1 超几何函数。 |
i0(x) | 修改贝塞尔函数零阶。 |
i0e(x) | 指数缩放的修改贝塞尔函数零阶。 |
i1(x) | 修改贝塞尔函数一阶。 |
i1e(x) | 指数缩放的修改贝塞尔函数一阶。 |
log_ndtr | 对数正态分布函数。 |
logit | 对数几率函数。 |
logsumexp() | 对数-总和-指数归约。 |
lpmn(m, n, z) | 第一类相关勒让德函数(ALFs)。 |
lpmn_values(m, n, z, is_normalized) | 第一类相关勒让德函数(ALFs)。 |
multigammaln(a, d) | 多变量伽马函数的自然对数。 |
ndtr(x) | 正态分布函数。 |
ndtri§ | 正态分布函数的反函数。 |
poch | Pochhammer 符号。 |
polygamma(n, x) | 多次伽马函数。 |
spence(x) | 斯宾斯函数,也称实数域下的二元对数函数。 |
sph_harm(m, n, theta, phi[, n_max]) | 计算球谐函数。 |
xlog1py | 计算 x*log(1 y),当 x=0 时返回 0。 |
xlogy | 计算 x*log(y),当 x=0 时返回 0。 |
zeta | 赫维茨 ζ 函数。 |
kl_div(p, q) | 库尔巴克-莱布勒散度。 |
| rel_entr
(p, q) | 相对熵函数。 | ## jax.scipy.stats
mode(a[, axis, nan_policy, keepdims]) | 计算数组沿轴的众数(最常见的值)。 |
---|---|
rankdata(a[, method, axis, nan_policy]) | 计算数组沿轴的排名。 |
sem(a[, axis, ddof, nan_policy, keepdims]) | 计算均值的标准误差。 |
jax.scipy.stats.bernoulli
logpmf(k, p[, loc]) | 伯努利对数概率质量函数。 |
---|---|
pmf(k, p[, loc]) | 伯努利概率质量函数。 |
cdf(k, p) | 伯努利累积分布函数。 |
| ppf
(q, p) | 伯努利百分位点函数。 | ### jax.scipy.stats.beta
logpdf(x, a, b[, loc, scale]) | Beta 对数概率分布函数。 |
---|---|
pdf(x, a, b[, loc, scale]) | Beta 概率分布函数。 |
cdf(x, a, b[, loc, scale]) | Beta 累积分布函数。 |
logcdf(x, a, b[, loc, scale]) | Beta 对数累积分布函数。 |
sf(x, a, b[, loc, scale]) | Beta 分布生存函数。 |
| logsf
(x, a, b[, loc, scale]) | Beta 分布对数生存函数。 | ### jax.scipy.stats.betabinom
logpmf(k, n, a, b[, loc]) | Beta-二项式对数概率质量函数。 |
---|
| pmf
(k, n, a, b[, loc]) | Beta-二项式概率质量函数。 | ### jax.scipy.stats.binom
logpmf(k, n, p[, loc]) | 二项式对数概率质量函数。 |
---|
| pmf
(k, n, p[, loc]) | 二项式概率质量函数。 | ### jax.scipy.stats.cauchy
logpdf(x[, loc, scale]) | 柯西对数概率分布函数。 |
---|---|
pdf(x[, loc, scale]) | 柯西概率分布函数。 |
cdf(x[, loc, scale]) | 柯西累积分布函数。 |
logcdf(x[, loc, scale]) | 柯西对数累积分布函数。 |
sf(x[, loc, scale]) | 柯西分布对数生存函数。 |
logsf(x[, loc, scale]) | 柯西对数生存函数。 |
isf(q[, loc, scale]) | 柯西分布逆生存函数。 |
| ppf
(q[, loc, scale]) | 柯西分布分位点函数。 | ### jax.scipy.stats.chi2
logpdf(x, df[, loc, scale]) | 卡方分布对数概率分布函数。 |
---|---|
pdf(x, df[, loc, scale]) | 卡方概率分布函数。 |
cdf(x, df[, loc, scale]) | 卡方累积分布函数。 |
logcdf(x, df[, loc, scale]) | 卡方对数累积分布函数。 |
sf(x, df[, loc, scale]) | 卡方生存函数。 |
| logsf
(x, df[, loc, scale]) | 卡方对数生存函数。 | ### jax.scipy.stats.dirichlet
logpdf(x, alpha) | 狄利克雷对数概率分布函数。 |
---|
| pdf
(x, alpha) | 狄利克雷概率分布函数。 | ### jax.scipy.stats.expon
logpdf(x[, loc, scale]) | 指数对数概率分布函数。 |
---|
| pdf
(x[, loc, scale]) | 指数概率分布函数。 | ### jax.scipy.stats.gamma
logpdf(x, a[, loc, scale]) | 伽玛对数概率分布函数。 |
---|---|
pdf(x, a[, loc, scale]) | 伽玛概率分布函数。 |
cdf(x, a[, loc, scale]) | 伽玛累积分布函数。 |
logcdf(x, a[, loc, scale]) | 伽玛对数累积分布函数。 |
sf(x, a[, loc, scale]) | 伽玛生存函数。 |
| logsf
(x, a[, loc, scale]) | 伽玛对数生存函数。 | ### jax.scipy.stats.gennorm
cdf(x, beta) | 广义正态累积分布函数。 |
---|---|
logpdf(x, beta) | 广义正态对数概率分布函数。 |
| pdf
(x, beta) | 广义正态概率分布函数。 | ### jax.scipy.stats.geom
logpmf(k, p[, loc]) | 几何对数概率质量函数。 |
---|
| pmf
(k, p[, loc]) | 几何概率质量函数。 | ### jax.scipy.stats.laplace
cdf(x[, loc, scale]) | 拉普拉斯累积分布函数。 |
---|---|
logpdf(x[, loc, scale]) | 拉普拉斯对数概率分布函数。 |
| pdf
(x[, loc, scale]) | 拉普拉斯概率分布函数。 | ### jax.scipy.stats.logistic
cdf(x[, loc, scale]) | Logistic 累积分布函数。 |
---|---|
isf(x[, loc, scale]) | Logistic 分布逆生存函数。 |
logpdf(x[, loc, scale]) | Logistic 对数概率分布函数。 |
pdf(x[, loc, scale]) | Logistic 概率分布函数。 |
ppf(x[, loc, scale]) | Logistic 分位点函数。 |
| sf
(x[, loc, scale]) | Logistic 分布生存函数。 | ### jax.scipy.stats.multinomial
logpmf(x, n, p) | 多项式对数概率质量函数。 |
---|
| pmf
(x, n, p) | 多项分布概率质量函数。 | ### jax.scipy.stats.multivariate_normal
logpdf(x, mean, cov[, allow_singular]) | 多元正态分布对数概率分布函数。 |
---|
| pdf
(x, mean, cov) | 多元正态分布概率分布函数。 | ### jax.scipy.stats.nbinom
logpmf(k, n, p[, loc]) | 负二项分布对数概率质量函数。 |
---|
| pmf
(k, n, p[, loc]) | 负二项分布概率质量函数。 | ### jax.scipy.stats.norm
logpdf(x[, loc, scale]) | 正态分布对数概率分布函数。 |
---|---|
pdf(x[, loc, scale]) | 正态分布概率分布函数。 |
cdf(x[, loc, scale]) | 正态分布累积分布函数。 |
logcdf(x[, loc, scale]) | 正态分布对数累积分布函数。 |
ppf(q[, loc, scale]) | 正态分布百分点函数。 |
sf(x[, loc, scale]) | 正态分布生存函数。 |
logsf(x[, loc, scale]) | 正态分布对数生存函数。 |
| isf
(q[, loc, scale]) | 正态分布逆生存函数。 | ### jax.scipy.stats.pareto
logpdf(x, b[, loc, scale]) | 帕累托对数概率分布函数。 |
---|
| pdf
(x, b[, loc, scale]) | 帕累托分布概率分布函数。 | ### jax.scipy.stats.poisson
logpmf(k, mu[, loc]) | 泊松分布对数概率质量函数。 |
---|---|
pmf(k, mu[, loc]) | 泊松分布概率质量函数。 |
| cdf
(k, mu[, loc]) | 泊松分布累积分布函数。 | ### jax.scipy.stats.t
logpdf(x, df[, loc, scale]) | 学生 t 分布对数概率分布函数。 |
---|
| pdf
(x, df[, loc, scale]) | 学生 t 分布概率分布函数。 | ### jax.scipy.stats.truncnorm
cdf(x, a, b[, loc, scale]) | 截断正态分布累积分布函数。 |
---|---|
logcdf(x, a, b[, loc, scale]) | 截断正态分布对数累积分布函数。 |
logpdf(x, a, b[, loc, scale]) | 截断正态分布对数概率分布函数。 |
logsf(x, a, b[, loc, scale]) | 截断正态分布对数生存函数。 |
pdf(x, a, b[, loc, scale]) | 截断正态分布概率分布函数。 |
| sf
(x, a, b[, loc, scale]) | 截断正态分布对数生存函数。 | ### jax.scipy.stats.uniform
logpdf(x[, loc, scale]) | 均匀分布对数概率分布函数。 |
---|---|
pdf(x[, loc, scale]) | 均匀分布概率分布函数。 |
cdf(x[, loc, scale]) | 均匀分布累积分布函数。 |
ppf(q[, loc, scale]) | 均匀分布百分点函数。 |
jax.scipy.stats.gaussian_kde
gaussian_kde(dataset[, bw_method, weights]) | 高斯核密度估计器 |
---|---|
gaussian_kde.evaluate(points) | 对给定点评估高斯核密度估计器。 |
gaussian_kde.integrate_gaussian(mean, cov) | 加权高斯积分分布。 |
gaussian_kde.integrate_box_1d(low, high) | 在给定限制下积分分布。 |
gaussian_kde.integrate_kde(other) | 集成两个高斯核密度估计分布的乘积。 |
gaussian_kde.resample(key[, shape]) | 从估计的概率密度函数中随机采样数据集 |
gaussian_kde.pdf(x) | 概率密度函数 |
gaussian_kde.logpdf(x) | 对数概率密度函数 |
jax.scipy.stats.vonmises
logpdf(x, kappa) | von Mises 对数概率分布函数。 |
---|
| pdf
(x, kappa) | von Mises 概率分布函数。 | ### jax.scipy.stats.wrapcauchy
logpdf(x, c) | Wrapped Cauchy 对数概率分布函数。 |
---|---|
pdf(x, c) | Wrapped Cauchy 概率分布函数。 |
jax.scipy.stats.bernoulli.logpmf
代码语言:javascript复制原文:
jax.readthedocs.io/en/latest/_autosummary/jax.scipy.stats.bernoulli.logpmf.html
jax.scipy.stats.bernoulli.logpmf(k, p, loc=0)
伯努利对数概率质量函数。
scipy.stats.bernoulli
的 JAX 实现 logpmf
伯努利概率质量函数定义如下
[begin{split}f(k) = begin{cases} 1 - p, & k = 0 p, & k = 1 0, & mathrm{otherwise} end{cases}end{split}]
参数:
- k (Array | ndarray | bool | number | bool | int | float | complex) – arraylike,要评估 PMF 的值
- p (Array | ndarray | bool | number | bool | int | float | complex) – arraylike,分布形状参数
- loc (Array | ndarray | bool | number | bool | int | float | complex) – arraylike,分布偏移量
返回值:
logpmf 值的数组
返回类型:
Array
另请参阅
-
jax.scipy.stats.bernoulli.cdf()
-
jax.scipy.stats.bernoulli.pmf()
-
jax.scipy.stats.bernoulli.ppf()
jax.scipy.stats.bernoulli.pmf
代码语言:javascript复制原文:
jax.readthedocs.io/en/latest/_autosummary/jax.scipy.stats.bernoulli.pmf.html
jax.scipy.stats.bernoulli.pmf(k, p, loc=0)
伯努利概率质量函数。
scipy.stats.bernoulli
pmf
的 JAX 实现
伯努利概率质量函数定义为
[begin{split}f(k) = begin{cases} 1 - p, & k = 0 p, & k = 1 0, & mathrm{otherwise} end{cases}end{split}]
参数:
- k (数组 | ndarray | 布尔 | 数值 | 布尔 | 整数 | 浮点数 | 复数*) – 类似数组,要评估 PMF 的值
- p (数组 | ndarray | 布尔 | 数值 | 布尔 | 整数 | 浮点数 | 复数*) – 类似数组,分布形状参数
- loc (数组 | ndarray | 布尔 | 数值 | 布尔 | 整数 | 浮点数 | 复数*) – 类似数组,分布偏移
返回:
pmf 值数组
返回类型:
数组
参见
-
jax.scipy.stats.bernoulli.cdf()
-
jax.scipy.stats.bernoulli.logpmf()
-
jax.scipy.stats.bernoulli.ppf()
jax.scipy.stats.bernoulli.cdf
代码语言:javascript复制原文:
jax.readthedocs.io/en/latest/_autosummary/jax.scipy.stats.bernoulli.cdf.html
jax.scipy.stats.bernoulli.cdf(k, p)
伯努利累积分布函数。
scipy.stats.bernoulli
的 JAX 实现 cdf
伯努利累积分布函数被定义为:
[f_{cdf}(k, p) = sum_{i=0}^k f_{pmf}(k, p)]
其中 (f_{pmf}(k, p)) 是伯努利概率质量函数 jax.scipy.stats.bernoulli.pmf()
。
参数:
- k (Array | ndarray | bool | number | bool | int | float | complex) – 数组,用于评估 CDF 的值
- p (Array | ndarray | bool | number | bool | int | float | complex) – 数组,分布形状参数
- loc – 数组,分布偏移
返回:
cdf 值的数组
返回类型:
Array
另请参见
-
jax.scipy.stats.bernoulli.logpmf()
-
jax.scipy.stats.bernoulli.pmf()
-
jax.scipy.stats.bernoulli.ppf()
jax.scipy.stats.bernoulli.ppf
代码语言:javascript复制原文:
jax.readthedocs.io/en/latest/_autosummary/jax.scipy.stats.bernoulli.ppf.html
jax.scipy.stats.bernoulli.ppf(q, p)
伯努利百分点函数。
JAX 实现的 scipy.stats.bernoulli
ppf
百分点函数是累积分布函数的反函数,jax.scipy.stats.bernoulli.cdf()
。
参数:
- k – arraylike,评估 PPF 的值
- p (Array | ndarray | bool | number | bool | int | float | complex) – arraylike,分布形状参数
- loc – arraylike,分布偏移
- q (Array | ndarray | bool | number | bool | int | float | complex)
返回:
ppf 值数组
返回类型:
Array
另见
-
jax.scipy.stats.bernoulli.cdf()
-
jax.scipy.stats.bernoulli.logpmf()
-
jax.scipy.stats.bernoulli.pmf()
jax.lax 模块
原文:
jax.readthedocs.io/en/latest/jax.lax.html
jax.lax
是支持诸如 jax.numpy
等库的基本操作的库。通常会定义转换规则,例如 JVP 和批处理规则,作为对 jax.lax
基元的转换。
许多基元都是等价于 XLA 操作的薄包装,详细描述请参阅XLA 操作语义文档。
在可能的情况下,优先使用诸如 jax.numpy
等库,而不是直接使用 jax.lax
。jax.numpy
API 遵循 NumPy,因此比 jax.lax
API 更稳定,更不易更改。
Operators
abs(x) | 按元素绝对值:(|x|)。 |
---|---|
acos(x) | 按元素求反余弦:(mathrm{acos}(x))。 |
acosh(x) | 按元素求反双曲余弦:(mathrm{acosh}(x))。 |
add(x, y) | 按元素加法:(x y)。 |
after_all(*operands) | 合并一个或多个 XLA 令牌值。 |
approx_max_k(operand, k[, …]) | 以近似方式返回 operand 的最大 k 值及其索引。 |
approx_min_k(operand, k[, …]) | 以近似方式返回 operand 的最小 k 值及其索引。 |
argmax(operand, axis, index_dtype) | 计算沿着 axis 的最大元素的索引。 |
argmin(operand, axis, index_dtype) | 计算沿着 axis 的最小元素的索引。 |
asin(x) | 按元素求反正弦:(mathrm{asin}(x))。 |
asinh(x) | 按元素求反双曲正弦:(mathrm{asinh}(x))。 |
atan(x) | 按元素求反正切:(mathrm{atan}(x))。 |
atan2(x, y) | 两个变量的按元素反正切:(mathrm{atan}({x over y}))。 |
atanh(x) | 按元素求反双曲正切:(mathrm{atanh}(x))。 |
batch_matmul(lhs, rhs[, precision]) | 批量矩阵乘法。 |
bessel_i0e(x) | 指数缩放修正贝塞尔函数 (0) 阶:(mathrm{i0e}(x) = e^{-|x|} mathrm{i0}(x)) |
bessel_i1e(x) | 指数缩放修正贝塞尔函数 (1) 阶:(mathrm{i1e}(x) = e^{-|x|} mathrm{i1}(x)) |
betainc(a, b, x) | 按元素的正则化不完全贝塔积分。 |
bitcast_convert_type(operand, new_dtype) | 按元素位转换。 |
bitwise_and(x, y) | 按位与运算:(x wedge y)。 |
bitwise_not(x) | 按位取反:(neg x)。 |
bitwise_or(x, y) | 按位或运算:(x vee y)。 |
bitwise_xor(x, y) | 按位异或运算:(x oplus y)。 |
population_count(x) | 按元素计算 popcount,即每个元素中设置的位数。 |
broadcast(operand, sizes) | 广播数组,添加新的前导维度。 |
broadcast_in_dim(operand, shape, …) | 包装 XLA 的 BroadcastInDim 操作符。 |
broadcast_shapes() | 返回经过 NumPy 广播后的形状。 |
broadcast_to_rank(x, rank) | 添加 1 的前导维度,使 x 的等级为 rank。 |
broadcasted_iota(dtype, shape, dimension) | iota的便捷封装器。 |
cbrt(x) | 元素级立方根:(sqrt[3]{x})。 |
ceil(x) | 元素级向上取整:(leftlceil x rightrceil)。 |
clamp(min, x, max) | 元素级 clamp 函数。 |
clz(x) | 元素级计算前导零的个数。 |
collapse(operand, start_dimension[, …]) | 将数组的维度折叠为单个维度。 |
complex(x, y) | 元素级构造复数:(x jy)。 |
concatenate(operands, dimension) | 沿指定维度连接一系列数组。 |
conj(x) | 元素级复数的共轭函数:(overline{x})。 |
conv(lhs, rhs, window_strides, padding[, …]) | conv_general_dilated的便捷封装器。 |
convert_element_type(operand, new_dtype) | 元素级类型转换。 |
conv_dimension_numbers(lhs_shape, rhs_shape, …) | 将卷积维度编号转换为 ConvDimensionNumbers。 |
conv_general_dilated(lhs, rhs, …[, …]) | 带有可选扩展的通用 n 维卷积运算符。 |
conv_general_dilated_local(lhs, rhs, …[, …]) | 带有可选扩展的通用 n 维非共享卷积运算符。 |
conv_general_dilated_patches(lhs, …[, …]) | 提取符合 conv_general_dilated 接受域的补丁。 |
conv_transpose(lhs, rhs, strides, padding[, …]) | 计算 N 维卷积的“转置”的便捷封装器。 |
conv_with_general_padding(lhs, rhs, …[, …]) | conv_general_dilated的便捷封装器。 |
cos(x) | 元素级余弦函数:(mathrm{cos}(x))。 |
cosh(x) | 元素级双曲余弦函数:(mathrm{cosh}(x))。 |
cumlogsumexp(operand[, axis, reverse]) | 沿轴计算累积 logsumexp。 |
cummax(operand[, axis, reverse]) | 沿轴计算累积最大值。 |
cummin(operand[, axis, reverse]) | 沿轴计算累积最小值。 |
cumprod(operand[, axis, reverse]) | 沿轴计算累积乘积。 |
cumsum(operand[, axis, reverse]) | 沿轴计算累积和。 |
digamma(x) | 元素级 digamma 函数:(psi(x))。 |
div(x, y) | 元素级除法:(x over y)。 |
dot(lhs, rhs[, precision, …]) | 向量/向量,矩阵/向量和矩阵/矩阵乘法。 |
dot_general(lhs, rhs, dimension_numbers[, …]) | 通用的点积/收缩运算符。 |
dynamic_index_in_dim(operand, index[, axis, …]) | 对 dynamic_slice 的便捷封装,用于执行整数索引。 |
dynamic_slice(operand, start_indices, …) | 封装了 XLA 的 DynamicSlice 操作符。 |
dynamic_slice_in_dim(operand, start_index, …) | 方便地封装了应用于单个维度的 lax.dynamic_slice()。 |
dynamic_update_index_in_dim(operand, update, …) | 方便地封装了 dynamic_update_slice(),用于在单个 axis 中更新大小为 1 的切片。 |
dynamic_update_slice(operand, update, …) | 封装了 XLA 的 DynamicUpdateSlice 操作符。 |
dynamic_update_slice_in_dim(operand, update, …) | 方便地封装了 dynamic_update_slice(),用于在单个 axis 中更新一个切片。 |
eq(x, y) | 元素级相等:(x = y)。 |
erf(x) | 元素级误差函数:(mathrm{erf}(x))。 |
erfc(x) | 元素级补充误差函数:(mathrm{erfc}(x) = 1 - mathrm{erf}(x))。 |
erf_inv(x) | 元素级反误差函数:(mathrm{erf}^{-1}(x))。 |
exp(x) | 元素级指数函数:(e^x)。 |
expand_dims(array, dimensions) | 将任意数量的大小为 1 的维度插入到数组中。 |
expm1(x) | 元素级运算 (e^{x} - 1)。 |
fft(x, fft_type, fft_lengths) | |
floor(x) | 元素级向下取整:(leftlfloor x rightrfloor)。 |
full(shape, fill_value[, dtype, sharding]) | 返回填充值为 fill_value 的形状数组。 |
full_like(x, fill_value[, dtype, shape, …]) | 基于示例数组 x 创建类似于 np.full 的完整数组。 |
gather(operand, start_indices, …[, …]) | Gather 操作符。 |
ge(x, y) | 元素级大于或等于:(x geq y)。 |
gt(x, y) | 元素级大于:(x > y)。 |
igamma(a, x) | 元素级正则化不完全 gamma 函数。 |
igammac(a, x) | 元素级补充正则化不完全 gamma 函数。 |
imag(x) | 提取复数的虚部:(mathrm{Im}(x))。 |
index_in_dim(operand, index[, axis, keepdims]) | 方便地封装了 lax.slice(),用于执行整数索引。 |
index_take(src, idxs, axes) | |
integer_pow(x, y) | 元素级幂运算:(x^y),其中 (y) 是固定整数。 |
iota(dtype, size) | 封装了 XLA 的 Iota 操作符。 |
is_finite(x) | 元素级 (mathrm{isfinite})。 |
le(x, y) | 元素级小于或等于:(x leq y)。 |
lgamma(x) | 元素级对数 gamma 函数:(mathrm{log}(Gamma(x)))。 |
log(x) | 元素级自然对数:(mathrm{log}(x))。 |
log1p(x) | 元素级 (mathrm{log}(1 x))。 |
logistic(x) | 元素级 logistic(sigmoid)函数:(frac{1}{1 e^{-x}})。 |
lt(x, y) | 元素级小于:(x < y)。 |
max(x, y) | 元素级最大值:(mathrm{max}(x, y)) |
min(x, y) | 元素级最小值:(mathrm{min}(x, y)) |
mul(x, y) | 元素级乘法:(x times y)。 |
ne(x, y) | 按位不等于:(x neq y)。 |
neg(x) | 按位取负:(-x)。 |
nextafter(x1, x2) | 返回 x1 在 x2 方向上的下一个可表示的值。 |
pad(operand, padding_value, padding_config) | 对数组应用低、高和/或内部填充。 |
polygamma(m, x) | 按位多次 gamma 函数:(psi^{(m)}(x))。 |
population_count(x) | 按位人口统计,统计每个元素中设置的位数。 |
pow(x, y) | 按位幂运算:(x^y)。 |
random_gamma_grad(a, x) | Gamma 分布导数的按位计算。 |
real(x) | 按位提取实部:(mathrm{Re}(x))。 |
reciprocal(x) | 按位倒数:(1 over x)。 |
reduce(operands, init_values, computation, …) | 封装了 XLA 的 Reduce 运算符。 |
reduce_precision(operand, exponent_bits, …) | 封装了 XLA 的 ReducePrecision 运算符。 |
reduce_window(operand, init_value, …[, …]) | |
rem(x, y) | 按位取余:(x bmod y)。 |
reshape(operand, new_sizes[, dimensions]) | 封装了 XLA 的 Reshape 运算符。 |
rev(operand, dimensions) | 封装了 XLA 的 Rev 运算符。 |
rng_bit_generator(key, shape[, dtype, algorithm]) | 无状态的伪随机数位生成器。 |
rng_uniform(a, b, shape) | 有状态的伪随机数生成器。 |
round(x[, rounding_method]) | 按位四舍五入。 |
rsqrt(x) | 按位倒数平方根:(1 over sqrt{x})。 |
scatter(operand, scatter_indices, updates, …) | Scatter-update 运算符。 |
scatter_add(operand, scatter_indices, …[, …]) | Scatter-add 运算符。 |
scatter_apply(operand, scatter_indices, …) | Scatter-apply 运算符。 |
scatter_max(operand, scatter_indices, …[, …]) | Scatter-max 运算符。 |
scatter_min(operand, scatter_indices, …[, …]) | Scatter-min 运算符。 |
scatter_mul(operand, scatter_indices, …[, …]) | Scatter-multiply 运算符。 |
shift_left(x, y) | 按位左移:(x ll y)。 |
shift_right_arithmetic(x, y) | 按位算术右移:(x gg y)。 |
shift_right_logical(x, y) | 按位逻辑右移:(x gg y)。 |
sign(x) | 按位符号函数。 |
sin(x) | 按位正弦函数:(mathrm{sin}(x))。 |
sinh(x) | 按位双曲正弦函数:(mathrm{sinh}(x))。 |
slice(operand, start_indices, limit_indices) | 封装了 XLA 的 Slice 运算符。 |
slice_in_dim(operand, start_index, limit_index) | lax.slice() 的单维度应用封装。 |
sort() | 封装了 XLA 的 Sort 运算符。 |
sort_key_val(keys, values[, dimension, …]) | 沿着dimension排序keys并对values应用相同的置换。 |
sqrt(x) | 逐元素平方根:(sqrt{x})。 |
square(x) | 逐元素平方:(x²)。 |
squeeze(array, dimensions) | 从数组中挤出任意数量的大小为 1 的维度。 |
sub(x, y) | 逐元素减法:(x - y)。 |
tan(x) | 逐元素正切:(mathrm{tan}(x))。 |
tanh(x) | 逐元素双曲正切:(mathrm{tanh}(x))。 |
top_k(operand, k) | 返回operand最后一轴上的前k个值及其索引。 |
transpose(operand, permutation) | 包装 XLA 的Transpose运算符。 |
zeros_like_array(x) | |
zeta(x, q) | 逐元素 Hurwitz zeta 函数:(zeta(x, q)) |
控制流操作符
associative_scan(fn, elems[, reverse, axis]) | 使用关联二元操作并行执行扫描。 |
---|---|
cond(pred, true_fun, false_fun, *operands[, …]) | 根据条件应用true_fun或false_fun。 |
fori_loop(lower, upper, body_fun, init_val, *) | 通过归约到jax.lax.while_loop()从lower到upper循环。 |
map(f, xs) | 在主要数组轴上映射函数。 |
scan(f, init[, xs, length, reverse, unroll, …]) | 在主要数组轴上扫描函数并携带状态。 |
select(pred, on_true, on_false) | 根据布尔谓词在两个分支之间选择。 |
select_n(which, *cases) | 从多个情况中选择数组值。 |
switch(index, branches, *operands[, operand]) | 根据index应用恰好一个branches。 |
while_loop(cond_fun, body_fun, init_val) | 在cond_fun为 True 时重复调用body_fun。 |
自定义梯度操作符
stop_gradient(x) | 停止梯度计算。 |
---|---|
custom_linear_solve(matvec, b, solve[, …]) | 使用隐式定义的梯度执行无矩阵线性求解。 |
custom_root(f, initial_guess, solve, …[, …]) | 可微分求解函数的根。 |
并行操作符
all_gather(x, axis_name, *[, …]) | 在所有副本中收集x的值。 |
---|---|
all_to_all(x, axis_name, split_axis, …[, …]) | 映射轴的实例化和映射不同轴。 |
pdot(x, y, axis_name[, pos_contract, …]) | |
psum(x, axis_name, *[, axis_index_groups]) | 在映射的轴axis_name上进行全归约求和。 |
psum_scatter(x, axis_name, *[, …]) | 像psum(x, axis_name),但每个设备仅保留部分结果。 |
pmax(x, axis_name, *[, axis_index_groups]) | 在映射的轴axis_name上计算全归约最大值。 |
pmin(x, axis_name, *[, axis_index_groups]) | 在映射的轴axis_name上计算全归约最小值。 |
pmean(x, axis_name, *[, axis_index_groups]) | 在映射的轴axis_name上计算全归约均值。 |
ppermute(x, axis_name, perm) | 根据置换 perm 执行集体置换。 |
pshuffle(x, axis_name, perm) | 使用替代置换编码的 jax.lax.ppermute 的便捷包装器 |
pswapaxes(x, axis_name, axis, *[, …]) | 将 pmapped 轴 axis_name 与非映射轴 axis 交换。 |
axis_index(axis_name) | 返回沿映射轴 axis_name 的索引。 |
与分片相关的操作符
with_sharding_constraint(x, shardings) | 在 jitted 计算中约束数组的分片机制 |
---|
线性代数操作符 (jax.lax.linalg)
cholesky(x, *[, symmetrize_input]) | Cholesky 分解。 |
---|---|
eig(x, *[, compute_left_eigenvectors, …]) | 一般矩阵的特征分解。 |
eigh(x, *[, lower, symmetrize_input, …]) | Hermite 矩阵的特征分解。 |
hessenberg(a) | 将方阵约化为上 Hessenberg 形式。 |
lu(x) | 带有部分主元列主元分解。 |
householder_product(a, taus) | 单元 Householder 反射的乘积。 |
qdwh(x, *[, is_hermitian, max_iterations, …]) | 基于 QR 的动态加权 Halley 迭代进行极分解。 |
qr(x, *[, full_matrices]) | QR 分解。 |
schur(x, *[, compute_schur_vectors, …]) | |
svd() | 奇异值分解。 |
triangular_solve(a, b, *[, left_side, …]) | 三角解法。 |
tridiagonal(a, *[, lower]) | 将对称/Hermitian 矩阵约化为三对角形式。 |
tridiagonal_solve(dl, d, du, b) | 计算三对角线性系统的解。 |
参数类
代码语言:javascript复制class jax.lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)
描述卷积的批量、空间和特征维度。
参数:
- lhs_spec (Sequence[int]) – 包含非负整数维度编号的元组,其中包括(批量维度,特征维度,空间维度…)。
- rhs_spec (Sequence[int]) – 包含非负整数维度编号的元组,其中包括(输出特征维度,输入特征维度,空间维度…)。
- out_spec (Sequence[int]) – 包含非负整数维度编号的元组,其中包括(批量维度,特征维度,空间维度…)。
jax.lax.ConvGeneralDilatedDimensionNumbers
alias of tuple
[str
, str
, str
] | ConvDimensionNumbers
| None
class jax.lax.GatherDimensionNumbers(offset_dims, collapsed_slice_dims, start_index_map)
描述了传递给 XLA 的 Gather 运算符 的维度号参数。有关维度号含义的详细信息,请参阅 XLA 文档。
Parameters:
- offset_dims (tuple[int, …**]) – gather 输出中偏移到从操作数切片的数组中的维度的集合。必须是升序整数元组,每个代表输出的一个维度编号。
- collapsed_slice_dims (tuple[int, …**]) – operand 中具有 slice_sizes[i] == 1 的维度 i 的集合,这些维度不应在 gather 输出中具有对应维度。必须是一个升序整数元组。
- start_index_map (tuple[int, …**]) – 对于 start_indices 中的每个维度,给出应该被切片的操作数中对应的维度。必须是一个大小等于 start_indices.shape[-1] 的整数元组。
与 XLA 的 GatherDimensionNumbers 结构不同,index_vector_dim 是隐含的;总是存在一个索引向量维度,且它必须始终是最后一个维度。要收集标量索引,请添加大小为 1 的尾随维度。
代码语言:javascript复制class jax.lax.GatherScatterMode(value)
描述了如何处理 gather 或 scatter 中的越界索引。
可能的值包括:
CLIP:
索引将被夹在最近的范围内值上,即整个要收集的窗口都在范围内。
FILL_OR_DROP:
如果收集窗口的任何部分越界,则返回整个窗口,即使其他部分原本在界内的元素也将用常量填充。如果分散窗口的任何部分越界,则整个窗口将被丢弃。
PROMISE_IN_BOUNDS:
用户承诺索引在范围内。不会执行额外检查。实际上,根据当前的 XLA 实现,这意味着越界的 gather 将被夹在范围内,但越界的 scatter 将被丢弃。如果索引越界,则梯度将不正确。
代码语言:javascript复制class jax.lax.Precision(value)
lax 函数的精度枚举
JAX 函数的精度参数通常控制加速器后端(即 TPU 和 GPU)上的数组计算速度和精度之间的权衡。成员包括:
默认:
最快模式,但最不准确。在 bfloat16 中执行计算。别名:'default'
,'fastest'
,'bfloat16'
。
高:
较慢但更准确。以 3 个 bfloat16 传递执行 float32 计算,或在可用时使用 tensorfloat32。别名:'high'
,'bfloat16_3x'
,'tensorfloat32'
。
最高:
最慢但最准确。根据适用情况在 float32 或 float64 中执行计算。别名:'highest'
,'float32'
。
jax.lax.PrecisionLike
别名为 str
| Precision
| tuple
[str
, str
] | tuple
[Precision
, Precision
] | None
class jax.lax.RoundingMethod(value)
一个枚举。
代码语言:javascript复制class jax.lax.ScatterDimensionNumbers(update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims)
描述了对 XLA 的 Scatter 操作符 的维度编号参数。有关维度编号含义的更多详细信息,请参阅 XLA 文档。
参数:
- update_window_dims (Sequence[int]) – 更新中作为窗口维度的维度集合。必须是整数元组,按升序排列,每个表示一个维度编号。
- inserted_window_dims (Sequence[int]) – 必须插入更新形状的大小为 1 的窗口维度集合。必须是整数元组,按升序排列,每个表示输出的维度编号的镜像图。这些是 gather 情况下 collapsed_slice_dims 的镜像图。
- scatter_dims_to_operand_dims (Sequence[int]) – 对于 scatter_indices 中的每个维度,给出 operand 中对应的维度。必须是整数序列,大小等于 scatter_indices.shape[-1]。
与 XLA 的 ScatterDimensionNumbers 结构不同,index_vector_dim 是隐式的;总是有一个索引向量维度,并且它必须始终是最后一个维度。要分散标量索引,添加一个尺寸为 1 的尾随维度。
jax.random 模块
原文:
jax.readthedocs.io/en/latest/jax.random.html
伪随机数生成的实用程序。
jax.random
包提供了多种例程,用于确定性生成伪随机数序列。
基本用法
代码语言:javascript复制>>> seed = 1701
>>> num_steps = 100
>>> key = jax.random.key(seed)
>>> for i in range(num_steps):
... key, subkey = jax.random.split(key)
... params = compiled_update(subkey, params, next(batches))
PRNG keys
与 NumPy 和 SciPy 用户习惯的 有状态 伪随机数生成器(PRNGs)不同,JAX 随机函数都要求作为第一个参数传递一个显式的 PRNG 状态。随机状态由我们称之为 key 的特殊数组元素类型描述,通常由 jax.random.key()
函数生成:
>>> from jax import random
>>> key = random.key(0)
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
然后,可以在 JAX 的任何随机数生成例程中使用该 key:
代码语言:javascript复制>>> random.uniform(key)
Array(0.41845703, dtype=float32)
请注意,使用 key 不会修改它,因此重复使用相同的 key 将导致相同的结果:
代码语言:javascript复制>>> random.uniform(key)
Array(0.41845703, dtype=float32)
如果需要新的随机数,可以使用 jax.random.split()
生成新的子 key:
>>> key, subkey = random.split(key)
>>> random.uniform(subkey)
Array(0.10536897, dtype=float32)
注意
类型化的 key 数组,例如上述 key<fry>
,在 JAX v0.4.16 中引入。在此之前,key 通常以 uint32
数组表示,其最终维度表示 key 的位级表示。
两种形式的 key 数组仍然可以通过 jax.random
模块创建和使用。新式的类型化 key 数组使用 jax.random.key()
创建。传统的 uint32
key 数组使用 jax.random.PRNGKey()
创建。
要在两者之间进行转换,使用 jax.random.key_data()
和 jax.random.wrap_key_data()
。当与 JAX 外部系统(例如将数组导出为可序列化格式)交互或将 key 传递给基于 JAX 的库时,可能需要传统的 key 格式。
否则,建议使用类型化的 key。传统 key 相对于类型化 key 的注意事项包括:
- 它们有一个额外的尾维度。
- 它们具有数字数据类型 (
uint32
),允许进行通常不用于 key 的操作,例如整数算术。 - 它们不包含有关 RNG 实现的信息。当传统 key 传递给
jax.random
函数时,全局配置设置确定 RNG 实现(参见下文的“高级 RNG 配置”)。
要了解更多关于此升级以及 key 类型设计的信息,请参阅 JEP 9263。
高级
设计和背景
TLDR:JAX PRNG = Threefry counter PRNG 一个功能数组导向的 分裂模型
更多详细信息,请参阅 docs/jep/263-prng.md。
总结一下,JAX PRNG 还包括但不限于以下要求:
- 确保可重现性,
- 良好的并行化,无论是向量化(生成数组值)还是多副本、多核计算。特别是它不应在随机函数调用之间使用顺序约束。
高级 RNG 配置
JAX 提供了几种 PRNG 实现。可以通过可选的 impl 关键字参数选择特定的实现。如果在密钥构造函数中没有传递 impl 选项,则实现由全局 jax_default_prng_impl 配置标志确定。
- 默认,“threefry2x32”: 基于 Threefry 哈希函数构建的基于计数器的 PRNG。
- 实验性 一种仅包装了 XLA 随机位生成器(RBG)算法的 PRNG。请参阅 TF 文档。
- “rbg” 使用 ThreeFry 进行分割,并使用 XLA RBG 进行数据生成。
- “unsafe_rbg” 仅用于演示目的,使用 RBG 进行分割(使用未经测试的虚构算法)和生成。
这些实验性实现生成的随机流尚未经过任何经验随机性测试(例如 Big Crush)。生成的随机比特可能会在 JAX 的不同版本之间变化。
不使用默认 RNG 的可能原因是:
- 可能编译速度较慢(特别是对于 Google Cloud TPU)
- 在 TPU 上执行速度较慢
- 不支持高效的自动分片/分区
这里是一个简短的总结:
属性 | Threefry | Threefry* | rbg | unsafe_rbg | rbg** | unsafe_rbg** |
---|---|---|---|---|---|---|
在 TPU 上最快 | ✅ | ✅ | ✅ | ✅ | ||
可以高效分片(使用 pjit) | ✅ | ✅ | ✅ | |||
在分片中相同 | ✅ | ✅ | ✅ | ✅ | ||
在 CPU/GPU/TPU 上相同 | ✅ | ✅ | ||||
在 JAX/XLA 版本间相同 | ✅ | ✅ |
(*): 设置了jax_threefry_partitionable=1
(**): 设置了XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1
“rbg” 和 “unsafe_rbg” 之间的区别在于,“rbg” 用于生成随机值时使用了较不稳定/研究较少的哈希函数(但不用于 jax.random.split 或 jax.random.fold_in),而 “unsafe_rbg” 还额外在 jax.random.split 和 jax.random.fold_in 中使用了更不稳定的哈希函数。因此,在不同密钥生成的随机流质量方面不那么安全。
要了解有关 jax_threefry_partitionable 的更多信息,请参阅jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
API 参考
密钥创建与操作
PRNGKey(seed, *[, impl]) | 给定整数种子创建伪随机数生成器(PRNG)密钥。 |
---|---|
key(seed, *[, impl]) | 给定整数种子创建伪随机数生成器(PRNG)密钥。 |
key_data(密钥) | 恢复 PRNG 密钥数组下的密钥数据位。 |
wrap_key_data(key_bits_array, *[, impl]) | 将密钥数据位数组包装成 PRNG 密钥数组。 |
fold_in(key, data) | 将数据折叠到 PRNG 密钥中,形成新的 PRNG 密钥。 |
split(key[, num]) | 将 PRNG 密钥按添加一个前导轴拆分为 num 个新密钥。 |
clone(key) | 克隆一个密钥以便重复使用。 |
随机抽样器
ball(key, d[, p, shape, dtype]) | 从单位 Lp 球中均匀采样。 |
---|---|
bernoulli(key[, p, shape]) | 采样给定形状和均值的伯努利分布随机值。 |
beta(key, a, b[, shape, dtype]) | 采样给定形状和浮点数数据类型的贝塔分布随机值。 |
binomial(key, n, p[, shape, dtype]) | 采样给定形状和浮点数数据类型的二项分布随机值。 |
bits(key[, shape, dtype]) | 以无符号整数的形式采样均匀比特。 |
categorical(key, logits[, axis, shape]) | 从分类分布中采样随机值。 |
cauchy(key[, shape, dtype]) | 采样给定形状和浮点数数据类型的柯西分布随机值。 |
chisquare(key, df[, shape, dtype]) | 采样给定形状和浮点数数据类型的卡方分布随机值。 |
choice(key, a[, shape, replace, p, axis]) | 从给定数组中生成随机样本。 |
dirichlet(key, alpha[, shape, dtype]) | 采样给定形状和浮点数数据类型的狄利克雷分布随机值。 |
double_sided_maxwell(key, loc, scale[, …]) | 从双边 Maxwell 分布中采样。 |
exponential(key[, shape, dtype]) | 采样给定形状和浮点数数据类型的指数分布随机值。 |
f(key, dfnum, dfden[, shape, dtype]) | 采样给定形状和浮点数数据类型的 F 分布随机值。 |
gamma(key, a[, shape, dtype]) | 采样给定形状和浮点数数据类型的伽马分布随机值。 |
generalized_normal(key, p[, shape, dtype]) | 从广义正态分布中采样。 |
geometric(key, p[, shape, dtype]) | 采样给定形状和浮点数数据类型的几何分布随机值。 |
gumbel(key[, shape, dtype]) | 采样给定形状和浮点数数据类型的 Gumbel 分布随机值。 |
laplace(key[, shape, dtype]) | 采样给定形状和浮点数数据类型的拉普拉斯分布随机值。 |
loggamma(key, a[, shape, dtype]) | 采样给定形状和浮点数数据类型的对数伽马分布随机值。 |
logistic(key[, shape, dtype]) | 采样给定形状和浮点数数据类型的 logistic 随机值。 |
lognormal(key[, sigma, shape, dtype]) | 采样给定形状和浮点数数据类型的对数正态分布随机值。 |
maxwell(key[, shape, dtype]) | 从单边 Maxwell 分布中采样。 |
multivariate_normal(key, mean, cov[, shape, …]) | 采样给定均值和协方差的多变量正态分布随机值。 |
normal(key[, shape, dtype]) | 采样给定形状和浮点数数据类型的标准正态分布随机值。 |
orthogonal(key, n[, shape, dtype]) | 从正交群 O(n) 中均匀采样。 |
pareto(key, b[, shape, dtype]) | 采样给定形状和浮点数数据类型的帕累托分布随机值。 |
permutation(key, x[, axis, independent]) | 返回随机排列的数组或范围。 |
poisson(key, lam[, shape, dtype]) | 采样给定形状和整数数据类型的泊松分布随机值。 |
rademacher(key[, shape, dtype]) | 从 Rademacher 分布中采样。 |
randint(key, shape, minval, maxval[, dtype]) | 用给定的形状和数据类型在[minval, maxval)范围内示例均匀随机整数值。 |
[rayleigh(key, scale[, shape, dtype]) | 用给定的形状和浮点数数据类型示例瑞利随机值。 |
t(key, df[, shape, dtype]) | 用给定的形状和浮点数数据类型示例学生 t 分布随机值。 |
triangular(key, left, mode, right[, shape, …]) | 用给定的形状和浮点数数据类型示例三角形随机值。 |
truncated_normal(key, lower, upper[, shape, …]) | 用给定的形状和数据类型示例截断标准正态随机值。 |
uniform(key[, shape, dtype, minval, maxval]) | 用给定的形状和数据类型在[minval, maxval)范围内示例均匀随机值。 |
[wald(key, mean[, shape, dtype]) | 用给定的形状和浮点数数据类型示例瓦尔德随机值。 |
weibull_min(key, scale, concentration[, …]) | 从威布尔分布中采样。 |
jax.sharding 模块
原文:
jax.readthedocs.io/en/latest/jax.sharding.html
类
代码语言:javascript复制class jax.sharding.Sharding
描述了jax.Array
如何跨设备布局。
property addressable_devices: set[Device]
Sharding
中由当前进程可寻址的设备集合。
addressable_devices_indices_map(global_shape)
从可寻址设备到它们包含的数组数据切片的映射。
addressable_devices_indices_map
包含适用于可寻址设备的device_indices_map
部分。
参数:
global_shape (tuple[int, …**])
返回类型:
Mapping[Device, tuple[slice, …] | None]
代码语言:javascript复制property device_set: set[Device]
这个Sharding
跨越的设备集合。
在多控制器 JAX 中,设备集合是全局的,即包括来自其他进程的不可寻址设备。
代码语言:javascript复制devices_indices_map(global_shape)
返回从设备到它们包含的数组切片的映射。
映射包括所有全局设备,即包括来自其他进程的不可寻址设备。
参数:
global_shape (tuple[int, …**])
返回类型:
Mapping[Device, tuple[slice, …]]
代码语言:javascript复制is_equivalent_to(other, ndim)
如果两个分片等效,则返回True
。
如果它们在相同设备上放置了相同的逻辑数组分片,则两个分片是等效的。
例如,如果NamedSharding
和PositionalSharding
都将数组的相同分片放置在相同的设备上,则它们可能是等效的。
参数:
- self (Sharding)
- other (Sharding)
- ndim (int)
返回类型:
bool
代码语言:javascript复制property is_fully_addressable: bool
此分片是否是完全可寻址的?
如果当前进程能够寻址Sharding
中列出的所有设备,则分片是完全可寻址的。在多进程 JAX 中,is_fully_addressable
等效于 “is_local”。
property is_fully_replicated: bool
此分片是否完全复制?
如果每个设备都有整个数据的完整副本,则分片是完全复制的。
代码语言:javascript复制property memory_kind: str | None
返回分片的内存类型。
代码语言:javascript复制shard_shape(global_shape)
返回每个设备上数据的形状。
此函数返回的分片形状是从global_shape
和分片属性计算得出的。
参数:
global_shape (tuple[int, …**])
返回类型:
tuple[int, …]
代码语言:javascript复制with_memory_kind(kind)
返回具有指定内存类型的新分片实例。
参数:
kind (str)
返回类型:
分片
代码语言:javascript复制class jax.sharding.SingleDeviceSharding
基类:分片
一个将其数据放置在单个设备上的分片
。
参数:
device – 单个设备
。
示例
代码语言:javascript复制>>> single_device_sharding = jax.sharding.SingleDeviceSharding(
... jax.devices()[0])
代码语言:javascript复制property device_set: set[Device]
此分片
跨越的设备集。
在多控制器 JAX 中,设备集是全局的,即包括来自其他进程的非可寻址设备。
代码语言:javascript复制devices_indices_map(global_shape)
返回从设备到每个包含的数组片段的映射。
映射包括所有全局设备,即包括来自其他进程的非可寻址设备。
参数:
global_shape (tuple[int, …**])
返回类型:
映射[设备, tuple[slice, …]]
代码语言:javascript复制property is_fully_addressable: bool
此分片是否完全可寻址?
如果当前进程可以寻址分片
中命名的所有设备,则称分片完全可寻址。is_fully_addressable
在多进程 JAX 中等同于“is_local”。
property is_fully_replicated: bool
此分片是否完全复制?
如果每个设备都有整个数据的完整副本,则分片完全复制。
代码语言:javascript复制property memory_kind: str | None
返回分片的内存类型。
代码语言:javascript复制with_memory_kind(kind)
返回具有指定内存类型的新分片实例。
参数:
kind (str)
返回类型:
单设备分片
代码语言:javascript复制class jax.sharding.NamedSharding
基类:分片
一个NamedSharding
使用命名轴来表示分片。
一个NamedSharding
是设备Mesh
和描述如何跨该网格对数组进行分片的PartitionSpec
的组合。
一个Mesh
是 JAX 设备的多维 NumPy 数组,其中网格的每个轴都有一个名称,例如 'x'
或 'y'
。
一个PartitionSpec
是一个元组,其元素可以是None
、一个网格轴或一组网格轴的元组。每个元素描述如何在零个或多个网格维度上对输入维度进行分区。例如,PartitionSpec('x', 'y')
表示数据的第一维在网格的 x
轴上进行分片,第二维在网格的 y
轴上进行分片。
分布式数组和自动并行化(jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names
)教程详细讲解了如何使用Mesh
和PartitionSpec
,包括更多细节和图示。
参数:
- mesh – 一个
jax.sharding.Mesh
对象。 - spec – 一个
jax.sharding.PartitionSpec
对象。
示例
代码语言:javascript复制>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> spec = P('x', 'y')
>>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
代码语言:javascript复制property addressable_devices: set[Device]
当前进程可以访问的Sharding
中的设备集。
property device_set: set[Device]
该Sharding
跨越的设备集。
在多控制器 JAX 中,设备集是全局的,即包括来自其他进程的不可寻址设备。
代码语言:javascript复制property is_fully_addressable: bool
此分片是否完全可寻址?
一个分片如果当前进程可以访问Sharding
中列出的所有设备,则被视为完全可寻址。在多进程 JAX 中,is_fully_addressable
等同于“is_local”。
property is_fully_replicated: bool
此分片是否完全复制?
如果每个设备都有整个数据的完整副本,则称分片为完全复制。
代码语言:javascript复制property memory_kind: str | None
返回分片的内存类型。
代码语言:javascript复制property mesh
(self) -> object
property spec
(self) -> object
with_memory_kind(kind)
返回具有指定内存类型的新Sharding
实例。
参数:
kind (str)
返回类型:
NamedSharding
代码语言:javascript复制class jax.sharding.PositionalSharding(devices, *, memory_kind=None)
基类:Sharding
参数:
- devices (Sequence*[xc.Device]* | np.ndarray)
- memory_kind (str | None)
property device_set: set[Device]
该Sharding
跨越的设备集。
在多控制器 JAX 中,设备集是全局的,即包括来自其他进程的不可寻址设备。
代码语言:javascript复制property is_fully_addressable: bool
此分片是否完全可寻址?
一个分片如果当前进程可以访问Sharding
中列出的所有设备,则被视为完全可寻址。在多进程 JAX 中,is_fully_addressable
等同于“is_local”。
property is_fully_replicated: bool
此分片是否完全复制?
如果每个设备都有整个数据的完整副本,则称分片为完全复制。
代码语言:javascript复制property memory_kind: str | None
返回分片的内存类型。
代码语言:javascript复制with_memory_kind(kind)
返回具有指定内存类型的新Sharding
实例。
参数:
kind (str)
返回类型:
PositionalSharding
代码语言:javascript复制class jax.sharding.PmapSharding
基类:Sharding
描述了jax.pmap()
使用的分片。
classmethod default(shape, sharded_dim=0, devices=None)
创建一个PmapSharding
,与jax.pmap()
使用的默认放置方式匹配。
参数:
- shape (tuple[int, …**]) – 输入数组的形状。
- sharded_dim (int") – 输入数组进行分片的维度。默认为 0。
- devices(Sequence[Device] | None) – 可选的设备序列。如果省略,隐含的
- used(pmap 使用的设备顺序是) –
jax.local_devices()
。 - of(这是顺序) –
jax.local_devices()
。
返回类型:
PmapSharding
代码语言:javascript复制property device_set: set[Device]
这个Sharding
跨越的设备集合。
在多控制器 JAX 中,设备集合是全局的,即包括其他进程的非可寻址设备。
代码语言:javascript复制property devices
(self)-> ndarray
代码语言:javascript复制devices_indices_map(global_shape)
返回设备到每个包含的数组切片的映射。
映射包括所有全局设备,即包括其他进程的非可寻址设备。
参数:
global_shape(元组[int,…**])
返回类型:
Mapping[Device,元组[切片,…]]
代码语言:javascript复制is_equivalent_to(other, ndim)
如果两个分片等效,则返回True
。
如果它们将相同的逻辑数组分片放置在相同的设备上,则两个分片是等效的。
例如,如果NamedSharding
和PositionalSharding
将数组的相同分片放置在相同的设备上,则它们可能是等效的。
参数:
- self(PmapSharding)
- other(PmapSharding)
- ndim(int)
返回类型:
布尔(“in Python v3.12”)
代码语言:javascript复制property is_fully_addressable: bool
这个分片是否完全可寻址?
如果当前进程能够处理Sharding
中命名的所有设备,则分片是完全可寻址的。在多进程 JAX 中,is_fully_addressable
相当于“is_local”。
property is_fully_replicated: bool
这个分片是否完全复制?
如果每个设备都有完整数据的副本,则分片是完全复制的。
代码语言:javascript复制property memory_kind: str | None
返回分片的内存类型。
代码语言:javascript复制shard_shape(global_shape)
返回每个设备上数据的形状。
此函数返回的分片形状是从global_shape
和分片属性计算而来的。
参数:
global_shape(元组[int,…**])
返回类型:
元组[int,…]
代码语言:javascript复制property sharding_spec
(self)-> jax::ShardingSpec
代码语言:javascript复制with_memory_kind(kind)
返回具有指定内存类型的新 Sharding 实例。
参数:
kind(str)
代码语言:javascript复制class jax.sharding.GSPMDSharding
基类:Sharding
property device_set: set[Device]
这个Sharding
跨越的设备集合。
在多控制器 JAX 中,设备集是全局的,即包括来自其他进程的不可寻址设备。
代码语言:javascript复制property is_fully_addressable: bool
此分片是否完全可寻址?
如果当前进程可以访问Sharding
中命名的所有设备,则分片是完全可寻址的。is_fully_addressable
相当于多进程 JAX 中的“is_local”。
property is_fully_replicated: bool
此分片是否完全复制?
一个分片是完全复制的,如果每个设备都有整个数据的完整副本。
代码语言:javascript复制property memory_kind: str | None
返回分片的内存类型。
代码语言:javascript复制with_memory_kind(kind)
返回具有指定内存类型的新 Sharding 实例。
参数:
kind(str)
返回类型:
GSPMDSharding
代码语言:javascript复制class jax.sharding.PartitionSpec(*partitions)
元组描述如何在设备网格上对数组进行分区。
每个元素都可以是None
、字符串或字符串元组。有关更多详细信息,请参阅jax.sharding.NamedSharding
的文档。
此类存在,以便 JAX 的 pytree 实用程序可以区分分区规范和应视为 pytrees 的元组。
代码语言:javascript复制class jax.sharding.Mesh(devices, axis_names)
声明在此管理器范围内可用的硬件资源。
特别是,所有axis_names
在管理块内都变成有效的资源名称,并且可以在jax.experimental.pjit.pjit()
的in_axis_resources
参数中使用,还请参阅 JAX 的多进程编程模型(jax.readthedocs.io/en/latest/multi_process.html
)和分布式数组与自动并行化教程(jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
)
如果您在多线程中编译,请确保with Mesh
上下文管理器位于线程将执行的函数内部。
参数:
- devices(ndarray) - 包含 JAX 设备对象(例如从
jax.devices()
获得的对象)的 NumPy ndarray 对象。 - axis_names(tuple[Any, …**]) - 资源轴名称序列,用于分配给
devices
参数的维度。其长度应与devices
的秩匹配。
示例
代码语言:javascript复制>>> from jax.experimental.pjit import pjit
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> inp = np.arange(16).reshape((8, 2))
>>> devices = np.array(jax.devices()).reshape(4, 2)
...
>>> # Declare a 2D mesh with axes `x` and `y`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> # Use the mesh object directly as a context manager.
>>> with global_mesh:
... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
代码语言:javascript复制>>> # Initialize the Mesh and use the mesh as the context manager.
>>> with Mesh(devices, ('x', 'y')) as global_mesh:
... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
代码语言:javascript复制>>> # Also you can use it as `with ... as ...`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> with global_mesh as m:
... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
代码语言:javascript复制>>> # You can also use it as `with Mesh(...)`.
>>> with Mesh(devices, ('x', 'y')):
... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
jax.debug 模块
原文:
jax.readthedocs.io/en/latest/jax.debug.html
运行时值调试实用工具
jax.debug.print 和 jax.debug.breakpoint 描述了如何利用 JAX 的运行时值调试功能。
callback(callback, *args[, ordered]) | 调用可分阶段的 Python 回调函数。 |
---|---|
print(fmt, *args[, ordered]) | 打印值,并在 JAX 函数中工作。 |
breakpoint(*[, backend, filter_frames, …]) | 在程序中某一点设置断点。 |
调试分片实用工具
能够在分段函数内(和外部)检查和可视化数组分片的函数。
inspect_array_sharding(value, *, callback) | 在 JIT 编译函数内部启用检查数组分片。 |
---|---|
visualize_array_sharding(arr, **kwargs) | 可视化数组的分片。 |
visualize_sharding(shape, sharding, *[, …]) | 使用 rich 可视化 Sharding。 |
jax.dlpack 模块
原文:
jax.readthedocs.io/en/latest/jax.dlpack.html
from_dlpack(external_array[, device, copy]) | 返回一个 DLPack 张量的 Array 表示形式。 |
---|---|
to_dlpack(x[, stream, src_device, …]) | 返回一个封装了 Array x 的 DLPack 张量。 |
jax.distributed 模块
原文:
jax.readthedocs.io/en/latest/jax.distributed.html
initialize([coordinator_address, …]) | 初始化 JAX 分布式系统。 |
---|---|
shutdown() | 关闭分布式系统。 |
jax.dtypes 模块
原文:
jax.readthedocs.io/en/latest/jax.dtypes.html
bfloat16 | bfloat16 浮点数值 |
---|---|
canonicalize_dtype(dtype[, allow_extended_dtype]) | 根据config.x64_enabled配置将 dtype 转换为规范的 dtype。 |
float0 | 对应于相同名称的标量类型和 dtype 的 DType 类。 |
issubdtype(a, b) | 如果第一个参数是类型代码在类型层次结构中较低/相等,则返回 True。 |
prng_key() | PRNG Key dtypes 的标量类。 |
result_type(*args[, return_weak_type_flag]) | 方便函数,用于应用 JAX 参数 dtype 提升。 |
scalar_type_of(x) | 返回与 JAX 值关联的标量类型。 |
jax.flatten_util 模块
原文:
jax.readthedocs.io/en/latest/jax.flatten_util.html
函数列表
- | ravel_pytree(pytree) | 将一个数组的 pytree 展平(压缩)为一个 1D 数组。 |
---|
jax.image 模块
原文:
jax.readthedocs.io/en/latest/jax.image.html
图像操作函数。
更多的图像操作函数可以在建立在 JAX 之上的库中找到,例如 PIX。
图像操作函数
resize(image, shape, method[, antialias, …]) | 图像调整大小。 |
---|---|
scale_and_translate(image, shape, …[, …]) | 对图像应用缩放和平移。 |
参数类
代码语言:javascript复制class jax.image.ResizeMethod(value)
图像调整大小方法。
可能的取值包括:
NEAREST:
最近邻插值。
LINEAR:
线性插值。
LANCZOS3:
Lanczos 重采样,使用半径为 3 的核。
LANCZOS5:
Lanczos 重采样,使用半径为 5 的核。
CUBIC:
三次插值,使用 Keys 三次核。
jax.nn 模块
原文:
jax.readthedocs.io/en/latest/jax.nn.html
jax.nn.initializers
模块
神经网络库常见函数。
激活函数
relu | 线性整流单元激活函数。 |
---|---|
relu6 | 线性整流单元 6 激活函数。 |
sigmoid(x) | Sigmoid 激活函数。 |
softplus(x) | Softplus 激活函数。 |
sparse_plus(x) | 稀疏加法函数。 |
sparse_sigmoid(x) | 稀疏 Sigmoid 激活函数。 |
soft_sign(x) | Soft-sign 激活函数。 |
silu(x) | SiLU(又称 swish)激活函数。 |
swish(x) | SiLU(又称 swish)激活函数。 |
log_sigmoid(x) | 对数 Sigmoid 激活函数。 |
leaky_relu(x[, negative_slope]) | 泄漏整流线性单元激活函数。 |
hard_sigmoid(x) | 硬 Sigmoid 激活函数。 |
hard_silu(x) | 硬 SiLU(swish)激活函数。 |
hard_swish(x) | 硬 SiLU(swish)激活函数。 |
hard_tanh(x) | 硬tanh 激活函数。 |
elu(x[, alpha]) | 指数线性单元激活函数。 |
celu(x[, alpha]) | 连续可微的指数线性单元激活函数。 |
selu(x) | 缩放的指数线性单元激活函数。 |
gelu(x[, approximate]) | 高斯误差线性单元激活函数。 |
glu(x[, axis]) | 门控线性单元激活函数。 |
squareplus(x[, b]) | Squareplus 激活函数。 |
mish(x) | Mish 激活函数。 |
其他函数
softmax(x[, axis, where, initial]) | Softmax 函数。 |
---|---|
log_softmax(x[, axis, where, initial]) | 对数 Softmax 函数。 |
logsumexp() | 对数-总和-指数归约。 |
standardize(x[, axis, mean, variance, …]) | 通过减去mean并除以(sqrt{mathrm{variance}})来标准化数组。 |
one_hot(x, num_classes, *[, dtype, axis]) | 对给定索引进行 One-hot 编码。 |
jax.nn.initializers 模块
原文:
jax.readthedocs.io/en/latest/jax.nn.initializers.html
与 Keras 和 Sonnet 中定义一致的常见神经网络层初始化器。
初始化器
该模块提供了与 Keras 和 Sonnet 中定义一致的常见神经网络层初始化器。
初始化器是一个函数,接受三个参数:(key, shape, dtype)
,并返回一个具有形状shape
和数据类型dtype
的数组。参数key
是一个 PRNG 密钥(例如来自jax.random.key()
),用于生成初始化数组的随机数。
constant(value[, dtype]) | 构建一个返回常数值数组的初始化器。 |
---|---|
delta_orthogonal([scale, column_axis, dtype]) | 构建一个用于增量正交核的初始化器。 |
glorot_normal([in_axis, out_axis, …]) | 构建一个 Glorot 正态初始化器(又称 Xavier 正态初始化器)。 |
glorot_uniform([in_axis, out_axis, …]) | 构建一个 Glorot 均匀初始化器(又称 Xavier 均匀初始化器)。 |
he_normal([in_axis, out_axis, batch_axis, dtype]) | 构建一个 He 正态初始化器(又称 Kaiming 正态初始化器)。 |
he_uniform([in_axis, out_axis, batch_axis, …]) | 构建一个 He 均匀初始化器(又称 Kaiming 均匀初始化器)。 |
lecun_normal([in_axis, out_axis, …]) | 构建一个 Lecun 正态初始化器。 |
lecun_uniform([in_axis, out_axis, …]) | 构建一个 Lecun 均匀初始化器。 |
normal([stddev, dtype]) | 构建一个返回实数正态分布随机数组的初始化器。 |
ones(key, shape[, dtype]) | 返回一个填充为一的常数数组的初始化器。 |
orthogonal([scale, column_axis, dtype]) | 构建一个返回均匀分布正交矩阵的初始化器。 |
truncated_normal([stddev, dtype, lower, upper]) | 构建一个返回截断正态分布随机数组的初始化器。 |
uniform([scale, dtype]) | 构建一个返回实数均匀分布随机数组的初始化器。 |
variance_scaling(scale, mode, distribution) | 初始化器,根据权重张量的形状调整其尺度。 |
zeros(key, shape[, dtype]) | 返回一个填充零的常数数组的初始化器。 |
jax.ops 模块
原文:
jax.readthedocs.io/en/latest/jax.ops.html
段落约简运算符
| segment_max
(data, segment_ids[, …]) | 计算数组段内的最大值。 |
函数 jax.ops.index_update、jax.ops.index_add 等已在 JAX 0.2.22 中弃用,并已移除。请改用 JAX 数组上的 jax.numpy.ndarray.at 属性。 |
---|
segment_min(data, segment_ids[, …]) |
segment_prod(data, segment_ids[, …]) |
segment_sum(data, segment_ids[, …]) |
jax.profiler 模块
原文:
jax.readthedocs.io/en/latest/jax.profiler.html
跟踪和时间分析
描述了如何利用 JAX 的跟踪和时间分析功能进行程序性能分析。
start_server(port) | 在指定端口启动分析器服务器。 |
---|---|
start_trace(log_dir[, create_perfetto_link, …]) | 启动性能分析跟踪。 |
stop_trace() | 停止当前正在运行的性能分析跟踪。 |
trace(log_dir[, create_perfetto_link, …]) | 上下文管理器,用于进行性能分析跟踪。 |
annotate_function(func[, name]) | 生成函数执行的跟踪事件的装饰器。 |
TraceAnnotation | 在分析器中生成跟踪事件的上下文管理器。 |
StepTraceAnnotation(name, **kwargs) | 在分析器中生成步骤跟踪事件的上下文管理器。 |
设备内存分析
请参阅设备内存分析,了解 JAX 的设备内存分析功能简介。
device_memory_profile([backend]) | 捕获 JAX 设备内存使用情况,格式为 pprof 协议缓冲区。 |
---|---|
save_device_memory_profile(filename[, backend]) | 收集设备内存使用情况,并将其写入文件。 |
jax.stages 模块
原文:
jax.readthedocs.io/en/latest/jax.stages.html
接口到编译执行过程的各个阶段。
JAX 转换,例如jax.jit
和jax.pmap
,也支持一种通用的显式降阶和预编译执行 ahead of time 的方式。 该模块定义了代表这一过程各个阶段的类型。
有关更多信息,请参阅AOT walkthrough。
类
代码语言:javascript复制class jax.stages.Wrapped(*args, **kwargs)
一个准备好进行追踪、降阶和编译的函数。
此协议反映了诸如jax.jit
之类的函数的输出。 调用它会导致 JIT(即时)降阶、编译和执行。 它也可以在编译之前明确降阶,并在执行之前编译结果。
__call__(*args, **kwargs)
执行包装的函数,根据需要进行降阶和编译。
代码语言:javascript复制lower(*args, **kwargs)
明确为给定的参数降阶此函数。
一个降阶函数被从 Python 阶段化,并翻译为编译器的输入语言,可能以依赖于后端的方式。 它已准备好进行编译,但尚未编译。
返回:
一个Lowered
实例,表示降阶。
返回类型:
降阶
代码语言:javascript复制trace(*args, **kwargs)
明确为给定的参数追踪此函数。
一个追踪函数被从 Python 阶段化,并翻译为一个 jaxpr。 它已准备好进行降阶,但尚未降阶。
返回:
一个Traced
实例,表示追踪。
返回类型:
追踪
代码语言:javascript复制class jax.stages.Lowered(lowering, args_info, out_tree, no_kwargs=False)
降阶一个根据参数类型和值特化的函数。
降阶是一种准备好进行编译的计算。 此类将降阶与稍后编译和执行所需的剩余信息一起携带。 它还提供了一个通用的 API,用于查询 JAX 各种降阶路径(jit()
、pmap()
等)中降阶计算的属性。
参数:
- 降阶(XlaLowering)
- args_info(Any)
- out_tree(PyTreeDef)
- no_kwargs(bool)
as_text(dialect=None)
此降阶的人类可读文本表示。
旨在可视化和调试目的。 这不必是有效的也不一定可靠的序列化。 它直接传递给外部调用者。
参数:
方言(str | 无) – 可选字符串,指定一个降阶方言(例如,“stablehlo”)
返回类型:
str
代码语言:javascript复制compile(compiler_options=None)
编译,并返回相应的Compiled
实例。
参数:
compiler_options (dict[str, str | bool] | None)
返回类型:
Compiled
代码语言:javascript复制compiler_ir(dialect=None)
这种降低的任意对象表示。
旨在调试目的。这不是有效的也不是可靠的序列化。输出在不同调用之间没有一致性的保证。
如果不可用,则返回None
,例如基于后端、编译器或运行时。
参数:
dialect (str | None) – 可选字符串,指定一个降低方言(例如“stablehlo”)
返回类型:
Any | None
代码语言:javascript复制cost_analysis()
执行成本估算的摘要。
旨在可视化和调试。此输出的对象是一些简单的数据结构,可以轻松打印或序列化(例如,带有数值叶的嵌套字典、列表和元组)。然而,它的结构可以是任意的:在 JAX 和 jaxlib 的不同版本甚至调用之间可能不一致。
如果不可用,则返回None
,例如基于后端、编译器或运行时。
返回类型:
Any | None
代码语言:javascript复制property in_tree: PyTreeDef
一对(位置参数、关键字参数)的树结构。
代码语言:javascript复制class jax.stages.Compiled(executable, args_info, out_tree, no_kwargs=False)
编译后的函数专门针对类型/值进行了优化表示。
编译计算与可执行文件相关联,并提供执行所需的剩余信息。它还为查询 JAX 的各种编译路径和后端中编译计算属性提供了一个共同的 API。
参数:
- args_info (Any)
- out_tree (PyTreeDef)
__call__(*args, **kwargs)
将自身作为函数调用。
代码语言:javascript复制as_text()
这是可执行文件的人类可读文本表示。
旨在可视化和调试。这不是有效的也不是可靠的序列化。
如果不可用,则返回None
,例如基于后端、编译器或运行时。
返回类型:
str | None
代码语言:javascript复制cost_analysis()
执行成本估算的摘要。
旨在可视化和调试。此输出的对象是一些简单的数据结构,可以轻松打印或序列化(例如,带有数值叶的嵌套字典、列表和元组)。然而,它的结构可以是任意的:在 JAX 和 jaxlib 的不同版本甚至调用之间可能不一致。
如果不可用,则返回None
,例如基于后端、编译器或运行时。
返回类型:
Any | None
代码语言:javascript复制property in_tree: PyTreeDef
(位置参数,关键字参数) 的树结构。
代码语言:javascript复制memory_analysis()
估计内存需求的摘要。
用于可视化和调试目的。由此输出的对象是一些简单的数据结构,可以轻松打印或序列化(例如嵌套的字典、列表和具有数字叶子的元组)。然而,其结构可以是任意的:在 JAX 和 jaxlib 的不同版本之间,甚至在不同调用之间可能是不一致的。
返回 None
如果不可用,例如基于后端、编译器或运行时。
返回类型:
任意 | None
代码语言:javascript复制runtime_executable()
此可执行对象的任意对象表示。
用于调试目的。这不是有效也不是可靠的序列化。输出不能保证在不同调用之间的一致性。
返回 None
如果不可用,例如基于后端、编译器或运行时。
返回类型:
任意 | None