Shortcuts

torch.cuda.nvtx 的源代码

r"""此包增加了对用于分析的NVIDIA工具扩展(NVTX)的支持。"""

from contextlib import contextmanager

try:
    from torch._C import _nvtx
except ImportError:

    class _NVTXStub:
        @staticmethod
        def _fail(*args, **kwargs):
            raise RuntimeError(
                "未安装NVTX函数。您确定您有CUDA构建吗?"
            )

        rangePushA = _fail
        rangePop = _fail
        markA = _fail

    _nvtx = _NVTXStub()  # type: ignore[assignment]

__all__ = ["range_push", "range_pop", "range_start", "range_end", "mark", "range"]


[docs]def range_push(msg): """ 将一个范围推入嵌套范围跨度的堆栈。返回开始范围的基于零的深度。 参数: msg (str): 与范围关联的ASCII消息 """ return _nvtx.rangePushA(msg)
[docs]def range_pop(): """从嵌套范围跨度的堆栈中弹出一个范围。返回结束范围的基于零的深度。""" return _nvtx.rangePop()
def range_start(msg) -> int: """ 用字符串消息标记范围的开始。它返回一个唯一的句柄 用于将此范围传递给相应的rangeEnd()调用。 此函数与range_push/range_pop的一个关键区别在于 range_start/range_end版本支持跨线程的范围(在一个线程上开始 并在另一个线程上结束)。 返回: 一个范围句柄(uint64_t),可以传递给range_end()。 参数: msg (str): 与范围关联的ASCII消息。 """ return _nvtx.rangeStartA(msg) def range_end(range_id) -> None: """ 标记给定range_id的范围的结束。 参数: range_id (int): 开始范围的唯一句柄。 """ _nvtx.rangeEnd(range_id)
[docs]def mark(msg): """ 描述在某个时间点发生的瞬时事件。 参数: msg (str): 与事件关联的ASCII消息。 """ return _nvtx.markA(msg)
@contextmanager def range(msg, *args, **kwargs): """ 上下文管理器/装饰器,在作用域开始时推送一个NVTX范围, 并在结束时弹出它。如果给出了额外的参数, 它们将作为参数传递给msg.format()。 参数: msg (str): 与范围关联的消息 """ range_push(msg.format(*args, **kwargs)) try: yield finally: range_pop()
优云智算