Shortcuts

torch.mps 的源代码

r"""
此包为在Python中访问MPS(Metal Performance Shaders)后端提供了接口。
Metal是Apple用于编程Metal GPU(图形处理器单元)的API。使用MPS意味着可以通过在Metal GPU上运行工作来实现更高的性能。
更多详情请参见https://developer.apple.com/documentation/metalperformanceshaders。
"""
import torch
from .. import Tensor

_is_in_bad_fork = getattr(torch._C, "_mps_is_in_bad_fork", lambda: False)
_default_mps_generator: torch._C.Generator = None  # type: ignore[assignment]


# 本地辅助函数(非公开或导出)
def _get_default_mps_generator() -> torch._C.Generator:
    global _default_mps_generator
    if _default_mps_generator is None:
        _default_mps_generator = torch._C._mps_get_default_generator()
    return _default_mps_generator


[docs]def synchronize() -> None: r"""等待MPS设备上的所有流中的所有内核完成。""" return torch._C._mps_deviceSynchronize()
[docs]def get_rng_state() -> Tensor: r"""返回随机数生成器状态作为ByteTensor。""" return _get_default_mps_generator().get_state()
[docs]def set_rng_state(new_state: Tensor) -> None: r"""设置随机数生成器状态。 参数: new_state (torch.ByteTensor): 所需的状态 """ new_state_copy = new_state.clone(memory_format=torch.contiguous_format) _get_default_mps_generator().set_state(new_state_copy)
[docs]def manual_seed(seed: int) -> None: r"""设置生成随机数的种子。 参数: seed (int): 所需的种子。 """ # torch.mps.manual_seed() 可以从全局的 torch.manual_seed() 调用 # 在 torch/random.py 中。所以我们需要确保 mps 是可用的(否则我们只是返回而不会 # 出错) if not torch._C._has_mps: return seed = int(seed) _get_default_mps_generator().manual_seed(seed)
[docs]def seed() -> None: r"""将生成随机数的种子设置为一个随机数。""" _get_default_mps_generator().seed()
[docs]def empty_cache() -> None: r"""释放缓存分配器当前持有的所有未占用的缓存内存,以便其他GPU应用程序可以使用这些内存。""" torch._C._mps_emptyCache()
[docs]def set_per_process_memory_fraction(fraction) -> None: r"""设置用于限制MPS设备上进程内存分配的内存比例。 允许的值等于比例乘以推荐的最大设备内存(从Metal API device.recommendedMaxWorkingSetSize获得)。 如果尝试在进程中分配超过允许值的内存,将在分配器中引发内存不足错误。 参数: fraction(float): 范围: 0~2。允许的内存等于总内存 * 比例。 .. 注意:: 传递0给fraction意味着无限制的分配(可能会导致系统内存不足)。 传递大于1.0的比例允许超出device.recommendedMaxWorkingSetSize返回值的限制。 """ if not isinstance(fraction, float): raise TypeError("fraction参数的类型无效,必须是`float`") if fraction < 0 or fraction > 2: raise ValueError(f"无效的比例值: {fraction}. 允许的范围: 0~2") torch._C._mps_setMemoryFraction(fraction)
[docs]def current_allocated_memory() -> int: r"""返回当前由张量占用的GPU内存,以字节为单位。 .. 注意:: 返回的大小不包括MPSAllocator内存池中的缓存分配。 """ return torch._C._mps_currentAllocatedMemory()
[docs]def driver_allocated_memory() -> int: r"""返回Metal驱动程序为进程分配的总GPU内存,以字节为单位。 .. 注意:: 返回的大小包括MPSAllocator池中的缓存分配以及来自MPS/MPSGraph框架的分配。 """ return torch._C._mps_driverAllocatedMemory()
from . import profiler from .event import Event __all__ = [ "get_rng_state", "manual_seed",