InfoNCELoss
- class InfoNCELoss(margin: float = 0.02, log_adversarial_temperature: float = -2.995732273553991, reduction: str = 'mean')[source]
基础类:
CrossEntropyLoss由[wang2022]提出的带有附加边际的InfoNCE损失。
这个损失等同于
CrossEntropyLoss,其中分数已经被转换:正分数减去边距 gamma,然后除以温度 tau
\[f'(k) = \frac{f(k) - \gamma}{\tau}\]负分数仅除以温度 tau
\[f'(k^-) = \frac{f(k^-)}{\tau}\]
初始化损失。
- Parameters:
margin (float) –
损失的边界(在参考论文中也写作 \(\gamma\))
注意
在官方实现中,边界参数似乎仅在训练期间使用。 https://github.com/intfloat/SimKGC/blob/4388ebc0c0011fe333bc5a98d0613ab0d1825ddc/models.py#L92-L94
log_adversarial_temperature (float) –
负采样温度的对数(在参考论文中也写作\(\tau\))。 我们遵循建议的参数化方法,确保所有超参数值的温度为正。
注意
对抗温度是计算权重时使用的softmax温度的倒数! 其名称仅用于与[wang2022]的命名法保持一致。
注意
在官方实现中,温度是一个可训练的参数,参见 https://github.com/intfloat/SimKGC/blob/4388ebc0c0011fe333bc5a98d0613ab0d1825ddc/models.py#L31
reduction (str) – 用于将批次中的单个损失值聚合为标量损失值的归约操作的名称。从 {‘mean’, ‘sum’} 中选择。
- Raises:
ValueError – 如果边距为负数
属性摘要
优化损失超参数的默认策略
方法总结
process_lcwa_scores(predictions, labels[, ...])处理来自LCWA训练循环的分数。
process_slcwa_scores(positive_scores, ...[, ...])处理来自sLCWA训练循环的分数。
属性文档
- hpo_default: ClassVar[Mapping[str, Any]] = {'log_adversarial_temperature': {'high': 3.0, 'low': -3.0, 'type': <class 'float'>}, 'margin': {'high': 0.1, 'low': 0.01, 'type': <class 'float'>}}
优化损失超参数的默认策略
方法文档
- process_lcwa_scores(predictions: Tensor, labels: Tensor, label_smoothing: float | None = None, num_entities: int | None = None) Tensor[来源]
处理来自LCWA训练循环的分数。
- process_slcwa_scores(positive_scores: Tensor, negative_scores: Tensor, label_smoothing: float | None = None, batch_filter: Tensor | None = None, num_entities: int | None = None) Tensor[source]
处理来自sLCWA训练循环的分数。
- Parameters:
positive_scores (Tensor) – 形状: (batch_size, 1) 正三元组的分数。
negative_scores (Tensor) – 形状: (batch_size, num_neg_per_pos) 或 (num_unfiltered_negatives,) 负三元组的分数,可以是密集的2D形状,或者如果它们已经被过滤,可以是稀疏形状。如果它们以稀疏形状给出,还需要提供batch_filter。
label_smoothing (float | None) – 一个可选的标签平滑参数。
batch_filter (Tensor | None) – 形状: (batch_size, num_neg_per_pos) 一个可选的负分数过滤器,用于保留哪些负分数。仅在负分数已被预过滤时提供。
num_entities (int | None) – 实体的数量。仅在启用标签平滑时需要。
- Returns:
一个标量损失项。
- Return type: