Shortcuts

前馈

class torchtune.modules.FeedForward(*, gate_proj: Module, down_proj: Module, up_proj: Optional[Module] = None, activation: Module = SiLU())[source]

该类实现了从Llama2派生的前馈网络。

Parameters:
  • gate_proj (nn.Module) – 从输入维度到隐藏维度的投影,通过激活函数传递并与up_proj相乘。

  • down_proj (nn.Module) – 最终投影到输出维度。

  • up_proj (可选[nn.Module]) – 从输入维度到隐藏维度的投影,乘以激活函数(gate_proj)。

  • activation (nn.Module) – 使用的激活函数。默认为 nn.SiLU()。

forward(x: Tensor) Tensor[source]
Parameters:

x (torch.Tensor) – 输入张量,形状为 (..., in_dim),其中 in_dimgate_projup_proj 的输入维度。

Returns:

输出张量的形状为 (..., out_dim),其中 out_dimdown_proj 的输出维度。

Return type:

torch.Tensor