多模态预测器.fit

MultiModalPredictor.fit(train_data: DataFrame | str, presets: str | None = None, tuning_data: DataFrame | str | None = None, max_num_tuning_data: int | None = None, id_mappings: Dict[str, Dict] | Dict[str, Series] | None = None, time_limit: int | None = None, save_path: str | None = None, hyperparameters: str | Dict | List[str] | None = None, column_types: dict | None = None, holdout_frac: float | None = None, teacher_predictor: str | MultiModalPredictor = None, seed: int | None = 0, standalone: bool | None = True, hyperparameter_tune_kwargs: dict | None = None, clean_ckpts: bool | None = True)[source]

拟合模型以基于其他列(特征)预测数据表的某一列(标签)。

Parameters:
  • train_data – 一个包含训练数据的 pd.DataFrame。

  • presets – 关于模型质量的预设,例如 best_quality、high_quality 和 medium_quality。 每种质量都有其对应的 HPO 预设:'best_quality_hpo'、'high_quality_hpo' 和 'medium_quality_hpo'。

  • tuning_data – 一个包含验证数据的 pd.DataFrame,其列应与 train_data 相同。 如果 tuning_data = Nonefit() 将自动从 train_data 中随机保留一些验证数据。

  • max_num_tuning_data – 调优样本的最大数量(用于目标检测)。

  • id_mappings – ID到内容的映射(用于语义匹配)。内容可以是文本、图像等。 当pd.DataFrame包含查询/响应标识符而不是它们的内容时使用此功能。

  • time_limitfit() 应该运行多长时间(以秒为单位的挂钟时间)。 如果未指定,fit() 将运行直到模型完成训练。

  • save_path – 模型和工件应保存到的目录路径。

  • 超参数

    这是为了覆盖一些默认配置。 例如,可以通过格式化来更改文本和图像骨干网络:

    一个字符串 hyperparameters = “model.hf_text.checkpoint_name=google/electra-small-discriminator model.timm_image.checkpoint_name=swin_small_patch4_window7_224”

    或一个字符串列表 hyperparameters = [“model.hf_text.checkpoint_name=google/electra-small-discriminator”, “model.timm_image.checkpoint_name=swin_small_patch4_window7_224”]

    或一个字典 hyperparameters = {

    ”model.hf_text.checkpoint_name”: “google/electra-small-discriminator”, “model.timm_image.checkpoint_name”: “swin_small_patch4_window7_224”,

    }

  • column_types

    一个将列名映射到其数据类型的字典。 例如:column_types = {“item_name”: “text”, “image”: “image_path”, “product_description”: “text”, “height”: “numerical”} 可以用于具有以下列的表:“item_name”、“brand”、“product_description”和“height”。 如果为None,column_types将从数据中自动推断。 当前支持的类型有:

    • ”image_path”:此列中的每一行都是一个图像路径。

    • ”text”:此列中的每一行包含文本(句子、段落等)。

    • ”numerical”:此列中的每一行包含一个数字。

    • ”categorical”:此列中的每一行属于K个类别之一。

  • holdout_frac – 作为调优数据保留的训练数据比例,用于优化超参数或提前停止(除非tuning_data = None,否则忽略)。 默认值(如果为None)根据训练数据中的行数以及是否使用超参数优化来选择。

  • teacher_predictor – 预训练的教师预测器或其保存路径。如果提供,fit() 可以将其知识蒸馏到学生预测器,即当前预测器。

  • seed – 用于训练的随机种子(默认值为0)。

  • standalone – 是否保存整个模型以用于离线部署。

  • hyperparameter_tune_kwargs

    超参数调优策略和参数(例如,运行多少次HPO试验)。 如果为None,则不会执行超参数调优。

    num_trials: int

    运行多少次HPO试验。需要指定num_trialstime_limitfit

    scheduler: Union[str, ray.tune.schedulers.TrialScheduler]

    如果传递了str,AutoGluon将为您创建调度程序,并带有一些默认参数。 如果传递了ray.tune.schedulers.TrialScheduler对象,您需要负责初始化该对象。

    scheduler_init_args: Optional[dict] = None

    如果向scheduler提供了str,您可以选择性地为调度程序提供自定义的init_args。

    searcher: Union[str, ray.tune.search.SearchAlgorithm, ray.tune.search.Searcher]

    如果传递了str,AutoGluon将为您创建搜索器,并带有一些默认参数。 如果传递了ray.tune.schedulers.TrialScheduler对象,您需要负责初始化该对象。 您不需要担心搜索器对象的metricmode。AutoGluon会自行解决。

    scheduler_init_args: Optional[dict] = None

    如果向searcher提供了str,您可以选择性地为搜索器提供自定义的init_args。 您不需要担心metricmode。AutoGluon会自行解决。

  • clean_ckpts – 是否在训练后清理中间检查点。

Return type:

一个“MultiModalPredictor”对象(其本身)。