Source code for torch.xpu

```python
r"""
此包引入了对XPU后端的支持,特别是针对Intel GPU的优化。

此包是延迟初始化的,因此您可以始终导入它,并使用
:func:`is_available()` 来确定您的系统是否支持XPU。
"""
import threading
import traceback
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch._C
from .. import device as _device
from .._utils import _dummy_type, _LazySeedTracker
from ._utils import _get_device_index
from .streams import Event, Stream

_initialized = False
_tls = threading.local()
_initialization_lock = threading.Lock()
_queued_calls: List[
    Tuple[Callable[[], None], List[str]]
] = []  # don't invoke these until initialization occurs
_is_in_bad_fork = getattr(torch._C, "_xpu_isInBadFork", lambda: False)
_device_t = Union[_device, str, int, None]
_lazy_seed_tracker = _LazySeedTracker()
default_generators: Tuple[torch._C.Generator] = ()  # type: ignore[assignment]


def _is_compiled() -> bool:
    r"""如果编译时支持XPU,则返回true。"""
    return torch._C._has_xpu


if _is_compiled():
    _XpuDeviceProperties = torch._C._XpuDeviceProperties
    _exchange_device = torch._C._xpu_exchangeDevice
    _maybe_exchange_device = torch._C._xpu_maybeExchangeDevice
else:
    # Define dummy objects if PyTorch was compiled without XPU support
    _XpuDeviceProperties = _dummy_type("_XpuDeviceProperties")  # type: ignore[assignment, misc]

    def _exchange_device(device: int) -> int:
        raise NotImplementedError("PyTorch编译时不带XPU支持")

    def _maybe_exchange_device(device: int) -> int:
        raise NotImplementedError("PyTorch编译时不带XPU支持")


@lru_cache(maxsize=1)
def device_count() -> int:
    r"""Return the number of XPU devices available."""
    if not _is_compiled():
        return 0
    return torch._C._xpu_getDeviceCount()


def is_available() -> bool:
    r"""Return a bool indicating if XPU is currently available."""
    # This function never throws.
    return device_count() > 0


def is_bf16_supported():
    r"""Return a bool indicating if the current XPU device supports dtype bfloat16."""
    return True


def is_initialized():
    r"""Return whether PyTorch's XPU state has been initialized."""
    return _initialized and not _is_in_bad_fork()


def _lazy_call(callable, **kwargs):
    if is_initialized():
        callable()
    else:
        global _lazy_seed_tracker
        if kwargs.get("seed_all", False):
            _lazy_seed_tracker.queue_seed_all(callable, traceback.format_stack())
        elif kwargs.get("seed", False):
            _lazy_seed_tracker.queue_seed(callable, traceback.format_stack())
        else:
            # Don't store the actual traceback to avoid memory cycle
            _queued_calls.append((callable, traceback.format_stack()))


def init():
    r"""Initialize PyTorch's XPU state.

    This is a Python API for lazy initialization, avoiding initialization of
    XPU until it is first accessed. Does nothing if the XPU state is already
    initialized.
    """
    _lazy_init()
def _lazy_init():
    global _initialized, _queued_calls
    if is_initialized() or hasattr(_tls, "is_initializing"):
        return
```
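
As the module docstring notes, the package is lazily initialized, so importing `torch.xpu` is always safe even on machines without an Intel GPU. Below is a minimal usage sketch (not part of the source above) built only on the functions in this listing plus the standard `torch.device`/`torch.ones` APIs:

```python
import torch

# torch.xpu can always be imported; the backend is only initialized on first use.
if torch.xpu.is_available():
    torch.xpu.init()  # optional: force lazy initialization explicitly
    print(f"XPU devices available: {torch.xpu.device_count()}")
    device = torch.device("xpu")
else:
    # Fall back to CPU on builds or machines without XPU support.
    device = torch.device("cpu")

x = torch.ones(2, 2, device=device)
```

Because `is_available()` is implemented as `device_count() > 0` and `device_count()` returns 0 when PyTorch was compiled without XPU support, this check never raises.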