torch.utils.backend_registration 的源代码
import torch
from torch._C import _rename_privateuse1_backend, _get_privateuse1_backend_name
from typing import List, Optional, Union
__all__ = ["rename_privateuse1_backend", "generate_methods_for_privateuse1_backend"]
# TODO: 应该使用 `torch._C._get_privateuse1_backend_name()` 来获取
# 重命名的后端名称 `privateuse1`,但该函数会导致 torch.jit.script 出错,
# 所以我们使用名为 `_privateuse1_backend_name` 的全局变量。
_privateuse1_backend_name = "privateuseone"
[docs]def rename_privateuse1_backend(backend_name: str) -> None:
r"""
重命名 privateuse1 后端设备,使其在 PyTorch API 中使用更方便。
步骤如下:
(1) (在 C++ 中)为各种 torch 操作实现内核,并将其注册到 PrivateUse1 调度键。
(2) (在 python 中)调用 torch.utils.rename_privateuse1_backend("foo")
你现在可以在 python 中将 "foo" 作为普通设备字符串使用。
注意:此 API 每个进程只能调用一次。尝试在已经设置外部后端后更改它将导致错误。
注意(AMP):如果你想在你的设备上支持 AMP,你可以注册一个自定义后端模块。
后端必须使用 ``torch._register_device_module("foo", BackendModule)`` 注册一个自定义后端模块。
BackendModule 需要有以下 API:
(1) ``get_amp_supported_dtype() -> List[torch.dtype]``
获取 AMP 中 "foo" 设备支持的 dtypes,可能 "foo" 设备支持更多的 dtype。
(2) ``is_autocast_enabled() -> bool``
检查 AMP 是否在你的 "foo" 设备上启用。
(3) ``get_autocast_dtype() -> torch.dtype``
获取 AMP 中 "foo" 设备支持的 dtype,由 ``set_autocast_dtype`` 设置或默认 dtype,默认 dtype 为 ``torch.float16``。
(4) ``set_autocast_enabled(bool) -> None``
启用或禁用 "foo" 设备上的 AMP。
(5) ``set_autocast_dtype(dtype) -> None``
在 AMP 中为 "foo" 设备设置支持的 dtype,dtype 应包含在从 ``get_amp_supported_dtype`` 获取的 dtypes 中。
注意(随机):如果你想支持为你的设备设置种子,BackendModule 需要有以下 API:
(1) ``_is_in_bad_fork() -> bool``
如果当前处于 bad_fork 中,返回 ``True``,否则返回 ``False``。
(2) ``manual_seed_all(seed int) -> None``
为你的设备设置生成随机数的种子。
(3) ``device_count() -> int``
返回可用的 "foo" 数量。
(4) ``get_rng_state(device: Union[int, str, torch.device] = 'foo') -> Tensor``
返回所有设备的随机数状态列表。
(5) ``set_rng_state(new_state: Tensor, device: Union[int, str, torch.device] = 'foo') -> None``
设置指定 "foo" 设备的随机数生成器状态。
还有一些常见的函数:
(1) ``is_available() -> bool``
返回一个布尔值,指示 "foo" 是否当前可用。
(2) ``current_device() -> int``
返回当前选择的设备的索引。
更多详情请参见 https://pytorch.org/tutorials/advanced/extend_dispatcher.html#get-a-dispatch-key-for-your-backend
现有示例请参见 https://github.com/bdhirsh/pytorch_open_registration_example
示例::
>>> # xdoctest: +SKIP("failing")
>>> torch.utils.rename_privateuse1_backend("foo")
# 这将起作用,假设你已经实现了正确的 C++ 内核来实现 torch.ones。
>>> a = torch.ones(2, device="foo")
"""
_rename_privateuse1_backend(backend_name)
global _privateuse1_backend_name
_privateuse1_backend_name = backend_name
def _check_register_once(module, attr):
if hasattr(module, attr):
raise RuntimeError(f"The custom device module of {module} has already been registered with {attr}")
def _normalization_device(custom_backend_name: str, device: Optional[Union[int, str, torch.device]] = None) -> int:
def _get_current_device_index():
_get_device_index = "current_device"
if hasattr(torch, custom_backend_name) and \
hasattr(getattr(torch, custom_backend_name), _get_device_index):
return getattr(getattr(torch, custom_backend_name), _get_device_index)()
else:
# 默认设备索引为 0。
return 0
if device is None:
return _get_current_device_index()
# 如果 isinstance(device, str),这意味着传入的参数是字符串格式 "foo:0"
# 将其转换为 torch.device 对象,然后统一处理
elif isinstance(device, str):
device = torch.device(device)
# 变量 device 只能是 torch.device 类型或 int 类型
if isinstance(device, torch.device):
if device.type != custom_backend_name:
raise RuntimeError(f"Invalid device, must be {custom_backend_name} device")
elif device.index is None:
device_idx = _get_current_device_index()
else:
device_idx = device.index
# 如果 isinstance(device, int),我们可以直接获取索引号
else:
device_idx = device
return device_idx
def _generate_tensor_methods_for_privateuse1_backend(custom_backend_name: str) -> None:
@property # type: ignore[misc]
def wrap_tensor_backend(self: torch.Tensor) -> bool:
return self.device.type == custom_backend