多模态预测器.fit¶
- MultiModalPredictor.fit(train_data: DataFrame | str, presets: str | None = None, tuning_data: DataFrame | str | None = None, max_num_tuning_data: int | None = None, id_mappings: Dict[str, Dict] | Dict[str, Series] | None = None, time_limit: int | None = None, save_path: str | None = None, hyperparameters: str | Dict | List[str] | None = None, column_types: dict | None = None, holdout_frac: float | None = None, teacher_predictor: str | MultiModalPredictor = None, seed: int | None = 0, standalone: bool | None = True, hyperparameter_tune_kwargs: dict | None = None, clean_ckpts: bool | None = True)[source]¶
拟合模型以基于其他列(特征)预测数据表的某一列(标签)。
- Parameters:
train_data – 一个包含训练数据的 pd.DataFrame。
presets – 关于模型质量的预设,例如 best_quality、high_quality 和 medium_quality。 每种质量都有其对应的 HPO 预设:'best_quality_hpo'、'high_quality_hpo' 和 'medium_quality_hpo'。
tuning_data – 一个包含验证数据的 pd.DataFrame,其列应与 train_data 相同。 如果 tuning_data = None,fit() 将自动从 train_data 中随机保留一些验证数据。
max_num_tuning_data – 调优样本的最大数量(用于目标检测)。
id_mappings – ID到内容的映射(用于语义匹配)。内容可以是文本、图像等。 当pd.DataFrame包含查询/响应标识符而不是它们的内容时使用此功能。
time_limit – fit() 应该运行多长时间(以秒为单位的挂钟时间)。 如果未指定,fit() 将运行直到模型完成训练。
save_path – 模型和工件应保存到的目录路径。
超参数 –
这是为了覆盖一些默认配置。 例如,可以通过格式化来更改文本和图像骨干网络:
一个字符串 hyperparameters = “model.hf_text.checkpoint_name=google/electra-small-discriminator model.timm_image.checkpoint_name=swin_small_patch4_window7_224”
或一个字符串列表 hyperparameters = [“model.hf_text.checkpoint_name=google/electra-small-discriminator”, “model.timm_image.checkpoint_name=swin_small_patch4_window7_224”]
或一个字典 hyperparameters = {
”model.hf_text.checkpoint_name”: “google/electra-small-discriminator”, “model.timm_image.checkpoint_name”: “swin_small_patch4_window7_224”,
}
column_types –
一个将列名映射到其数据类型的字典。 例如:column_types = {“item_name”: “text”, “image”: “image_path”, “product_description”: “text”, “height”: “numerical”} 可以用于具有以下列的表:“item_name”、“brand”、“product_description”和“height”。 如果为None,column_types将从数据中自动推断。 当前支持的类型有:
”image_path”:此列中的每一行都是一个图像路径。
”text”:此列中的每一行包含文本(句子、段落等)。
”numerical”:此列中的每一行包含一个数字。
”categorical”:此列中的每一行属于K个类别之一。
holdout_frac – 作为调优数据保留的训练数据比例,用于优化超参数或提前停止(除非tuning_data = None,否则忽略)。 默认值(如果为None)根据训练数据中的行数以及是否使用超参数优化来选择。
teacher_predictor – 预训练的教师预测器或其保存路径。如果提供,fit() 可以将其知识蒸馏到学生预测器,即当前预测器。
seed – 用于训练的随机种子(默认值为0)。
standalone – 是否保存整个模型以用于离线部署。
hyperparameter_tune_kwargs –
超参数调优策略和参数(例如,运行多少次HPO试验)。 如果为None,则不会执行超参数调优。
- num_trials: int
运行多少次HPO试验。需要指定num_trials或time_limit到fit。
- scheduler: Union[str, ray.tune.schedulers.TrialScheduler]
如果传递了str,AutoGluon将为您创建调度程序,并带有一些默认参数。 如果传递了ray.tune.schedulers.TrialScheduler对象,您需要负责初始化该对象。
- scheduler_init_args: Optional[dict] = None
如果向scheduler提供了str,您可以选择性地为调度程序提供自定义的init_args。
- searcher: Union[str, ray.tune.search.SearchAlgorithm, ray.tune.search.Searcher]
如果传递了str,AutoGluon将为您创建搜索器,并带有一些默认参数。 如果传递了ray.tune.schedulers.TrialScheduler对象,您需要负责初始化该对象。 您不需要担心搜索器对象的metric和mode。AutoGluon会自行解决。
- scheduler_init_args: Optional[dict] = None
如果向searcher提供了str,您可以选择性地为搜索器提供自定义的init_args。 您不需要担心metric和mode。AutoGluon会自行解决。
clean_ckpts – 是否在训练后清理中间检查点。
- Return type:
一个“MultiModalPredictor”对象(其本身)。