公共API: jax 包#
子包#
jax.numpy模块jax.scipy模块jax.lax模块jax.random模块jax.sharding模块jax.debug模块jax.dlpack模块jax.distributed模块jax.dtypes模块jax.flatten_util模块jax.image模块jax.nn模块jax.ops模块jax.profiler模块jax.stages模块jax.tree模块jax.tree_util模块jax.typing模块jax.export模块jax.extend模块jax.example_libraries模块jax.experimental模块
配置#
用于 jax_check_tracer_leaks 配置选项的上下文管理器。 |
|
用于 jax_check_tracer_leaks 配置选项的上下文管理器。 |
|
用于 jax_debug_nans 配置选项的上下文管理器。 |
|
用于 jax_debug_infs 配置选项的上下文管理器。 |
|
jax_default_device 配置选项的上下文管理器。 |
|
用于 jax_default_matmul_precision 配置选项的上下文管理器。 |
|
jax_default_prng_impl 配置选项的上下文管理器。 |
|
jax_enable_checks 配置选项的上下文管理器。 |
|
用于 jax_enable_custom_prng 配置选项的上下文管理器(临时)。 |
|
用于 jax_enable_custom_vjp_by_custom_transpose 配置选项的上下文管理器(瞬态)。 |
|
jax_log_compiles 配置选项的上下文管理器。 |
|
jax_numpy_rank_promotion 配置选项的上下文管理器。 |
|
|
一个上下文管理器,用于控制所有传输的传输保护级别。 |
即时编译 (jit)#
|
为 |
|
在其动态上下文中禁用 |
上下文管理器以确保在跟踪/编译时进行评估(或错误)。 |
|
|
创建一个函数,该函数在给定示例参数的情况下生成其XLA计算。 |
|
创建一个函数,该函数在给定示例参数的情况下生成其 jaxpr。 |
|
在不进行任何浮点运算的情况下计算 |
|
用于存储数组的形状、数据类型和其他静态属性的容器。 |
|
将 |
|
将数组传输到每个指定的设备并形成数组。 |
|
将数组分片传输到指定设备并形成数组。 |
|
将 |
返回默认XLA后端的平台名称。 |
|
|
在分阶段执行JAX计算时,为函数添加用户指定的名称。 |
|
一个上下文管理器,将用户指定的名称添加到 JAX 名称堆栈中。 |
尝试在 pytree 叶子上调用 |
自动微分#
|
创建一个评估 |
|
创建一个函数,该函数计算 |
|
|
|
|
|
|
|
计算 |
使用 |
|
|
转置一个保证为线性的函数。 |
|
计算 |
|
用于定义自定义 VJP 规则(即自定义梯度)的便捷函数。 |
|
闭包转换工具,用于高阶自定义导数。 |
|
当微分时,使 |
custom_jvp#
|
为自定义 JVP 规则定义设置一个可 JAX 变换的函数。 |
|
为此实例表示的函数定义一个自定义的 JVP 规则。 |
|
为每个参数分别定义JVP的便捷包装器。 |
custom_vjp#
|
为自定义 VJP 规则定义设置一个可 JAX 变换的函数。 |
|
为此实例表示的函数定义一个自定义的 VJP 规则。 |
jax.Array (jax.Array)#
|
JAX 的数组基类 |
|
通过从 |
|
从一个设备上的 |
|
使用进程中可用的数据创建分布式张量。 |
数组属性和方法#
可寻址分片列表。 |
|
|
测试沿给定轴的所有数组元素是否评估为 True。 |
|
测试沿给定轴的任何数组元素是否评估为 True。 |
|
返回最大值的索引。 |
|
返回最小值的索引。 |
|
返回部分排序数组的索引。 |
|
返回对数组进行排序的索引。 |
|
复制数组并转换为指定的数据类型。 |
用于索引更新功能的辅助属性。 |
|
|
从多个数组的元素中构建一个数组。 |
|
返回一个数组,其值被限制在指定范围内。 |
|
返回沿给定轴的此数组的选择切片。 |
返回数组的复共轭。 |
|
返回数组的复共轭。 |
|
返回数组的副本。 |
|
异步地将 |
|
|
返回数组的累积乘积。 |
|
返回数组的累计和。 |
与数组API兼容的设备属性。 |
|
|
返回数组中指定的对角线。 |
|
计算两个数组的点积。 |
数组的 数据类型 ( |
|
使用 |
|
|
将数组展平为1维形状。 |
全局分片列表。 |
|
返回数组的虚部。 |
|
这个数组是完全可寻址的吗? |
|
这个数组是完全复制的吗? |
|
|
将数组的一个元素复制到一个标准的 Python 标量并返回它。 |
一个数组元素的字节长度。 |
|
|
返回沿给定轴的数组元素的最大值。 |
|
返回沿给定轴的数组元素的平均值。 |
|
返回沿给定轴的数组元素的最小值。 |
数组元素消耗的总字节数。 |
|
数组的维度数量。 |
|
|
返回数组中非零元素的索引。 |
|
返回数组元素在给定轴上的乘积。 |
|
返回给定轴上的峰峰值范围。 |
|
将数组展平为1维形状。 |
返回数组的实部。 |
|
|
从重复元素构建数组。 |
|
返回一个包含相同数据但形状不同的新数组。 |
|
将数组元素四舍五入到给定的小数位。 |
|
在一个已排序的数组中执行二分查找。 |
数组的形状。 |
|
数组的切片。 |
|
数组中元素的总数。 |
|
|
返回数组的排序副本。 |
|
从数组中移除一个或多个长度为1的轴。 |
|
计算沿给定轴的标准差。 |
|
数组元素在给定轴上的总和。 |
|
交换数组的两个轴。 |
|
从数组中提取元素。 |
|
返回指定设备上的数组副本 |
|
返回对角线上的元素之和。 |
|
返回数组的一个转置副本。 |
|
计算沿指定轴的方差。 |
|
返回数组的按位复制,视作新的数据类型。 |
计算全轴数组转置。 |
|
计算(批量)矩阵转置。 |
向量化 (vmap)#
|
向量化映射。 |
|
定义一个带有广播功能的矢量化函数。 |
并行化 (pmap)#
|
支持集体操作的并行映射。 |
|
返回给定后端的所有设备列表。 |
|
类似于 |
|
返回此进程的整数进程索引。 |
|
返回设备总数。 |
|
返回此进程可寻址的设备数量。 |
|
返回与后端关联的JAX进程数。 |
|
返回与后端关联的所有 JAX 进程索引的列表。 |
回调#
|
调用一个纯Python回调函数。 |
|
调用一个不纯的Python回调函数。 |
|
调用一个可分阶段的 Python 回调。 |
|
打印值并在分阶段输出的 JAX 函数中工作。 |
杂项#
一个可用设备的描述。 |
|
|
返回一个包含本地环境及 JAX 安装信息字符串。 |
|
返回 platform 后端中的所有活动数组。 |
清除所有编译和暂存缓存。 |