torch.mps.profiler 的源代码
import contextlib
import torch
__all__ = ["start", "stop", "profile"]
[docs]def start(mode: str = "interval", wait_until_completed: bool = False) -> None:
r"""从MPS后端开始OS Signpost跟踪。
生成的OS Signposts可以在XCode Instruments Logging工具中记录和查看。
参数:
mode(str): OS Signpost跟踪模式可以是"interval"、"event",
或两者"interval,event"。
间隔模式跟踪操作执行的持续时间,
而事件模式标记执行的完成。
有关更多信息,请参阅文档`Recording Performance Data`_。
wait_until_completed(bool): 等待MPS Stream完成
执行每个编码的GPU操作。这有助于在跟踪的时间轴上生成单个调度。
请注意,启用此选项会对性能产生负面影响。
.. _Recording Performance Data:
https://developer.apple.com/documentation/os/logging/recording_performance_data
"""
mode_normalized = mode.lower().replace(" ", "")
torch._C._mps_profilerStartTrace(mode_normalized, wait_until_completed)
[docs]def stop():
r"""停止从MPS后端生成OS Signpost跟踪。"""
torch._C._mps_profilerStopTrace()
[docs]@contextlib.contextmanager
def profile(mode: str = "interval", wait_until_completed: bool = False):
r"""上下文管理器,用于启用从MPS后端生成OS Signpost跟踪。
参数:
mode(str): OS Signpost跟踪模式可以是"interval"、"event",
或两者"interval,event"。
间隔模式跟踪操作执行的持续时间,
而事件模式标记执行的完成。
有关更多信息,请参阅文档`Recording Performance Data`_。
wait_until_completed(bool): 等待MPS Stream完成
执行每个编码的GPU操作。这有助于在跟踪的时间轴上生成单个调度。
请注意,启用此选项会对性能产生负面影响。
.. _Recording Performance Data:
https://developer.apple.com/documentation/os/logging/recording_performance_data
"""
try:
start(mode, wait_until_completed)
yield
finally:
stop()