用于训练的任务类与数据输入系统
用于训练的任务类与数据输入系统
任务类
在ESpnet1中,我们有太多重复的Python模块。ESPnet2的主要目标之一是提供通用接口,使我们能够更专注于每个任务的独特部分。
Task 类是一个通用系统,用于为每个任务(如ASR、TTS、LM等)构建训练工具,其灵感来源于Fairseq Task的设计理念。要构建您的任务,您只需继承AbsTask类即可:
from espnet2.tasks.abs_task import AbsTask
from espnet2.train.abs_espnet_model import AbsESPnetModel
class NewModel(ESPnetModel):
def forward(self, input, target):
(...)
# loss: The loss of the task. Must be a scalar value.
# stats: A dict object, used for logging and validation criterion
# weight: A scalar value that is used for normalization of loss and stats values among each mini-batches.
# In many cases, this value should be equal to the mini-batch-size
return loss, stats, weight
class NewTask(AbsTask):
@classmethod
def add_task_arguments(cls, parser):
parser.add_arguments(...)
(...)
@classmethod
def build_collate_fn(cls, args: argparse.Namespace)
(...)
@classmethod
def build_preprocess_fn(cls, args, train):
(...)
@classmethod
def required_data_names(cls, inference: bool = False):
(...)
@classmethod
def optional_data_names(cls, inference: bool = False):
(...)
@classmethod
def build_model(cls, args):
return NewModel(...)
if __name__ == "__main__":
# Start training
NewTask.main()数据输入系统
Espnet2还提供了一个命令行界面来描述训练语料库。相反,与fairseq或pytorch-lightning等训练系统不同,我们的Task类没有显式构建数据集的接口。这是因为我们目前只专注于与语音/文本相关的任务,因此暂时不需要如此通用的系统。
以下是命令行参数的一个示例:
python -m espnet2.bin.asr_train \
--train_data_path_and_name_and_type=/some/path/tr/wav.scp,speech,sound \
--train_data_path_and_name_and_type=/some/path/tr/token_int,text,text_int \
--valid_data_path_and_name_and_type=/some/path/dev/wav.scp,speech,sound \
--valid_data_path_and_name_and_type=/some/path/dev/token_int,text,text_int首先,我们的mini-batch始终是一个dict对象:
# In training iteration
for batch in iterator:
# e.g. batch = {"speech": ..., "text": ...}
# Forward
model(**batch)其中model与通过Task.build_model()构建的模型相同。
你可以灵活地使用--*_data_path_and_name_and_type来构建这个小批量对象。--*_data_path_and_name_and_type可以根据需要重复使用,每个--*_data_path_and_name_and_type对应小批量中的一个元素。同时请记住输入数据和目标数据之间没有区别。
--train_data_path_and_name_and_type 的参数应以逗号分隔的三个值给出,例如 。
key-name指定字典的键名file-path是数据源的文件/目录路径。file-format表示由file-path指定的文件格式。例如sound、kaldi_ark等。
scp 文件
你可以使用--help选项查看支持的文件格式。
python -m espnet2.bin.asr_train --help几乎所有格式都按照Kaldi-ASR规范被称为scp文件。scp只是一个文本文件,每行包含两列:第一列表示样本ID,第二列是某些值,例如文件路径、文本转录或数字序列。
- format=npy
样本ID_a /某个路径/a.npy 样本ID_b /某个路径/b.npy - 格式=音频
样本ID_a /某个路径/a.flac 样本ID_b /某个路径/a.wav - format=kaldi_ark
样本ID_a /some/path/a.ark:1234 样本ID_b /some/path/a.ark:5678 - format=text_int
sample_id_a 10 2 4 4 sample_id_b 3 2 0 1 6 2 - format=text
sample_id_a 你好世界 sample_id_b 今天下雨了
required_data_names() 和 optional_data_names()
虽然本系统可以创建任意字典,但每项任务都假定特定键是用于特定目的的。例如,ASR任务需要speech和text键,其值分别用作输入数据和目标数据。请再次查看Task类的方法:required_data_names()和optional_data_names()。
class NewTask(AbsTask):
@classmethod
def required_data_names(cls, inference: bool = False):
if not inference:
retval = ("input", "target")
else:
retval = ("input",)
return retval
@classmethod
def optional_data_names(cls, inference: bool = False):
retval = ("auxially_feature",)
return retvalrequired_data_names() 确定必需的数据名称,而 optional_data_names() 提供可选数据。这意味着其他名称可以通过命令行参数指定。
# The following is the expected argument
python -m new_task \
--train_data_path_and_name_and_type=filepath,input,sometype \
--train_data_path_and_name_and_type=filepath,target,sometype \
--train_data_path_and_name_and_type=filepath,auxially_feature,sometype
# The following raises an error
python -m new_task \
--train_data_path_and_name_and_type=filepath,unknown,sometype该系统的目的仅用于断言检查,如果觉得不必要,您可以通过--allow_variable_data_keys true关闭此检查。
# Ignore assertion checking for data names
python -m new_task \
--train_data_path_and_name_and_type=filepath,unknown_name,sometype \
--allow_variable_data_keys true自定义PyTorch数据加载器的collate_fn
Task 类有一个方法可以自定义 collate_fn:
class NewTask(AbsTask):
@classmethod
def build_collate_fn(cls, args: argparse.Namespace):
...collate_fn 是 torch.utils.data.DataLoader 的一个参数,它可以修改从数据加载器接收到的数据。例如:
def collate_fn(data):
# data is a list of the return value of Dataset class:
modified_data = (...touch data)
return modified_data
from torch.utils.data import DataLoader
data_loader = DataLoader(dataset, collate_fn=collate_fn)
for modified_data in data_loader:
...参数类型由输入的dataset类决定,我们的数据集始终是espnet2.train.dataset.ESPnetDataset,其返回值是一个包含样本ID和张量字典的元组。
batch = ("sample_id", {"speech": tensor, "text": tensor})因此,类型是一个张量字典的列表。
data = [
("sample_id", {"speech": tensor, "text": tensor}),
("sample_id2", {"speech": tensor, "text": tensor}),
...
]在espnet2中,collate_fn的返回类型应为列表元组和张量字典,因此Task的collate_fn必须将数据类型转换为此格式。
for ids, batch in data_loader:
model(**batch)我们提供了通用的collate_fn函数,该函数可支持多种场景,因此您可能无需自定义。此collate_fn能够处理seq2seq任务中的变长序列特征:
- 数据集输出的序列张量的第一轴必须是长度轴,例如:(Length, Dim)、(Length, Dim, Dim2) 或 (Length, ...)
- It's not necessary to make the lengths of each sample unified and they are stacked with zero-padding.
- 填充值可以修改。
from espnet2.train.collate_fn import CommonCollateFn @classmethod def build_collate_fn(cls, args): # float_pad_value用于浮点张量,int_pad_value用于整型张量 return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
- 填充值可以修改。
- 同时也会附加表示每个样本长度的张量
batch = {"speech": ..., "speech_lengths": ..., "text": ..., "text_lengths": ...} - 如果特征不是序列数据,可以禁用此行为。
python -m new_task --train_data_path_and_name_and_type=filepath,foo,npy@classmethod def build_collate_fn(cls, args): return CommonCollateFn(not_sequence=["foo"])
