speechbrain.k2_integration.graph_compiler 模块

图形编译器类,用于在speechbrain中创建、存储和使用k2解码图。将输出单词限制在词汇表中的单词。

这段代码是一个扩展,因此深受或取自icefall的(https://github.com/k2-fsa/icefall)图编译器。

Authors:
  • 皮埃尔·冠军 2023

  • 赵泽宇 2023

  • 乔治奥斯·卡拉卡西迪斯 2023

摘要

类:

CtcGraphCompiler

此类用于编译CTC训练的解码图。

GraphCompiler

这个抽象类用于编译用于训练和解码的图。

参考

class speechbrain.k2_integration.graph_compiler.GraphCompiler[source]

基础类: ABC

这个抽象类用于编译用于训练和解码的图。

abstract property topo: k2.Fsa

返回用于编译图的拓扑结构。

abstract property lexicon: 词汇表

返回用于编译图表的词汇表。

abstract property device

返回用于编译图的设备。

abstract compile(texts: List[str], is_training: bool = True) Tuple[k2.Fsa, Tensor][source]

为给定的文本编译图形。

Parameters:
  • 文本 (列表[字符串]) –

    一个字符串列表。每个字符串包含一个句子的表达。 一个句子由空格分隔的单词组成。一个示例 文本 看起来像:

    [‘hello world’, ‘CTC training with k2’]

  • is_training (bool) – 指示这是否用于训练 (训练中的OOV警告)。

Returns:

  • graph (GraphCompiler) – 一个FsaVec,是self.ctc_topo和转录FSA的组合结果。

  • target_lens (Torch.tensor) – 它是一个形状为(batch,)的长整型张量。它包含每个目标序列的长度。

compile_HL(cache_dir: str | None = None, cache: bool = False)[source]

通过将H与L组合来编译解码图。 这是用于没有语言模型的解码。

Parameters:
  • cache_dir (str) – 存储组合的路径,格式为 .pt。

  • cache (bool) – 是否从.pt格式加载组合(在cache_dir目录中)。

Returns:

HL – HL 组合

Return type:

k2.Fsa

compile_HLG(G, cache_dir: str | None = None, cache: bool = False)[source]

通过将H与LG组合来编译解码图。 这是用于使用小型语言模型进行解码的。

Parameters:
  • G (k2.Fsa) – 语言模型 FSA。

  • cache_dir (str) – 存储组合的路径,格式为 .pt。

  • cache (bool) – 是否从.pt格式加载组合(在cache_dir目录中)。

Returns:

HL – HLG组合

Return type:

k2.Fsa

class speechbrain.k2_integration.graph_compiler.CtcGraphCompiler(_lexicon: 词汇表, device: device, need_repeat_flag: bool = False)[source]

基础类: GraphCompiler

此类用于编译CTC训练的解码图。

Parameters:
  • _lexicon (Lexicon) – 它是从 data/lang/lexicon.txt 构建的。

  • device (torch.device) – 用于将转录本编译为FSA的操作的设备。

  • need_repeat_flag (bool) – 如果为True,将会在ctc_topo中添加一个名为_is_repeat_token_的属性,用于指示该token是否是ctc图中的重复token。此属性是实现基于音素的ctc损失的延迟惩罚所必需的。更多详情请参见https://github.com/k2-fsa/k2/pull/1086。注意:上述更改必须包含在k2中才能启用此标志,因此请确保您使用的是最新版本。

Example

>>> import torch
>>> from speechbrain.k2_integration.losses import ctc_k2
>>> from speechbrain.k2_integration.graph_compiler import CtcGraphCompiler
>>> from speechbrain.k2_integration.lexicon import Lexicon
>>> from speechbrain.k2_integration.prepare_lang import prepare_lang
>>> # Create a random batch of log-probs
>>> batch_size = 4
>>> log_probs = torch.randn(batch_size, 100, 30)
>>> log_probs.requires_grad = True
>>> # Assume all utterances have the same length so no padding was needed.
>>> input_lens = torch.ones(batch_size)
>>> # Create a small lexicon containing only two words and write it to a file.
>>> lang_tmpdir = getfixture('tmpdir')
>>> lexicon_sample = "hello h e l l o\nworld w o r l d\n<UNK> <unk>"
>>> lexicon_file = lang_tmpdir.join("lexicon.txt")
>>> lexicon_file.write(lexicon_sample)
>>> # Create a lang directory with the lexicon and L.pt, L_inv.pt, L_disambig.pt
>>> prepare_lang(lang_tmpdir)
>>> # Create a lexicon object
>>> lexicon = Lexicon(lang_tmpdir)
>>> # Create a random decoding graph
>>> graph = CtcGraphCompiler(
...     lexicon,
...     log_probs.device,
... )
>>> isinstance(graph.topo, k2.Fsa)
True
property topo

返回ctc_topo。

property lexicon

返回词典。

property device

返回用于编译图形的设备。

compile(texts: List[str], is_training: bool = True) Tuple[k2.Fsa, Tensor][source]

通过将ctc_topo与给定的转录本组合来构建解码图。

Parameters:
  • 文本 (列表[字符串]) –

    一个字符串列表。每个字符串包含一个句子的表达。 一个句子由空格分隔的单词组成。一个示例 文本 看起来像:

    [‘hello world’, ‘CTC training with k2’]

  • is_training (bool) – 指示这是否用于训练 (训练中的OOV警告)。

Returns:

  • graph (GraphCompiler) – 一个FsaVec,是self.ctc_topo和转录FSA的组合结果。

  • target_lens (Torch.tensor) – 它是一个形状为(batch,)的长整型张量。它包含每个目标序列的长度。