伯努利负采样器

class BernoulliNegativeSampler(*, mapped_triples: Tensor, **kwargs)[源代码]

基础类:NegativeSampler

[wang2014]提出的伯努利负采样方法的实现。

在关系 \((h,r,t) \in \mathcal{K}\) 中,破坏头 \(h\) 或尾 \(t\) 的概率由关系 \(r\) 的全局属性决定:

  • \(r\)一对多的关系(例如motherOf):更高的概率被分配给替换\(h\)

  • \(r\)多对一的(例如 bornIn):分配了更高的概率来替换 \(t\)

更准确地说,对于每个关系 \(r \in \mathcal{R}\),首先计算每个头的平均尾数 (tph) 和每个尾的平均头数 (hpt)。

然后,头部损坏概率 \(p_r\) 定义为 \(p_r = \frac{tph}{tph + hpt}\)。 尾部损坏概率定义为 \(1 - p_r = \frac{hpt}{tph + hpt}\)

对于每个三元组 \((h,r,t) \in \mathcal{K}\),头部以概率 \(p_r\) 被破坏,尾部以概率 \(1 - p_r\) 被破坏。

如果 filtered 设置为 True,所有作为实际正三元组存在的被提议的损坏三元组 \((h,r,t) \in \mathcal{K}\) 将被移除。

使用给定的实体初始化伯努利负采样器。

Parameters:

方法总结

corrupt_batch(positive_batch)

从正样本批次中生成负样本,不应用任何过滤器。

方法文档

corrupt_batch(positive_batch: Tensor) Tensor[来源]

从正样本批次中生成负样本,不应用任何过滤器。

Parameters:

positive_batch (Tensor) – 形状: (*batch_dims, 3) 正样本三元组。

Returns:

形状:(*batch_dims, num_negs_per_pos, 3) 负三元组。result[*bi, :, :] 包含从 positive_batch[*bi, :] 生成的负样本。

Return type:

Tensor