Shortcuts

torch.library 的源代码

```html
from ._ops import OpOverload
from typing import Any, Optional, Set, List
import traceback
import torch
import weakref
import functools
import inspect
import re
import contextlib
import sys

__all__ = [
    'Library',
    'impl',
    'define',
    'fallthrough_kernel',
    'impl_abstract',
    'get_ctx',
]

# 包含已注册新内核的(命名空间、操作符、DispatchKey)组合的集合
# 集合中的键的形式为 `namespace + "/" + op_name + "/" + dispatch_key`。
# 此集合用于确保两个库不会尝试覆盖完全相同的功能,以避免
# 库调用未被意图调用的内核。
_impls: Set[str] = set()
_defs: Set[str] = set()

# prim 由 TorchScript 解释器保留
_reserved_namespaces = ['prim']

[docs]def fallthrough_kernel(): """ 一个虚拟函数,用于传递给 ``Library.impl`` 以注册一个 fallthrough。 """ raise NotImplementedError("fallthrough_kernel() 不应该被调用。")
[docs]class Library: """ 一个用于创建库的类,该库可用于从 Python 注册新操作符或 覆盖现有库中的操作符。 用户可以选择传入一个调度键名,如果他们只想注册 对应于特定调度键的内核。 要创建一个库来覆盖现有库(名称为 ns)中的操作符,请将 kind 设置为 "IMPL"。 要创建一个新库(名称为 ns)来注册新操作符,请将 kind 设置为 "DEF"。 要创建一个可能存在的库的片段来注册操作符(并绕过 给定命名空间只有一个库的限制),请将 kind 设置为 "FRAGMENT"。 参数: ns: 库名称 kind: "DEF", "IMPL"(默认: "IMPL"), "FRAGMENT" dispatch_key: PyTorch 调度键(默认: "") """ def __init__(self, ns, kind, dispatch_key=""): if kind not in ('IMPL', 'DEF', 'FRAGMENT'): raise ValueError("不支持的 kind: ", kind) if ns in _reserved_namespaces and (kind == "DEF" or kind == 'FRAGMENT'): raise ValueError(ns, " 是一个保留的命名空间。请尝试使用另一个名称创建库。") frame = traceback.extract_stack(limit=3)[0] filename, lineno = frame.filename, frame.lineno self.m: Optional[Any] = torch._C._dispatch_library(kind, ns, dispatch_key, filename, lineno) self.ns = ns self._op_defs: Set[str] = set() self._op_impls: Set[str] = set() self._registration_handles: List["torch._library.utils.RegistrationHandle"] = [] self.kind = kind self.dispatch_key = dispatch_key # 使用 finalizer 来设置 "析构函数" 而不是 __del__。 # Python 的 __del__ 可能会导致奇怪的事情(当 __del__ 实际被调用时,全局变量和局部变量可能已经 # 消失了!)。finalizers 有助于这种情况,因为它让我们捕获引用并保持它们的活动状态 weakref.finalize(self, _del_library, _impls, self._op_impls, _defs, self._op_defs, self._registration_handles) def __repr__(self): return f"Library(kind={self.kind}, ns={self.ns}, dispatch_key={self.dispatch_key})>"
[docs] def define(self, schema, alias_analysis="", *, tags=()): r'''在 ns 命名空间中定义一个新操作符及其语义。 参数: schema: 定义新操作符的函数 schema。 alias_analysis (可选): 指示操作符参数的别名属性是否可以从 schema 推断(默认行为)或不推断("CONSERVATIVE")。 tags (Tag | Sequence[Tag]): 一个或多个 torch.Tag 应用于该操作符。标记操作符会改变操作符在各种 PyTorch 子系统下的行为;请在使用前仔细阅读 torch.Tag 的文档。 返回: 从 schema 推断的操作符名称。 示例:: >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LIBRARY) >>> my_lib = Library("foo", "DEF") >>> my_lib.define("sum(Tensor self) -> Tensor") ''' # 这是为了禁止 PURE_FUNCTION 别名分析,因为它是 C++ 中的有效 AliasAnalysis 类型 if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]: raise RuntimeError</