ONNX 形状推断¶
ONNX 提供了对 ONNX 图形进行形状推断的可选实现。此实现涵盖了每个核心运算符,并提供了可扩展性的接口。因此,您可以选择在您的图形上调用现有的形状推断功能,或者定义形状推断实现以配合您的自定义运算符(或两者都做!)。形状推断函数作为 OpSchema 对象的成员存储。
在ONNX 1.10版本中,符号生成和传播以及形状数据传播被添加到了ONNX图级形状推断中。详细提案请参见这里
背景¶
请参阅部分的IR.md以了解静态张量形状的回顾。
特别是,静态张量形状(由TensorShapeProto表示)与运行时张量形状不同。当确切的运行时张量形状在静态(即在编译时)未知时,通常会使用此功能。
一个具有未定义
shape字段的Tensor用于表示未知秩的张量。一个定义了
shape的Tensor表示一个已知秩的张量。每个
TensorShapeProto的Dimension可以有一个已知的整数值 (由dim_value字段表示),或者它可以有一个由符号标识表示的未知值 (由dim_param字段表示),或者它可能没有定义任何字段 (在这种情况下,它表示一个匿名的未知值)。
调用形状推断¶
形状推断可以通过C++或Python调用。Python API的描述和示例可以在这里找到。
C++ API 由一个单一的函数组成
shape_inference::InferShapes(
ModelProto& m,
const ISchemaRegistry* schema_registry);
第一个参数是一个ModelProto,用于执行形状推断,该参数会就地标注形状信息。第二个参数是可选的。
限制¶
形状推断不能保证是完整的。特别是,一些动态行为会阻止形状推断的流程,例如将形状重塑为动态提供的形状。此外,并非所有操作符都需要有形状推断的实现。
形状推断仅适用于常量和简单变量。它不支持包含变量的算术表达式。例如,Concat 在形状为 (5, 2) 和 (7, 2) 的张量上可以推断出形状为 (12, 2) 的结果,但在形状为 (5, 2) 和 (N, 2) 的张量上,Concat 只会产生 (M, 2),而不是包含 N+5 的表示。请注意,不同的未知符号值将被传播,因此这里的 M 表示一个未知量,与其他出现的 M 相同。
这些限制是当前实现的属性,不是基本约束 - 如果您需要更高级的功能,请告诉我们!
为操作符实现形状推断¶
你可以通过以下方式向你的操作符的Schema添加一个形状推断函数
OpSchema& Opschema::TypeAndShapeInferenceFunction(InferenceFunction inferenceFunction);
InferenceFunction 定义在
shape_inference.h 中,与核心
接口结构体 InferenceContext 以及一系列辅助
方法一起。InferenceContext 是提供给
你的推理函数的核心结构体。它允许访问有关
操作符输入的信息,并且允许写出推断出的信息。
要查看众多示例,请在代码库中搜索TypeAndShapeInferenceFunction的出现。一个相对复杂的实现是Concat的实现,位于onnx/defs/tensor/defs.cc中。
在实现操作符的形状推断方法时,请注意以下几点以避免常见错误:
在访问任何输入的
shape之前,代码必须检查形状是否可用。如果不可用,则应将其视为一个动态张量,其秩未知,并应适当处理。通常,形状推断逻辑由调用hasInputShape或hasNInputShapes来保护。在访问任何维度的
dim_value或dim_param之前,代码必须检查这些字段是否有值。特别是,代码必须处理维度可能没有静态已知值的可能性。
在shape_inference.h中有几个实用函数来处理各种常见情况。
使用
checkInputRank来处理必须具有固定秩的输入。(参见RoiAlign的推理作为示例。)unifyInputDim和unifyDim以及updateOutputShape可以在期望多个输入维度相同,并且输入维度传播到特定输出维度时使用。(参见RoiAlign的推理示例。)重载运算符
*和/可以在符号维度上使用,当输出维度通过算术从输入维度计算时。(参见SpaceToDepth的推理示例。)
这些工具可以安全地处理缺失的形状和尺寸。
示例: 考虑一个简单的矩阵乘法操作,它期望输入的形状为
[M,K] 和 [K,N],并返回一个形状为 [M,N] 的输出。这可以如下编码:
// Check that input 0 has rank 2 (if its rank is known).
checkInputRank(ctx, 0, 2);
// Check that input 1 has rank 2 (if its rank is known).
checkInputRank(ctx, 1, 2);
Dim M, K, N;
// Check various dimensions, handling missing dimensions/shapes safely.
unifyInputDim(ctx, 0, 0, M);
unifyInputDim(ctx, 0, 1, K);
unifyInputDim(ctx, 1, 0, K);
unifyInputDim(ctx, 1, 1, N);
updateOutputShape(ctx, 0, {M. N});