评估器
- class Evaluator(filtered: bool = False, requires_positive_mask: bool = False, batch_size: int | None = None, slice_size: int | None = None, mode: Literal['training', 'validation', 'testing'] | None = None)[source]
基类:
ABC,Generic[MetricKeyType]KGE模型的抽象评估器。
评估器封装了基于头部和尾部得分的评估指标计算。为此,它提供了两种方法来处理一批三元组以及由某些模型生成的得分。它在状态中维护中间结果,并提供了一个方法在完成后获取最终结果。
初始化评估器。
- Parameters:
方法总结
clear()清除缓冲区和中间结果。
evaluate(model, mapped_triples[, ...])评估模型在映射三元组上的指标。
finalize()计算最终结果,并清除缓冲区。
获取评估器的规范化名称。
process_scores_(hrt_batch, target, scores[, ...])处理一批三元组及其为所有实体计算的分数。
方法文档
- evaluate(model: Model, mapped_triples: Tensor, batch_size: int | None = None, slice_size: int | None = None, device: device | None = None, use_tqdm: bool = True, tqdm_kwargs: Mapping[str, Any] | None = None, restrict_entities_to: Collection[int] | None = None, restrict_relations_to: Collection[int] | None = None, do_time_consuming_checks: bool = True, additional_filter_triples: None | Tensor | list[Tensor] = None, pre_filtered_triples: bool = True, targets: Collection[Literal['head', 'relation', 'tail']] = ('head', 'tail')) MetricResults[MetricKeyType][源代码]
评估模型在映射三元组上的指标。
- Parameters:
模型 (Model) – 要评估的模型。
mapped_triples (Tensor) – 用于评估的三元组。映射的三元组应永远不包含反向三元组 - 这些是由模型类动态创建的。
batch_size (int | None) – >0 一个正整数,用作批量大小。通常尽可能选择较大的值。如果为None,则默认为1。
slice_size (int | None) – >0 使用切片时评分函数的除数。
device (device | None) – 评估将在其上运行的设备。如果为None,则使用模型的设备。
use_tqdm (bool) – 是否应该显示进度条?
restrict_entities_to (Collection[int] | None) – 可选地限制评估到给定的实体ID。如果只对实体的一部分感兴趣,例如由于类型限制,但希望在所有可用数据上进行训练,这可能很有用。对于实体排名,我们仍然计算所有可能替换实体的所有分数,以避免可能降低性能的不规则访问模式,但之后会过滤分数,只保留感兴趣的分数。如果提供,我们默认假设三元组已经被过滤,因此它只包含感兴趣的实体。要在此方法中显式过滤,请传递pre_filtered_triples=False。
restrict_relations_to (Collection[int] | None) – 可选地限制评估到给定的关系ID。如果只对关系的一部分感兴趣(例如由于关系类型),但希望在所有可用数据上进行训练,这可能很有用。如果提供了,我们默认假设三元组已经被过滤,只包含感兴趣的关系。要在此方法中显式过滤,请传递pre_filtered_triples=False。
do_time_consuming_checks (bool) – 是否对提供的参数执行一些耗时的检查。目前,这仅包括: 如果 restrict_entities_to 或 restrict_relations_to 不是 None,则检查三元组是否已被过滤。禁用此选项可以加速该方法。仅在 pre_filtered_triples 设置为 True 时有效。
pre_filtered_triples (bool) – 三元组是否已经预先过滤以符合 restrict_entities_to / restrict_relations_to。 当设置为 True 时,如果三元组没有被过滤,结果可能无效。预先过滤三元组可以加速此方法,并且建议在相同三元组集上多次评估时使用。
additional_filter_triples (None | Tensor | list[Tensor]) – 在过滤评估期间要过滤掉的额外真实三元组。
targets (Collection[Literal['head', 'relation', 'tail']]) – 预测目标
- Raises:
NotImplementedError – 如果请求了关系预测评估
ValueError – 如果 pre_filtered_triples 包含不需要的实体(只能通过耗时的检查来检测)。
MemoryError – 如果评估在CPU上失败
- Returns:
评估结果
- Return type:
MetricResults[MetricKeyType]
- abstract finalize() MetricResults[MetricKeyType][源代码]
计算最终结果,并清除缓冲区。
- Return type:
MetricResults[MetricKeyType]