Shortcuts

torch.cuda.graphs 的源代码

```html
import gc
from typing import Optional

import torch
from torch.utils import _pytree
from .._utils import _dummy_type

if not hasattr(torch._C, "_CudaStreamBase"):
    # 定义虚拟基类
    torch._C.__dict__["_CUDAGraph"] = _dummy_type("_CUDAGraph")
    torch._C.__dict__["_graph_pool_handle"] = _dummy_type("_graph_pool_handle")
    torch._C.__dict__["_cuda_isCurrentStreamCapturing"] = _dummy_type(
        "_cuda_isCurrentStreamCapturing"
    )

from torch._C import (  # noqa: F401
    _cuda_isCurrentStreamCapturing,
    _CUDAGraph,
    _graph_pool_handle,
)


[docs]def is_current_stream_capturing(): r"""如果当前CUDA流正在进行CUDA图捕获,则返回True,否则返回False。 如果当前设备上不存在CUDA上下文,则返回False而不初始化上下文。 """ return _cuda_isCurrentStreamCapturing()
# Python shim帮助Sphinx更可靠地处理文档字符串。
[docs]def graph_pool_handle(): r"""返回一个表示图内存池id的不透明令牌。 请参阅 :ref:`图内存管理`。 .. 警告:: 此API处于测试阶段,可能会在未来的版本中更改。 """ return _graph_pool_handle()
# Python shim帮助Sphinx更可靠地处理文档字符串。
[docs]class CUDAGraph(torch._C._CUDAGraph): r"""围绕CUDA图的包装器。 .. 警告:: 此API处于测试阶段,可能会在未来的版本中更改。 """ def __new__(cls): return super().__new__(cls)
[docs] def capture_begin(self, pool=None, capture_error_mode="global"): r"""开始在当前流上捕获CUDA工作。 通常,您不应该自己调用 ``capture_begin``。 使用 :class:`~torch.cuda.graph` 或 :func:`~torch.cuda.make_graphed_callables`, 它们会在内部调用 ``capture_begin``。 参数: pool (可选): 令牌(由 :func:`~torch.cuda.graph_pool_handle` 或 :meth:`other_Graph_instance.pool()` 返回),提示此图可能与指示的池共享内存。 请参阅 :ref:`图内存管理`。 capture_error_mode (str, 可选): 指定图捕获流的cudaStreamCaptureMode。 可以是 "global", "thread_local" 或 "relaxed"。在cuda图捕获期间,某些操作(如cudaMalloc)可能不安全。 "global" 会在其他线程中的操作出错,"thread_local" 只会对当前线程中的操作出错,而 "relaxed" 不会对这些操作出错。 除非您熟悉 `cudaStreamCaptureMode `_,否则不要更改此设置。 """ # noqa: B950 super().capture_begin(pool=pool, capture_error_mode=capture_error_mode)
[docs] def capture_end(self): r"""结束当前流上的CUDA图捕获。 在 ``capture_end`` 之后,可以在此实例上调用 ``replay``。 通常,您不应该自己调用 ``capture_end``。 使用 :class:`~torch.cuda.graph` 或 :func:`~torch.cuda.make_graphed_callables`, 它们会在内部调用 ``capture_end``。 """ super().capture_end()
[docs] def replay(self): r"""重放此图捕获的CUDA工作。""" super().replay()
[docs] def reset(self): r"""删除此实例当前持有的图。""" super().reset()
[docs] def pool(self): r"""返回一个表示此图内存池id的不透明令牌。 此id可以选择传递给另一个图的 ``capture_begin``, 提示另一个图可能共享相同的内存池。 """ return super().pool()
[docs] def enable_debug_mode(self): r"""为CUDAGraph.debug_dump启用调试模式。""" return super().enable_debug_mode()
[docs] def debug_dump(self, debug_path): r""" 参数: debug_path (必需): 转储图的路径。 如果通过CUDAGraph.enable_debug_mode()启用了调试,则调用调试函数转储图。 """ return super().debug_dump(debug_path)
[docs]class <