Shortcuts

torch.jit.接口

torch.jit.interface(obj)[源代码]

装饰以注释不同类型的类或模块。

这个装饰器可以用来定义一个接口,该接口可以用于注释不同类型的类或模块。这可以用于注释一个子模块或属性类,该子模块或属性类可能具有实现相同接口的不同类型,或者可能在运行时被交换;或者用于存储一个包含不同类型的模块或类的列表。

它有时用于实现“可调用对象”——实现接口的函数或模块,但它们的实现方式不同,并且可以互换使用。

示例: .. testcode:

import torch
from typing import List

@torch.jit.interface
class InterfaceType:
    def run(self, x: torch.Tensor) -> torch.Tensor:
        pass

# 实现 InterfaceType
@torch.jit.script
class Impl1:
    def run(self, x: torch.Tensor) -> torch.Tensor:
        return x.relu()

class Impl2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.val = torch.rand(())

    @torch.jit.export
    def run(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.val

def user_fn(impls: List[InterfaceType], idx: int, val: torch.Tensor) -> torch.Tensor:
    return impls[idx].run(val)

user_fn_jit = torch.jit.script(user_fn)

impls = [Impl1(), torch.jit.script(Impl2())]
val = torch.rand(4, 4)
user_fn_jit(impls, 0, val)
user_fn_jit(impls, 1, val)