speechbrain.lobes.models.ContextNet 模块

ContextNet的SpeechBrain实现由 https://arxiv.org/pdf/2005.03191.pdf

Authors
  • 钟建元 2020

摘要

类:

ContextNet

该类实现了ContextNet。

ContextNetBlock

该类实现了ContextNet中的一个块。

SEmodule

该类实现了Squeeze-and-Excitation模块。

参考

class speechbrain.lobes.models.ContextNet.ContextNet(input_shape, out_channels=640, conv_channels=None, kernel_size=3, strides=None, num_blocks=21, num_layers=5, inner_dim=12, alpha=1, beta=1, dropout=0.15, activation=<class 'speechbrain.nnet.activations.Swish'>, se_activation=<class 'torch.nn.modules.activation.Sigmoid'>, norm=<class 'speechbrain.nnet.normalization.BatchNorm1d'>, residuals=None)[source]

基础类: Sequential

该类实现了ContextNet。

参考论文: https://arxiv.org/pdf/2005.03191.pdf

Parameters:
  • input_shape (tuple) – 输入的预期形状。

  • out_channels (int) – 该模型的输出通道数(默认640)。

  • conv_channels (可选 (列表[整数])) – 每个contextnet块的输出通道数。如果未提供,将初始化为上述论文的默认设置。

  • kernel_size (int) – 卷积层的核大小(默认为3)。

  • strides (可选 (list[int])) – 每个上下文块的步幅因子。此步幅应用于每个上下文块的最后一个卷积层。如果未提供,将初始化为上述论文的默认设置。

  • num_blocks (int) – 上下文块的数量(默认21)。

  • num_layers (int) – 每个上下文块的深度卷积层数(默认值为5)。

  • inner_dim (int) – SE模块瓶颈网络的内部维度(默认值为12)。

  • alpha (float) – 用于缩放网络输出通道的因子(默认值为1)。

  • beta (float) – 用于缩放Swish激活的Beta值(默认值为1)。

  • dropout (float) – 丢弃率(默认 0.15)。

  • activation (torch class) – 每个上下文块的激活函数(默认为Swish)。

  • se_activation (torch class) – SE模块的激活函数(默认为 torch.nn.Sigmoid)。

  • norm (torch 类) – 用于正则化模型的归一化方法(默认为 BatchNorm1d)。

  • residuals (可选 (list[bool])) – 是否在每个上下文块应用残差连接(默认无)。

Example

>>> inp = torch.randn([8, 48, 40])
>>> block = ContextNet(input_shape=inp.shape, num_blocks=14)
>>> out = block(inp)
>>> out.shape
torch.Size([8, 6, 640])
class speechbrain.lobes.models.ContextNet.SEmodule(input_shape, inner_dim, activation=<class 'torch.nn.modules.activation.Sigmoid'>, norm=<class 'speechbrain.nnet.normalization.BatchNorm1d'>)[source]

基础:Module

该类实现了Squeeze-and-Excitation模块。

Parameters:
  • input_shape (tuple) – 输入的预期形状。

  • inner_dim (int) – SE模块瓶颈网络的内部维度(默认值为12)。

  • activation (torch 类) – SE 模块的激活函数(默认为 torch.nn.Sigmoid)。

  • norm (torch 类) – 用于正则化模型的归一化方法(默认为 BatchNorm1d)。

Example

>>> inp = torch.randn([8, 120, 40])
>>> net = SEmodule(input_shape=inp.shape, inner_dim=64)
>>> out = net(inp)
>>> out.shape
torch.Size([8, 120, 40])
forward(x)[source]

处理输入张量 x 并返回输出张量。

class speechbrain.lobes.models.ContextNet.ContextNetBlock(out_channels, kernel_size, num_layers, inner_dim, input_shape, stride=1, beta=1, dropout=0.15, activation=<class 'speechbrain.nnet.activations.Swish'>, se_activation=<class 'torch.nn.modules.activation.Sigmoid'>, norm=<class 'speechbrain.nnet.normalization.BatchNorm1d'>, residual=True)[source]

基础:Module

该类实现了ContextNet中的一个块。

Parameters:
  • out_channels (int) – 该模型的输出通道数(默认640)。

  • kernel_size (int) – 卷积层的核大小(默认为3)。

  • num_layers (int) – 此上下文块的深度卷积层数(默认值为5)。

  • inner_dim (int) – SE模块瓶颈网络的内部维度(默认为12)。

  • input_shape (tuple) – 输入的预期形状。

  • stride (int) – 此上下文块的步幅因子(默认为1)。

  • beta (float) – 用于缩放Swish激活的Beta值(默认值为1)。

  • dropout (float) – 丢弃率(默认 0.15)。

  • activation (torch class) – 此上下文块的激活函数(默认为Swish)。

  • se_activation (torch class) – SE模块的激活函数(默认为 torch.nn.Sigmoid)。

  • norm (torch class) – 用于正则化模型的归一化方法(默认为 BatchNorm1d)。

  • 残差 (bool) – 是否在此上下文块中应用残差连接(默认值为 None)。

Example

>>> inp = torch.randn([8, 120, 40])
>>> block = ContextNetBlock(256, 3, 5, 12, input_shape=inp.shape, stride=2)
>>> out = block(inp)
>>> out.shape
torch.Size([8, 60, 256])
forward(x)[source]

处理输入张量 x 并返回输出张量。