jax.命名作用域#
- jax.named_scope(name)[源代码][源代码]#
一个上下文管理器,将用户指定的名称添加到 JAX 名称堆栈中。
当为即时编译到XLA(或其他后端如TensorFlow)进行计算时,JAX默认情况下不会保留它遇到的Python函数的名称(或其他源元数据)。这会使调试程序的编译表示形式变得复杂,因为每个执行的操作都缺乏上下文信息。
named_scope告诉 JAX 对给定的函数进行分阶段处理,并在底层操作上添加额外的注释。JAX 在内部通过一个名称栈来跟踪这些注释。当分阶段输出的程序通过 XLA 编译时,这些注释会被保留,并在 TensorBoard 中的 TensorFlow Profiler 等调试工具中显示。当使用experimental.jax2tf.convert()将 JAX 程序分阶段输出到 TensorFlow 时,名称也会被保留。- 参数:
name (str) – 用于命名名称范围内所有创建的操作的前缀。
- 生成器:
返回
None,但进入一个上下文,在此上下文中 name 将被附加到活动名称堆栈中。- 返回类型:
Generator[None, None, None]
示例
named_scope可以在编译函数内部作为上下文管理器使用:>>> import jax >>> >>> @jax.jit ... def layer(w, x): ... with jax.named_scope("dot_product"): ... logits = w.dot(x) ... with jax.named_scope("activation"): ... return jax.nn.relu(logits)
它也可以用作装饰器:
>>> @jax.jit ... @jax.named_scope("layer") ... def layer(w, x): ... logits = w.dot(x) ... return jax.nn.relu(logits)