speechbrain.lobes.models.PIQ 模块
该文件实现了通过量化实现事后解释所需的类和函数。
作者 * Cem Subakan 2023 * Francesco Paissan 2023
摘要
类:
该类实现了一个卷积编码器,用于从对数谱中提取分类嵌入。 |
|
该类实现了一个残差块。 |
|
实现VQ字典。 |
|
此类定义了向量量化的前向方法。 |
|
此类定义了向量量化的前向方法。 |
|
该类从FocalNet分类器的表示中重建对数功率谱图。 |
|
该类从ViT分类器的表示中重建对数功率谱图。 |
|
该类从分类器的表示中重建对数功率谱图。 |
函数:
该类返回一个二进制矩阵,该矩阵指示给定标签数组的VQ字典中的无关区域 |
|
对网络权重应用Xavier初始化。 |
参考
- speechbrain.lobes.models.PIQ.get_irrelevant_regions(labels, K, num_classes, N_shared=5, stage='TRAIN')[source]
该类返回一个二进制矩阵,该矩阵根据给定的标签数组指示VQ字典中的无关区域
- Parameters:
- Returns:
无关区域
- Return type:
torch.Tensor
Example
>>> labels = torch.Tensor([1, 0, 2]) >>> irrelevant_regions = get_irrelevant_regions(labels, 20, 3, 5) >>> print(irrelevant_regions.shape) torch.Size([3, 20])
- class speechbrain.lobes.models.PIQ.VectorQuantization(*args, **kwargs)[source]
基础:
Function该类定义了向量量化的前向方法。由于VQ不可微分,如果在调用
.grad()时,它将返回一个RuntimeError。请参考VectorQuantizationStraightThrough以获取VQ操作的直通梯度估计。- static forward(ctx, inputs, codebook, labels=None, num_classes=10, activate_class_partitioning=True, shared_keys=10, training=True)[source]
将VQ应用于向量
input,使用codebook作为VQ字典。- Parameters:
ctx (torch context) – 用于存储反向传播信息的上下文对象。
inputs (torch.Tensor) – 需要量化的隐藏表示。期望的形状是
torch.Size([B, W, H, C])。codebook (torch.Tensor) – 用于量化的VQ字典。期望的形状为
torch.Size([K, C]),其中K为字典元素的数量。labels (torch.Tensor) – 分类标签。用于定义不相关区域并根据预测类别划分潜在空间。形状应为
torch.Size([B])。num_classes (int) – 可能的类别数量
activate_class_partitioning (bool) –
True如果潜在空间应该为不同的类别进行量化。shared_keys (int) – 类之间共享的键的数量。
训练 (bool) –
True如果阶段是训练。
- Returns:
用于量化表示的Codebook索引
- Return type:
torch.Tensor
Example
>>> inputs = torch.ones(3, 14, 25, 256) >>> codebook = torch.randn(1024, 256) >>> labels = torch.Tensor([1, 0, 2]) >>> print(VectorQuantization.apply(inputs, codebook, labels).shape) torch.Size([3, 14, 25])
- class speechbrain.lobes.models.PIQ.VectorQuantizationStraightThrough(*args, **kwargs)[source]
基础:
Function这个类定义了向量量化的前向方法。由于VQ不可微分,它近似了VQ的梯度,如https://arxiv.org/abs/1711.00937中所述。
- static forward(ctx, inputs, codebook, labels=None, num_classes=10, activate_class_partitioning=True, shared_keys=10, training=True)[source]
将VQ应用于向量
input,使用codebook作为VQ字典,并通过量化步骤的直通(id)近似来估计梯度。- Parameters:
ctx (torch context) – 用于存储反向传播信息的上下文对象。
inputs (torch.Tensor) – 需要量化的隐藏表示。期望的形状是
torch.Size([B, W, H, C])。codebook (torch.Tensor) – 用于量化的VQ字典。期望的形状为
torch.Size([K, C]),其中K为字典元素的数量。labels (torch.Tensor) – 分类标签。用于定义不相关区域并根据预测类别划分潜在空间。形状应为
torch.Size([B])。num_classes (int) – 可能的类别数量
activate_class_partitioning (bool) –
True如果潜在空间应该为不同的类别进行量化。shared_keys (int) – 类之间共享的键的数量。
训练 (bool) –
True如果阶段是训练。
- Returns:
量化表示和量化表示的码本索引
- Return type:
Example
>>> inputs = torch.ones(3, 14, 25, 256) >>> codebook = torch.randn(1024, 256) >>> labels = torch.Tensor([1, 0, 2]) >>> quant, quant_ind = VectorQuantizationStraightThrough.apply(inputs, codebook, labels) >>> print(quant.shape, quant_ind.shape) torch.Size([3, 14, 25, 256]) torch.Size([1050])
- static backward(ctx, grad_output, grad_indices, labels=None, num_classes=None, activate_class_partitioning=True, shared_keys=10, training=True)[source]
假设向量量化作为恒等函数来估计梯度。(https://arxiv.org/abs/1711.00937)
- class speechbrain.lobes.models.PIQ.Conv2dEncoder_v2(dim=256)[source]
基础:
Module该类实现了一个卷积编码器,用于从对数谱中提取分类嵌入。
- Parameters:
dim (int) – 提取嵌入的通道数。
Example
>>> inputs = torch.ones(3, 431, 513) >>> model = Conv2dEncoder_v2() >>> print(model(inputs).shape) torch.Size([3, 256, 26, 32])
- class speechbrain.lobes.models.PIQ.ResBlockAudio(dim)[source]
基础:
Module这个类实现了一个残差块。
- Parameters:
dim (int) – 要处理的张量的输入通道数。与残差块的输出通道数匹配。
Example
>>> res = ResBlockAudio(128) >>> x = torch.randn(2, 128, 16, 16) >>> print(x.shape) torch.Size([2, 128, 16, 16])
- class speechbrain.lobes.models.PIQ.VectorQuantizedPSI_Audio(dim=128, K=512, numclasses=50, activate_class_partitioning=True, shared_keys=0, use_adapter=True, adapter_reduce_dim=True)[source]
基础:
Module该类从分类器的表示中重建对数功率谱图。
- Parameters:
Example
>>> psi = VectorQuantizedPSI_Audio(dim=256, K=1024) >>> x = torch.randn(2, 256, 16, 16) >>> labels = torch.Tensor([0, 2]) >>> logspectra, hcat, z_q_x = psi(x, labels) >>> print(logspectra.shape, hcat.shape, z_q_x.shape) torch.Size([2, 1, 257, 257]) torch.Size([2, 256, 8, 8]) torch.Size([2, 256, 8, 8])
- class speechbrain.lobes.models.PIQ.VectorQuantizedPSIFocalNet_Audio(dim=1024, **kwargs)[source]
-
该类从FocalNet分类器的表示中重建对数功率谱图。
Example
>>> psi = VectorQuantizedPSIFocalNet_Audio(dim=256, K=1024) >>> x = torch.randn(2, 256, 16, 16) >>> labels = torch.Tensor([0, 2]) >>> logspectra, hcat, z_q_x = psi(x, labels) >>> print(logspectra.shape, hcat.shape, z_q_x.shape) torch.Size([2, 1, 495, 593]) torch.Size([2, 256, 8, 8]) torch.Size([2, 256, 8, 8])
- class speechbrain.lobes.models.PIQ.VectorQuantizedPSIViT_Audio(dim=768, **kwargs)[source]
-
该类从ViT分类器的表示中重建对数功率谱图。
Example
>>> psi = VectorQuantizedPSIViT_Audio(dim=256, K=1024) >>> x = torch.randn(2, 256, 16, 16) >>> labels = torch.Tensor([0, 2]) >>> logspectra, hcat, z_q_x = psi(x, labels) >>> print(logspectra.shape, hcat.shape, z_q_x.shape) torch.Size([2, 1, 495, 593]) torch.Size([2, 256, 8, 8]) torch.Size([2, 256, 8, 8])
- class speechbrain.lobes.models.PIQ.VQEmbedding(K, D, numclasses=50, activate_class_partitioning=True, shared_keys=0)[source]
基础:
Module实现VQ字典。封装了
VectorQuantization和VectorQuantizationStraightThrough。更多详情请参考具体类。- Parameters:
- forward(z_e_x, labels=None)[source]
封装了向量量化。计算输入量化的VQ字典索引。请注意,这个前向步骤是不可微分的。
- Parameters:
z_e_x (torch.Tensor) – 要量化的输入张量。
labels (torch.Tensor) – 输入表示的预测类别(用于潜在空间量化)。
- Returns:
用于量化表示的Codebook索引
- Return type:
torch.Tensor
Example
>>> inputs = torch.ones(3, 256, 14, 25) >>> codebook = VQEmbedding(1024, 256) >>> labels = torch.Tensor([1, 0, 2]) >>> print(codebook(inputs, labels).shape) torch.Size([3, 14, 25])
- straight_through(z_e_x, labels=None)[source]
实现了使用梯度直通近似的向量量化。
- Parameters:
z_e_x (torch.Tensor) – 要量化的输入张量。
labels (torch.Tensor) – 输入表示的预测类别(用于潜在空间量化)。
- Returns:
直接量化表示和量化表示
- Return type:
Example
>>> inputs = torch.ones(3, 256, 14, 25) >>> codebook = VQEmbedding(1024, 256) >>> labels = torch.Tensor([1, 0, 2]) >>> quant, quant_ind = codebook.straight_through(inputs, labels) >>> print(quant.shape, quant_ind.shape) torch.Size([3, 256, 14, 25]) torch.Size([3, 256, 14, 25])