工具

PyTorch LLaMA 模型。

LlamaAttention

多头注意力机制来自《Attention Is All You Need》论文。

LlamaDecoderLayer

Llama解码器层类。

LlamaMLP

LlamaMLP 类。

LlamaRMSNorm

LlamaRMSNorm 类。

LlamaRotaryEmbedding

Llama 旋转嵌入。

函数

apply_rotary_pos_emb

应用旋转位置嵌入。

expand_mask

将attention_mask从[bsz, seq_len]扩展到[bsz, 1, tgt_seq_len, src_seq_len]

make_causal_mask

制作用于双向自注意力的因果掩码。

repeat_kv

这相当于 torch.repeat_interleave(x, dim=1, repeats=n_rep)。

rotate_half

旋转输入的一半隐藏维度。

class LlamaAttention

基础:Module

来自《Attention Is All You Need》论文的多头注意力机制。

__init__(hidden_size, num_attention_heads, num_key_value_heads, max_position_embeddings, rope_theta)

LlamaAttention 的初始化函数。

forward(hidden_states, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False)

LlamaAttention的前向函数。

Parameters:
  • hidden_states (张量) –

  • attention_mask (Tensor | None) –

  • position_ids (LongTensor | None) –

  • past_key_value (Tuple[Tensor] | None) –

  • output_attentions (bool) –

  • use_cache (bool) –

Return type:

元组[张量, 张量 | , 元组[张量] | ]

class LlamaDecoderLayer

基础:Module

Llama解码器层类。

__init__(index, hidden_size, intermediate_size=14336, rms_norm_eps=1e-05, num_attention_heads=32, num_key_value_heads=8, max_position_embeddings=131072, rope_theta=500000.0)

LlamaDecoderLayer的初始化函数。

forward(hidden_states, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False)

LlamaDecoderLayer的前向函数。

Parameters:
  • hidden_states (张量) –

  • attention_mask (Tensor | None) –

  • position_ids (LongTensor | None) –

  • past_key_value (Tuple[Tensor] | None) –

  • output_attentions (bool | None) –

  • use_cache (bool | None) –

Return type:

元组[浮点张量, 元组[浮点张量, 浮点张量] | ]

class LlamaMLP

基础:Module

LlamaMLP 类。

__init__(hidden_size, intermediate_size)

LlamaMLP的初始化函数。

forward(x)

LlamaMLP的前向函数。

class LlamaRMSNorm

基础:Module

LlamaRMSNorm 类。

__init__(hidden_size, eps=1e-06)

LlamaRMSNorm 等同于 T5LayerNorm。

forward(hidden_states)

LlamaRMSNorm 的前向函数。

class LlamaRotaryEmbedding

基础:Module

Llama 旋转嵌入。

__init__(dim, max_position_embeddings=2048, base=10000, device=None)

LlamaRotaryEmbedding 的初始化函数。

forward(x, seq_len=None)

LlamaRotaryEmbedding 的前向函数。

apply_rotary_pos_emb(q, k, cos, sin, position_ids)

应用旋转位置嵌入。

expand_mask(mask, dtype, tgt_len=None)

将attention_mask从[bsz, seq_len]扩展到[bsz, 1, tgt_seq_len, src_seq_len]

Parameters:
  • mask (张量) –

  • dtype (dtype) –

  • tgt_len (int | None) –

make_causal_mask(input_ids_shape, dtype, device, past_key_values_length=0)

制作用于双向自注意力的因果掩码。

Parameters:
  • input_ids_shape (大小) –

  • dtype (dtype) –

  • 设备 (device) –

  • past_key_values_length (int) –

repeat_kv(hidden_states, n_rep)

这相当于 torch.repeat_interleave(x, dim=1, repeats=n_rep)。

隐藏状态从 (batch, num_key_value_heads, seqlen, head_dim) 变为 (batch, num_attention_heads, seqlen, head_dim)

Parameters:
  • hidden_states (张量) –

  • n_rep (int) –

Return type:

张量

rotate_half(x)

旋转输入的一半隐藏维度。