speechbrain.nnet.hypermixing 模块
该模块通过HyperMixing混合来自不同令牌的信息。 它可以被视为(自)注意力的线性时间替代方案。
来源: https://arxiv.org/abs/2203.03691
- Authors
弗洛里安·迈 2023
胡安·巴勃罗·祖鲁阿加 2023
摘要
类:
该类实现了多头HyperMixing。 |
|
该类实现了超网络。 |
|
实现多头HyperMixer或HyperConformer的类。 |
参考
- class speechbrain.nnet.hypermixing.HyperMixing(input_output_dim: int, hypernet_size: int, tied: bool = False, num_heads: int = 1, fix_tm_hidden_size: bool = False, max_length: int = 3000)[source]
基础:
Module该类实现了多头HyperMixing。 它是HyperMixer中token-mixing组件的实现,HyperMixer是一种线性时间的自注意力替代方案。与原始的HyperMixer相比,该模块支持多头机制,这提高了模型的表达能力,同时减少了参数数量。
参考:https://arxiv.org/abs/2203.03691
- Parameters:
Example
>>> import torch >>> inputs = torch.rand([8, 60, 512]) >>> net = HyperMixing(512, 2048, num_heads=8) >>> outputs, attn = net(inputs, inputs, inputs) >>> outputs.shape torch.Size([8, 60, 512])
- forward(query, key, value, attn_mask: Tensor | None = None, key_padding_mask: Tensor | None = None, return_attn_weights: bool | None = True, pos_embs: Tensor | None = None)[source]
此方法的签名特意选择与sb.nnet.attention.MultiHeadAttention相同,以便在SpeechBrain中保持兼容性。
注意:key、value、attn_mask 和 pos_embs 没有影响。Query 用于所有三个。因此,目前该模块仅应用于替换自注意力。
- Parameters:
query (torch.Tensor) – (B, L, E) 其中 L 是目标序列长度, B 是批次大小,E 是嵌入维度。
key (torch.Tensor) – (B, S, E) 其中 S 是源序列长度, B 是批次大小,E 是嵌入维度。 当前未使用。所有
value (torch.Tensor) – (B, S, E) 其中 S 是源序列长度, B 是批次大小,E 是嵌入维度。 当前未使用。
attn_mask (torch.Tensor, optional) – 注意:目前没有效果。
key_padding_mask (torch.Tensor, optional) – (B, S) 其中 B 是批量大小,S 是源序列长度。如果提供了 ByteTensor,非零位置将被忽略,而零位置将保持不变。如果提供了 BoolTensor,值为 True 的位置将被忽略,而值为 False 的位置将保持不变。
return_attn_weights (torch.Tensor, optional) – 注意:目前没有效果。
pos_embs (torch.Tensor, optional) – 注意:目前没有效果。
输出
-------
attn_output (torch.Tensor) – (B, L, E) 其中 L 是目标序列长度,B 是批次大小,E 是嵌入维度。
attn_output_weights (torch.Tensor) – (B, L, S) 其中 B 是批次大小,L 是目标序列长度,S 是源序列长度。 注意:总是返回全零。
- class speechbrain.nnet.hypermixing.HyperNetwork(input_output_dim: int, hypernet_size: int, tied=False, num_heads=1, keep_output_size=True)[source]
基础:
Module该类实现了超网络。这是一种使用一个网络(也称为超网络)来生成另一个网络的权重的方法。在这里,它用于生成线性层的标签。
参考:https://arxiv.org/abs/1609.09106
- Parameters: