工具
PyTorch LLaMA 模型。
类
多头注意力机制来自《Attention Is All You Need》论文。 |
|
Llama解码器层类。 |
|
LlamaMLP 类。 |
|
LlamaRMSNorm 类。 |
|
Llama 旋转嵌入。 |
函数
应用旋转位置嵌入。 |
|
将attention_mask从[bsz, seq_len]扩展到[bsz, 1, tgt_seq_len, src_seq_len]。 |
|
制作用于双向自注意力的因果掩码。 |
|
这相当于 torch.repeat_interleave(x, dim=1, repeats=n_rep)。 |
|
旋转输入的一半隐藏维度。 |
- class LlamaAttention
基础:
Module来自《Attention Is All You Need》论文的多头注意力机制。
- __init__(hidden_size, num_attention_heads, num_key_value_heads, max_position_embeddings, rope_theta)
LlamaAttention 的初始化函数。
- forward(hidden_states, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False)
LlamaAttention的前向函数。
- Parameters:
hidden_states (张量) –
attention_mask (Tensor | None) –
position_ids (LongTensor | None) –
past_key_value (Tuple[Tensor] | None) –
output_attentions (bool) –
use_cache (bool) –
- Return type:
元组[张量, 张量 | 无, 元组[张量] | 无]
- class LlamaDecoderLayer
基础:
ModuleLlama解码器层类。
- __init__(index, hidden_size, intermediate_size=14336, rms_norm_eps=1e-05, num_attention_heads=32, num_key_value_heads=8, max_position_embeddings=131072, rope_theta=500000.0)
LlamaDecoderLayer的初始化函数。
- forward(hidden_states, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False)
LlamaDecoderLayer的前向函数。
- Parameters:
hidden_states (张量) –
attention_mask (Tensor | None) –
position_ids (LongTensor | None) –
past_key_value (Tuple[Tensor] | None) –
output_attentions (bool | None) –
use_cache (bool | None) –
- Return type:
元组[浮点张量, 元组[浮点张量, 浮点张量] | 无]
- class LlamaMLP
基础:
ModuleLlamaMLP 类。
- __init__(hidden_size, intermediate_size)
LlamaMLP的初始化函数。
- forward(x)
LlamaMLP的前向函数。
- class LlamaRMSNorm
基础:
ModuleLlamaRMSNorm 类。
- __init__(hidden_size, eps=1e-06)
LlamaRMSNorm 等同于 T5LayerNorm。
- forward(hidden_states)
LlamaRMSNorm 的前向函数。
- class LlamaRotaryEmbedding
基础:
ModuleLlama 旋转嵌入。
- __init__(dim, max_position_embeddings=2048, base=10000, device=None)
LlamaRotaryEmbedding 的初始化函数。
- forward(x, seq_len=None)
LlamaRotaryEmbedding 的前向函数。
- apply_rotary_pos_emb(q, k, cos, sin, position_ids)
应用旋转位置嵌入。
- expand_mask(mask, dtype, tgt_len=None)
将attention_mask从[bsz, seq_len]扩展到[bsz, 1, tgt_seq_len, src_seq_len]。
- Parameters:
mask (张量) –
dtype (dtype) –
tgt_len (int | None) –
- make_causal_mask(input_ids_shape, dtype, device, past_key_values_length=0)
制作用于双向自注意力的因果掩码。
- Parameters:
input_ids_shape (大小) –
dtype (dtype) –
设备 (device) –
past_key_values_length (int) –
- repeat_kv(hidden_states, n_rep)
这相当于 torch.repeat_interleave(x, dim=1, repeats=n_rep)。
隐藏状态从 (batch, num_key_value_heads, seqlen, head_dim) 变为 (batch, num_attention_heads, seqlen, head_dim)
- Parameters:
hidden_states (张量) –
n_rep (int) –
- Return type:
张量
- rotate_half(x)
旋转输入的一半隐藏维度。