speechbrain.augment.augmenter 模块
用于实现数据增强管道的类。
- Authors
Mirco Ravanelli 2022
摘要
类:
应用数据增强的管道。 |
参考
- class speechbrain.augment.augmenter.Augmenter(parallel_augment=False, parallel_augment_fixed_bs=False, concat_original=False, min_augmentations=None, max_augmentations=None, shuffle_augmentations=False, repeat_augment=1, augment_start_index=0, augment_end_index=None, concat_start_index=0, concat_end_index=None, augment_prob=1.0, augmentations=[], enable_augmentations=None)[source]
基础:
Module应用数据增强的管道。
- Parameters:
parallel_augment (bool) – 如果为False,增强将按照pipeline参数中指定的顺序依次应用。 如果为True,所有N个增强将在输出中沿批次轴连接。
parallel_augment_fixed_bs (bool) – 如果为False,每个并行执行的增强器生成的增强示例数量等于批次大小。因此,总体而言,使用此选项会生成N*批次大小的人工数据,其中N是增强器的数量。当为True时,总增强示例的数量保持固定为批次大小,因此,对于每个增强器,固定为批次大小 // N个示例。此选项有助于控制合成示例的数量相对于原始数据分布,因为它始终保持50%的原始数据和50%的增强数据。
concat_original (bool) – 如果为True,原始输入将与增强输出连接(在批次轴上)。
min_augmentations (int) – 输入信号的增强次数在min_augmentations和max_augmentations之间随机采样。例如,如果增强字典包含N=6个增强,并且我们设置min_augmentations=1和max_augmentations=4,我们将应用最多M=4个增强。所选的增强按照增强字典中指定的顺序应用。如果shuffle_augmentations = True,则随机选择一组M个增强。
max_augmentations (int) – 最大增强次数。有关更多详细信息,请参见 min_augmentations。
shuffle_augmentations (bool) – 如果为True,它会打乱增强字典的条目。 效果是随机选择增强的顺序来应用。
repeat_augment (int) – 应用增强算法N次。这可以用于执行更多的数据增强。
augment_start_index (int) – 输入批次中数据增强应从其开始的第一个元素的索引。 此参数允许您指定应用数据增强的起始点。
augment_end_index (int) – 输入批次中数据增强应停止的最后一个元素的索引。 您可以使用此参数来定义在批次内应用数据增强的终点。
concat_start_index (int) – 如果
concat_original设置为 True,您可以指定原始批次的一部分以在输出中连接。 使用此参数选择从原始输入批次中开始复制的第一个元素的索引。concat_end_index (int) – 如果
concat_original设置为 True,您可以指定原始批次的一部分以在输出中连接。使用此参数选择原始输入批次中最后一个元素的索引以结束复制过程。augment_prob (float) – 应用数据增强的概率(0.0 到 1.0)。当设置为 0.0 时,返回原始信号而不进行任何增强。当设置为 1.0 时,始终应用增强。介于两者之间的值决定了增强的可能性。
augmentations (list) – 用于组合执行数据增强的增强器对象列表。
enable_augmentations (list) – 一个布尔值列表,用于选择性地启用或禁用‘augmentations’列表中的特定增强技术。 每个布尔值对应于‘augmentations’列表中的一个增强对象,并且应该具有相同的长度和顺序。 此功能对于执行增强技术的消融实验以针对特定任务进行定制非常有用。
Example
>>> from speechbrain.augment.time_domain import DropFreq, DropChunk >>> freq_dropper = DropFreq() >>> chunk_dropper = DropChunk(drop_start=100, drop_end=16000) >>> augment = Augmenter(parallel_augment=False, concat_original=False, augmentations=[freq_dropper, chunk_dropper]) >>> signal = torch.rand([4, 16000]) >>> output_signal, lengths = augment(signal, lengths=torch.tensor([0.2,0.5,0.7,1.0]))
- augment(x, lengths, selected_augmentations)[source]
对选定的增强应用数据增强。
- Parameters:
x (torch.Tensor (batch, time, channel)) – 输入以增强。
lengths (torch.Tensor) – 批次中每个序列的长度。
selected_augmentations (dict) – 包含要应用的选定增强的字典。
- Returns:
output (torch.Tensor) – 增强的输出。
output_lengths (torch.Tensor) – 每个输出对应的长度。
- forward(x, lengths)[source]
应用数据增强。
- Parameters:
x (torch.Tensor (batch, time, channel)) – 输入以增强。
lengths (torch.Tensor) – 批次中每个序列的长度。
- Returns:
output (torch.Tensor) – 增强的输出。
output_lengths (torch.Tensor) – 每个输出对应的长度。
- concatenate_outputs(augment_lst, augment_len_lst)[source]
连接一系列增强信号,考虑不同的时间长度。 应用填充以确保所有信号都可以连接。
- Parameters:
augment_lst (List of torch.Tensor) – 要连接的增强信号列表。
augment_len_lst (List of torch.Tensor) – 对应于增强信号的长度列表。
- Returns:
concatenated_signals (torch.Tensor) – 包含连接信号的张量。
concatenated_lengths (torch.Tensor) – 包含连接信号长度的张量。
注释
该函数接收一个增强信号列表,这些信号可能由于速度变化等原因具有不同的时间长度。它会将信号填充以匹配输入信号中找到的最大时间维度,并在连接它们之前相应地重新缩放长度。
- replicate_multiple_labels(*args)[source]
沿着批次轴复制标签的次数与增强的数量相对应。实际上,并行和串联增强会改变时间维度。
- Parameters:
*args (tuple) – 要复制的输入标签张量。可以是单个或一组torch.Tensors。
- Returns:
augmented_labels – 与增强输入对应的标签。返回与输入中给定的相同数量的 torch.Tensor。
- Return type:
torch.Tensor