TabularPredictor.distill

TabularPredictor.distill(train_data: DataFrame | str = None, tuning_data: DataFrame | str = None, augmentation_data: DataFrame = None, time_limit: float = None, hyperparameters: dict | str = None, holdout_frac: float = None, teacher_preds: str = 'soft', augment_method: str = 'spunge', augment_args: dict = {'max_size': 100000, 'size_factor': 5}, models_name_suffix: str = None, verbosity: int = None)[source]

[实验性] 将AutoGluon最准确的集成预测器提炼为更简单/更快且需要更少内存/计算的单一模型。 蒸馏可以产生一个比直接在原始训练数据上拟合的相同模型更准确的模型。 调用distill()后,此预测器中将有更多模型可用,可以使用predictor.leaderboard(test_data)进行评估,并通过以下方式部署:predictor.predict(test_data, model=MODEL_NAME)。 如果在fit()中先前设置了cache_data=False,这将引发异常。

注意:在catboost v0.24发布之前,在多类分类中使用CatBoost学生的distill()需要您首先安装catboost-dev:pip install catboost-dev

Parameters:
  • train_data (str 或 pd.DataFrame, 默认 = None) – 与 fit()train_data 参数相同。 如果为 None,将从用于生成此 Predictor 的 fit() 调用中加载相同的训练数据。

  • tuning_data (str 或 pd.DataFrame, 默认 = None) – 与 fit()tuning_data 参数相同。 如果 tuning_data = Nonetrain_data = None:将从用于生成此 Predictor 的 fit() 调用中加载相同的训练/验证分割, 除非之前使用了 bagging/stacking,在这种情况下将执行新的训练/验证分割。

  • augmentation_data (pd.DataFrame, 默认 = None) – 一个可选的额外数据集,包含未标记的行,可用于在蒸馏过程中增强用于拟合学生模型的数据集(如果为None则忽略)。

  • time_limit (int, default = None) – 蒸馏过程大约应该运行多长时间(以秒为单位)。 如果为None,则不会强制执行时间限制,允许蒸馏模型完全训练。

  • 超参数 (字典字符串, 默认 = 无) – 指定使用哪些模型作为学生模型以及为它们使用哪些超参数值。 与 fit()hyperparameters 参数相同。 如果 = 无,则学生模型将使用与生成此预测器时使用的 fit() 相同的超参数。 注意:目前仅支持 [‘GBM’,’NN_TORCH’,’RF’,’CAT’] 学生模型的蒸馏,其他模型及其超参数在此处被忽略。

  • holdout_frac (float) – 与 holdout_frac 参数相同,参考 TabularPredictor.fit()

  • teacher_preds (str, default = 'soft') – 从哪种形式的教师预测中提取(教师指的是最准确的AutoGluon集成预测器)。 如果为None,我们只使用原始标签进行训练(无数据增强)。 如果为‘hard’,标签是由teacher.predict()给出的硬教师预测。 如果为‘soft’,标签是由teacher.predict_proba()给出的软教师预测。 注意:对于回归问题,‘hard’和‘soft’是等价的。 如果augment_method不为None,教师预测仅用于标记增强数据(训练数据保持原始标签)。 要应用标签平滑:teacher_preds=’onehot’将使用原始训练数据标签转换为多类问题的一热向量(无数据增强)。

  • augment_method (str, default='spunge') –

    指定用于生成蒸馏学生模型的增强数据的方法。 选项包括:

    None : 不执行数据增强。 ‘munge’ : MUNGE算法 (https://www.cs.cornell.edu/~caruana/compression.kdd06.pdf)。 ‘spunge’ : MUNGE算法的一个更简单、更高效的变体。

  • augment_args (dict, default = {'size_factor':5, 'max_size': int(1e5)}) –

    包含以下控制所选augment_method的kwargs(如果augment_method=None,则忽略这些参数):

    ’num_augmented_samples’: int, 蒸馏过程中使用的增强数据点的数量。如果指定,将覆盖‘size_factor’和‘max_size’。 ‘max_size’: float, 要添加的最大增强数据点数量(如果指定了‘num_augmented_samples’,则忽略此参数)。 ‘size_factor’: float, 如果n = 训练数据样本大小,则添加int(n * size_factor)个增强数据点,最多到‘max_size’。 augment_args中的较大值会减慢distill()的运行时间,如果提供的时间限制太小,可能会产生更差的结果。 您还可以为autogluon.tabular.augmentation.distill_utils中的spunge_augmentmunge_augment函数传递kwargs。

  • models_name_suffix (str, default = None) – 可选的附加后缀,可以附加在所有蒸馏学生模型名称的末尾。 注意:默认情况下,所有蒸馏模型的名称中都会包含‘_DSTL’子字符串。

  • verbosity (int, default = None) – 控制蒸馏过程中打印输出的数量(4 = 最高,0 = 最低)。 与 verbosity 参数相同,该参数属于 TabularPredictor。 如果为 None,则再次使用之前拟合中使用的相同 verbosity

Return type:

对应于蒸馏模型的名称列表(str)。

示例

>>> from autogluon.tabular import TabularDataset, TabularPredictor
>>> train_data = TabularDataset('train.csv')
>>> predictor = TabularPredictor(label='class').fit(train_data, auto_stack=True)
>>> distilled_model_names = predictor.distill()
>>> test_data = TabularDataset('test.csv')
>>> ldr = predictor.leaderboard(test_data)
>>> model_to_deploy = distilled_model_names[0]
>>> predictor.predict(test_data, model=model_to_deploy)