jax.export.Exported#
- class jax.export.Exported(fun_name, in_tree, in_avals, out_tree, out_avals, in_shardings_hlo, out_shardings_hlo, nr_devices, platforms, ordered_effects, unordered_effects, disabled_safety_checks, mlir_module_serialized, calling_convention_version, module_kept_var_idx, uses_global_constants, _get_vjp)[源代码][源代码]#
一个降低到 StableHLO 的 JAX 函数。
- 参数:
fun_name (str)
in_tree (tree_util.PyTreeDef)
in_avals (tuple[core.ShapedArray, ...])
out_tree (tree_util.PyTreeDef)
out_avals (tuple[core.ShapedArray, ...])
in_shardings_hlo (tuple[HloSharding | None, ...])
out_shardings_hlo (tuple[HloSharding | None, ...])
nr_devices (int)
ordered_effects (tuple[effects.Effect, ...])
unordered_effects (tuple[effects.Effect, ...])
disabled_safety_checks (Sequence[DisabledSafetyCheck])
mlir_module_serialized (bytes)
calling_convention_version (int)
uses_global_constants (bool)
- in_tree#
描述降低后的 JAX 函数元组 (args, kwargs) 的 PyTreeDef。实际的降低过程不依赖于 in_tree,但可以使用相同的参数结构调用导出的函数。
- 类型:
tree_util.PyTreeDef
- out_tree#
描述降低后的 JAX 函数结果的 PyTreeDef。
- 类型:
tree_util.PyTreeDef
- in_shardings_hlo#
扁平化的输入分片,一个与 in_avals 长度相同的序列。None 表示未指定的分片。请注意,这些不包括网格或网格中使用的实际设备。有关如何将这些转换为可与 JAX API 一起使用的分片规范,请参见 in_shardings_jax。
- 类型:
tuple[HloSharding | None, …]
- out_shardings_hlo#
扁平化的输出分片,一个与 out_avals 长度相同的序列。None 表示未指定的分片。请注意,这些不包括网格或网格中使用的实际设备。请参阅 out_shardings_jax 以了解如何将这些转换为可与 JAX API 一起使用的分片规范。
- 类型:
tuple[HloSharding | None, …]
- nr_devices#
该模块的设备数量已减少。
- 类型:
整数
- platforms#
一个包含应导出函数的平台的元组。JAX 中的平台集合是开放的;用户可以添加平台。JAX 内置平台有:’tpu’, ‘cpu’, ‘cuda’, ‘rocm’。参见 https://jax.readthedocs.io/en/latest/export/export.html#cross-platform-and-multi-platform-export。
- ordered_effects#
序列化模块中存在的顺序效应。这是从序列化版本9开始存在的。有关存在顺序效应时的调用约定,请参见 https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention。
- 类型:
tuple[effects.Effect, …]
- mlir_module_serialized#
序列化的低级 VHLO 模块。
- 类型:
字节
- calling_convention_version#
导出模块的调用约定的版本号。更多版本详情请参见 https://jax.readthedocs.io/en/latest/export/export.html#calling-convention-versions。
- 类型:
整数
- uses_global_constants#
是否 mlir_module_serialized 使用了形状多态性或多平台导出。这可能是因为 in_avals 包含了维度变量,或者由于导出模块的内部调用具有维度变量或平台索引参数。此类模块在XLA编译之前需要进行形状细化。
- 类型:
布尔
- disabled_safety_checks#
在导出时已禁用的安全检查描述符列表。参见 DisabledSafetyCheck 的文档字符串。
- 类型:
Sequence[DisabledSafetyCheck]
- _get_vjp#
一个可选函数,它接受当前导出的函数并返回导出的 VJP 函数。VJP 函数接受一个扁平的参数列表,从原始参数开始,然后是每个原始输出的一个余切参数。它返回一个元组,其中包含与扁平原始输入相对应的余切。
- 类型:
Callable[[导出], 导出] | None
查看 [关于 mlir_module 调用约定的描述](https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention)。
- __init__(fun_name, in_tree, in_avals, out_tree, out_avals, in_shardings_hlo, out_shardings_hlo, nr_devices, platforms, ordered_effects, unordered_effects, disabled_safety_checks, mlir_module_serialized, calling_convention_version, module_kept_var_idx, uses_global_constants, _get_vjp)#
- 参数:
fun_name (str)
in_tree (tree_util.PyTreeDef)
in_avals (tuple[core.ShapedArray, ...])
out_tree (tree_util.PyTreeDef)
out_avals (tuple[core.ShapedArray, ...])
in_shardings_hlo (tuple[HloSharding | None, ...])
out_shardings_hlo (tuple[HloSharding | None, ...])
nr_devices (int)
ordered_effects (tuple[effects.Effect, ...])
unordered_effects (tuple[effects.Effect, ...])
disabled_safety_checks (Sequence[DisabledSafetyCheck])
mlir_module_serialized (bytes)
calling_convention_version (int)
uses_global_constants (bool)
- 返回类型:
None
方法
__init__(fun_name, in_tree, in_avals, ...)call(*args, **kwargs)has_vjp()返回此导出是否支持 VJP。
in_shardings_jax(mesh)创建与 self.in_shardings_hlo 对应的 Shardings。
mlir_module()out_shardings_jax(mesh)创建与 self.out_shardings_hlo 对应的 Shardings。
serialize([vjp_order])序列化一个导出的对象。
vjp()获取导出的 VJP。
属性
in_shardingslowering_platforms已弃用。
mlir_module_serialization_version已弃用。
out_shardingsuses_shape_polymorphism已弃用。