speechbrain.decoders.seq2seq 模块
用于序列到序列自回归模型的解码方法。
- Authors
阿德尔·穆门 2022, 2023, 2024
周珏洁 2020
彼得·普兰廷加 2020
Mirco Ravanelli 2020
叶松林 2020
摘要
类:
该类处理解码过程中假设的数据。 |
|
S2SBaseSearcher 类,用于被其他解码方法继承,用于序列到序列模型。 |
|
该类实现了seq2seq模型的束搜索算法。 |
|
该类实现了贪婪解码方法的通用前向传递。 |
|
该类实现了基于文本的HF序列到序列模型的束搜索解码,例如mBART或NLLB。 |
|
该类实现了AttentionalRNNDecoder(speechbrain/nnet/RNN.py)的束搜索解码。 |
|
该类实现了AttentionalRNNDecoder(speechbrain/nnet/RNN.py)的贪婪解码。 |
|
该类实现了Transformer的束搜索解码。 |
|
该类实现了Transformer的贪婪解码。 |
|
该类实现了由OpenAI在https://cdn.openai.com/papers/whisper.pdf中提出的Whisper神经网络的束搜索解码。 |
|
该类实现了由OpenAI在https://cdn.openai.com/papers/whisper.pdf中提出的Whisper神经网络的贪婪解码。 |
参考
- class speechbrain.decoders.seq2seq.AlivedHypotheses(alived_seq, alived_log_probs, sequence_scores)[source]
基础:
Module该类在解码过程中处理假设的数据。
- Parameters:
alived_seq (torch.Tensor) – 每个假设的标记序列。
alived_log_probs (torch.Tensor) – 每个假设中每个标记的对数概率。
sequence_scores (torch.Tensor) – 每个假设的对数概率之和。
- class speechbrain.decoders.seq2seq.S2SBaseSearcher(bos_index, eos_index, min_decode_ratio, max_decode_ratio)[source]
基础:
ModuleS2SBaseSearcher 类将由其他解码方法继承,用于 seq2seq 模型。
- Parameters:
- forward(enc_states, wav_len)[source]
此方法应实现解码方法的前向算法。
- Parameters:
enc_states (torch.Tensor) – 解码时要使用的预计算编码器状态。 (例如,要关注的编码语音表示)。
wav_len (torch.Tensor) – 语音大脑风格的相对长度。
- Returns:
hyps – 预测的标记,作为列表的列表,或者如果 return_topk 为 True, 则为形状为 (batch, topk, 最大 token_id 序列长度) 的张量。
top_lengths – 批次中每个 topk 序列的长度。
top_scores – 这是 topk 假设的最终分数。
top_log_probs – 每个假设的对数概率。
- forward_step(inp_tokens, memory, enc_states, enc_lens)[source]
此方法应实现自回归模型中的前向操作的一步。
- Parameters:
inp_tokens (torch.Tensor) – 当前步骤的输入张量。
memory (无限制) – 此步骤的内存变量输入。 (例如 RNN 隐藏状态)。
enc_states (torch.Tensor) – 需要关注的编码器状态。
enc_lens (torch.Tensor) – 每个enc_states序列的实际长度。
- Returns:
log_probs (torch.Tensor) – 当前步骤输出的对数概率。
memory (No limit) – 在此步骤中生成的内存变量。 (例如 RNN 隐藏状态)。
attn (torch.Tensor) – 用于进行惩罚的注意力权重。
- class speechbrain.decoders.seq2seq.S2SGreedySearcher(bos_index, eos_index, min_decode_ratio, max_decode_ratio)[source]
基础类:
S2SBaseSearcher该类实现了贪婪解码方法的通用前向传递。另请参阅 S2SBaseSearcher()。
- forward(enc_states, wav_len)[source]
此方法执行贪婪搜索。
- Parameters:
enc_states (torch.Tensor) – 解码时要使用的预计算编码器状态。 (例如,要关注的编码语音表示)。
wav_len (torch.Tensor) – 语音大脑风格的相对长度。
- Returns:
hyps (List[List[int]]) – 包含假设的列表。
top_lengths (torch.Tensor (batch)) – 该张量包含每个假设的长度。
top_scores (torch.Tensor (batch)) – 每个假设的分数。
top_log_probs (torch.Tensor (batch, max length of token_id sequences)) – 每个假设的对数概率。
- class speechbrain.decoders.seq2seq.S2STransformerGreedySearcher(modules, temperature=0.0, **kwargs)[source]
基础类:
S2SGreedySearcher该类实现了Transformer的贪婪解码。
- Parameters:
模块 (包含以下内容的列表:) –
- modeltorch.nn.Module
一个TransformerASR模型。
- seq_lintorch.nn.Module
用于seq2seq模型的线性输出层。
温度 (float) – 解码期间使用的温度。
**kwargs – 传递给 S2SGreedySearcher 的参数
- class speechbrain.decoders.seq2seq.S2SWhisperGreedySearcher(model, temperature=0.0, use_kv_cache=True, suppress_blank=True, suppress_tokens='-1', sample_len=None, prefix=None, prompt=None, **kwargs)[source]
基础类:
S2SGreedySearcher该类实现了由OpenAI开发的Whisper神经网络的贪婪解码 在https://cdn.openai.com/papers/whisper.pdf中。
- Parameters:
model (HuggingFaceWhisper) – Whisper 模型。
温度 (float) – 解码期间使用的温度。
use_kv_cache (bool (默认值: True)) – 是否使用键值缓存。
suppress_blank (bool (默认值: True)) – 这将抑制空白输出。
suppress_tokens (str 或 list (默认: "-1")) – 要抑制的令牌ID列表(或以逗号分隔的令牌ID) “-1” 将抑制一组符号,如
model.non_speech_tokens()中定义的sample_len (int (默认值: None)) – 采样的最大令牌数。
prefix (str 或 list (默认值: None)) – 要添加到输入标记的前缀。 参见: https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
提示 (str 或 list (默认值: None)) – 要添加到输入标记中的提示。 参见: https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
**kwargs – 参见 S2SBaseSearcher,参数直接传递。
- property get_tokens_to_suppress
如果self.config.suppress_tokens为None,获取在解码期间要抑制的tokens。
- class speechbrain.decoders.seq2seq.S2SRNNGreedySearcher(embedding, decoder, linear, temperature=0.0, **kwargs)[source]
基础类:
S2SGreedySearcher该类实现了AttentionalRNNDecoder(speechbrain/nnet/RNN.py)的贪婪解码。 另请参见S2SBaseSearcher()和S2SGreedySearcher()。
- Parameters:
embedding (torch.nn.Module) – 一个嵌入层。
decoder (torch.nn.Module) – 注意力RNN解码器。
linear (torch.nn.Module) – 一个线性输出层。
温度 (float) – 解码期间使用的温度。
**kwargs – 参见 S2SBaseSearcher,参数直接传递。
Example
>>> import speechbrain as sb >>> from speechbrain.decoders import S2SRNNGreedySearcher >>> emb = torch.nn.Embedding(5, 3) >>> dec = sb.nnet.RNN.AttentionalRNNDecoder( ... "gru", "content", 3, 3, 1, enc_dim=7, input_size=3 ... ) >>> lin = sb.nnet.linear.Linear(n_neurons=5, input_size=3) >>> searcher = S2SRNNGreedySearcher( ... embedding=emb, ... decoder=dec, ... linear=lin, ... bos_index=0, ... eos_index=1, ... min_decode_ratio=0, ... max_decode_ratio=1, ... ) >>> batch_size = 2 >>> enc = torch.rand([batch_size, 6, 7]) >>> wav_len = torch.ones([batch_size]) >>> top_hyps, top_lengths, _, _ = searcher(enc, wav_len)
- class speechbrain.decoders.seq2seq.S2SBeamSearcher(bos_index, eos_index, min_decode_ratio, max_decode_ratio, beam_size, scorer=None, return_topk=False, topk=1, using_eos_threshold=True, eos_threshold=1.5, length_normalization=True, using_max_attn_shift=False, max_attn_shift=60, minus_inf=-1e+20)[source]
基础类:
S2SBaseSearcher该类实现了用于seq2seq模型的beam-search算法。 另请参阅S2SBaseSearcher()。
- Parameters:
bos_index (int) – 序列开始标记的索引。
eos_index (int) – 序列结束标记的索引。
min_decode_ratio (float) – 最小解码步骤与编码器状态长度的比率。
max_decode_ratio (float) – 最大解码步骤与编码器状态长度的比率。
beam_size (int) – 光束的宽度。
scorer (speechbrain.decoders.scorers.ScorerBuilder) – 评分器实例。默认值:None。
return_topk (bool) – 是否返回topk假设。topk假设将被填充到相同的长度。默认值:False。
topk (int) – 如果 return_topk 为 True,则返回 topk 个假设。默认值:1。
using_eos_threshold (bool) – 是否使用eos阈值。默认值:True。
eos_threshold (float) – eos token的阈值系数。默认值:1.5。 参见参考文献中的3.1.2节:https://arxiv.org/abs/1904.02619
length_normalization (bool) – 是否将分数除以长度。默认值:True。
using_max_attn_shift (bool) – 是否使用max_attn_shift约束。默认值:False。
max_attn_shift (int) – 在波束搜索中,如果注意力偏移超过max_attn_shift,将会阻止这些波束。默认值:60。 参考:https://arxiv.org/abs/1904.02619
minus_inf (float) – 用于阻止搜索路径的负无穷大值。默认值:-1e20。
- init_beam_search_data(enc_states, wav_len)[source]
初始化波束搜索数据。
- Parameters:
enc_states (torch.Tensor) – 需要关注的编码器状态。
wav_len (torch.Tensor) – 每个enc_states序列的实际长度。
- Returns:
alived_hyps (AlivedHypotheses) – 存活的假设。
inp_tokens (torch.Tensor) – 当前步骤的输入张量。
log_probs (torch.Tensor) – 当前步骤输出的对数概率。
eos_hyps_and_log_probs_scores (list) – 生成的假设(已达到eos的假设)和对数概率分数。
memory (No limit) – 在此步骤中生成的内存变量。
scorer_memory (No limit) – 在此步骤中生成的内存变量。
attn (torch.Tensor) – 注意力权重。
prev_attn_peak (torch.Tensor) – 前一个注意力峰值位置。
enc_states (torch.Tensor) – 需要关注的编码器状态。
enc_lens (torch.Tensor) – 每个enc_states序列的实际长度。
- search_step(alived_hyps, inp_tokens, log_probs, eos_hyps_and_log_probs_scores, memory, scorer_memory, attn, prev_attn_peak, enc_states, enc_lens, step)[source]
搜索下一步最可能的标记。
- Parameters:
alived_hyps (AlivedHypotheses) – 存活的假设。
inp_tokens (torch.Tensor) – 当前步骤的输入张量。
log_probs (torch.Tensor) – 当前步骤输出的对数概率。
eos_hyps_and_log_probs_scores (list) – 生成的假设(已达到eos的假设)和对数概率分数。
memory (无限制) – 此步骤的内存变量输入。 (例如 RNN 隐藏状态)。
scorer_memory (无限制) – 此步骤的内存变量输入。 (例如 RNN 隐藏状态)。
attn (torch.Tensor) – 注意力权重。
prev_attn_peak (torch.Tensor) – 之前的注意力峰值位置。
enc_states (torch.Tensor) – 需要关注的编码器状态。
enc_lens (torch.Tensor) – 每个enc_states序列的实际长度。
step (int) – 当前的解码步骤。
- Returns:
alived_hyps (AlivedHypotheses) – 存活的假设。
inp_tokens (torch.Tensor) – 当前步骤的输入张量。
log_probs (torch.Tensor) – 当前步骤输出的对数概率。
eos_hyps_and_log_probs_scores (list) – 生成的假设(已达到eos的假设)和对数概率分数。
memory (No limit) – 在此步骤中生成的内存变量。
scorer_memory (No limit) – 在此步骤中生成的内存变量。
attn (torch.Tensor) – 注意力权重。
prev_attn_peak (torch.Tensor) – 前一个注意力峰值位置。
scores (torch.Tensor) – 当前步骤输出的分数。
- forward(enc_states, wav_len)[source]
应用beamsearch并返回预测的标记。
- Parameters:
enc_states (torch.Tensor) – 需要关注的编码器状态。
wav_len (torch.Tensor) – 每个enc_states序列的实际长度。
- Returns:
hyps (list) – 预测的标记。
best_lens (torch.Tensor) – 每个预测标记的长度。
best_scores (torch.Tensor) – 每个预测标记的分数。
best_log_probs (torch.Tensor) – 每个预测标记的对数概率。
- class speechbrain.decoders.seq2seq.S2SRNNBeamSearcher(embedding, decoder, linear, temperature=1.0, **kwargs)[source]
基础类:
S2SBeamSearcher该类实现了AttentionalRNNDecoder(speechbrain/nnet/RNN.py)的束搜索解码。 另请参见S2SBaseSearcher()、S2SBeamSearcher()。
- Parameters:
embedding (torch.nn.Module) – 一个嵌入层。
decoder (torch.nn.Module) – 注意力RNN解码器。
linear (torch.nn.Module) – 一个线性输出层。
温度 (float) – 应用于softmax的温度因子。它改变了概率分布,当T>1时更柔和,当T<1时更尖锐。
**kwargs – 参见 S2SBeamSearcher,参数直接传递。
Example
>>> import speechbrain as sb >>> vocab_size = 5 >>> emb = torch.nn.Embedding(vocab_size, 3) >>> dec = sb.nnet.RNN.AttentionalRNNDecoder( ... "gru", "content", 3, 3, 1, enc_dim=7, input_size=3 ... ) >>> lin = sb.nnet.linear.Linear(n_neurons=vocab_size, input_size=3) >>> coverage_scorer = sb.decoders.scorer.CoverageScorer(vocab_size) >>> scorer = sb.decoders.scorer.ScorerBuilder( ... full_scorers = [coverage_scorer], ... partial_scorers = [], ... weights= dict(coverage=1.5) ... ) >>> searcher = S2SRNNBeamSearcher( ... embedding=emb, ... decoder=dec, ... linear=lin, ... bos_index=4, ... eos_index=4, ... min_decode_ratio=0, ... max_decode_ratio=1, ... beam_size=2, ... scorer=scorer, ... ) >>> batch_size = 2 >>> enc = torch.rand([batch_size, 6, 7]) >>> wav_len = torch.ones([batch_size]) >>> hyps, _, _, _ = searcher(enc, wav_len)
- class speechbrain.decoders.seq2seq.S2STransformerBeamSearcher(modules, temperature=1.0, **kwargs)[source]
基础类:
S2SBeamSearcher该类实现了Transformer的束搜索解码。 另请参见S2SBaseSearcher(), S2SBeamSearcher()。
- Parameters:
模块 (包含以下内容的列表:) –
- modeltorch.nn.Module
一个Transformer模型。
- seq_lintorch.nn.Module
一个线性输出层。
temperature (float) – 应用于softmax的温度因子。它改变了概率分布,当T>1时更柔和,当T<1时更尖锐。
**kwargs – 传递给S2SBeamSearcher的参数
Example
>>> from speechbrain.nnet.linear import Linear >>> from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR >>> from speechbrain.decoders import S2STransformerBeamSearcher >>> batch_size=8 >>> n_channels=6 >>> input_size=40 >>> d_model=128 >>> tgt_vocab=140 >>> src = torch.rand([batch_size, n_channels, input_size]) >>> tgt = torch.randint(0, tgt_vocab, [batch_size, n_channels]) >>> net = TransformerASR( ... tgt_vocab, input_size, d_model, 8, 1, 1, 1024, activation=torch.nn.GELU ... ) >>> ctc_lin = Linear(input_shape=(1, 40, d_model), n_neurons=tgt_vocab) >>> lin = Linear(input_shape=(1, 40, d_model), n_neurons=tgt_vocab) >>> searcher = S2STransformerBeamSearcher( ... modules=[net, lin], ... bos_index=1, ... eos_index=2, ... min_decode_ratio=0.0, ... max_decode_ratio=1.0, ... using_eos_threshold=False, ... beam_size=7, ... temperature=1.15, ... ) >>> enc, dec = net.forward(src, tgt) >>> hyps, _, _, _ = searcher(enc, torch.ones(batch_size))
- class speechbrain.decoders.seq2seq.S2SWhisperBeamSearcher(module, temperature=1.0, use_kv_cache=True, suppress_blank=True, suppress_tokens='-1', sample_len=None, prefix=None, prompt=None, **kwargs)[source]
基础类:
S2SBeamSearcher该类实现了由OpenAI开发的Whisper神经网络的束搜索解码 在https://cdn.openai.com/papers/whisper.pdf中。
束搜索是有状态的,这意味着一些变量存储在搜索器中。如果你想在不同的上下文中重用搜索器,你应该确保变量相应地更新。
- Parameters:
模块 (包含以下内容的列表:) –
- 模型torch.nn.Module
一个whisper模型。它应该有一个decode()方法。
温度 (float) – 解码期间使用的温度。
use_kv_cache (bool (默认值: True)) – 是否使用键值缓存。
suppress_blank (bool (默认值: True)) – 这将抑制空白输出。
suppress_tokens (str 或 list (默认: "-1")) – 要抑制的令牌ID列表(或以逗号分隔的令牌ID) “-1” 将抑制一组符号,如
model.non_speech_tokens()中定义的sample_len (int (默认值: None)) – 采样的最大令牌数。
prefix (str 或 list (默认值: None)) – 要添加到输入标记的前缀。 参见: https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
提示 (str 或 list (默认值: None)) – 要添加到输入标记中的提示。 参见: https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
**kwargs – 参见 S2SBeamSearcher,参数直接传递。
- property get_tokens_to_suppress
如果self.config.suppress_tokens为None,获取在解码期间要抑制的tokens。
- class speechbrain.decoders.seq2seq.S2SHFTextBasedBeamSearcher(modules, vocab_size, **kwargs)[source]
基础类:
S2STransformerBeamSearcher该类实现了基于文本的HF序列到序列模型(如mBART或NLLB)的束搜索解码。 它与S2STransformerBeamSearcher没有显著不同。 这就是为什么它继承了S2STransformerBeamSearcher。 主要区别可能出现在希望直接使用基于文本的HF模型的lm_head而不是创建新的投影层(self.fc = None)时。
- Parameters:
模块 (包含以下内容的列表:) –
- modeltorch.nn.Module
一个Transformer模型。
- seq_lintorch.nn.Module
一个线性输出层。 通常在此用例中设置为None。
vocab_size (int) – lm_head的维度。
**kwargs – 传递给S2SBeamSearcher的参数