speechbrain.lobes.models.huggingface_transformers.whisper 模块
该模块支持集成huggingface预训练的whisper模型。
需要安装来自HuggingFace的Transformer: https://huggingface.co/transformers/installation.html
- Authors
阿德尔·穆门 2022, 2024
Titouan Parcollet 2022
卢卡·德拉·利贝拉 2022
Ha Nguyen 2023
摘要
类:
该模块支持集成HuggingFace预训练的Whisper模型。 |
参考
- class speechbrain.lobes.models.huggingface_transformers.whisper.Whisper(source, save_path, sampling_rate=16000, encoder_only=False, freeze=False, freeze_encoder=False, output_attentions=False, output_all_hiddens=False, language=None, task='transcribe')[source]
-
该模块支持集成HuggingFace预训练的Whisper模型。
- Source paper whisper:
需要安装来自HuggingFace的Transformer: https://huggingface.co/transformers/installation.html
代码的某些部分也改编自官方的OpenAI仓库: https://github.com/openai/whisper
模型可以进行微调。它将自动从HuggingFace下载模型或使用本地路径。
- Parameters:
source (str) – HuggingFace 中心名称:例如 “openai/whisper-tiny”
save_path (str) – 下载模型的路径(目录)。
sampling_rate (int (默认值: 16000)) – 音频信号的采样率。
encoder_only (bool (默认值: False)) – 如果为True,前向函数输出编码器最后一个transformer层的隐藏状态。 如果为False,则执行并返回解码器的一步。
freeze (bool (默认值: False)) – 如果为True,模型将被冻结。
freeze_encoder (bool (默认值: False)) – 如果为True,则编码器被冻结。
output_attentions (bool (默认值: False)) – 如果
True,前向函数会输出注意力权重。默认情况下,它是False,因为 flash attention 需要设置output_attentions=False。如果output_attentions是True, 则会使用从头实现的注意力机制,这可能会使代码变慢并增加 VRAM 内存使用量。output_all_hiddens (bool (默认值: False)) – 如果为True,前向函数会输出编码器中所有transformer层的隐藏状态。 例如,whisper-base有6个transformer层,输出形状为(7, B, T, C), 其中CNN的输出被添加到开头。 如果为False,前向函数仅输出编码器最后一个transformer层的隐藏状态。
语言 (str (默认: "en")) – 用于解码器的语言标记。
任务 (str (默认: "transcribe")) – 用于解码器的任务令牌。它必须是以下之一: - “transcribe” - “translate”
Example
>>> model_hub = "openai/whisper-tiny" >>> save_path = "savedir" >>> sampling_rate = 16000 >>> model = Whisper(model_hub, save_path, sampling_rate) >>> tokens = torch.tensor([[1, 1]]) * model.model.config.decoder_start_token_id >>> inputs = torch.randn([1, 93680]) >>> outputs = model(inputs, tokens)
- freeze_model(model)[source]
冻结模型的参数。
- Parameters:
model (来自 AutoModel.from_config) – 有效的 HuggingFace transformers 模型对象。
- forward(wav, decoder_input_ids=None)[source]
执行梅尔变换和一步耳语(编码器-解码器)。
- Parameters:
wav (torch.Tensor) – 一批要转换为特征的音频信号。
decoder_input_ids (torch.Tensor) – 解码器的输入标记。这可以是语言、任务等。 请参阅whisper论文以获取更多详细信息,或前往SpeechBrain中的 seq2seq2.py文件查看如何使用贪婪搜索和/或束搜索生成标记。
- Returns:
out_encoder (torch.Tensor) – 编码器模型的输出。
decoder_logits (torch.Tensor) – 解码器模型的输出。
decoder_attn (torch.Tensor) – 解码器模型的注意力值。
- log_mel_spectrogram(audio, padding: int = 0)[source]
计算一批输入波形的梅尔频谱图。
参考:改编自 https://github.com/openai/whisper/blob/eff383b27b783e280c089475852ba83f20f64998/whisper/audio.py#L92
- Parameters:
audio (torch.Tensor) – 一批16 kHz的音频波形。
padding (int) – 要附加到音频张量末尾的样本数量。
- Returns:
log_spec – 一个包含批量梅尔频谱图的张量。
- Return type:
torch.Tensor
- pad_or_trim(array, length: int = 480000, axis=-1)[source]
按照编码器的预期填充或修剪梅尔频谱图。
参考:改编自 https://github.com/openai/whisper/blob/eff383b27b783e280c089475852ba83f20f64998/whisper/audio.py#L52
- forward_encoder(mel)[source]
接收一个输入mel并返回其对应的编码器状态。 如果output_all_hiddens为True,则返回编码器的最后一个隐藏状态或所有隐藏状态。
- Parameters:
mel (torch.Tensor (signal)) – 一批音频mel,用于转换为特征。
- Returns:
编码器的最后一个隐藏状态,如果output_all_hiddens为True,则为所有隐藏状态。
- Return type:
torch.Tensor
- forward_decoder(encoder_states, decoder_input_ids, use_cache=True, past_key_values=None)[source]
执行whisper解码器的一步。
- Parameters:
encoder_states (torch.Tensor) – 一批编码器状态特征(mel + whisper 特征提取器)。
decoder_input_ids (torch.Tensor) – 解码器的输入标记。这可以是语言、任务等。 请参阅whisper论文以获取更多详细信息,或前往SpeechBrain中的 seq2seq2.py文件查看如何使用贪婪搜索和/或束搜索生成标记。
use_cache (bool) – 如果为True,键和值将作为KV缓存的输出返回。
past_key_values (torch.Tensor (默认值: None)) – 如果不为None,则使用过去的键值进行KV缓存,并避免重新计算注意力权重。
- Returns:
logits (torch.Tensor) – 解码器的logits。
attn (torch.Tensor | None) – 如果
output_attentions为True,则返回注意力权重。否则,返回None。past_key_values (torch.Tensor) – 解码器的过去键值。
- property all_language_tokens
返回与语言标记对应的标记列表。
- property all_language_codes
返回与语言标记对应的语言代码列表。
- property non_speech_tokens
返回要抑制的标记列表,以避免任何说话者标签或非语音注释,以防止采样实际上未在音频中说的文本,例如。
♪♪♪
( 说外语 )
[DAVID] 你好,
保留基本标点符号,如逗号、句号、问号、感叹号等。
取自:openai/whisper GitHub
- property is_multilingual
如果模型是多语言的,则返回 True,否则返回 False。
- property get_suppress_tokens
返回要抑制的令牌列表
- detect_language(mel)[source]
检测给定梅尔频谱图特征的语言。
- Parameters:
mel (torch.Tensor) – 用于检测语言的梅尔频谱图特征。
- Returns:
language_tokens (torch.Tensor of shape (batch_size,)) – 最可能的语言标记的ID,出现在转录开始标记之后。
language_probs (List[Dict[str, float]]) – 包含所有语言概率分布的字典列表。
- Raises:
ValueError – 如果模型没有语言标记。