autogluon.multimodal.MultiModalPredictor¶
- class autogluon.multimodal.MultiModalPredictor(label: str | None = None, problem_type: str | None = None, query: str | List[str] | None = None, response: str | List[str] | None = None, match_label: int | str | None = None, presets: str | None = None, eval_metric: str | Scorer | None = None, hyperparameters: dict | None = None, path: str | None = None, verbosity: int | None = 2, num_classes: int | None = None, classes: list | None = None, warn_if_exist: bool | None = True, enable_progress_bar: bool | None = None, pretrained: bool | None = True, validation_metric: str | None = None, sample_data_path: str | None = None)[source]¶
AutoMM 旨在简化基础模型在下游应用中的微调,只需三行代码即可完成。 AutoMM 无缝集成了流行的模型库,如 HuggingFace Transformers、 TIMM、 和 MMDetection, 适应多种数据模态,包括图像、文本、表格和文档数据,无论是单独使用还是组合使用。 它支持多种任务,包括分类、回归、 目标检测、命名实体识别、语义匹配和图像分割。
- __init__(label: str | None = None, problem_type: str | None = None, query: str | List[str] | None = None, response: str | List[str] | None = None, match_label: int | str | None = None, presets: str | None = None, eval_metric: str | Scorer | None = None, hyperparameters: dict | None = None, path: str | None = None, verbosity: int | None = 2, num_classes: int | None = None, classes: list | None = None, warn_if_exist: bool | None = True, enable_progress_bar: bool | None = None, pretrained: bool | None = True, validation_metric: str | None = None, sample_data_path: str | None = None)[source]¶
- Parameters:
label – 包含要预测的目标变量的pd.DataFrame列的名称。
problem_type –
问题类型。我们支持的标准问题包括
’binary’: 二分类
’multiclass’: 多分类
’regression’: 回归
’classification’: 分类问题包括‘二分类’和‘多分类’。
此外,我们还支持高级问题,例如
’object_detection’: 目标检测
’ner’ 或 ‘named_entity_recognition’: 命名实体提取
’text_similarity’: 文本-文本语义匹配
’image_similarity’: 图像-图像语义匹配
’image_text_similarity’: 文本-图像语义匹配
’feature_extraction’: 特征提取(仅支持推理)
’zero_shot_image_classification’: 零样本图像分类(仅支持推理)
’few_shot_classification’: 图像或文本数据的少样本分类。
’semantic_segmentation’: 使用Segment Anything模型进行语义分割。
对于某些问题类型,默认行为是基于预设/超参数加载预训练模型,并且预测器可以进行零样本推理(无需.fit()即可运行推理)。这些问题类型包括:
’object_detection’
’text_similarity’
’image_similarity’
’image_text_similarity’
’feature_extraction’
’zero_shot_image_classification’
query – 在语义匹配任务中包含查询数据的pd.DataFrame列的名称。
response – 在语义匹配任务中包含响应数据的pd.DataFrame列的名称。 如果未提供标签列,则假定一个pd.DataFrame行中的查询和响应对为正对。
match_label – 表示
对被计为“匹配”的标签类。 当任务属于语义匹配且标签为二元时使用此标签。 例如,在重复检测任务中,标签列可以包含[“duplicate”, “not duplicate”]。 match_label应为“duplicate”,因为它表示两个项目匹配。 presets – 关于模型质量的预设,例如,‘best_quality’、‘high_quality’(默认)和‘medium_quality’。 每种质量都有其对应的HPO预设:‘best_quality_hpo’、‘high_quality_hpo’和‘medium_quality_hpo’。
eval_metric – 评估指标名称。如果 eval_metric = None,则根据 problem_type 自动选择。 对于多类分类默认为‘accuracy’,对于二分类默认为 roc_auc, 对于回归默认为‘root_mean_squared_error’。
hyperparameters –
This is to override some default configurations. For example, changing the text and image backbones can be done by formatting:
a string hyperparameters = “model.hf_text.checkpoint_name=google/electra-small-discriminator model.timm_image.checkpoint_name=swin_small_patch4_window7_224”
or a list of strings hyperparameters = [“model.hf_text.checkpoint_name=google/electra-small-discriminator”, “model.timm_image.checkpoint_name=swin_small_patch4_window7_224”]
or a dictionary hyperparameters = {
”model.hf_text.checkpoint_name”: “google/electra-small-discriminator”, “model.timm_image.checkpoint_name”: “swin_small_patch4_window7_224”,
}
path – 模型及相关工件应保存的目录路径。 如果未指定,将在工作目录中创建一个名为“AutogluonAutoMM/ag-[TIMESTAMP]”的时间戳文件夹。 注意:要调用fit()两次并保存每次拟合的所有结果, 您必须指定不同的path位置或根本不指定path。
verbosity – 详细程度级别从0到4,控制打印多少日志信息。 更高的级别对应更详细的打印语句。 您可以设置 verbosity = 0 来抑制警告。
num_classes – 类别数量(用于目标检测)。 如果指定了此参数并且与预训练模型的输出形状不同, 模型的头部将被更改为具有
输出。 classes – 所有类别(用于目标检测)。
warn_if_exist – 如果指定路径已经存在,是否发出警告(默认为 True)。
enable_progress_bar – 是否显示进度条(默认为True)。如果设置了环境变量 os.environ[“AUTOMM_DISABLE_PROGRESS_BAR”],则会被禁用。
pretrained – 是否使用预训练权重初始化模型(默认为 True)。 如果为 False,则创建一个随机初始化的模型。
validation_metric – 用于选择最佳模型和训练期间早停的验证指标。 如果未提供,将根据问题类型自动选择。
sample_data_path – 样本数据的路径,我们可以从中推断出用于目标检测的num_classes或classes。
方法
将模型权重和配置保存到本地目录。
在给定数据集上评估模型。
将此预测器的模型导出为ONNX文件。
Extract features for each sample, i.e., one row in the provided pd.DataFrame data.
拟合模型以基于其他列(特征)预测数据表的某一列(标签)。
Output the training summary information from fit().
从配置中获取GPU的数量。
列出每种问题类型支持的模型。
从由path指定的目录加载一个预测器对象。
优化预测器的模型以进行推理。
预测新数据的标签列值。
预测类别概率而不是类别标签。
Save this predictor to file in directory specified by path.
在配置中设置GPU的数量。
设置日志的详细程度。
属性
class_labels类标签的原始名称。
classes用于目标检测问题类型的对象类。
column_typespd.DataFrame中的列类型。
eval_metric用于评估预测性能的指标是什么。
label包含要预测的目标变量的pd.DataFrame列的名称。
match_label在语义匹配任务中,表示
对的标签类被计为“匹配”。 model_size返回模型大小,单位为兆字节。
path存储模型和相关工件的目录路径。
positive_class将映射到1的类标签的名称。
problem_property问题的属性,存储问题类型及其相关属性。
problem_type这个预测器已经训练用于哪种类型的预测问题。
query在语义匹配任务中,包含查询数据的pd.DataFrame列的名称。
response在语义匹配任务中包含响应数据的pd.DataFrame列的名称。
total_parameters模型参数的数量。
trainable_parameters可训练模型参数的数量,通常是那些requires_grad=True的参数。
validation_metric用于选择最佳模型和训练期间早停的验证指标。
verbosity详细级别范围从0到4,控制打印多少信息。