calib_utils
提供基本的校准工具。
类
校准数据提供者类。 |
|
带有随机数据提供器的校准数据读取器类。 |
函数
读取TensorRT校准缓存并返回为字典。 |
- class CalibrationDataProvider
基础:
CalibrationDataReader校准数据提供者类。
- __init__(onnx_path, calibration_data, calibration_shapes=None)
使用校准数据迭代器初始化数据提供者类。
- Parameters:
onnx_path (str) – ONNX模型的路径。
calibration_data (ndarray | Dict[str, ndarray]) – 用于校准模型的Numpy数据。 例如,如果一个模型的输入形状为{“sample”: (2, 4, 64, 64), “timestep”: (1,), “encoder_hidden_states”: (2, 16, 768)},那么校准数据应该具有形状为{“sample”: (1024, 4, 64, 64), “timestep”: (512,), “encoder_hidden_states”: (1024, 16, 768)}的张量字典,以便使用512个样本进行校准。
calibration_shapes (str) –
- get_next()
返回阅读器中下一个可用的校准输入。
- class RandomDataProvider
基础:
CalibrationDataReader带有随机数据提供器的校准数据读取器类。
- __init__(onnx_model, calibration_shapes=None)
使用随机校准数据初始化数据读取器类。
- Parameters:
onnx_model (str | ModelProto) –
calibration_shapes (str) –
- get_next()
返回阅读器中下一个可用的校准输入。
- import_scales_from_calib_cache(cache_path)
读取TensorRT校准缓存并返回为字典。
- Parameters:
cache_path (str) – 校准缓存路径。
- Returns:
float_scale}.
- Return type:
带有比例的字典,格式为 {tensor_name