Shortcuts

torch.set_default_dtype

torch.set_default_dtype(d)[源代码]

将默认的浮点数数据类型设置为 d。支持 torch.float32 和 torch.float64 作为输入。其他数据类型可能会被接受但没有抱怨,但它们不受支持,并且不太可能按预期工作。

当 PyTorch 初始化时,其默认的浮点数数据类型是 torch.float32,而 set_default_dtype(torch.float64) 的目的是为了实现类似 NumPy 的类型推断。默认的浮点数数据类型用于:

  1. 隐式确定默认的复数数据类型。当默认的浮点类型为float32时,默认的复数数据类型为complex64,而当默认的浮点类型为float64时,默认的复数类型为complex128。

  2. 推断使用Python浮点数或复数Python构造的张量的数据类型。请参见下面的示例。

  3. 确定布尔值和整数张量之间以及Python浮点数和复数Python数字之间的类型提升结果。

Parameters

d (torch.dtype) – 要设为默认的浮点数数据类型。 可以是 torch.float32 或 torch.float64。

示例

>>> # 初始默认的浮点类型是 torch.float32
>>> # Python 浮点数被解释为 float32
>>> torch.tensor([1.2, 3]).dtype
torch.float32
>>> # 初始默认的浮点类型是 torch.complex64
>>> # 复数 Python 数字被解释为 complex64
>>> torch.tensor([1.2, 3j]).dtype
torch.complex64
>>> torch.set_default_dtype(torch.float64)
>>> # Python浮点数现在被解释为float64
>>> torch.tensor([1.2, 3]).dtype    # 一个新的浮点数张量
torch.float64
>>> # 复数Python数现在被解释为complex128
>>> torch.tensor([1.2, 3j]).dtype   # 一个新的复数张量
torch.complex128
优云智算