get_memory_stats¶
- torchtune.training.get_memory_stats(device: device, reset_stats: bool = True) dict[source]¶
计算传入设备的内存摘要。如果
reset_stats为True,这还将重置CUDA的峰值内存跟踪。这对于获取峰值内存的相对使用数据(例如,模型初始化期间的峰值内存,前向传播期间的峰值内存等)并优化训练各个部分的内存非常有用。- Parameters:
device (torch.device) – 获取内存摘要的设备。仅支持CUDA设备。
reset_stats (bool) – 是否重置CUDA的峰值内存跟踪。
- Returns:
一个包含峰值内存活动、峰值内存分配和峰值内存保留的字典。这个字典对于记录内存统计信息非常有用。
- Return type:
- Raises:
ValueError – 如果传入的设备不是CUDA。