jax.numpy 模块#
使用 jax.lax 中的原语实现了 NumPy API。
虽然 JAX 尽可能地遵循 NumPy API,但有时 JAX 无法完全遵循 NumPy。
值得注意的是,由于 JAX 数组是不可变的,NumPy API 中那些就地修改数组的函数无法在 JAX 中实现。然而,JAX 通常能够提供一个纯函数式的替代 API。例如,JAX 提供了替代的就地数组更新函数
x.at[i].set(y)(参见ndarray.at),而不是就地数组更新 (x[i] = y)。同样地,一些 NumPy 函数在可能的情况下通常返回数组的视图(例如
transpose()和reshape())。JAX 版本的此类函数将返回副本,尽管在使用jax.jit()编译操作序列时,XLA 通常会优化掉这些副本。NumPy 在将值提升为
float64类型时非常激进。JAX 在类型提升方面有时不那么激进(参见 类型提升)。一些 NumPy 例程的输出形状依赖于数据(例如
unique()和nonzero())。由于 XLA 编译器要求在编译时知道数组形状,因此这些操作与 JIT 不兼容。为此,JAX 为这些函数添加了一个可选的size参数,可以在使用 JIT 时静态指定该参数。
几乎所有适用的 NumPy 函数都在 jax.numpy 命名空间中实现;它们列在下面。
用于索引更新功能的辅助属性。 |
|
|
|
|
逐元素计算绝对值。 |
|
|
|
|
逐元素相加两个数组。 |
|
|
测试沿给定轴的所有数组元素是否评估为 True。 |
|
检查两个数组在容差范围内是否逐元素近似相等。 |
|
别名 |
|
别名 |
|
返回复数值数或数组的角。 |
|
测试沿给定轴的数组元素是否评估为 True。 |
|
返回一个新数组,该数组在原数组的末尾附加了值。 |
|
沿给定轴对1-D切片应用函数。 |
|
在多个轴上重复应用函数。 |
|
创建一个等间距值的数组。 |
|
逐元素计算三角反余弦。 |
|
逐元素计算反双曲余弦。 |
|
逐元素计算反正弦。 |
|
逐元素计算反双曲正弦。 |
|
三角反切函数,逐元素计算。 |
|
逐元素计算 |
|
逐元素计算反双曲正切。 |
|
返回沿某个轴的最大值的索引。 |
|
返回沿某个轴的最小值的索引。 |
|
返回部分排序数组的索引。 |
|
返回排序数组的索引。 |
|
查找非零数组元素的索引 |
|
|
|
将对象转换为 JAX 数组。 |
|
检查两个数组是否逐元素相等。 |
|
检查两个数组是否逐元素相等。 |
|
返回数组的字符串表示形式。 |
|
将一个数组分割成子数组。 |
|
返回数组中数据的字符串表示形式。 |
|
将对象转换为 JAX 数组。 |
|
|
|
|
|
将数组转换为指定的数据类型。 |
|
|
|
|
|
|
将输入转换为至少具有一个维度的数组。 |
|
将输入视为至少具有两个维度的数组。 |
|
将输入视为至少具有三个维度的数组。 |
|
|
计算沿指定轴的加权平均值。 |
|
返回 Bartlett 窗口。 |
|
计算整数数组中每个值的出现次数。 |
按元素计算按位与操作。 |
|
|
计算 |
|
按元素计算位反转,或按位 NOT。 |
|
将整数的位向左移动。 |
|
按元素计算位反转,或按位 NOT。 |
按元素计算按位或运算。 |
|
|
|
逐元素计算按位异或运算。 |
|
|
返回布莱克曼窗。 |
|
从嵌套的块列表中组装一个nd数组。 |
|
|
|
将数组广播到公共形状。 |
将输入形状广播到公共输出形状。 |
|
|
将数组广播到指定形状。 |
沿最后一个轴连接切片、标量和类似数组的对象。 |
|
|
如果根据类型转换规则可以在数据类型之间进行转换,则返回 True。 |
|
返回数组的立方根,逐元素进行。 |
|
|
|
将输入向上舍入到最接近的整数。 |
所有字符串标量类型的抽象基类。 |
|
|
从索引数组和选择数组列表中构造一个数组。 |
|
将数组值裁剪到指定范围。 |
|
将一维数组堆叠为二维数组的列。 |
|
|
|
|
|
|
所有由浮点数组成的复数标量类型的抽象基类。 |
|
将复杂数据类型转换为实数数据类型时引发的警告。 |
|
|
使用布尔条件沿给定轴压缩数组。 |
|
沿现有轴连接一系列数组。 |
|
沿现有轴连接一系列数组。 |
|
返回逐元素的复共轭。 |
|
返回逐元素的复共轭。 |
|
两个一维数组的卷积。 |
|
返回数组的副本。 |
|
将 |
|
返回皮尔逊积矩相关系数。 |
|
两个一维数组的关联。 |
|
计算输入中每个元素的三角余弦值。 |
|
双曲余弦,逐元素计算。 |
|
返回沿指定轴的非零元素的数量。 |
|
给定数据和权重,估计一个协方差矩阵。 |
|
返回两个(数组)向量的叉积。 |
|
|
|
沿轴的元素累积乘积。 |
|
沿轴的元素累计和。 |
|
沿数组轴的累积和。 |
|
将角度从度转换为弧度。 |
|
将角度从弧度转换为度数。 |
|
从数组中删除条目或多个条目。 |
|
返回指定的对角线或构造一个对角线数组。 |
|
返回用于访问多维数组主对角线的索引。 |
|
返回用于访问给定数组主对角线的索引。 |
|
返回一个二维数组,其中扁平化的输入数组沿对角线排列。 |
|
返回数组的指定对角线。 |
|
计算沿给定轴的第 n 次离散差分。 |
|
返回输入数组中每个值所属的箱子的索引。 |
|
|
|
计算 x1 除以 x2 的整数商和余数,逐元素进行。 |
|
计算两个数组的点积。 |
|
|
|
将数组深度分割为子数组。 |
|
按深度顺序堆叠数组(沿第三轴)。 |
|
创建一个数据类型对象。 |
|
数组中连续元素之间的差异。 |
|
爱因斯坦求和 |
在不评估einsum的情况下,评估最优的收缩路径。 |
|
|
创建一个空数组。 |
|
创建一个与指定数组具有相同形状和数据类型的空数组。 |
|
逐元素返回 (x1 == x2)。 |
|
计算输入的逐元素指数。 |
|
计算输入的逐元素以2为底的指数。 |
|
将长度为1的维度插入数组 |
|
计算输入的每个元素的 |
|
返回满足条件的数组元素。 |
|
创建一个方形或矩形的单位矩阵 |
|
计算实值输入的逐元素绝对值。 |
|
返回一个数组的副本,其中对角线被覆盖。 |
|
浮点类型的机器限制。 |
|
将输入四舍五入到最接近的整数,趋向于零。 |
|
返回展平数组中非零元素的索引 |
|
所有无预定义长度的标量类型的抽象基类。 |
|
沿给定轴反转数组元素的顺序。 |
|
沿轴1反转数组元素的顺序。 |
|
沿轴 0 反转数组元素的顺序。 |
|
|
|
计算元素级别的 |
|
|
|
|
|
|
|
所有浮点标量类型的抽象基类。 |
|
将输入向下舍入到最接近的整数。 |
|
计算 x1 与 x2 的逐元素整除 |
|
返回输入数组中逐元素的最大值。 |
|
返回输入数组中逐元素的最小值。 |
|
返回逐元素的除法余数。 |
|
将 x 的元素分解为尾数和二的指数。 |
|
将缓冲区解释为一维数组。 |
|
未实现的 JAX 包装器用于 jnp.fromfile。 |
|
通过在每个坐标上执行一个函数来构造一个数组。 |
|
未实现的 JAX 包装器用于 jnp.fromiter。 |
|
从任意兼容JAX的标量函数创建一个JAX ufunc。 |
|
从字符串中的文本数据初始化的新 1-D 数组。 |
|
通过 DLPack 构建 JAX 数组。 |
|
创建一个充满指定值的数组。 |
|
创建一个充满指定值的数组,其形状和数据类型与另一个数组相同。 |
|
计算两个数组的最大公约数。 |
|
numpy 标量类型的基类。 |
|
返回在对数尺度上均匀分布的数字(几何级数)。 |
返回当前的打印选项。 |
|
|
返回一个 N 维数组的梯度。 |
|
返回 (x1 > x2) 的元素级真值。 |
|
返回 (x1 >= x2) 的元素级真值。 |
|
返回汉明窗。 |
|
返回汉宁窗。 |
|
计算 Heaviside 阶跃函数。 |
|
计算数据集的直方图。 |
|
用于计算 histogram 使用的箱子边缘的函数 |
|
计算两个数据样本的二维直方图。 |
|
计算某些数据的多维直方图。 |
|
将数组水平分割成子数组。 |
|
按顺序水平堆叠数组(按列)。 |
|
给定直角三角形的“边”,返回其斜边。 |
第一类修正贝塞尔函数,0阶。 |
|
|
创建一个单位矩阵 |
|
|
|
返回复数参数的虚部。 |
构建数组索引元组的更好方法。 |
|
|
返回一个表示网格索引的数组。 |
|
所有数值标量类型的抽象基类,其范围内的值可能具有(潜在的)不精确表示,例如浮点数。 |
|
计算两个数组的内积。 |
|
在给定的轴上,在给定的索引之前插入值。 |
|
|
|
|
|
|
|
|
|
|
|
所有整数标量类型的抽象基类。 |
|
针对单调递增样本点的一维线性插值。 |
|
计算两个一维数组的集合交集。 |
|
按元素计算位反转,或按位 NOT。 |
|
检查两个数组的元素是否在容差范围内近似相等。 |
|
返回布尔数组,显示输入是否为复数。 |
|
检查输入是否为复数或包含复数元素的数组。 |
|
返回一个布尔值,指示提供的 dtype 是否属于指定类型。 |
|
测试逐元素是否为有限值(不是无穷大且不是非数值)。 |
|
确定 |
|
测试元素是否为正无穷或负无穷。 |
|
逐元素测试 NaN 并返回结果为布尔数组。 |
|
逐元素测试正无穷大,返回结果为布尔数组。 |
|
逐元素测试正无穷大,返回结果为布尔数组。 |
|
返回一个布尔数组,显示输入是否为实数。 |
|
检查输入是否不是复数或包含复数元素的数组。 |
|
如果 element 的类型是标量类型,则返回 True。 |
|
如果第一个参数在类型层次结构中低于或等于第二个参数,则返回 True。 |
|
检查一个对象是否可以被迭代。 |
|
从 N 个一维序列返回一个多维网格(开放网格)。 |
|
返回凯泽窗。 |
|
计算两个输入数组的Kronecker积。 |
|
计算两个数组的最小公倍数。 |
|
返回 x1 * 2**x2,逐元素计算。 |
|
将整数的位向左移动。 |
|
返回 (x1 < x2) 的元素级真值。 |
|
返回 (x1 <= x2) 的元素级真值。 |
|
使用一系列键执行间接稳定排序。 |
|
返回区间内的等间隔数字。 |
|
从 |
|
计算输入的逐元素自然对数。 |
|
计算 x 元素的以 10 为底的对数 |
|
计算输入元素加一的对数, |
|
计算 x 的逐元素以 2 为底的对数 |
|
计算 |
以2为底的对数,输入的指数和。 |
|
逐元素计算逻辑与操作。 |
|
|
计算 NOT x 的元素级真值。 |
计算逐元素的逻辑或运算。 |
|
逐元素计算逻辑异或运算。 |
|
|
返回在对数刻度上均匀间隔的数字。 |
|
返回一个掩码函数的索引,以访问 (n, n) 数组。 |
|
执行矩阵乘法。 |
|
转置数组的最后两个维度。 |
|
返回沿给定轴的数组元素的最大值。 |
|
返回输入数组中逐元素的最大值。 |
|
返回沿给定轴的数组元素的平均值。 |
|
返回沿给定轴的数组元素的中位数。 |
|
从坐标向量返回坐标矩阵的元组。 |
返回密集的多维“网格”。 |
|
|
返回沿给定轴的数组元素的最小值。 |
|
返回输入数组中逐元素的最小值。 |
|
返回逐元素的除法余数。 |
|
返回数组中每个元素的分数部分和整数部分。 |
|
将数组轴移动到新位置 |
逐元素相乘两个数组。 |
|
|
将 NaN 替换为零,将无穷大替换为大的有限数(默认) |
|
返回指定轴上最大值的索引,忽略 |
|
返回指定轴上最小值的索引,忽略 |
|
沿轴的元素累积乘积,忽略NaN值。 |
|
沿轴的元素累积和,忽略NaN值。 |
|
返回沿给定轴的数组元素的最大值,忽略 NaNs。 |
|
返回沿给定轴的数组元素的均值,忽略 NaNs。 |
|
返回沿给定轴的数组元素的中位数,忽略 NaNs。 |
|
返回沿给定轴的数组元素的最小值,忽略 NaNs。 |
|
计算沿指定轴的数据百分位数,忽略 NaN 值。 |
|
返回沿给定轴的数组元素的乘积,忽略 NaNs。 |
|
计算沿指定轴的数据分位数,忽略 NaNs。 |
|
计算沿指定轴的标准差,忽略 NaNs。 |
|
返回沿给定轴的数组元素之和,忽略 NaNs。 |
|
计算沿给定轴的数组元素的方差,忽略 NaNs。 |
|
|
|
返回数组的维度数。 |
|
返回输入元素的负值。 |
|
返回 |
|
返回数组中非零元素的索引。 |
|
逐元素返回 (x1 != x2)。 |
|
所有数值标量类型的抽象基类。 |
任何 Python 对象。 |
|
返回开放的多维“网格”。 |
|
|
创建一个充满1的数组。 |
|
创建一个与给定数组具有相同形状和数据类型的全一数组。 |
|
计算两个数组的外积。 |
|
将二值数组的元素打包成 uint8 数组中的位。 |
|
填充数组。 |
|
返回数组的部分排序副本。 |
|
计算数据沿指定轴的百分位数。 |
|
置换数组的轴/维度。 |
|
在整个定义域上分段评估一个函数。 |
|
基于掩码更新数组元素。 |
|
返回给定根序列的多项式的系数。 |
|
返回两个多项式的和。 |
|
返回指定阶数多项式的导数的系数。 |
|
返回多项式除法的商和余数。 |
|
最小二乘多项式拟合数据。 |
|
返回多项式的指定阶数积分的系数。 |
|
返回两个多项式的乘积。 |
|
返回两个多项式的差。 |
|
在特定值处计算多项式。 |
|
返回输入元素的正值。 |
|
第一个数组的元素按第二个数组的元素逐个求幂。 |
|
第一个数组的元素按第二个数组的元素逐个求幂。 |
|
用于设置打印选项的上下文管理器。 |
|
返回数组元素在给定轴上的乘积。 |
|
返回二元运算应将其参数转换为的类型。 |
|
返回给定轴上的峰峰值范围。 |
|
将元素放入指定索引的数组中。 |
|
计算数据沿指定轴的分位数。 |
沿第一个轴连接切片、标量和类似数组的对象。 |
|
|
将角度从弧度转换为度数。 |
|
将角度从度转换为弧度。 |
|
将数组展平为1维形状。 |
|
将多维索引转换为扁平索引。 |
|
返回复数参数的实部。 |
|
返回参数的倒数,逐元素进行。 |
|
返回逐元素的除法余数。 |
|
从重复元素构建数组。 |
|
返回数组的重新形状的副本。 |
|
返回一个具有指定形状的新数组。 |
|
返回应用 NumPy 后的类型 |
|
将 |
|
将 x 的元素四舍五入到最近的整数 |
|
沿指定轴滚动数组的元素。 |
|
将指定轴滚动到给定位置。 |
|
返回给定系数 |
|
在由轴指定的平面内将数组逆时针旋转90度。 |
|
将输入值四舍五入到给定的位数。 |
|
将输入值四舍五入到给定的位数。 |
构建数组索引元组的更好方法。 |
|
|
将数组保存为 NumPy |
|
将多个数组保存到一个未压缩的 |
|
在一个已排序的数组中执行二分查找。 |
|
根据一系列条件选择值。 |
|
设置打印选项。 |
|
计算两个一维数组的集合差。 |
|
计算两个数组中元素的集合异或。 |
|
返回数组的形状。 |
|
返回输入的逐元素符号指示。 |
|
返回元素级 True,其中符号位已设置(小于零)。 |
所有有符号整数标量类型的抽象基类。 |
|
|
计算输入中每个元素的三角正弦值。 |
|
返回归一化的sinc函数。 |
|
|
|
双曲正弦,逐元素计算。 |
|
返回沿给定轴的元素数量。 |
|
返回数组的排序副本。 |
|
首先按实部排序,然后按虚部排序一个复杂的数组。 |
|
将一个数组分割成子数组。 |
|
返回数组的非负平方根,逐元素进行。 |
|
返回输入的逐元素平方。 |
|
从数组中移除一个或多个长度为1的轴 |
|
沿新轴连接数组的序列。 |
|
计算沿给定轴的标准差。 |
|
逐元素减去参数。 |
|
数组元素在给定轴上的总和。 |
|
交换数组的两个轴。 |
|
从数组中提取元素。 |
|
从数组中提取元素。 |
|
计算输入中每个元素的三角正切值。 |
|
逐元素计算双曲正切。 |
|
计算两个N维数组的张量点积。 |
|
通过重复 A 的次数来构造一个数组,次数由 reps 给出。 |
|
返回数组对角线上的元素之和。 |
|
使用复合梯形法则沿给定轴进行积分。 |
|
返回一个 N 维数组的转置版本。 |
|
返回一个数组,其中对角线及其下方为1,其他位置为0。 |
|
返回数组的下三角部分。 |
|
返回大小为 |
|
返回给定数组的下三角索引。 |
|
修剪输入数组的前导和/或尾随零。 |
|
返回数组的上三角部分。 |
|
返回大小为 |
|
返回给定数组的上三角索引。 |
|
计算 x1 与 x2 的逐元素除法 |
|
将输入四舍五入到最接近的整数,趋向于零。 |
|
对数组进行逐元素操作的通用函数。 |
|
|
|
|
|
|
|
|
|
|
|
计算两个一维数组的并集。 |
|
返回数组中的唯一值。 |
|
从 x 中返回唯一值,以及索引、逆索引和计数。 |
|
从 x 中返回唯一值及其计数。 |
|
从 x 中返回唯一值,以及索引、逆索引和计数。 |
|
从 x 中返回唯一值,以及索引、逆索引和计数。 |
|
将 uint8 数组的元素解包到二进制值的输出数组中。 |
|
将平面索引转换为多维索引。 |
|
沿着给定轴将数组分割成一系列数组。 |
所有无符号整数标量类型的抽象基类。 |
|
|
通过取大增量相对于周期的补码来进行解包。 |
|
生成一个范德蒙矩阵。 |
|
计算沿指定轴的方差。 |
|
对两个一维向量执行共轭乘法。 |
|
执行两个批量向量的共轭乘法。 |
|
定义一个带有广播功能的矢量化函数。 |
|
将数组垂直分割成子数组。 |
|
按顺序垂直堆叠数组(按行)。 |
|
根据条件从两个数组中选择元素。 |
|
创建一个充满零的数组。 |
|
创建一个充满零的数组,其形状和数据类型与给定数组相同。 |
jax.numpy.fft#
|
沿给定轴计算一维离散傅里叶变换。 |
|
沿着给定的轴计算二维离散傅里叶变换。 |
|
返回离散傅里叶变换的样本频率。 |
|
沿给定轴计算多维离散傅里叶变换。 |
|
将零频分量移到频谱中心。 |
|
计算具有厄米特对称性的数组的 1-D FFT。 |
|
计算一维逆离散傅里叶变换。 |
|
计算二维逆离散傅里叶变换。 |
|
计算多维逆离散傅里叶变换。 |
|
fftshift 的逆操作。 |
|
计算具有厄米特对称性的数组的 1-D 逆 FFT。 |
|
计算一个实数值的一维离散傅里叶逆变换。 |
|
计算一个实数值的二维离散傅里叶逆变换。 |
|
计算实数值多维逆离散傅里叶变换。 |
|
计算实值数组的一维离散傅里叶变换。 |
|
计算实值数组的二维离散傅里叶变换。 |
|
返回离散傅里叶变换的样本频率。 |
|
计算实值数组的多维离散傅里叶变换。 |
jax.numpy.linalg#
|
计算矩阵的 Cholesky 分解。 |
|
计算矩阵的条件数。 |
|
计算两个三维向量的叉积 |
|
计算数组的行列式。 |
|
提取矩阵或矩阵堆的对角线。 |
|
计算方阵的特征值和特征向量。 |
|
计算厄米矩阵的特征值和特征向量。 |
|
计算一般矩阵的特征值。 |
|
计算厄米矩阵的特征值。 |
|
返回一个方阵的逆矩阵 |
|
返回线性方程的最小二乘解。 |
|
执行矩阵乘法。 |
|
计算矩阵或矩阵堆栈的范数。 |
|
将一个方阵提升到一个整数幂。 |
|
计算矩阵的秩。 |
|
转置矩阵或矩阵堆栈。 |
|
高效计算数组序列之间的矩阵乘积。 |
|
计算矩阵或向量的范数。 |
|
计算两个一维数组的外积。 |
|
计算矩阵的 (Moore-Penrose) 伪逆。 |
|
计算数组的QR分解 |
|
计算数组的符号和(自然)对数行列式。 |
|
求解线性方程组 |
|
计算奇异值分解。 |
|
计算矩阵的奇异值。 |
|
计算两个N维数组的张量点积。 |
|
计算数组的张量逆。 |
|
求解张量方程 a x = b 中的 x。 |
|
计算矩阵的迹。 |
|
计算向量或一批向量的向量范数。 |
|
计算两个数组的(批量)向量共轭点积。 |
JAX 数组#
JAX 的 ndarray)是 JAX 中的核心数组对象:你可以将其视为 JAX 中与 numpy.ndarray 等效的对象。与 numpy.ndarray 类似,大多数用户不需要手动实例化 Array 对象,而是通过 jax.numpy 函数(如 array()、arange()、linspace() 等)来创建它们。
复制与序列化#
JAX Array 对象旨在在适当的情况下与 Python 标准库工具无缝协作。
使用内置的 copy 模块,当 copy.copy() 或 copy.deepcopy() 遇到 Array 时,它等同于调用 copy() 方法,该方法将在与原始数组相同的设备上创建缓冲区的副本。这在跟踪/JIT 编译的代码中将正确工作,尽管在此上下文中编译器可能会省略复制操作。
当内置的 pickle 模块遇到一个 Array 时,它将通过一种紧凑的位表示形式进行序列化,类似于被pickle的 numpy.ndarray 对象。当反序列化时,结果将是一个新的 Array 对象 在默认设备上。这是因为通常情况下,序列化和反序列化可能发生在不同的运行时环境中,并且没有一种通用的方法可以将一个运行时的设备ID映射到另一个运行时的设备ID。如果在追踪/JIT编译的代码中使用 pickle,将会导致 ConcretizationTypeError。
Python 数组 API 标准#
备注
在 JAX v0.4.32 之前,您必须 import jax.experimental.array_api 以启用 JAX 数组的数组 API。在 JAX v0.4.32 之后,导入此模块不再需要,并且会引发弃用警告。
从 JAX v0.4.32 开始,jax.Array 和 jax.numpy 与 Python 数组 API 标准 兼容。您可以通过 jax.Array.__array_namespace__() 访问数组 API 命名空间:
>>> def f(x):
... nx = x.__array_namespace__()
... return nx.sin(x) ** 2 + nx.cos(x) ** 2
>>> import jax.numpy as jnp
>>> x = jnp.arange(5)
>>> f(x).round()
Array([1., 1., 1., 1., 1.], dtype=float32)
JAX 在某些方面与标准有所不同,主要是因为 JAX 数组是不可变的,不支持就地更新。其中一些不兼容性正在通过 array-api-compat 模块解决。
更多信息,请参阅 Python 数组 API 标准 文档。