onnx._custom_element_types

此模块定义了numpy不支持的自定义数据类型。 函数 onnx.numpy_helper.from_array()onnx.numpy_helper.to_array() 使用它们 来在这些类型之间转换数组。 类 onnx.reference.ReferenceEvalutor 也使用它们。 例如,为了创建用于单元测试的此类数组,可以方便地编写 如下内容:

import numpy as np
from onnx import TensorProto
from onnx.reference.ops.op_cast import Cast_19 as Cast

tensor_bfloat16 = Cast.eval(np.array([0, 1], dtype=np.float32), to=TensorProto.BFLOAT16)

下面使用的numpy表示dtypes仅供内部使用。它们可能会根据这些numpy类型的行业标准化在未来发生变化。

onnx._custom_element_types.bfloat16 = dtype((numpy.uint16, [('bfloat16', '<u2')]))

将bfloat16定义为uint16。

onnx._custom_element_types.float4e2m1 = dtype((numpy.uint8, [('float4e2m1', 'u1')]))

定义浮点数4 e2m1类型,详情请参见存储在4位中的浮点数以获取技术细节。 请注意,一个整数使用一个字节存储,因此其大小是其onnx大小的两倍。

onnx._custom_element_types.float8e4m3fn = dtype((numpy.uint8, [('e4m3fn', 'u1')]))

定义 float 8 e4m3fn 类型,有关技术细节,请参见 Float stored in 8 bits

onnx._custom_element_types.float8e4m3fnuz = dtype((numpy.uint8, [('e4m3fnuz', 'u1')]))

定义 float 8 e4m3fnuz 类型,详情请参阅 存储在8位中的浮点数 技术细节。

onnx._custom_element_types.float8e5m2 = dtype((numpy.uint8, [('e5m2', 'u1')]))

定义 float 8 e5m2 类型,详情请参阅 Float stored in 8 bits 的技术细节。

onnx._custom_element_types.float8e5m2fnuz = dtype((numpy.uint8, [('e5m2fnuz', 'u1')]))

定义 float 8 e5m2fnuz 类型,详情请参阅 存储在8位中的浮点数 技术细节。

onnx._custom_element_types.int4 = dtype((numpy.int8, [('int4', 'i1')]))

定义int4,详情请参见4位整数类型的技术细节。 请注意,一个整数使用一个字节存储,因此其大小是其onnx大小的两倍。

onnx._custom_element_types.uint4 = dtype((numpy.uint8, [('uint4', 'u1')]))

定义int4,详情请参见4位整数类型的技术细节。 请注意,一个整数使用一个字节存储,因此其大小是其onnx大小的两倍。