set_default_dtype¶
- torchtune.training.set_default_dtype(dtype: dtype) Generator[None, None, None][source]¶
上下文管理器用于设置torch的默认数据类型。
- Parameters:
dtype (torch.dpython:type) – 上下文管理器中所需的默认数据类型。
- Returns:
用于设置默认数据类型的上下文管理器。
- Return type:
上下文管理器
示例
>>> with set_default_dtype(torch.bfloat16): >>> x = torch.tensor([1, 2, 3]) >>> x.dtype torch.bfloat16