
Neural Network Adapters for Faster, Low-Memory Fine-Tuning

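This tutorial covers SpeechBrain's implementation of adapters such as LoRA. It shows how to integrate three kinds of adapters into a pretrained model: adapters implemented in SpeechBrain, custom adapters, and adapters from libraries such as PEFT.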

Prerequisites

Introduction and Background

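As pretrained models become larger and more capable, there is growing interest in methods that adapt them to specific tasks within a reasonable amount of time and memory. One such technique is to freeze the original parameters and insert a small number of additional parameters, called "adapters", into the original model. Adapters can often match the performance of full fine-tuning while training only a small fraction of the parameters, which makes fine-tuning faster and more memory-efficient [1]. A popular technique of this kind is Low-Rank Adaptation (LoRA) [2].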
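To make "a small fraction of the parameters" concrete, here is a back-of-the-envelope sketch. It assumes the standard LoRA formulation from [2]. In that formulation, a frozen d_out × d_in weight matrix is augmented with a trainable product B·A, where A has shape r × d_in and B has shape d_out × r. The layer size below is illustrative and is not taken from the model used later in this tutorial.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters LoRA adds to one d_in -> d_out linear layer.

    A (rank x d_in) projects down, B (d_out x rank) projects back up.
    """
    return rank * d_in + d_out * rank


def full_param_count(d_in: int, d_out: int) -> int:
    """Weights updated by full fine-tuning of the same layer (bias ignored)."""
    return d_in * d_out


# Illustrative layer size, not taken from the CRDNN model in this tutorial
d_in = d_out = 1024
lora = lora_param_count(d_in, d_out, rank=8)
full = full_param_count(d_in, d_out)
print(lora, full, f"{100 * lora / full:.4f}%")  # 16384 1048576 1.5625%
```

LoRA's cost grows with d_in + d_out, while the full matrix grows with d_in × d_out. The relative savings therefore increase as layers get wider.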

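On the software side, HuggingFace has developed a popular adapter library called PEFT [3]. Our implementation includes some of that library's features, and it can also integrate PEFT adapters into SpeechBrain models. We start with a basic YAML example, so if you learn best by experimenting, you can try it out yourself right away.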

Relevant Literature

  1. N. Houlsby, A. Giurgiu, S. Jastrzebski, B. Morrone, Q. De Laroussilhe, A. Gesmundo, M. Attariyan, and S. Gelly, "Parameter-Efficient Transfer Learning for NLP." In International Conference on Machine Learning, 2019.

  2. E.J. Hu, Y. Shen, P. Wallis, Z. Allen-Zhu, Y. Li, S. Wang, L. Wang, and W. Chen, "LoRA: Low-Rank Adaptation of Large Language Models." In International Conference on Learning Representations, 2022.

  3. S. Mangrulkar, S. Gugger, L. Debut, Y. Belkada, S. Paul, and B. Bossan, "PEFT: State-of-the-art Parameter-Efficient Fine-Tuning methods." GitHub repository, 2022.

TL;DR

In short, to create a model with adapters, add a section like this to your HyperPyYAML file:

adapted_model: !new:speechbrain.nnet.adapters.AdaptedModel
    model_to_adapt: !ref <model>
    adapter_class: !name:speechbrain.nnet.adapters.LoRA
    all_linear: True
    unfrozen_layers: ["conv_1d_*"]
    adapter_kwargs:
        rank: 8

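This section takes the already-defined key model and adds a LoRA adapter to every linear layer, passing the keyword argument rank=8 to each adapter. The all_linear and all_conv arguments add adapters to all linear layers or all convolutional layers, respectively.

By default, this class freezes the parameters of every layer that is not adapted. You can use the unfrozen_layers argument to name layers that should still be trained, at the cost of more trainable parameters. In the example above, layers matching conv_1d_* stay trainable. The target_layers argument selects specific layers to adapt. All of these arguments support unix-style wildcard matching through Python's fnmatch library.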
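The wildcard patterns behave like shell globs. The sketch below applies Python's standard fnmatch module to a list of hypothetical layer names, which are not the real names in the template model. It shows which layers a pattern such as conv_1d_* would select. It only illustrates the matching rules; it is not SpeechBrain's internal code.

```python
from fnmatch import fnmatch

# Hypothetical layer names, in the dotted style of torch's named_modules()
layer_names = [
    "conv_1d_0",
    "conv_1d_1",
    "rnn.linear",
    "dnn.block_0.linear",
    "output",
]


def select_layers(names, patterns):
    """Return the names that match at least one unix-style wildcard pattern."""
    return [n for n in names if any(fnmatch(n, p) for p in patterns)]


print(select_layers(layer_names, ["conv_1d_*"]))  # ['conv_1d_0', 'conv_1d_1']
# '*' also matches dots, so a pattern can reach into nested modules
print(select_layers(layer_names, ["*linear"]))  # ['rnn.linear', 'dnn.block_0.linear']
```

Because "*" also matches ".", a short pattern can select more layers than intended. Printing the layer names of your own model before choosing a pattern is a cheap sanity check.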

If the TL;DR isn't enough and you'd like to work through an example in more detail, continue to the next section.

Detailed Tutorial

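We'll demonstrate adapters on the template recipe, which includes everything needed for a full training run. The first step is to pretrain a model so that we have something to add adapters to later.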

!git clone --depth 1 --branch v1.0.2 https://github.com/speechbrain/speechbrain.git
!python -m pip install -e speechbrain
# In order to use speechbrain in this repo we have to add it to the path
import os, sys

sys.path.append(os.path.join(os.getcwd(), 'speechbrain'))
%cd speechbrain/templates/speech_recognition/ASR
/home/pplantinga/Documents/Repositories/speechbrain/docs/tutorials/nn/speechbrain/templates/speech_recognition/ASR
!python train.py train.yaml --number_of_epochs=1 --batch_size=2 --test_scorer "!ref <valid_scorer>" --enable_add_reverb=False --enable_add_noise=False #To speed up
INFO:speechbrain.utils.seed:Setting seed to 2602
speechbrain.core - Beginning experiment!
speechbrain.core - Experiment folder: results/CRDNN_BPE_960h_LM/2602
mini_librispeech_prepare - Preparation completed in previous run, skipping.
../data/noise/data.zip exists. Skipping download
../data/rir/data.zip exists. Skipping download
speechbrain.utils.fetching - Fetch lm.ckpt: Fetching from HuggingFace Hub 'speechbrain/asr-crdnn-rnnlm-librispeech' if not cached
speechbrain.utils.fetching - Fetch tokenizer.ckpt: Fetching from HuggingFace Hub 'speechbrain/asr-crdnn-rnnlm-librispeech' if not cached
speechbrain.utils.fetching - Fetch asr.ckpt: Fetching from HuggingFace Hub 'speechbrain/asr-crdnn-rnnlm-librispeech' if not cached
speechbrain.utils.parameter_transfer - Loading pretrained files for: lm, tokenizer, model
speechbrain.core - Info: ckpt_interval_minutes arg from hparam file is used
speechbrain.core - Gradscaler enabled: False. Using precision: fp32.
speechbrain.core - ASR Model Statistics:
* Total Number of Trainable Parameters: 173.0M
* Total Number of Parameters: 173.0M
* Trainable Parameters represent 100.0000% of the total size.
speechbrain.utils.checkpoints - Would load a checkpoint here, but none found yet.
speechbrain.utils.epoch_loop - Going into epoch 1
speechbrain.augment.augmenter - No augmentation is applied because the augmentation start index is greater than or equal to the number of examples in the input batch.
100%|████████████████████████| 760/760 [08:28<00:00,  1.49it/s, train_loss=1.35]
100%|█████████████████████████████████████████| 545/545 [01:28<00:00,  6.18it/s]
speechbrain.utils.train_logger - epoch: 1, lr: 1.00e+00 - train loss: 1.35 - valid loss: 1.31, valid CER: 7.71, valid WER: 20.06
speechbrain.utils.checkpoints - Saved an end-of-epoch checkpoint in results/CRDNN_BPE_960h_LM/2602/save/CKPT+2024-10-08+11-08-06+00
speechbrain.utils.checkpoints - Loading a checkpoint from results/CRDNN_BPE_960h_LM/2602/save/CKPT+2024-10-08+11-08-06+00
100%|███████████████████████████████████████| 1310/1310 [09:25<00:00,  2.32it/s]
speechbrain.utils.train_logger - Epoch loaded: 1 - test loss: 1.30, test CER: 5.75, test WER: 17.57
speechbrain.utils.checkpoints - Saved an end-of-epoch checkpoint in results/CRDNN_BPE_960h_LM/2602/save/CKPT+latest

Inference

To show that the model works, let's run inference on a single file. This code is taken from transcribe_file.py.

from speechbrain.inference.ASR import EncoderDecoderASR
from speechbrain.utils.fetching import fetch, LocalStrategy

# Ensure all the needed files end up in the same place to load with the transcriber
save_dir = os.path.abspath("results/CRDNN_BPE_960h_LM/2602/save/CKPT+latest")
fetch("lm.ckpt", "speechbrain/asr-crdnn-rnnlm-librispeech", save_dir, local_strategy=LocalStrategy.SYMLINK)
fetch("tokenizer.ckpt", "speechbrain/asr-crdnn-rnnlm-librispeech", save_dir, local_strategy=LocalStrategy.SYMLINK)
fetch("inference.yaml", os.getcwd(), save_dir, local_strategy=LocalStrategy.SYMLINK)

transcriber = EncoderDecoderASR.from_hparams(source=save_dir, hparams_file="inference.yaml")
speech_file = "../data/LibriSpeech/dev-clean-2/1272/135031/1272-135031-0015.flac"
transcriber.transcribe_file(speech_file)
INFO:speechbrain.utils.fetching:Fetch lm.ckpt: Fetching from HuggingFace Hub 'speechbrain/asr-crdnn-rnnlm-librispeech' if not cached
INFO:speechbrain.utils.fetching:Fetch tokenizer.ckpt: Fetching from HuggingFace Hub 'speechbrain/asr-crdnn-rnnlm-librispeech' if not cached
INFO:speechbrain.utils.parameter_transfer:Loading pretrained files for: lm, tokenizer, model, normalizer
'THE METAL FOREST IS IN THE GREAT DOMED CAVERN THE LARGEST IN ALL OUR DOMINIONS REPLIED CALICO ⁇ '

Adding Adapters

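Now that we've shown the model works, let's add adapters. We need a new yaml file that adds adapters to the model, and then we train with that new file. To get it, we start from the old yaml file and patch every part that needs to change to train the adapted model.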

%%writefile train_lora.patch
--- train.yaml	2024-10-07 19:23:49.839501714 -0400
+++ train_lora.yaml	2024-10-07 19:25:40.340933091 -0400
@@ -30,7 +30,7 @@
 NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
 RIR_DATASET_URL: https://www.dropbox.com/scl/fi/linhy77c36mu10965a836/RIRs.zip?rlkey=pg9cu8vrpn2u173vhiqyu743u&dl=1
 
-output_folder: !ref results/CRDNN_BPE_960h_LM/<seed>
+output_folder: !ref results/crdnn_lora/<seed>
 test_wer_file: !ref <output_folder>/wer_test.txt
 save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
@@ -41,7 +41,7 @@
 # speechbrain HuggingFace repository. However, a local path pointing to a
 # directory containing the lm.ckpt and tokenizer.ckpt may also be specified
 # instead. E.g if you want to use your own LM / tokenizer.
-pretrained_path: speechbrain/asr-crdnn-rnnlm-librispeech
+pretrained_path: results/CRDNN_BPE_960h_LM/2602/save/CKPT+latest
 
 
 # Path where data manifest files will be stored. The data manifest files are created by the
@@ -481,10 +481,9 @@
     ctc_lin: !ref <ctc_lin>
     seq_lin: !ref <seq_lin>
     normalize: !ref <normalize>
-    lm_model: !ref <lm_model>
 
 # Gathering all the submodels in a single model object.
-model: !new:torch.nn.ModuleList
+model_pretrained: !new:torch.nn.ModuleList
     - - !ref <encoder>
       - !ref <embedding>
       - !ref <decoder>
@@ -629,8 +628,31 @@
     loadables:
         lm: !ref <lm_model>
         tokenizer: !ref <tokenizer>
-        model: !ref <model>
+        model: !ref <model_pretrained>
     paths:
         lm: !ref <pretrained_path>/lm.ckpt
         tokenizer: !ref <pretrained_path>/tokenizer.ckpt
-        model: !ref <pretrained_path>/asr.ckpt
+        model: !ref <pretrained_path>/model.ckpt
+
+new_encoder: !new:speechbrain.nnet.adapters.AdaptedModel
+    model_to_adapt: !ref <encoder>
+    adapter_class: !name:speechbrain.nnet.adapters.LoRA
+    all_linear: True
+    manual_adapter_insertion: True
+    adapter_kwargs:
+        rank: 8
+
+new_decoder: !new:speechbrain.nnet.adapters.AdaptedModel
+    model_to_adapt: !ref <decoder>
+    adapter_class: !name:speechbrain.nnet.adapters.LoRA
+    all_linear: True
+    manual_adapter_insertion: True
+    adapter_kwargs:
+        rank: 8
+
+model: !new:torch.nn.ModuleList
+    - - !ref <new_encoder>
+      - !ref <embedding>
+      - !ref <new_decoder>
+      - !ref <ctc_lin>
+      - !ref <seq_lin>
Overwriting train_lora.patch
!patch train.yaml -i train_lora.patch -o train_lora.yaml
patching file train_lora.yaml (read from train.yaml)

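Because the pretrainer loads the pretrained parameters, the adapters must be inserted only after those parameters have been loaded.

That is why the yaml sets manual_adapter_insertion: True, and why the training code needs the following small change: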

%%writefile train_lora_py.patch
--- train.py	2024-10-07 14:57:21.534381751 -0400
+++ train_lora.py	2024-10-07 19:33:12.839895913 -0400
@@ -473,6 +473,8 @@
     # the path given in the YAML file). The tokenizer is loaded at the same time.
     hparams["pretrainer"].collect_files()
     hparams["pretrainer"].load_collected()
+    hparams["new_encoder"].insert_adapters()
+    hparams["new_decoder"].insert_adapters()
 
     # Trainer initialization
     asr_brain = ASR(
Overwriting train_lora_py.patch
!patch train.py -i train_lora_py.patch -o train_lora.py
patching file train_lora.py (read from train.py)
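Why does the order matter? One plausible reason is that wrapping a layer changes its parameter names. A checkpoint saved from the unadapted model then no longer lines up with the adapted one. This is an assumption on our part, consistent with the PoolLoRA example later in this tutorial, which stores the wrapped layer as pretrained_module.

The minimal PyTorch sketch below shows the effect. It uses a hypothetical ToyLoRA wrapper and a hypothetical insert_adapters helper, not SpeechBrain's own classes. The pretrained state_dict loads cleanly before insertion, and the parameter keys change afterwards.

```python
import torch
from fnmatch import fnmatch


class ToyLoRA(torch.nn.Module):
    """Hypothetical LoRA-style wrapper: frozen layer plus a low-rank update."""

    def __init__(self, target_module, rank=2, alpha=1.0):
        super().__init__()
        # Keep the wrapped layer under a new attribute name and freeze it
        self.pretrained_module = target_module
        for param in self.pretrained_module.parameters():
            param.requires_grad = False
        self.down = torch.nn.Linear(target_module.in_features, rank, bias=False)
        self.up = torch.nn.Linear(rank, target_module.out_features, bias=False)
        torch.nn.init.zeros_(self.up.weight)  # the update starts as a no-op
        self.scaling = alpha / rank

    def forward(self, x):
        return self.pretrained_module(x) + self.up(self.down(x)) * self.scaling


def insert_adapters(model, adapter_class, patterns, **adapter_kwargs):
    """Hypothetical helper: wrap each nn.Linear whose name matches a pattern."""
    # Collect the names first so the module tree is not mutated mid-iteration
    targets = [
        name
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear)
        and any(fnmatch(name, p) for p in patterns)
    ]
    for name in targets:
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name)  # "" returns the model itself
        layer = getattr(parent, child_name)
        setattr(parent, child_name, adapter_class(layer, **adapter_kwargs))
    return model


def make_mlp():
    return torch.nn.Sequential(
        torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2)
    )


torch.manual_seed(0)
pretrained = make_mlp()
checkpoint = pretrained.state_dict()  # keys: '0.weight', '0.bias', '2.weight', ...

model = make_mlp()
model.load_state_dict(checkpoint)  # 1) load the pretrained weights first
insert_adapters(model, ToyLoRA, ["*"], rank=2)  # 2) then insert the adapters

# The wrapped weights now live under keys such as '0.pretrained_module.weight',
# so loading the same checkpoint after insertion fails with strict matching.
try:
    model.load_state_dict(checkpoint)
    late_load_failed = False
except RuntimeError as err:
    late_load_failed = True
    print(str(err).splitlines()[0])
```

Because the up-projection is zero-initialized, the adapted model initially produces exactly the same outputs as the pretrained one. Only the small down/up matrices are trainable.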

Training the Adapted Model

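Training works the same way as before, using the updated lora files; the adapted model is designed as a drop-in replacement. Note that the number of trainable parameters drops to roughly 1.5% of the total (1.8M of 120.0M). The total is also lower than the 173.0M reported earlier, most likely because the patch removes lm_model from the training modules.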

!python train_lora.py train_lora.yaml --number_of_epochs=1 --batch_size=2 --test_scorer "!ref <valid_scorer>" --enable_add_reverb=False --enable_add_noise=False #To speed up
INFO:speechbrain.utils.seed:Setting seed to 2602
speechbrain.core - Beginning experiment!
speechbrain.core - Experiment folder: results/crdnn_lora/2602
mini_librispeech_prepare - Preparation completed in previous run, skipping.
../data/noise/data.zip exists. Skipping download
../data/rir/data.zip exists. Skipping download
speechbrain.utils.parameter_transfer - Loading pretrained files for: lm, tokenizer, model
speechbrain.core - Info: ckpt_interval_minutes arg from hparam file is used
speechbrain.core - Gradscaler enabled: False. Using precision: fp32.
speechbrain.core - ASR Model Statistics:
* Total Number of Trainable Parameters: 1.8M
* Total Number of Parameters: 120.0M
* Trainable Parameters represent 1.4807% of the total size.
speechbrain.utils.checkpoints - Would load a checkpoint here, but none found yet.
speechbrain.utils.epoch_loop - Going into epoch 1
speechbrain.augment.augmenter - No augmentation is applied because the augmentation start index is greater than or equal to the number of examples in the input batch.
100%|███████████████████████████| 760/760 [04:09<00:00,  3.04it/s, train_loss=1]
100%|█████████████████████████████████████████| 545/545 [01:40<00:00,  5.42it/s]
speechbrain.utils.train_logger - epoch: 1, lr: 1.00e+00 - train loss: 1.00 - valid loss: 1.26, valid CER: 7.29, valid WER: 19.16
speechbrain.utils.checkpoints - Saved an end-of-epoch checkpoint in results/crdnn_lora/2602/save/CKPT+2024-10-08+11-23-53+00
speechbrain.utils.checkpoints - Loading a checkpoint from results/crdnn_lora/2602/save/CKPT+2024-10-08+11-23-53+00
100%|███████████████████████████████████████| 1310/1310 [12:55<00:00,  1.69it/s]
speechbrain.utils.train_logger - Epoch loaded: 1 - test loss: 1.26, test CER: 5.62, test WER: 17.05
speechbrain.utils.checkpoints - Saved an end-of-epoch checkpoint in results/crdnn_lora/2602/save/CKPT+latest
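The "ASR Model Statistics" block above reports trainable versus total parameters. Here is a rough sketch of how such a figure can be computed: count the elements of every parameter and split them on requires_grad. This is a generic PyTorch helper written for illustration, not SpeechBrain's own reporting code, and the toy model is hypothetical.

```python
import torch


def parameter_stats(model: torch.nn.Module):
    """Return (trainable, total, percent) parameter counts for a model."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return trainable, total, 100.0 * trainable / total


# Toy example: a frozen 256 -> 256 layer plus a rank-4 low-rank update
base = torch.nn.Linear(256, 256)
base.requires_grad_(False)
down = torch.nn.Linear(256, 4, bias=False)
up = torch.nn.Linear(4, 256, bias=False)
toy = torch.nn.ModuleList([base, down, up])

trainable, total, percent = parameter_stats(toy)
print(trainable, total, f"{percent:.2f}%")  # 2048 67840 3.02%
```

Frozen parameters still count toward the total. This is why the adapted model's total stays close to the original size even though only a small share is trained.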

Custom Adapters

We designed the interface so that you can swap the SpeechBrain adapter for a PEFT adapter:

new_encoder: !new:speechbrain.nnet.adapters.AdaptedModel
    model_to_adapt: !ref <encoder>
-   adapter_class: !name:speechbrain.nnet.adapters.LoRA
+   adapter_class: !name:peft.tuners.lora.layer.Linear
    all_linear: True
    manual_adapter_insertion: True
    adapter_kwargs:
-       rank: 8
+       r: 8
+       adapter_name: lora
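The keys under adapter_kwargs change because, as we understand the design, they are forwarded unchanged to the adapter class constructor, roughly as adapter_class(layer, **adapter_kwargs). The stdlib-only sketch below uses two hypothetical stand-in classes, not the real SpeechBrain or PEFT classes. Their signatures mimic the relevant parameters, which shows why rank must become r and why the PEFT layer also needs an adapter_name.

```python
class SpeechBrainStyleLoRA:
    """Hypothetical stand-in for a constructor that takes `rank`."""

    def __init__(self, target_module, rank=16, alpha=1.0):
        self.target_module, self.rank, self.alpha = target_module, rank, alpha


class PeftStyleLinear:
    """Hypothetical stand-in for a constructor that takes `adapter_name` and `r`."""

    def __init__(self, base_layer, adapter_name, r=0, lora_alpha=1):
        self.base_layer, self.adapter_name, self.r = base_layer, adapter_name, r


def build_adapter(adapter_class, layer, adapter_kwargs):
    # Assumed behavior: the YAML adapter_kwargs are passed through verbatim
    return adapter_class(layer, **adapter_kwargs)


layer = "some_linear_layer"  # placeholder for a real torch.nn.Linear
sb_adapter = build_adapter(SpeechBrainStyleLoRA, layer, {"rank": 8})
peft_adapter = build_adapter(PeftStyleLinear, layer, {"r": 8, "adapter_name": "lora"})

# Reusing the SpeechBrain-style key with the PEFT-style class is rejected
try:
    build_adapter(PeftStyleLinear, layer, {"rank": 8})
    rank_key_rejected = False
except TypeError as err:
    rank_key_rejected = True
    print("TypeError:", err)
```

A mismatched key therefore fails immediately with a TypeError when the adapters are inserted, rather than silently training with default settings.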

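However, this trains essentially the same model as before, so we won't repeat it here. A more interesting exercise is to design a custom adapter. The PoolLoRA module below average-pools the input features by a factor of stride before the low-rank down-projection. That projection therefore sees input_size // stride features instead of input_size: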

%%file pool_lora.py

import torch


class PoolLoRA(torch.nn.Module):
    """LoRA variant that average-pools the input features by `stride`
    before the low-rank down-projection."""

    def __init__(self, target_module, stride=2, rank=16, alpha=1.0):
        super().__init__()

        # nn.Linear stores its weight as (out_features, in_features)
        input_size = target_module.weight.data.shape[1]
        output_size = target_module.weight.data.shape[0]

        # Keep the pretrained module, but disable its gradients
        self.pretrained_module = target_module
        for param in self.pretrained_module.parameters():
            param.requires_grad = False
        device = target_module.weight.device

        # AvgPool1d pools the last (feature) dimension: input_size -> input_size // stride
        self.adapter_down_scale = torch.nn.AvgPool1d(kernel_size=stride)
        self.adapter_down_proj = torch.nn.Linear(
            input_size // stride, rank, bias=False, device=device
        )
        self.adapter_up_proj = torch.nn.Linear(
            rank, output_size, bias=False, device=device
        )
        # Zero-init the up-projection so the adapted layer initially
        # behaves exactly like the pretrained layer
        self.adapter_up_proj.weight.data.fill_(0.0)

        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor):
        """Applies the pooled LoRA adapter.

        Arguments
        ---------
        x: torch.Tensor
            Input tensor to the adapter module.

        Returns
        -------
        The linear outputs
        """
        x_pretrained = self.pretrained_module(x)

        x_downsample = self.adapter_down_proj(self.adapter_down_scale(x))
        x_pool_lora = self.adapter_up_proj(x_downsample)

        return x_pretrained + x_pool_lora * self.scaling
Overwriting pool_lora.py
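What does the pooling buy? Assume the adapter wraps a plain nn.Linear with in inputs and out outputs. PoolLoRA's down-projection sees only in // stride features, so it adds rank × (in // stride) + out × rank trainable parameters, versus rank × (in + out) for plain LoRA. The layer size below is illustrative and is not taken from the CRDNN model.

```python
def lora_params(d_in, d_out, rank):
    # Down-projection (d_in -> rank) plus up-projection (rank -> d_out)
    return rank * d_in + d_out * rank


def pool_lora_params(d_in, d_out, rank, stride):
    # AvgPool1d(kernel_size=stride) first shrinks the features to d_in // stride
    return rank * (d_in // stride) + d_out * rank


d_in = d_out = 512  # illustrative layer size
print(lora_params(d_in, d_out, rank=8))  # 8192
print(pool_lora_params(d_in, d_out, rank=16, stride=2))  # 12288
```

Doubling the rank while halving the down-projection's input keeps the down-projection the same size. However, the up-projection doubles, so the rank-16, stride-2 PoolLoRA used below is always slightly larger than rank-8 LoRA. This is consistent with the slightly higher trainable share in the training log that follows.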
%%writefile train_pool_lora.patch
--- train_lora.yaml	2024-10-07 22:44:02.767830301 -0400
+++ train_pool_lora.yaml	2024-10-07 22:41:30.602641301 -0400
@@ -30,7 +30,7 @@
 NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
 RIR_DATASET_URL: https://www.dropbox.com/scl/fi/linhy77c36mu10965a836/RIRs.zip?rlkey=pg9cu8vrpn2u173vhiqyu743u&dl=1
 
-output_folder: !ref results/crdnn_lora/<seed>
+output_folder: !ref results/crdnn_pool_lora/<seed>
 test_wer_file: !ref <output_folder>/wer_test.txt
 save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
@@ -636,19 +636,21 @@
 
 new_encoder: !new:speechbrain.nnet.adapters.AdaptedModel
     model_to_adapt: !ref <encoder>
-    adapter_class: !name:speechbrain.nnet.adapters.LoRA
+    adapter_class: !name:pool_lora.PoolLoRA
     all_linear: True
     manual_adapter_insertion: True
     adapter_kwargs:
-        rank: 8
+        stride: 2
+        rank: 16
 
 new_decoder: !new:speechbrain.nnet.adapters.AdaptedModel
     model_to_adapt: !ref <decoder>
-    adapter_class: !name:speechbrain.nnet.adapters.LoRA
+    adapter_class: !name:pool_lora.PoolLoRA
     all_linear: True
     manual_adapter_insertion: True
     adapter_kwargs:
-        rank: 8
+        stride: 2
+        rank: 16
 
 model: !new:torch.nn.ModuleList
     - - !ref <new_encoder>
Overwriting train_pool_lora.patch
!patch train_lora.yaml -i train_pool_lora.patch -o train_pool_lora.yaml
patching file train_pool_lora.yaml (read from train_lora.yaml)
!python train_lora.py train_pool_lora.yaml --number_of_epochs=1 --batch_size=2 --test_scorer "!ref <valid_scorer>" --enable_add_reverb=False --enable_add_noise=False #To speed up
INFO:speechbrain.utils.seed:Setting seed to 2602
speechbrain.core - Beginning experiment!
speechbrain.core - Experiment folder: results/crdnn_pool_lora/2602
mini_librispeech_prepare - Preparation completed in previous run, skipping.
../data/noise/data.zip exists. Skipping download
../data/rir/data.zip exists. Skipping download
speechbrain.utils.parameter_transfer - Loading pretrained files for: lm, tokenizer, model
speechbrain.core - Info: ckpt_interval_minutes arg from hparam file is used
speechbrain.core - Gradscaler enabled: False. Using precision: fp32.
speechbrain.core - ASR Model Statistics:
* Total Number of Trainable Parameters: 1.8M
* Total Number of Parameters: 120.0M
* Trainable Parameters represent 1.5210% of the total size.
speechbrain.utils.checkpoints - Would load a checkpoint here, but none found yet.
speechbrain.utils.epoch_loop - Going into epoch 1
speechbrain.augment.augmenter - No augmentation is applied because the augmentation start index is greater than or equal to the number of examples in the input batch.
100%|████████████████████████| 760/760 [04:19<00:00,  2.93it/s, train_loss=0.98]
100%|█████████████████████████████████████████| 545/545 [01:44<00:00,  5.24it/s]
speechbrain.utils.train_logger - epoch: 1, lr: 1.00e+00 - train loss: 9.80e-01 - valid loss: 1.26, valid CER: 7.18, valid WER: 18.92
speechbrain.utils.checkpoints - Saved an end-of-epoch checkpoint in results/crdnn_pool_lora/2602/save/CKPT+2024-10-08+11-43-00+00
speechbrain.utils.checkpoints - Loading a checkpoint from results/crdnn_pool_lora/2602/save/CKPT+2024-10-08+11-43-00+00
100%|███████████████████████████████████████| 1310/1310 [14:07<00:00,  1.55it/s]
speechbrain.utils.train_logger - Epoch loaded: 1 - test loss: 1.25, test CER: 5.61, test WER: 16.99
speechbrain.utils.checkpoints - Saved an end-of-epoch checkpoint in results/crdnn_pool_lora/2602/save/CKPT+latest

Conclusion

That's it, thanks for following along! Now go and build some cool adapters.

Citing SpeechBrain

If you use SpeechBrain in your research or business, please cite it using the following BibTeX entries:

@misc{speechbrainV1,
  title={Open-Source Conversational AI with {SpeechBrain} 1.0},
  author={Mirco Ravanelli and Titouan Parcollet and Adel Moumen and Sylvain de Langen and Cem Subakan and Peter Plantinga and Yingzhi Wang and Pooneh Mousavi and Luca Della Libera and Artem Ploujnikov and Francesco Paissan and Davide Borra and Salah Zaiem and Zeyu Zhao and Shucong Zhang and Georgios Karakasidis and Sung-Lin Yeh and Pierre Champion and Aku Rouhe and Rudolf Braun and Florian Mai and Juan Zuluaga-Gomez and Seyed Mahed Mousavi and Andreas Nautsch and Xuechen Liu and Sangeet Sagar and Jarod Duret and Salima Mdhaffar and Gaelle Laperriere and Mickael Rouvier and Renato De Mori and Yannick Esteve},
  year={2024},
  eprint={2407.00463},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2407.00463},
}
@misc{speechbrain,
  title={{SpeechBrain}: A General-Purpose Speech Toolkit},
  author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
  year={2021},
  eprint={2106.04624},
  archivePrefix={arXiv},
  primaryClass={eess.AS},
  note={arXiv:2106.04624}
}