jax.export.符号参数规格

jax.export.符号参数规格#

jax.export.symbolic_args_specs(args, shapes_specs, constraints=(), scope=None, symbolic_constraints=(), symbolic_scope=None)[源代码][源代码]#

export 构建一个 jax.ShapeDtypeSpec 参数规格的 pytree。

详情请参阅 jax.export.symbolic_shape() 的文档以及 [形状多态性文档](https://jax.readthedocs.io/en/latest/export/shape_poly.html)。

参数:
  • args – 参数的 pytree。这些可以是 jax.Array 或 jax.ShapeDTypeSpec。它们用于学习参数的 pytree 结构、它们的 dtypes,并在 shapes_specs 包含占位符的地方填充实际形状。请注意,仅从 args 中使用 shapes_specs 为占位符的形状维度。

  • shapes_specs – 应该是 None`(所有参数都有静态形状),一个字符串(参见 :func:`jax.export.symbolic_shapeshape_spec;适用于所有参数),或者一个与 args 前缀匹配的 pytree。参见 [如何将可选参数匹配到参数](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees)。

  • constraints (Sequence[str]) – 至于 jax.export.symbolic_shape()

  • scope (SymbolicScope | None) – 至于 jax.export.symbolic_shape()

  • symbolic_constraints (Sequence[str]) – 已弃用,请使用 constraints

  • symbolic_scope (SymbolicScope | None) – 已弃用,请使用 scope

返回: 一个与 args 形状匹配的 jax.ShapeDTypeStruct 的 pytree

替换为 shapes_specs 指定的符号维度。