torch.jit.annotate¶
- torch.jit.annotate(the_type, the_value)[源代码]¶
用于在TorchScript编译器中指定the_value的类型。
此方法是一个透传函数,返回the_value,用于提示TorchScript编译器the_value的类型。在TorchScript之外运行时,它是一个空操作。
尽管 TorchScript 可以为大多数 Python 表达式推断出正确的类型,但在某些情况下,类型推断可能会出错,包括:
空的容器,如 [] 和 {},TorchScript 假设它们是 Tensor 的容器
可选类型如 Optional[T] 但被赋予了类型 T 的有效值,TorchScript 会假设它是类型 T 而不是 Optional[T]
请注意,annotate() 在 torch.nn.Module 子类的 __init__ 方法中不起作用,因为它是在急切模式下执行的。要注释 torch.nn.Module 属性的类型,请改用
Annotate()
。示例:
import torch from typing import Dict @torch.jit.script def fn(): # 告诉 TorchScript 这个空字典是一个 (str -> int) 字典 # 而不是默认的 (str -> Tensor) 字典类型。 d = torch.jit.annotate(Dict[str, int], {}) # 如果没有上面的 `torch.jit.annotate`,下面的语句会因为 # 类型不匹配而失败。 d["name"] = 20
- Parameters
the_type – 应传递给 TorchScript 编译器的 Python 类型,作为 the_value 的类型提示
the_value – 值或表达式以提示类型。
- Returns
the_value 作为返回值被传递回来。