speechbrain.utils.bertscore 模块
为BERTscore指标提供了一个指标类。
作者 * Sylvain de Langen 2024
摘要
类:
使用提供的HuggingFace Transformers文本编码器计算BERTScore,使用论文BERTScore: Evaluating Text Generation with BERT中描述的方法。 |
函数:
返回一个带有特殊标记掩码的标记掩码。 |
|
返回用于BERTScore指标的令牌权重。 |
参考
- class speechbrain.utils.bertscore.BERTScoreStats(lm: TextEncoder, batch_size: int = 64, use_idf: bool = True, sentence_level_averaging: bool = True, allow_matching_special_tokens: bool = False)[source]
基础:
MetricStats使用提供的HuggingFace Transformers文本编码器计算BERTScore, 采用论文BERTScore: Evaluating Text Generation with BERT中描述的方法。
BERTScore 操作于上下文标记(例如 BERT 的输出,但许多其他模型也可以使用)。由于使用了余弦相似度,输出范围将在
-1和1之间。更多详情请参阅链接资源。特殊标记(如从标记器查询的)被完全忽略。
作者关于该指标的参考实现可以在这里找到。链接页面详细描述了该方法,并比较了BERTScore与人类评估在许多不同模型中的关系。
警告
开箱即用,此实现可能不会严格匹配参考实现的结果。请阅读参数文档以了解差异。
- Parameters:
lm (speechbrain.lobes.models.huggingface_transformers.TextEncoder) – 用作语言模型的HF Transformers分词器和文本编码器包装器。
batch_size (int, optional) – 一次应考虑多少对语句。数值越大速度越快,但可能导致内存不足(OOM)。
use_idf (bool, 可选) – 如果启用(默认),参考中的标记将通过逆文档频率进行加权,这可以减少可能携带较少信息的常见词的影响。在IDF计算中,每个附加的句子都被视为一个文档。
sentence_level_averaging (bool, 可选) – 当
True时,最终的召回率/精确度指标将是每个测试句子的召回率/精确度的平均值,而不是每个测试标记的平均值,例如,在最终指标中,一个非常长的句子将与一个非常短的句子具有相同的权重。默认值为True,这与参考实现一致。allow_matching_special_tokens (bool, 可选) – 当
True时,非特殊标记可能在贪婪匹配期间与特殊标记匹配(例如[CLS]/[SEP])。由于填充处理的原因,批量大小必须为 1。 默认值为False,这与参考实现的行为不同(参见 bert_score#180)。
- speechbrain.utils.bertscore.get_bert_token_mask(tokenizer) BoolTensor[source]
返回一个带有特殊标记被屏蔽的标记掩码。
- Parameters:
tokenizer (transformers.PreTrainedTokenizer) – 用于BERT模型的HuggingFace分词器。
- Returns:
一个可以通过令牌ID索引的掩码张量(形状为
[vocab_size])。- Return type:
torch.BoolTensor
- speechbrain.utils.bertscore.get_bertscore_token_weights(tokenizer, corpus: Iterable[str] | None = None) Tensor[source]
返回用于BERTScore指标的令牌权重。 当指定
corpus时,权重是每个令牌的逆文档频率 (IDF),从corpus中提取。IDF公式改编自BERTScore论文,其中参考语料库中缺失的单词使用
+1平滑进行加权。- Parameters:
tokenizer (transformers.PreTrainedTokenizer) – 用于BERT模型的HuggingFace分词器。
corpus (Iterable[str], optional) – 用于计算IDF的可迭代语料库。在IDF计算中,每个迭代值被视为语料库中的一个文档。 如果省略,则不进行IDF加权。
- Returns:
一个可以通过令牌ID索引的浮点张量,形状为
[vocab_size],其中每个条目表示给定令牌的影响应乘以多少。- Return type:
torch.Tensor