torch.xpu.streams 的源代码
import ctypes
import torch
from torch._streambase import _EventBase, _StreamBase
from .._utils import _dummy_type
if not hasattr(torch._C, "_XpuStreamBase"):
# 定义虚拟基类
torch._C.__dict__["_XpuStreamBase"] = _dummy_type("_XpuStreamBase")
torch._C.__dict__["_XpuEventBase"] = _dummy_type("_XpuEventBase")
[docs]class Stream(torch._C._XpuStreamBase, _StreamBase):
r"""XPU 流包装器。
XPU 流是属于特定设备的线性执行序列,独立于其他流。
参数:
device(torch.device 或 int, 可选): 分配流的设备。如果 :attr:`device` 为 ``None`` (默认) 或负整数,将使用当前设备。
priority(int, 可选): 流的优先级,应为 0 或负数,负数表示高优先级。默认情况下,流的优先级为 0。
"""
def __new__(cls, device=None, priority=0, **kwargs):
# 设置设备管理器开销较大,因此除非必要,否则我们避免使用它
if device is None or ("stream_id" in kwargs and "device_index" in kwargs):
return super().__new__(cls, priority=priority, **kwargs)
else:
with torch.xpu.device(device):
return super().__new__(cls, priority=priority, **kwargs)
[docs] def wait_event(self, event):
r"""使提交到流的所有未来工作等待一个事件。
参数:
event (torch.xpu.Event): 等待的事件。
"""
event.wait(self)
[docs] def wait_stream(self, stream):
r"""与另一个流同步。
提交到此流的所有未来工作将等待直到在调用时提交到给定流的所有内核完成。
参数:
stream (Stream): 要同步的流。
"""
self.wait_event(stream.record_event())
[docs] def record_event(self, event=None):
r"""记录一个事件。
参数:
event (torch.xpu.Event, 可选): 要记录的事件。如果未给出,将分配一个新的。
返回:
记录的事件。
"""
if event is None:
event = Event()
event.record(self)
return event
[docs] def query(self):
r"""检查所有提交的工作是否已完成。
返回:
一个布尔值,指示此流中的所有内核是否已完成。
"""
return super().query()
[docs] def synchronize(self):
r"""等待此流中的所有内核完成。"""
super().synchronize()
@property
def _as_parameter_(self):
return ctypes.c_void_p(self.sycl_queue)
def __eq__(self, o):
if isinstance(o, Stream):
return super().__eq__(o)
return False
def __hash__(self):
return hash((self.sycl_queue, self.device))
def __repr__(self):
return f"torch.xpu.Stream(device={self.device} sycl_queue={self.sycl_queue:#x})"
[docs]class Event(torch._C._XpuEventBase, _EventBase):
r"""XPU 事件包装器。
XPU 事件是同步标记,可用于监控设备的进度,并同步 XPU 流。
底层 XPU 事件在事件首次记录时延迟初始化。创建后,只有同一设备上的流可以记录事件。但是,任何设备上的流都可以等待事件。
参数:
enable_timing (bool, 可选): 指示事件是否应测量时间 (默认: ``False``)
"""
def __new__(cls, enable_timing=False<span