torch_frame.nn.decoder.TromptDecoder
- class TromptDecoder(in_channels: int, out_channels: int, num_prompts: int)[source]
基础类:
Decoder在“Trompt: Towards a Better Deep Neural Network for Tabular Data”论文中介绍的Trompt下游。
- forward(x: Tensor) Tensor[source]
将形状为
[batch_size, num_cols, channels]的x解码为形状为[batch_size, out_channels]的输出张量。- Parameters:
x (torch.Tensor) – 输入列方向的张量,形状为
[batch_size, num_cols, hidden_channels]。args (Any) – 额外参数。
kwargs (Any) – 额外的关键字参数。