speechbrain.lobes.models.huggingface_transformers.gpt 模块

该模块支持集成huggingface预训练的GPT2LMHeadModel模型。

需要安装来自HuggingFace的Transformer: https://huggingface.co/transformers/installation.html

Authors
  • Pooneh Mousavi 2023

  • 西蒙·阿尔吉西 2023

摘要

类:

GPT

该模块支持集成HuggingFace预训练的GPT模型。

参考

class speechbrain.lobes.models.huggingface_transformers.gpt.GPT(source, save_path, freeze=False, max_new_tokens=200, min_length=1, top_k=45, top_p=0.9, num_beams=8, eos_token_id=50258, early_stopping=True)[source]

基础类: HFTransformersInterface

This lobe enables the integration of HuggingFace pretrained GPT model.
Source paper whisper:

https://life-extension.github.io/2020/05/27/GPT%E6%8A%80%E6%9C%AF%E5%88%9D%E6%8E%A2/language-models.pdf

Transformer from HuggingFace needs to be installed:

https://huggingface.co/transformers/installation.html

模型可以进行微调。它将自动从HuggingFace下载模型或使用本地路径。

Parameters:
  • source (str) – HuggingFace 中心名称:例如 “gpt2”

  • save_path (str) – 下载模型的路径(目录)。

  • freeze (bool (默认值: False)) – 如果为True,模型将被冻结。如果为False,模型将与管道的其余部分一起训练。

  • max_new_tokens (int) – 允许的最大新令牌数量。

  • min_length (int) – 输入令牌的最小数量

  • top_k (int) – 保留的顶部结果数量

  • top_p (float) – 保留的顶部结果的比例

  • num_beams (int) – 解码器光束数量

  • eos_token_id (int) – 句子结束标记的索引。

  • early_stopping (int) – 是否提前停止训练。

Example

>>> model_hub = "gpt2"
>>> save_path = "savedir"
>>> model = GPT(model_hub, save_path)
>>> tokens = torch.tensor([[1, 1]])
>>> tokens_type = torch.tensor([[1, 1]])
>>> attention_mask = torch.tensor([[1, 1]])
>>> outputs = model(tokens, tokens_type, attention_mask)
forward(input_ids: Tensor, token_type_ids: Tensor, attention_mask: Tensor)[source]

接收对话历史作为输入,并返回相应的回复。

Parameters:
  • input_ids (torch.Tensor) – 一批要转换为特征的输入ID。

  • token_type_ids (torch.Tensor) – 输入ID中每个标记的标记类型(说话者)。

  • attention_mask (torch.Tensor) – 一批attention_mask。

Returns:

output – 回复对话

Return type:

torch.Tensor

generate(input_ids: Tensor, token_type_ids, attention_mask: Tensor, decoder_type='greedy')[source]

接收对话历史作为输入,并返回相应的回复。

Parameters:
  • input_ids (torch.Tensor) – 一批输入ID,这些是对话上下文的标记

  • token_type_ids (torch.Tensor)

  • attention_mask (torch.Tensor) – 一批attention_mask。

  • decoder_type (str) – 它显示了自回归解码的策略,可以是束搜索或贪婪。

Returns:

hyp – 对话回复。

Return type:

torch.Tensor