autogluon.multimodal.MultiModalPredictor

class autogluon.multimodal.MultiModalPredictor(label: str | None = None, problem_type: str | None = None, query: str | List[str] | None = None, response: str | List[str] | None = None, match_label: int | str | None = None, presets: str | None = None, eval_metric: str | Scorer | None = None, hyperparameters: dict | None = None, path: str | None = None, verbosity: int | None = 2, num_classes: int | None = None, classes: list | None = None, warn_if_exist: bool | None = True, enable_progress_bar: bool | None = None, pretrained: bool | None = True, validation_metric: str | None = None, sample_data_path: str | None = None)[source]

AutoMM 旨在简化基础模型在下游应用中的微调,只需三行代码即可完成。 AutoMM 无缝集成了流行的模型库,如 HuggingFace TransformersTIMM、 和 MMDetection, 适应多种数据模态,包括图像、文本、表格和文档数据,无论是单独使用还是组合使用。 它支持多种任务,包括分类、回归、 目标检测、命名实体识别、语义匹配和图像分割。

__init__(label: str | None = None, problem_type: str | None = None, query: str | List[str] | None = None, response: str | List[str] | None = None, match_label: int | str | None = None, presets: str | None = None, eval_metric: str | Scorer | None = None, hyperparameters: dict | None = None, path: str | None = None, verbosity: int | None = 2, num_classes: int | None = None, classes: list | None = None, warn_if_exist: bool | None = True, enable_progress_bar: bool | None = None, pretrained: bool | None = True, validation_metric: str | None = None, sample_data_path: str | None = None)[source]
Parameters:
  • label – 包含要预测的目标变量的pd.DataFrame列的名称。

  • problem_type

    问题类型。我们支持的标准问题包括

    • ’binary’: 二分类

    • ’multiclass’: 多分类

    • ’regression’: 回归

    • ’classification’: 分类问题包括‘二分类’和‘多分类’。

    此外,我们还支持高级问题,例如

    • ’object_detection’: 目标检测

    • ’ner’ 或 ‘named_entity_recognition’: 命名实体提取

    • ’text_similarity’: 文本-文本语义匹配

    • ’image_similarity’: 图像-图像语义匹配

    • ’image_text_similarity’: 文本-图像语义匹配

    • ’feature_extraction’: 特征提取(仅支持推理)

    • ’zero_shot_image_classification’: 零样本图像分类(仅支持推理)

    • ’few_shot_classification’: 图像或文本数据的少样本分类。

    • ’semantic_segmentation’: 使用Segment Anything模型进行语义分割。

    对于某些问题类型,默认行为是基于预设/超参数加载预训练模型,并且预测器可以进行零样本推理(无需.fit()即可运行推理)。这些问题类型包括:

    • ’object_detection’

    • ’text_similarity’

    • ’image_similarity’

    • ’image_text_similarity’

    • ’feature_extraction’

    • ’zero_shot_image_classification’

  • query – 在语义匹配任务中包含查询数据的pd.DataFrame列的名称。

  • response – 在语义匹配任务中包含响应数据的pd.DataFrame列的名称。 如果未提供标签列,则假定一个pd.DataFrame行中的查询和响应对为正对。

  • match_label – 表示对被计为“匹配”的标签类。 当任务属于语义匹配且标签为二元时使用此标签。 例如,在重复检测任务中,标签列可以包含[“duplicate”, “not duplicate”]。 match_label应为“duplicate”,因为它表示两个项目匹配。

  • presets – 关于模型质量的预设,例如,‘best_quality’、‘high_quality’(默认)和‘medium_quality’。 每种质量都有其对应的HPO预设:‘best_quality_hpo’、‘high_quality_hpo’和‘medium_quality_hpo’。

  • eval_metric – 评估指标名称。如果 eval_metric = None,则根据 problem_type 自动选择。 对于多类分类默认为‘accuracy’,对于二分类默认为 roc_auc, 对于回归默认为‘root_mean_squared_error’。

  • hyperparameters

    This is to override some default configurations. For example, changing the text and image backbones can be done by formatting:

    a string hyperparameters = “model.hf_text.checkpoint_name=google/electra-small-discriminator model.timm_image.checkpoint_name=swin_small_patch4_window7_224”

    or a list of strings hyperparameters = [“model.hf_text.checkpoint_name=google/electra-small-discriminator”, “model.timm_image.checkpoint_name=swin_small_patch4_window7_224”]

    or a dictionary hyperparameters = {

    ”model.hf_text.checkpoint_name”: “google/electra-small-discriminator”, “model.timm_image.checkpoint_name”: “swin_small_patch4_window7_224”,

    }

  • path – 模型及相关工件应保存的目录路径。 如果未指定,将在工作目录中创建一个名为“AutogluonAutoMM/ag-[TIMESTAMP]”的时间戳文件夹。 注意:要调用fit()两次并保存每次拟合的所有结果, 您必须指定不同的path位置或根本不指定path

  • verbosity – 详细程度级别从0到4,控制打印多少日志信息。 更高的级别对应更详细的打印语句。 您可以设置 verbosity = 0 来抑制警告。

  • num_classes – 类别数量(用于目标检测)。 如果指定了此参数并且与预训练模型的输出形状不同, 模型的头部将被更改为具有 输出。

  • classes – 所有类别(用于目标检测)。

  • warn_if_exist – 如果指定路径已经存在,是否发出警告(默认为 True)。

  • enable_progress_bar – 是否显示进度条(默认为True)。如果设置了环境变量 os.environ[“AUTOMM_DISABLE_PROGRESS_BAR”],则会被禁用。

  • pretrained – 是否使用预训练权重初始化模型(默认为 True)。 如果为 False,则创建一个随机初始化的模型。

  • validation_metric – 用于选择最佳模型和训练期间早停的验证指标。 如果未提供,将根据问题类型自动选择。

  • sample_data_path – 样本数据的路径,我们可以从中推断出用于目标检测的num_classes或classes。

方法

dump_model

将模型权重和配置保存到本地目录。

evaluate

在给定数据集上评估模型。

export_onnx

将此预测器的模型导出为ONNX文件。

extract_embedding

Extract features for each sample, i.e., one row in the provided pd.DataFrame data.

fit

拟合模型以基于其他列(特征)预测数据表的某一列(标签)。

fit_summary

Output the training summary information from fit().

get_num_gpus

从配置中获取GPU的数量。

list_supported_models

列出每种问题类型支持的模型。

load

从由path指定的目录加载一个预测器对象。

optimize_for_inference

优化预测器的模型以进行推理。

predict

预测新数据的标签列值。

predict_proba

预测类别概率而不是类别标签。

save

Save this predictor to file in directory specified by path.

set_num_gpus

在配置中设置GPU的数量。

set_verbosity

设置日志的详细程度。

属性

class_labels

类标签的原始名称。

classes

用于目标检测问题类型的对象类。

column_types

pd.DataFrame中的列类型。

eval_metric

用于评估预测性能的指标是什么。

label

包含要预测的目标变量的pd.DataFrame列的名称。

match_label

在语义匹配任务中,表示对的标签类被计为“匹配”。

model_size

返回模型大小,单位为兆字节。

path

存储模型和相关工件的目录路径。

positive_class

将映射到1的类标签的名称。

problem_property

问题的属性,存储问题类型及其相关属性。

problem_type

这个预测器已经训练用于哪种类型的预测问题。

query

在语义匹配任务中,包含查询数据的pd.DataFrame列的名称。

response

在语义匹配任务中包含响应数据的pd.DataFrame列的名称。

total_parameters

模型参数的数量。

trainable_parameters

可训练模型参数的数量,通常是那些requires_grad=True的参数。

validation_metric

用于选择最佳模型和训练期间早停的验证指标。

verbosity

详细级别范围从0到4,控制打印多少信息。