JitScalarType¶
- class torch.onnx.JitScalarType(value)¶
在 torch 中定义的标量类型。
使用
JitScalarType将 PyTorch 和 JIT 标量类型转换为 ONNX 标量类型。示例
>>> JitScalarType.from_value(torch.ones(1, 2)).onnx_type() TensorProtoDataType.FLOAT
>>> JitScalarType.from_value(torch_c_value_with_type_float).onnx_type() TensorProtoDataType.FLOAT
>>> JitScalarType.from_dtype(torch.get_default_dtype).onnx_type() TensorProtoDataType.FLOAT
- classmethod from_dtype(dtype)[源代码]¶
将 torch dtype 转换为 JitScalarType。
- Note: DO NOT USE this API when dtype comes from a torch._C.Value.type() calls.
在形状信息不存在的情况下,可能会引发“RuntimeError: INTERNAL ASSERT FAILED at “../aten/src/ATen/core/jit_type_base.h”错误。 应改用from_value API,这样更安全。
- Parameters
- Returns
JitScalarType
- Raises
OnnxExporterError – 如果 dtype 不是有效的 torch.dtype 或如果它是 None。
- Return type
- classmethod from_value(value, default=None)[源代码]¶
从值的标量类型创建一个JitScalarType。
- Parameters
- Returns
JitScalarType。
- Raises
OnnxExporterError – 如果值没有有效的标量类型且默认值为 None。
SymbolicValueError – 当 value.type() 的信息为空且默认值为 None 时
- Return type
- scalar_name()[源代码]¶
将 JitScalarType 转换为 JIT 标量类型名称。
- Return type
字面量[‘字节’, ‘字符’, ‘双精度’, ‘浮点’, ‘半精度’, ‘整数’, ‘长整数’, ‘短整数’, ‘布尔’, ‘复数半精度’, ‘复数浮点’, ‘复数双精度’, ‘QInt8’, ‘QUInt8’, ‘QInt32’, ‘BFloat16’, ‘Float8E5M2’, ‘Float8E4M3FN’, ‘Float8E5M2FNUZ’, ‘Float8E4M3FNUZ’, ‘未定义’]
- torch_name()[源代码]¶
将 JitScalarType 转换为 torch 类型名称。
- Return type
字面量[‘bool’, ‘uint8_t’, ‘int8_t’, ‘double’, ‘float’, ‘half’, ‘int’, ‘int64_t’, ‘int16_t’, ‘complex32’, ‘complex64’, ‘complex128’, ‘qint8’, ‘quint8’, ‘qint32’, ‘bfloat16’, ‘float8_e5m2’, ‘float8_e4m3fn’, ‘float8_e5m2fnuz’, ‘float8_e4m3fnuz’]