前馈¶
- class torchtune.modules.FeedForward(*, gate_proj: Module, down_proj: Module, up_proj: Optional[Module] = None, activation: Module = SiLU())[source]¶
该类实现了从Llama2派生的前馈网络。
- Parameters:
gate_proj (nn.Module) – 从输入维度到隐藏维度的投影,通过激活函数传递并与up_proj相乘。
down_proj (nn.Module) – 最终投影到输出维度。
up_proj (可选[nn.Module]) – 从输入维度到隐藏维度的投影,乘以激活函数(gate_proj)。
activation (nn.Module) – 使用的激活函数。默认为 nn.SiLU()。
- forward(x: Tensor) Tensor[source]¶
- Parameters:
x (torch.Tensor) – 输入张量,形状为
(..., in_dim),其中in_dim是gate_proj和up_proj的输入维度。- Returns:
输出张量的形状为
(..., out_dim),其中out_dim是down_proj的输出维度。- Return type: