Shortcuts

torch.cuda.make_graphed_callables

torch.cuda.make_graphed_callables(callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None)[源代码]

接受可调用对象(函数或nn.Module)并返回图形化版本。

每个绘制的可调用对象的前向传播在其源可调用对象的CUDA工作内部作为一个CUDA图在单个自动求导节点中运行。

绘制的可调用对象的前向传播也会向自动求导图追加一个反向节点。在反向传播过程中,此节点将可调用对象的反向工作作为CUDA图运行。

因此,每个绘制的可调用对象都应是其源可调用对象的即插即用替代品,适用于启用了自动求导的训练循环。

参见部分网络捕获以了解详细的使用方法和限制。

如果你传递一个包含多个可调用对象的元组,它们的捕获将使用相同的内存池。 请参阅图内存管理以了解何时适用此方法。

Parameters
  • 可调用对象 (torch.nn.ModulePython 函数,或 元组这些) – 要图化的可调用对象或可调用对象元组。 请参阅 图内存管理 以了解何时传递可调用对象元组是合适的。 如果您传递一个可调用对象元组,它们在元组中的顺序必须与它们在实际工作负载中运行的顺序相同。

  • sample_args (元组张量,或 元组元组张量) – 每个可调用对象的样本参数。 如果传递了一个可调用对象,sample_args 必须是一个参数张量的元组。 如果传递了一个可调用对象的元组,sample_args 必须是参数张量的元组的元组。

  • num_warmup_iters (int) – 预热迭代的次数。目前,DataDistributedParallel 需要 11 次迭代进行预热。默认值:3

  • allow_unused_input (布尔值) – 如果为 False,指定在计算输出时未使用的输入(因此它们的梯度始终为零)将是一个错误。默认为 False。

  • 可选)——令牌(由graph_pool_handle()other_Graph_instance.pool()返回),提示此图可能与指示的池共享内存。请参阅图内存管理

注意

sample_args 中的每个 Tensor 的 requires_grad 状态必须与训练循环中相应真实输入的预期状态匹配。

警告

此API目前处于测试阶段,可能会在未来的版本中进行更改。

警告

sample_args 对于每个可调用对象,必须仅包含张量。不允许其他类型。

警告

返回的可调用对象不支持高阶微分(例如,双重反向传播)。

警告

在传递给 make_graphed_callables() 的任何 Module 中,只有参数可能是可训练的。缓冲区必须具有 requires_grad=False

警告

在将 torch.nn.Module 通过 make_graphed_callables() 传递后,您可能无法添加或删除该模块的任何参数或缓冲区。

警告

torch.nn.Module 传递给 make_graphed_callables() 时,不能在其上注册模块钩子。然而,在通过 make_graphed_callables() 传递后,允许在模块上注册钩子。

警告

当运行一个图形化的可调用对象时,你必须按照相同的顺序和格式传递其参数,就像它们在那个可调用对象的 sample_args 中出现的那样。

警告

自动混合精度在make_graphed_callables()中仅在禁用缓存时支持。上下文管理器torch.cuda.amp.autocast()必须设置cache_enabled=False