Shortcuts

torch.jit.annotate

torch.jit.annotate(the_type, the_value)[源代码]

用于在TorchScript编译器中指定the_value的类型。

此方法是一个透传函数,返回the_value,用于提示TorchScript编译器the_value的类型。在TorchScript之外运行时,它是一个空操作。

尽管 TorchScript 可以为大多数 Python 表达式推断出正确的类型,但在某些情况下,类型推断可能会出错,包括:

  • 空的容器,如 []{},TorchScript 假设它们是 Tensor 的容器

  • 可选类型如 Optional[T] 但被赋予了类型 T 的有效值,TorchScript 会假设它是类型 T 而不是 Optional[T]

请注意,annotate()torch.nn.Module 子类的 __init__ 方法中不起作用,因为它是在急切模式下执行的。要注释 torch.nn.Module 属性的类型,请改用 Annotate()

示例:

import torch
from typing import Dict

@torch.jit.script
def fn():
    # 告诉 TorchScript 这个空字典是一个 (str -> int) 字典
    # 而不是默认的 (str -> Tensor) 字典类型。
    d = torch.jit.annotate(Dict[str, int], {})

    # 如果没有上面的 `torch.jit.annotate`,下面的语句会因为
    # 类型不匹配而失败。
    d["name"] = 20
Parameters
  • the_type – 应传递给 TorchScript 编译器的 Python 类型,作为 the_value 的类型提示

  • the_value – 值或表达式以提示类型。

Returns

the_value 作为返回值被传递回来。