get_dtype¶
- torchtune.training.get_dtype(dtype: Optional[str] = None, device: Optional[device] = None) dtype[source]¶
获取与给定精度字符串对应的 torch.dtype。如果没有传递字符串,我们将默认使用 torch.float32。
注意
如果使用CUDA设备请求bf16精度,我们会验证设备是否确实支持bf16内核。如果不支持,则会引发
RuntimeError。- Parameters:
dtype (可选[str]) – 精度数据类型。默认值:
None,在这种情况下我们默认使用 torch.float32device (可选[torch.device]) – 用于训练的设备。仅支持CUDA和CPU设备。如果传入CUDA设备,会进行额外的检查以确保设备支持所请求的精度。默认值:
None,在这种情况下假定为CUDA设备。
- Raises:
ValueError – 如果库不支持精度
RuntimeError – 如果请求了bf16精度但此硬件不支持。
- Returns:
对应的 torch.dtype。
- Return type: