speechbrain.lobes.models.huggingface_transformers.gpt 模块
该模块支持集成huggingface预训练的GPT2LMHeadModel模型。
需要安装来自HuggingFace的Transformer: https://huggingface.co/transformers/installation.html
- Authors
Pooneh Mousavi 2023
西蒙·阿尔吉西 2023
摘要
类:
该模块支持集成HuggingFace预训练的GPT模型。 |
参考
- class speechbrain.lobes.models.huggingface_transformers.gpt.GPT(source, save_path, freeze=False, max_new_tokens=200, min_length=1, top_k=45, top_p=0.9, num_beams=8, eos_token_id=50258, early_stopping=True)[source]
-
- This lobe enables the integration of HuggingFace pretrained GPT model.
- Transformer from HuggingFace needs to be installed:
模型可以进行微调。它将自动从HuggingFace下载模型或使用本地路径。
- Parameters:
source (str) – HuggingFace 中心名称:例如 “gpt2”
save_path (str) – 下载模型的路径(目录)。
freeze (bool (默认值: False)) – 如果为True,模型将被冻结。如果为False,模型将与管道的其余部分一起训练。
max_new_tokens (int) – 允许的最大新令牌数量。
min_length (int) – 输入令牌的最小数量
top_k (int) – 保留的顶部结果数量
top_p (float) – 保留的顶部结果的比例
num_beams (int) – 解码器光束数量
eos_token_id (int) – 句子结束标记的索引。
early_stopping (int) – 是否提前停止训练。
Example
>>> model_hub = "gpt2" >>> save_path = "savedir" >>> model = GPT(model_hub, save_path) >>> tokens = torch.tensor([[1, 1]]) >>> tokens_type = torch.tensor([[1, 1]]) >>> attention_mask = torch.tensor([[1, 1]]) >>> outputs = model(tokens, tokens_type, attention_mask)
- forward(input_ids: Tensor, token_type_ids: Tensor, attention_mask: Tensor)[source]
接收对话历史作为输入,并返回相应的回复。
- Parameters:
input_ids (torch.Tensor) – 一批要转换为特征的输入ID。
token_type_ids (torch.Tensor) – 输入ID中每个标记的标记类型(说话者)。
attention_mask (torch.Tensor) – 一批attention_mask。
- Returns:
output – 回复对话
- Return type:
torch.Tensor
- generate(input_ids: Tensor, token_type_ids, attention_mask: Tensor, decoder_type='greedy')[source]
接收对话历史作为输入,并返回相应的回复。
- Parameters:
input_ids (torch.Tensor) – 一批输入ID,这些是对话上下文的标记
token_type_ids (torch.Tensor)
attention_mask (torch.Tensor) – 一批attention_mask。
decoder_type (str) – 它显示了自回归解码的策略,可以是束搜索或贪婪。
- Returns:
hyp – 对话回复。
- Return type:
torch.Tensor