speechbrain.k2_integration.graph_compiler 模块
图形编译器类,用于在speechbrain中创建、存储和使用k2解码图。将输出单词限制在词汇表中的单词。
这段代码是一个扩展,因此深受或取自icefall的(https://github.com/k2-fsa/icefall)图编译器。
- Authors:
皮埃尔·冠军 2023
赵泽宇 2023
乔治奥斯·卡拉卡西迪斯 2023
摘要
类:
此类用于编译CTC训练的解码图。 |
|
这个抽象类用于编译用于训练和解码的图。 |
参考
- class speechbrain.k2_integration.graph_compiler.GraphCompiler[source]
基础类:
ABC这个抽象类用于编译用于训练和解码的图。
- abstract property topo: k2.Fsa
返回用于编译图的拓扑结构。
- abstract property device
返回用于编译图的设备。
- abstract compile(texts: List[str], is_training: bool = True) Tuple[k2.Fsa, Tensor][source]
为给定的文本编译图形。
- class speechbrain.k2_integration.graph_compiler.CtcGraphCompiler(_lexicon: 词汇表, device: device, need_repeat_flag: bool = False)[source]
基础类:
GraphCompiler此类用于编译CTC训练的解码图。
- Parameters:
_lexicon (Lexicon) – 它是从
data/lang/lexicon.txt构建的。device (torch.device) – 用于将转录本编译为FSA的操作的设备。
need_repeat_flag (bool) – 如果为True,将会在ctc_topo中添加一个名为
_is_repeat_token_的属性,用于指示该token是否是ctc图中的重复token。此属性是实现基于音素的ctc损失的延迟惩罚所必需的。更多详情请参见https://github.com/k2-fsa/k2/pull/1086。注意:上述更改必须包含在k2中才能启用此标志,因此请确保您使用的是最新版本。
Example
>>> import torch >>> from speechbrain.k2_integration.losses import ctc_k2 >>> from speechbrain.k2_integration.graph_compiler import CtcGraphCompiler >>> from speechbrain.k2_integration.lexicon import Lexicon >>> from speechbrain.k2_integration.prepare_lang import prepare_lang
>>> # Create a random batch of log-probs >>> batch_size = 4
>>> log_probs = torch.randn(batch_size, 100, 30) >>> log_probs.requires_grad = True >>> # Assume all utterances have the same length so no padding was needed. >>> input_lens = torch.ones(batch_size) >>> # Create a small lexicon containing only two words and write it to a file. >>> lang_tmpdir = getfixture('tmpdir') >>> lexicon_sample = "hello h e l l o\nworld w o r l d\n<UNK> <unk>" >>> lexicon_file = lang_tmpdir.join("lexicon.txt") >>> lexicon_file.write(lexicon_sample) >>> # Create a lang directory with the lexicon and L.pt, L_inv.pt, L_disambig.pt >>> prepare_lang(lang_tmpdir) >>> # Create a lexicon object >>> lexicon = Lexicon(lang_tmpdir) >>> # Create a random decoding graph >>> graph = CtcGraphCompiler( ... lexicon, ... log_probs.device, ... ) >>> isinstance(graph.topo, k2.Fsa) True
- property topo
返回ctc_topo。
- property lexicon
返回词典。
- property device
返回用于编译图形的设备。