torch.backends.mps 的源代码
from functools import lru_cache as _lru_cache
from typing import Optional
import torch
from ...library import Library as _Library
__all__ = ["is_built", "is_available", "is_macos13_or_newer", "is_macos_or_newer"]
[docs]def is_built() -> bool:
r"""返回 PyTorch 是否构建了 MPS 支持。
注意,这并不一定意味着 MPS 是可用的;只是说
如果这个 PyTorch 二进制文件在具有工作 MPS 驱动程序
和设备的机器上运行,我们将能够使用它。
"""
return torch._C._has_mps
[docs]@_lru_cache
def is_available() -> bool:
r"""返回一个布尔值,指示 MPS 当前是否可用。"""
return torch._C._mps_is_available()
@_lru_cache
def is_macos_or_newer(major: int, minor: int) -> bool:
r"""返回一个布尔值,指示 MPS 是否在给定的 MacOS 或更新版本上运行。"""
return torch._C._mps_is_on_macos_or_newer(major, minor)
@_lru_cache
def is_macos13_or_newer(minor: int = 0) -> bool:
r"""返回一个布尔值,指示 MPS 是否在 MacOS 13 或更新版本上运行。"""
return torch._C._mps_is_on_macos_or_newer(13, minor)
_lib: Optional[_Library] = None
def _init():
r"""将 prims 注册为 var_mean 和 group_norm 的实现。"""
global _lib
if is_built() is False or _lib is not None:
return
from ..._decomp.decompositions import (
native_group_norm_backward as _native_group_norm_backward,
)
from ..._refs import native_group_norm as _native_group_norm
_lib = _Library("aten", "IMPL")
_lib.impl("native_group_norm", _native_group_norm, "MPS")
_lib.impl("native_group_norm_backward", _native_group_norm_backward, "MPS")