from __future__ import annotations
import math
from typing import Any
from torch import Tensor
from torch.nn import (
BatchNorm1d,
Dropout,
LayerNorm,
Linear,
Module,
ReLU,
Sequential,
)
import torch_frame
from torch_frame import TensorFrame, stype
from torch_frame.data.stats import StatType
from torch_frame.nn.encoder.stype_encoder import (
EmbeddingEncoder,
LinearEncoder,
StypeEncoder,
)
from torch_frame.nn.encoder.stypewise_encoder import StypeWiseFeatureEncoder
class FCResidualBlock(Module):
r"""Fully connected residual block.
Args:
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
normalization (str, optional): The type of normalization to use.
:obj:`layer_norm`, :obj:`batch_norm`, or :obj:`None`.
(default: :obj:`layer_norm`)
        dropout_prob (float): The dropout probability (default: :obj:`0.0`,
            i.e., no dropout).
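
    Example (an illustrative sketch; the shapes are arbitrary and
    :obj:`torch` is assumed to be imported)::

        >>> block = FCResidualBlock(32, 64)
        >>> x = torch.randn(10, 32)  # [batch_size, in_channels]
        >>> block(x).shape
        torch.Size([10, 64])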
"""
def __init__(
self,
in_channels: int,
out_channels: int,
normalization: str | None = "layer_norm",
dropout_prob: float = 0.0,
) -> None:
super().__init__()
self.lin1 = Linear(in_channels, out_channels)
self.lin2 = Linear(out_channels, out_channels)
self.relu = ReLU()
self.dropout = Dropout(dropout_prob)
self.norm1: BatchNorm1d | LayerNorm | None
self.norm2: BatchNorm1d | LayerNorm | None
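        # Optionally normalize after each linear layer; `None` disables it.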
if normalization == "batch_norm":
self.norm1 = BatchNorm1d(out_channels)
self.norm2 = BatchNorm1d(out_channels)
elif normalization == "layer_norm":
self.norm1 = LayerNorm(out_channels)
self.norm2 = LayerNorm(out_channels)
else:
self.norm1 = self.norm2 = None
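        # A linear projection aligns the residual when the input and output
        # widths differ; otherwise, the identity shortcut is used.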
self.shortcut: Linear | None
if in_channels != out_channels:
self.shortcut = Linear(in_channels, out_channels)
else:
self.shortcut = None
def reset_parameters(self) -> None:
self.lin1.reset_parameters()
self.lin2.reset_parameters()
if self.norm1 is not None:
self.norm1.reset_parameters()
if self.norm2 is not None:
self.norm2.reset_parameters()
if self.shortcut is not None:
self.shortcut.reset_parameters()
def forward(self, x: Tensor) -> Tensor:
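        # Two Linear -> Norm -> ReLU -> Dropout stages.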
out = self.lin1(x)
        out = self.norm1(out) if self.norm1 is not None else out
out = self.relu(out)
out = self.dropout(out)
out = self.lin2(out)
        out = self.norm2(out) if self.norm2 is not None else out
out = self.relu(out)
out = self.dropout(out)
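        # Residual connection: project x when the channel widths differ.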
if self.shortcut is not None:
x = self.shortcut(x)
out = out + x
return out


class ResNet(Module):
r"""The ResNet model introduced in the
`"Revisiting Deep Learning Models for Tabular Data"
<https://arxiv.org/abs/2106.11959>`_ paper.
.. note::
For an example of using ResNet, see `examples/revisiting.py
<https://github.com/pyg-team/pytorch-frame/blob/master/examples/
revisiting.py>`_.
Args:
channels (int): The number of channels in the backbone layers.
out_channels (int): The number of output channels in the decoder.
num_layers (int): The number of layers in the backbone.
        col_stats (dict[str, dict[:class:`torch_frame.data.stats.StatType`,
            Any]]): A dictionary that maps each column name to its statistics.
            Available as :obj:`dataset.col_stats`.
        col_names_dict (dict[:class:`torch_frame.stype`, list[str]]): A
            dictionary that maps each stype to a list of column names. The
            column names are sorted in the order in which they appear in
            :obj:`tensor_frame.feat_dict`. Available as
            :obj:`tensor_frame.col_names_dict`.
        stype_encoder_dict
            (dict[:class:`torch_frame.stype`,
            :class:`torch_frame.nn.encoder.StypeEncoder`], optional):
            A dictionary that maps each stype to its stype encoder.
            (default: :obj:`None`, will call :obj:`EmbeddingEncoder()`
            for categorical features and :obj:`LinearEncoder()` for
            numerical features)
normalization (str, optional): The type of normalization to use.
:obj:`batch_norm`, :obj:`layer_norm`, or :obj:`None`.
(default: :obj:`layer_norm`)
        dropout_prob (float): The dropout probability (default: :obj:`0.2`).
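
    Example (an illustrative sketch; assumes a materialized :obj:`dataset`
    and a :obj:`tensor_frame` mini-batch drawn from it)::

        >>> model = ResNet(
        ...     channels=64,
        ...     out_channels=1,
        ...     num_layers=4,
        ...     col_stats=dataset.col_stats,
        ...     col_names_dict=tensor_frame.col_names_dict,
        ... )
        >>> out = model(tensor_frame)  # [batch_size, 1]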
"""
def __init__(
self,
channels: int,
out_channels: int,
num_layers: int,
col_stats: dict[str, dict[StatType, Any]],
col_names_dict: dict[torch_frame.stype, list[str]],
stype_encoder_dict: dict[torch_frame.stype, StypeEncoder]
| None = None,
normalization: str | None = "layer_norm",
dropout_prob: float = 0.2,
) -> None:
super().__init__()
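        # Fall back to the default per-stype encoders: embedding lookups for
        # categorical columns and linear projections for numerical columns.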
if stype_encoder_dict is None:
stype_encoder_dict = {
stype.categorical: EmbeddingEncoder(),
stype.numerical: LinearEncoder(),
}
self.encoder = StypeWiseFeatureEncoder(
out_channels=channels,
col_stats=col_stats,
col_names_dict=col_names_dict,
stype_encoder_dict=stype_encoder_dict,
)
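        # Each column is encoded into `channels` dimensions, so the flattened
        # input to the first backbone block has width `channels * num_cols`.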
        num_cols = sum(
            len(col_names) for col_names in col_names_dict.values())
in_channels = channels * num_cols
self.backbone = Sequential(*[
FCResidualBlock(
in_channels if i == 0 else channels,
channels,
normalization=normalization,
dropout_prob=dropout_prob,
) for i in range(num_layers)
])
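        # The decoder maps the backbone output to the final prediction size.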
self.decoder = Sequential(
LayerNorm(channels),
ReLU(),
Linear(channels, out_channels),
)
self.reset_parameters()
def reset_parameters(self) -> None:
self.encoder.reset_parameters()
for block in self.backbone:
block.reset_parameters()
self.decoder[0].reset_parameters()
self.decoder[-1].reset_parameters()

    def forward(self, tf: TensorFrame) -> Tensor:
        r"""Transform a :class:`TensorFrame` object into an output prediction.

        Args:
            tf (TensorFrame): Input :class:`TensorFrame` object.

        Returns:
            torch.Tensor: Output of shape :obj:`[batch_size, out_channels]`.
        """
x, _ = self.encoder(tf)
        # Flatten the per-column embeddings of shape
        # [batch_size, num_cols, channels] into [batch_size, num_cols * channels].
        x = x.view(x.size(0), math.prod(x.shape[1:]))
x = self.backbone(x)
out = self.decoder(x)
return out
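

# A minimal smoke-test sketch. It assumes :class:`torch_frame.datasets.FakeDataset`
# accepts `num_rows` and `stypes` as below; adjust to your torch_frame version.
if __name__ == "__main__":
    from torch_frame.datasets import FakeDataset

    dataset = FakeDataset(num_rows=10, stypes=[stype.numerical, stype.categorical])
    dataset.materialize()  # Computes col_stats and builds the TensorFrame.
    tf = dataset.tensor_frame

    model = ResNet(
        channels=16,
        out_channels=1,
        num_layers=2,
        col_stats=dataset.col_stats,
        col_names_dict=tf.col_names_dict,
    )
    out = model(tf)
    assert out.shape == (10, 1)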