Shortcuts

get_dtype

torchtune.training.get_dtype(dtype: Optional[str] = None, device: Optional[device] = None) dtype[source]

获取与给定精度字符串对应的 torch.dtype。如果没有传递字符串,我们将默认使用 torch.float32。

注意

如果使用CUDA设备请求bf16精度,我们会验证设备是否确实支持bf16内核。如果不支持,则会引发RuntimeError

Parameters:
  • dtype (可选[str]) – 精度数据类型。默认值:None,在这种情况下我们默认使用 torch.float32

  • device (可选[torch.device]) – 用于训练的设备。仅支持CUDA和CPU设备。如果传入CUDA设备,会进行额外的检查以确保设备支持所请求的精度。默认值:None,在这种情况下假定为CUDA设备。

Raises:
Returns:

对应的 torch.dtype。

Return type:

torch.dtype