Shortcuts

torch.utils.rename_privateuse1_backend

torch.utils.rename_privateuse1_backend(backend_name)[源代码]

将privateuse1后端设备重命名为更方便在PyTorch API中使用的设备名称。

步骤如下:

  1. (在C++中)为各种torch操作实现内核,并将其注册到PrivateUse1调度键。

  2. (在 Python 中)调用 torch.utils.rename_privateuse1_backend(“foo”)

你现在可以在Python中将“foo”用作普通设备字符串。

注意:此API每个进程只能调用一次。尝试在已设置外部后端后更改它将导致错误。

注意(AMP): 如果你想在设备上支持AMP,你可以注册一个自定义的后端模块。 后端必须使用torch._register_device_module("foo", BackendModule)注册一个自定义的后端模块。 BackendModule需要具有以下API:

  1. get_amp_supported_dtype() -> List[torch.dtype] 获取在AMP中“foo”设备上支持的数据类型,可能“foo”设备支持一种额外的数据类型。

  2. is_autocast_enabled() -> bool 检查AMP是否在您的“foo”设备上启用。

  3. get_autocast_dtype() -> torch.dtype 获取在AMP中“foo”设备上支持的dtype,该dtype由set_autocast_dtype设置或默认dtype,默认dtype为torch.float16

  4. set_autocast_enabled(bool) -> None 是否在您的“foo”设备上启用AMP。

  5. set_autocast_dtype(dtype) -> None 在AMP中为您的“foo”设备设置支持的dtype,并且dtype应包含在从get_amp_supported_dtype获得的dtypes中。

注意(随机):如果您希望支持为您的设备设置种子,BackendModule 需要具有以下 API:

  1. _is_in_bad_fork() -> bool 如果当前处于bad_fork中,则返回True,否则返回False

  2. manual_seed_all(seed int) -> None 为您的设备设置生成随机数的种子。

  3. device_count() -> int 返回可用的“foo”数量。

  4. get_rng_state(device: Union[int, str, torch.device] = 'foo') -> Tensor 返回一个表示所有设备随机数状态的ByteTensor列表。

  5. set_rng_state(new_state: Tensor, device: Union[int, str, torch.device] = 'foo') -> None 设置指定“foo”设备的随机数生成器状态。

还有一些常见的函数:

  1. is_available() -> bool 返回一个布尔值,指示“foo”当前是否可用。

  2. current_device() -> int 返回当前选定设备的索引。

更多详情,请参阅 https://pytorch.org/tutorials/advanced/extend_dispatcher.html#get-a-dispatch-key-for-your-backend 有关现有示例,请参阅 https://github.com/bdhirsh/pytorch_open_registration_example

示例:

>>> torch.utils.rename_privateuse1_backend("foo")
# 这将有效,假设你已经实现了正确的C++内核
# 来实现torch.ones。
>>> a = torch.ones(2, device="foo")