torch.cuda.nvtx 的源代码
r"""此包增加了对用于分析的NVIDIA工具扩展(NVTX)的支持。"""
from contextlib import contextmanager
try:
from torch._C import _nvtx
except ImportError:
class _NVTXStub:
@staticmethod
def _fail(*args, **kwargs):
raise RuntimeError(
"未安装NVTX函数。您确定您有CUDA构建吗?"
)
rangePushA = _fail
rangePop = _fail
markA = _fail
_nvtx = _NVTXStub() # type: ignore[assignment]
__all__ = ["range_push", "range_pop", "range_start", "range_end", "mark", "range"]
[docs]def range_push(msg):
"""
将一个范围推入嵌套范围跨度的堆栈。返回开始范围的基于零的深度。
参数:
msg (str): 与范围关联的ASCII消息
"""
return _nvtx.rangePushA(msg)
[docs]def range_pop():
"""从嵌套范围跨度的堆栈中弹出一个范围。返回结束范围的基于零的深度。"""
return _nvtx.rangePop()
def range_start(msg) -> int:
"""
用字符串消息标记范围的开始。它返回一个唯一的句柄
用于将此范围传递给相应的rangeEnd()调用。
此函数与range_push/range_pop的一个关键区别在于
range_start/range_end版本支持跨线程的范围(在一个线程上开始
并在另一个线程上结束)。
返回: 一个范围句柄(uint64_t),可以传递给range_end()。
参数:
msg (str): 与范围关联的ASCII消息。
"""
return _nvtx.rangeStartA(msg)
def range_end(range_id) -> None:
"""
标记给定range_id的范围的结束。
参数:
range_id (int): 开始范围的唯一句柄。
"""
_nvtx.rangeEnd(range_id)
[docs]def mark(msg):
"""
描述在某个时间点发生的瞬时事件。
参数:
msg (str): 与事件关联的ASCII消息。
"""
return _nvtx.markA(msg)
@contextmanager
def range(msg, *args, **kwargs):
"""
上下文管理器/装饰器,在作用域开始时推送一个NVTX范围,
并在结束时弹出它。如果给出了额外的参数,
它们将作为参数传递给msg.format()。
参数:
msg (str): 与范围关联的消息
"""
range_push(msg.format(*args, **kwargs))
try:
yield
finally:
range_pop()