torch.cuda.make_graphed_callables¶
- torch.cuda.make_graphed_callables(callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None)[源代码]¶
接受可调用对象(函数或
nn.Module
)并返回图形化版本。每个绘制的可调用对象的前向传播在其源可调用对象的CUDA工作内部作为一个CUDA图在单个自动求导节点中运行。
绘制的可调用对象的前向传播也会向自动求导图追加一个反向节点。在反向传播过程中,此节点将可调用对象的反向工作作为CUDA图运行。
因此,每个绘制的可调用对象都应是其源可调用对象的即插即用替代品,适用于启用了自动求导的训练循环。
参见部分网络捕获以了解详细的使用方法和限制。
如果你传递一个包含多个可调用对象的元组,它们的捕获将使用相同的内存池。 请参阅图内存管理以了解何时适用此方法。
- Parameters
可调用对象 (torch.nn.Module 或 Python 函数,或 元组 的 这些) – 要图化的可调用对象或可调用对象元组。 请参阅 图内存管理 以了解何时传递可调用对象元组是合适的。 如果您传递一个可调用对象元组,它们在元组中的顺序必须与它们在实际工作负载中运行的顺序相同。
sample_args (元组 的 张量,或 元组 的 元组 的 张量) – 每个可调用对象的样本参数。 如果传递了一个可调用对象,
sample_args
必须是一个参数张量的元组。 如果传递了一个可调用对象的元组,sample_args
必须是参数张量的元组的元组。num_warmup_iters (int) – 预热迭代的次数。目前,
DataDistributedParallel
需要 11 次迭代进行预热。默认值:3
。allow_unused_input (布尔值) – 如果为 False,指定在计算输出时未使用的输入(因此它们的梯度始终为零)将是一个错误。默认为 False。
池(可选)——令牌(由
graph_pool_handle()
或other_Graph_instance.pool()
返回),提示此图可能与指示的池共享内存。请参阅图内存管理。
注意
在
sample_args
中的每个 Tensor 的requires_grad
状态必须与训练循环中相应真实输入的预期状态匹配。警告
此API目前处于测试阶段,可能会在未来的版本中进行更改。
警告
sample_args
对于每个可调用对象,必须仅包含张量。不允许其他类型。警告
返回的可调用对象不支持高阶微分(例如,双重反向传播)。
警告
在传递给
make_graphed_callables()
的任何Module
中,只有参数可能是可训练的。缓冲区必须具有requires_grad=False
。警告
在将
torch.nn.Module
通过make_graphed_callables()
传递后,您可能无法添加或删除该模块的任何参数或缓冲区。警告
torch.nn.Module
传递给make_graphed_callables()
时,不能在其上注册模块钩子。然而,在通过make_graphed_callables()
传递后,允许在模块上注册钩子。警告
当运行一个图形化的可调用对象时,你必须按照相同的顺序和格式传递其参数,就像它们在那个可调用对象的
sample_args
中出现的那样。警告
自动混合精度在
make_graphed_callables()
中仅在禁用缓存时支持。上下文管理器torch.cuda.amp.autocast()必须设置cache_enabled=False。