深度表格模型的模块化设计

我们的关键观察是,许多表格深度学习模型都遵循三个组件的模块化设计:

  1. FeatureEncoder

  2. TableConv

  3. Decoder

如下图所示:

../_images/modular.png
  • 首先,输入的DataFrame具有不同的列,它被转换为TensorFrame,其中列根据它们的stype(语义类型,如分类、数值和文本)进行组织。

  • 然后,TensorFrame 被输入到 FeatureEncoder 中,它将每个 stype 特征转换为一个三维的 Tensor

  • 不同stypesTensors随后被连接成一个形状为[batch_size, num_cols, num_channels]的单个Tensor x

  • 然后通过TableConvs迭代更新Tensor x

  • 更新后的Tensor x作为输入传递给Decoder,以生成形状为[batch_size, out_channels]的输出Tensor

1. FeatureEncoder

FeatureEncoder 将输入的 TensorFrame 转换为 x,一个大小为 [batch_size, num_cols, channels]torch.Tensor。 该类可以包含可学习的参数和 NaN(缺失值)处理。

StypeWiseFeatureEncoder 继承自 FeatureEncoder。 它接受 TensorFrame 作为输入,并通过 stype_encoder_dict 指定的特定类型的特征编码器应用于每个 Tensorstype 以获取每个 stype 的嵌入。 然后,不同 stypes 的嵌入被连接起来,形成最终的3维 Tensor x,其形状为 [batch_size, num_cols, channels]

注意

存在面向用户和内部的stypes类型。

面向用户的stypesDataset级别上声明,用户可以在给定的DataFrame中为每一列指定stype。 在物化过程中,面向用户的stype的原始数据将被转换为内部stype的数据。 我们将内部stype称为面向用户的stype的父级。 例如,stype.text_embedded是一个面向用户的stype,因为它声明了存储在DataFrame中的原始数据的语义类型。

在物化过程中,我们将存储为文本的原始数据转换为嵌入,这使得它与存储为stype.embedding的数据没有区别。 因此,列的相应语义类型在TensorFrame中变为stype.embedding。 我们将stype.embedding视为stype.text_embedded的父类。 在stype_encoder_dict中仅支持父语义类型。 这种设计的动机是,在内部,相同stype的数据可以分组以提高效率。

以下是StypeWiseFeatureEncoder的示例用法,包括 EmbeddingEncoder用于编码stype.categorical列, LinearEmbeddingEncoder用于编码stype.embedding列, 以及LinearEncoder用于编码stype.numerical列。

from torch_frame import stype
from torch_frame.nn import (
    StypeWiseFeatureEncoder,
    EmbeddingEncoder,
    LinearEmbeddingEncoder,
    LinearEncoder,
)

stype_encoder_dict = {
    stype.categorical: EmbeddingEncoder(),
    stype.numerical: LinearEncoder(),
    stype.embedding: LinearEmbeddingEncoder(),
}

encoder = StypeWiseFeatureEncoder(
    out_channels=channels,
    col_stats=col_stats,
    col_names_dict=col_names_dict,
    stype_encoder_dict=stype_encoder_dict,
)

还有其他实现的编码器,例如用于numerical列的LinearBucketEncoderExcelFormerEncoder。 请参阅torch_frame.nn以获取内置编码器的完整列表。

你也可以通过继承StypeEncoder为给定的stype实现自定义编码器。

2. TableConv

表格卷积层继承自 TableConv。 它接受形状为 [batch_size, num_cols, channels] 的三维 Tensor x 作为输入, 并根据其他列的嵌入更新列嵌入;从而建模不同列值之间的复杂交互。 下面,我们展示了一个基于自注意力的简单表格卷积,用于建模列之间的交互。

import torch.nn.functional as F
from torch import Tensor
from torch.nn import Linear
from torch_frame.nn import TableConv

class SelfAttentionConv(TableConv):
  def __init__(self, channels: int):
      super().__init__()
      self.channels = channels
      # Linear functions for modeling key/query/value in self-attention.
      self.lin_k = Linear(channels, channels)
      self.lin_q = Linear(channels, channels)
      self.lin_v = Linear(channels, channels)

  def forward(self, x: Tensor) -> Tensor:
      # [batch_size, num_cols, channels]
      x_key = self.lin_k(x)
      x_query = self.lin_q(x)
      x_value = self.lin_v(x)
      prod = x_query.bmm(x_key.transpose(2, 1)) / math.sqrt(self.channels)
      # Attention weights between all pairs of columns.
      attn = F.softmax(prod, dim=-1)
      # Mix `x_value` based on the attention weights
      out = attn.bmm(x_value)
      return out

初始化和调用它是直接的。

conv = SelfAttentionConv(32)
x = conv(x)

查看 torch_frame.nn 获取内置卷积层的完整列表。

3. Decoder

Decoder 将输入的 Tensor x 转换为 out,一个形状为 [batch_size, out_channels]Tensor,表示原始 DataFrame 的行嵌入。

下面是一个简单的Decoder示例,它对列嵌入进行均值池化,然后进行线性变换。

import torch
from torch import Tensor
from torch.nn import Linear
from torch_frame.nn import Decoder

class MeanDecoder(Decoder):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x: Tensor) -> Tensor:
        # Mean pooling over the column dimension
        # [batch_size, num_cols, in_channels] -> [batch_size, in_channels]
        out = torch.mean(x, dim=1)
        # [batch_size, out_channels]
        return self.lin(out)

查看 torch_frame.nn 获取内置解码器的完整列表。