jax.命名作用域

目录

jax.命名作用域#

jax.named_scope(name)[源代码][源代码]#

一个上下文管理器,将用户指定的名称添加到 JAX 名称堆栈中。

当为即时编译到XLA(或其他后端如TensorFlow)进行计算时,JAX默认情况下不会保留它遇到的Python函数的名称(或其他源元数据)。这会使调试程序的编译表示形式变得复杂,因为每个执行的操作都缺乏上下文信息。

named_scope 告诉 JAX 对给定的函数进行分阶段处理,并在底层操作上添加额外的注释。JAX 在内部通过一个名称栈来跟踪这些注释。当分阶段输出的程序通过 XLA 编译时,这些注释会被保留,并在 TensorBoard 中的 TensorFlow Profiler 等调试工具中显示。当使用 experimental.jax2tf.convert() 将 JAX 程序分阶段输出到 TensorFlow 时,名称也会被保留。

参数:

name (str) – 用于命名名称范围内所有创建的操作的前缀。

生成器:

返回 None,但进入一个上下文,在此上下文中 name 将被附加到活动名称堆栈中。

返回类型:

Generator[None, None, None]

示例

named_scope 可以在编译函数内部作为上下文管理器使用:

>>> import jax
>>>
>>> @jax.jit
... def layer(w, x):
...   with jax.named_scope("dot_product"):
...     logits = w.dot(x)
...   with jax.named_scope("activation"):
...     return jax.nn.relu(logits)

它也可以用作装饰器:

>>> @jax.jit
... @jax.named_scope("layer")
... def layer(w, x):
...   logits = w.dot(x)
...   return jax.nn.relu(logits)