speechbrain.decoders.utils 模块

用于解码模块的工具函数。

Authors
  • 阿德尔·穆门 2023

  • 周珏洁 2020

  • 彼得·普兰廷加 2020

  • Mirco Ravanelli 2020

  • 叶松林 2020

摘要

函数:

batch_filter_seq2seq_output

调用 batch_size 次 filter_seq2seq_output。

filter_seq2seq_output

过滤输出,直到第一个eos出现(不包括eos)。

inflate_tensor

此函数沿维度多次扩展张量。

mask_by_condition

此函数将用fill_value屏蔽张量中的某些元素,如果condition=False。

参考

speechbrain.decoders.utils.inflate_tensor(tensor, times, dim)[source]

此函数沿维度将张量膨胀多次。

Parameters:
  • tensor (torch.Tensor) – 需要扩展的张量。

  • times (int) – 张量将膨胀的次数。

  • dim (int) – 要扩展的维度。

Returns:

膨胀的张量。

Return type:

torch.Tensor

Example

>>> tensor = torch.Tensor([[1,2,3], [4,5,6]])
>>> new_tensor = inflate_tensor(tensor, 2, dim=0)
>>> new_tensor
tensor([[1., 2., 3.],
        [1., 2., 3.],
        [4., 5., 6.],
        [4., 5., 6.]])
speechbrain.decoders.utils.mask_by_condition(tensor, cond, fill_value)[source]

如果条件为假,此函数将用fill_value屏蔽张量中的某些元素。

Parameters:
  • tensor (torch.Tensor) – 需要被掩码的张量。

  • cond (torch.BoolTensor) – 这个张量必须与张量的大小相同。 每个元素表示是否保留张量中的值。

  • fill_value (float) – 用于填充被屏蔽元素的值。

Returns:

被掩码的张量。

Return type:

torch.Tensor

Example

>>> tensor = torch.Tensor([[1,2,3], [4,5,6]])
>>> cond = torch.BoolTensor([[True, True, False], [True, False, False]])
>>> mask_by_condition(tensor, cond, 0)
tensor([[1., 2., 0.],
        [4., 0., 0.]])
speechbrain.decoders.utils.batch_filter_seq2seq_output(prediction, eos_id=-1)[source]

调用 batch_size 次 filter_seq2seq_output。

Parameters:
  • 预测 (列表torch.Tensor) – 包含由seq2seq系统预测的输出整数的列表。

  • eos_id (int, string) – eos的id。

Returns:

由seq2seq模型预测的输出。

Return type:

list

Example

>>> predictions = [torch.IntTensor([1,2,3,4]), torch.IntTensor([2,3,4,5,6])]
>>> predictions = batch_filter_seq2seq_output(predictions, eos_id=4)
>>> predictions
[[1, 2, 3], [2, 3]]
speechbrain.decoders.utils.filter_seq2seq_output(string_pred, eos_id=-1)[source]

过滤输出直到第一个eos出现(不包括eos)。

Parameters:
  • string_pred (list) – 包含由seq2seq系统预测的输出字符串/整数的列表。

  • eos_id (int, string) – eos的id。

Returns:

由seq2seq模型预测的输出。

Return type:

list

Example

>>> string_pred = ['a','b','c','d','eos','e']
>>> string_out = filter_seq2seq_output(string_pred, eos_id='eos')
>>> string_out
['a', 'b', 'c', 'd']