Shortcuts

get_memory_stats

torchtune.training.get_memory_stats(device: device, reset_stats: bool = True) dict[source]

计算传入设备的内存摘要。如果reset_statsTrue,这还将重置CUDA的峰值内存跟踪。这对于获取峰值内存的相对使用数据(例如,模型初始化期间的峰值内存,前向传播期间的峰值内存等)并优化训练各个部分的内存非常有用。

Parameters:
  • device (torch.device) – 获取内存摘要的设备。仅支持CUDA设备。

  • reset_stats (bool) – 是否重置CUDA的峰值内存跟踪。

Returns:

一个包含峰值内存活动、峰值内存分配和峰值内存保留的字典。这个字典对于记录内存统计信息非常有用。

Return type:

字典[str, float]

Raises:

ValueError – 如果传入的设备不是CUDA。