speechbrain.decoders.utils 模块
用于解码模块的工具函数。
- Authors
阿德尔·穆门 2023
周珏洁 2020
彼得·普兰廷加 2020
Mirco Ravanelli 2020
叶松林 2020
摘要
函数:
调用 batch_size 次 filter_seq2seq_output。 |
|
过滤输出,直到第一个eos出现(不包括eos)。 |
|
此函数沿维度多次扩展张量。 |
|
此函数将用fill_value屏蔽张量中的某些元素,如果condition=False。 |
参考
- speechbrain.decoders.utils.inflate_tensor(tensor, times, dim)[source]
此函数沿维度将张量膨胀多次。
- Parameters:
- Returns:
膨胀的张量。
- Return type:
torch.Tensor
Example
>>> tensor = torch.Tensor([[1,2,3], [4,5,6]]) >>> new_tensor = inflate_tensor(tensor, 2, dim=0) >>> new_tensor tensor([[1., 2., 3.], [1., 2., 3.], [4., 5., 6.], [4., 5., 6.]])
- speechbrain.decoders.utils.mask_by_condition(tensor, cond, fill_value)[source]
如果条件为假,此函数将用fill_value屏蔽张量中的某些元素。
- Parameters:
tensor (torch.Tensor) – 需要被掩码的张量。
cond (torch.BoolTensor) – 这个张量必须与张量的大小相同。 每个元素表示是否保留张量中的值。
fill_value (float) – 用于填充被屏蔽元素的值。
- Returns:
被掩码的张量。
- Return type:
torch.Tensor
Example
>>> tensor = torch.Tensor([[1,2,3], [4,5,6]]) >>> cond = torch.BoolTensor([[True, True, False], [True, False, False]]) >>> mask_by_condition(tensor, cond, 0) tensor([[1., 2., 0.], [4., 0., 0.]])
- speechbrain.decoders.utils.batch_filter_seq2seq_output(prediction, eos_id=-1)[source]
调用 batch_size 次 filter_seq2seq_output。
- Parameters:
- Returns:
由seq2seq模型预测的输出。
- Return type:
Example
>>> predictions = [torch.IntTensor([1,2,3,4]), torch.IntTensor([2,3,4,5,6])] >>> predictions = batch_filter_seq2seq_output(predictions, eos_id=4) >>> predictions [[1, 2, 3], [2, 3]]
- speechbrain.decoders.utils.filter_seq2seq_output(string_pred, eos_id=-1)[source]
过滤输出直到第一个eos出现(不包括eos)。
- Parameters:
- Returns:
由seq2seq模型预测的输出。
- Return type:
Example
>>> string_pred = ['a','b','c','d','eos','e'] >>> string_out = filter_seq2seq_output(string_pred, eos_id='eos') >>> string_out ['a', 'b', 'c', 'd']