Source code for torch_frame.nn.models.trompt

from __future__ import annotations

from typing import Any

import torch
from torch import Tensor
from torch.nn import LayerNorm, Module, ModuleList, Parameter, ReLU, Sequential

import torch_frame
from torch_frame import TensorFrame, stype
from torch_frame.data.stats import StatType
from torch_frame.nn.conv import TromptConv
from torch_frame.nn.decoder import TromptDecoder
from torch_frame.nn.encoder.stype_encoder import (
    EmbeddingEncoder,
    LinearEncoder,
    StypeEncoder,
)
from torch_frame.nn.encoder.stypewise_encoder import StypeWiseFeatureEncoder
from torch_frame.typing import NAStrategy


class Trompt(Module):
    r"""The Trompt model introduced in the
    `"Trompt: Towards a Better Deep Neural Network for Tabular Data"
    <https://arxiv.org/abs/2305.18446>`_ paper.

    .. note::

        For an example of using Trompt, see `examples/trompt.py
        <https://github.com/pyg-team/pytorch-frame/blob/master/examples/
        trompt.py>`_.

    Args:
        channels (int): Hidden channel dimensionality.
        out_channels (int): Output channel dimensionality.
        num_prompts (int): Number of prompt columns.
        num_layers (int): Number of :class:`TromptConv` layers.
        col_stats (dict[str, dict[:class:`torch_frame.data.stats.StatType`,
            Any]]): A dictionary that maps column names to stats. Available
            as :obj:`dataset.col_stats`.
        col_names_dict (dict[:obj:`torch_frame.stype`, list[str]]): A
            dictionary that maps stype to a list of column names. The column
            names are sorted based on the ordering that appears in
            :obj:`tensor_frame.feat_dict`. Available as
            :obj:`tensor_frame.col_names_dict`.
        stype_encoder_dicts (list[dict[:class:`torch_frame.stype`,
            :class:`torch_frame.nn.encoder.StypeEncoder`]], optional): A list
            of :obj:`num_layers` dictionaries, where each dictionary maps
            stypes to their stype encoders. (default: :obj:`None`, will call
            :obj:`EmbeddingEncoder()` for categorical features and
            :obj:`LinearEncoder()` for numerical features)
    """
    def __init__(
        self,
        channels: int,
        out_channels: int,
        num_prompts: int,
        num_layers: int,
        # kwargs for encoder
        col_stats: dict[str, dict[StatType, Any]],
        col_names_dict: dict[torch_frame.stype, list[str]],
        stype_encoder_dicts: list[dict[torch_frame.stype, StypeEncoder]]
        | None = None,
    ) -> None:
        super().__init__()
        if num_layers <= 0:
            raise ValueError(
                f"num_layers must be a positive integer (got {num_layers})")

        self.channels = channels
        self.out_channels = out_channels
        self.num_layers = num_layers

        num_cols = sum(
            [len(col_names) for col_names in col_names_dict.values()])
        self.x_prompt = Parameter(torch.empty(num_prompts, channels))
        self.encoders = ModuleList()
        self.trompt_convs = ModuleList()
        for i in range(num_layers):
            if stype_encoder_dicts is None:
                stype_encoder_dict_layer = {
                    stype.categorical:
                    EmbeddingEncoder(
                        post_module=LayerNorm(channels),
                        na_strategy=NAStrategy.MOST_FREQUENT,
                    ),
                    stype.numerical:
                    LinearEncoder(
                        post_module=Sequential(
                            ReLU(),
                            LayerNorm(channels),
                        ),
                        na_strategy=NAStrategy.MEAN,
                    ),
                }
            else:
                stype_encoder_dict_layer = stype_encoder_dicts[i]
            self.encoders.append(
                StypeWiseFeatureEncoder(
                    out_channels=channels,
                    col_stats=col_stats,
                    col_names_dict=col_names_dict,
                    stype_encoder_dict=stype_encoder_dict_layer,
                ))
            self.trompt_convs.append(
                TromptConv(channels, num_cols, num_prompts))
        # Decoder is shared across layers.
        self.trompt_decoder = TromptDecoder(channels, out_channels,
                                            num_prompts)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        torch.nn.init.normal_(self.x_prompt, std=0.01)
        for encoder in self.encoders:
            encoder.reset_parameters()
        for trompt_conv in self.trompt_convs:
            trompt_conv.reset_parameters()
        self.trompt_decoder.reset_parameters()
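As a quick orientation, here is a minimal construction sketch, not part of the library source. It uses :class:`torch_frame.datasets.FakeDataset` to obtain the required :obj:`col_stats` and :obj:`col_names_dict`; the exact :class:`FakeDataset` keyword arguments and all hyperparameter values below are assumptions of this sketch, not prescribed defaults.

    import torch_frame
    from torch_frame.datasets import FakeDataset
    from torch_frame.nn import Trompt

    # Build a small synthetic dataset (keyword arguments are assumed here).
    dataset = FakeDataset(
        num_rows=32,
        stypes=[torch_frame.categorical, torch_frame.numerical],
    )
    dataset.materialize()  # computes dataset.col_stats
    tf = dataset.tensor_frame

    model = Trompt(
        channels=32,     # hidden dimensionality (illustrative value)
        out_channels=1,  # e.g. a single regression target
        num_prompts=16,  # number of prompt columns (illustrative value)
        num_layers=6,
        col_stats=dataset.col_stats,
        col_names_dict=tf.col_names_dict,
    )
    out = model(tf)  # [32, 6, 1]: one prediction per layer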
    def forward(self, tf: TensorFrame) -> Tensor:
        r"""Transforms a :class:`TensorFrame` object into a series of output
        predictions, one at each layer. Used during training to compute the
        layer-wise loss.

        Args:
            tf (:class:`torch_frame.TensorFrame`): Input :class:`TensorFrame`
                object.

        Returns:
            torch.Tensor: Output predictions stacked across layers. The shape
                is :obj:`[batch_size, num_layers, out_channels]`.
        """
        batch_size = len(tf)
        outs = []
        # [batch_size, num_prompts, channels]
        x_prompt = self.x_prompt.repeat(batch_size, 1, 1)
        for i in range(self.num_layers):
            # [batch_size, num_cols, channels]
            x, _ = self.encoders[i](tf)
            # [batch_size, num_prompts, channels]
            x_prompt = self.trompt_convs[i](x, x_prompt)
            # [batch_size, out_channels]
            out = self.trompt_decoder(x_prompt)
            # [batch_size, 1, out_channels]
            out = out.view(batch_size, 1, self.out_channels)
            outs.append(out)
        # [batch_size, num_layers, out_channels]
        stacked_out = torch.cat(outs, dim=1)
        return stacked_out
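The stacked output makes it straightforward to supervise every layer at once, as the docstring suggests. Below is a hedged sketch of such a layer-wise loss, assuming a classification setup where :obj:`tf.y` holds integer class labels, :obj:`out_channels` equals the number of classes, and :obj:`model` is a :class:`Trompt` instance; it is not the library's training loop itself.

    import torch.nn.functional as F

    out = model(tf)                    # [batch_size, num_layers, out_channels]
    num_layers = out.size(1)
    # Flatten layers into the batch dimension so every layer's prediction
    # is scored against the same label.
    pred = out.view(-1, out.size(-1))  # [batch_size * num_layers, out_channels]
    y = tf.y.repeat_interleave(num_layers)
    loss = F.cross_entropy(pred, y)
    loss.backward()

At inference time, the per-layer predictions are typically aggregated instead, e.g. by averaging over the layer dimension with :obj:`out.mean(dim=1)`.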