speechbrain.inference.text 模块

指定文本处理模块的推理接口。

Authors:
  • 阿库·罗赫 2021

  • 彼得·普兰廷加 2021

  • 洛伦·卢戈斯奇 2020

  • Mirco Ravanelli 2020

  • Titouan Parcollet 2021

  • 阿卜杜勒·赫巴 2021

  • 安德烈亚斯·诺茨 2022, 2023

  • Pooneh Mousavi 2023

  • Sylvain de Langen 2023

  • 阿德尔·穆门 2023

  • 普拉迪亚·坎达尔卡 2023

摘要

类:

GPTResponseGenerator

一个即用型的响应生成器模型

GraphemeToPhoneme

一个预训练模型实现,用于将原始自然语言文本作为输入的Grapheme-to-Phoneme (G2P) 模型。

Llama2ResponseGenerator

一个即用型的响应生成器模型

ResponseGenerator

一个即用型的响应生成器模型

参考

class speechbrain.inference.text.GraphemeToPhoneme(*args, **kwargs)[source]

基础类: Pretrained, EncodeDecodePipelineMixin

一个预训练模型实现,用于将原始自然语言文本作为输入的Grapheme-to-Phoneme (G2P) 模型

Parameters:
  • *args (元组)

  • **kwargs (dict) – 参数被转发到 Pretrained 父类。

Example

>>> text = ("English is tough. It can be understood "
...         "through thorough thought though")
>>> from speechbrain.inference.text import GraphemeToPhoneme
>>> tmpdir = getfixture('tmpdir')
>>> g2p = GraphemeToPhoneme.from_hparams('path/to/model', savedir=tmpdir) 
>>> phonemes = g2p.g2p(text) 
INPUT_STATIC_KEYS = ['txt']
OUTPUT_KEYS = ['phonemes']
property phonemes

返回可用的音素

property language

返回此模型可用的语言

g2p(text)[source]

执行字形到音素的转换

Parameters:

文本 (strlist[str]) – 要编码为音素的单个字符串 - 或一系列字符串

Returns:

result – 如果提供了单个示例,则返回值是音素的单个列表

Return type:

list

load_dependencies()[source]

加载任何相关的模型依赖项

__call__(text)[source]

一个方便的调用包装器 - 与G2P相同

Parameters:

文本 (strlist[str]) – 要编码为音素的单个字符串 - 或一个字符串序列

Returns:

result – 如果提供了单个示例,则返回值是音素的单个列表

Return type:

list

forward(noisy, lengths=None)[source]

对噪声输入进行增强

class speechbrain.inference.text.ResponseGenerator(*args, **kwargs)[source]

基础类:Pretrained

一个即用型的响应生成器模型

该类可用于根据用户输入生成并继续对话。 给定的YAML必须包含*_NEEDED[]列表中指定的字段。 它需要与custom.py一起使用,以加载带有添加的标记(如bos、eos和说话者标记)的扩展模型。

Parameters:
  • *args (元组)

  • **kwargs (dict) – 参数被转发到 Pretrained 父类。

MODULES_NEEDED = ['model']
generate_response(turn)[source]

根据用户的输入完成对话。 :param turn: 用户输入,即对话的最后一轮。 :type turn: str

Returns:

根据对话历史生成的用户输入响应。

Return type:

响应

prepare_input()[source]

用户应根据自己的任务修改此函数。

generate()[source]

用户应根据自己的任务修改此函数。

class speechbrain.inference.text.GPTResponseGenerator(*args, **kwargs)[source]

基础类:ResponseGenerator

一个即用型的响应生成器模型

该类可用于根据用户输入生成并继续对话。 给定的YAML必须包含*_NEEDED[]列表中指定的字段。 它需要与custom.py一起使用,以加载扩展的GPT模型,并添加诸如bos、eos和说话者标记等额外标记。

Parameters:
  • *args (元组)

  • **kwargs (dict) – 参数被转发到 Pretrained 父类。

Example

>>> from speechbrain.inference.text import GPTResponseGenerator
>>> tmpdir = getfixture("tmpdir")
>>> res_gen_model = GPTResponseGenerator.from_hparams(source="speechbrain/MultiWOZ-GPT-Response_Generation",
... pymodule_file="custom.py")  
>>> response = res_gen_model.generate_response("I want to book a table for dinner")  
generate(inputs)[source]

根据用户的输入完成对话。

Parameters:

inputs (tuple) – history_bos 是带有适当说话者标记的标记化历史+输入值,在每个回合之前附加,history_token_type 根据说出该标记的人(用户或系统)确定每个标记的类型。

Returns:

根据对话历史为用户输入生成的假设。

Return type:

响应

prepare_input()[source]
Convert user input and previous histories to the format acceptable for GPT model.

它会附加所有先前的历史和输入,并根据max_history值进行截断。 然后,它对输入进行标记化,并生成额外的输入,以确定每个标记的类型(系统或用户)。

Returns:

  • history_bos (torch.Tensor) – 在每个回合前附加了适当说话者标记的标记化历史+输入值。

  • history_token_type (torch.LongTensor) – 基于谁说出该标记的每个标记的类型(用户或系统)

class speechbrain.inference.text.Llama2ResponseGenerator(*args, **kwargs)[source]

基础类:ResponseGenerator

一个即用型的响应生成器模型

该类可用于根据用户输入生成并继续对话。 给定的YAML必须包含*_NEEDED[]列表中指定的字段。 它需要与custom.py一起使用,以加载扩展的Llama2模型,并添加诸如bos、eos和说话者标记等额外标记。

Parameters:
  • *args (元组)

  • **kwargs (dict) – 参数被转发到 Pretrained 父类。

Example

>>> from speechbrain.inference.text import Llama2ResponseGenerator
>>> tmpdir = getfixture("tmpdir")
>>> res_gen_model = Llama2ResponseGenerator.from_hparams(source="speechbrain/MultiWOZ-Llama2-Response_Generation",
... pymodule_file="custom.py")  
>>> response = res_gen_model.generate_response("I want to book a table for dinner")  
generate(inputs)[source]

根据用户的输入完成对话。 :param inputs: 传递给llama2模型进行生成的提示输入。 :type inputs: prompt_bos

Returns:

基于对话历史为用户输入生成的假设。

Return type:

响应

prepare_input()[source]
Convert user input and previous histories to the format acceptable for Llama2 model.

它会附加所有之前的历史记录和输入,并根据max_history值进行截断。 然后对输入进行标记化并添加提示。

Returns:

prompt_bos – 使用适当的提示符对历史记录+输入值进行标记化。

Return type:

torch.Tensor