torch_frame.nn.conv.TromptConv
- class TromptConv(channels: int, num_cols: int, num_prompts: int, num_groups: int = 2)[source]
基础类:
TableConv在“Trompt: Towards a Better Deep Neural Network for Tabular Data”论文中介绍的Trompt单元。
- Parameters:
- forward(x: Tensor, x_prompt: Tensor) Tensor[source]
将
x和x_prompt转换为下一层的x_prompt。- Parameters:
x (torch.Tensor) – 基于特征的嵌入,形状为
[batch_size, num_cols, channels]x_prompt (torch.Tensor) – 输入提示嵌入的形状为
[batch_size, num_prompts, channels]。
- Returns:
- 输出下一层的提示嵌入。其
形状为
[batch_size, num_prompts, channels]。
- Return type: