过滤器

class Filterer(*args, **kwargs)[source]

基础类: Module

用于过滤负三元组方法的接口。

初始化内部模块状态,由nn.Module和ScriptModule共享。

方法总结

contains(batch)

检查是否包含三元组。

forward(negative_batch)

过滤掉训练数据集中所有被标记为正样本的负样本。

方法文档

abstract contains(batch: Tensor) Tensor[source]

检查是否包含三元组。

支持批处理。

Parameters:

batch (Tensor) – 形状 (batch_size, 3) 三元组的批次。

Returns:

形状: (batch_size,) 三元组是否包含在训练三元组中。

Return type:

Tensor

forward(negative_batch: Tensor) Tensor[来源]

过滤掉训练数据集中所有被标记为正样本的负样本。

通常情况下,在训练数据集中,提出的负样本是正样本的概率较低,因此它们会作为假阴性样本。这预计会作为一种正则化手段,因为它向训练数据中添加了噪声信号。然而,正则化的程度很难控制,因为添加的噪声信号取决于给定实体关系或实体对中真实三元组的比例。因此,效果很难控制,研究人员可能希望排除在提出的负三元组中存在假阴性样本的可能性。

注意

过滤是一项非常昂贵的任务,因为每个提出的负样本都必须与整个训练数据集进行检查。

Parameters:

negative_batch (Tensor) – 形状: (batch_size, num_negatives, 3) 负三元组的批次。

Returns:

形状: (batch_size, num_negatives) 一个掩码,其中True表示负样本是有效的。

Return type:

Tensor