图¶
- class torch.cuda.graph(cuda_graph, pool=None, stream=None, capture_error_mode='global')[源代码]¶
用于将CUDA工作捕获到
torch.cuda.CUDAGraph对象中以供后续重放的上下文管理器。参见CUDA Graphs以获取一般介绍、详细使用方法及约束条件。
- Parameters
cuda_graph (torch.cuda.CUDAGraph) – 用于捕获的图形对象。
池(可选)——不透明的令牌(通过调用
graph_pool_handle()或other_Graph_instance.pool())提示此图的捕获 可能从指定的池中共享内存。请参阅图内存管理。stream(torch.cuda.Stream,可选)– 如果提供,将在上下文中设置为当前流。 如果未提供,
graph将在上下文中将其自己的内部辅助流设置为当前流。capture_error_mode (str, 可选) – 指定用于图捕获流的cudaStreamCaptureMode。可以是“global”、“thread_local”或“relaxed”。在cuda图捕获期间,某些操作(如cudaMalloc)可能不安全。“global”会在其他线程中的操作上出错,“thread_local”仅会在当前线程中的操作上出错,而“relaxed”不会在操作上出错。除非您熟悉cudaStreamCaptureMode,否则请勿更改此设置。
注意
为了实现有效的内存共享,如果你传递一个之前捕获使用的
pool,并且之前的捕获使用了显式的stream参数,你应该将相同的stream参数传递给这次捕获。警告
此API目前处于测试阶段,可能会在未来的版本中进行更改。