speechbrain.nnet.transducer.transducer_joint 模块

实现transducer_joint的库。

Author

阿卜杜勒瓦哈布·赫巴 2020

摘要

类:

Transducer_joint

计算转录网络(TN)和预测网络(PN)之间的联合张量

参考

class speechbrain.nnet.transducer.transducer_joint.Transducer_joint(joint_network=None, joint='sum', nonlinearity=<class 'torch.nn.modules.activation.LeakyReLU'>)[source]

基础:Module

计算转录网络(TN)和预测网络(PN)之间的联合张量

Parameters:
  • joint_network (torch.class (神经网络模块)) – 如果 joint == “concat”,我们在 TN 和 PN 连接后调用此网络 如果为 None,我们不使用此网络。

  • joint (str) – 通过(“sum”或“concat”)选项连接两个张量。

  • nonlinearity (torch class) – 在TN和PN之间的连接后使用的激活函数 非线性类型(tanh, relu)。

Example

>>> from speechbrain.nnet.transducer.transducer_joint import Transducer_joint
>>> from speechbrain.nnet.linear import Linear
>>> input_TN = torch.rand(8, 200, 1, 40)
>>> input_PN = torch.rand(8, 1, 12, 40)
>>> joint_network = Linear(input_size=80, n_neurons=80)
>>> TJoint = Transducer_joint(joint_network, joint="concat")
>>> output = TJoint(input_TN, input_PN)
>>> output.shape
torch.Size([8, 200, 12, 80])
init_params(first_input)[source]
Parameters:

first_input (tensor) – 用于初始化参数的第一个输入。

forward(input_TN, input_PN)[source]

返回输入张量的融合。

Parameters:
  • input_TN (torch.Tensor) – 来自转录网络的输入。

  • input_PN (torch.Tensor) – 来自预测网络的输入。

Return type:

输入张量的融合。