深度表格模型的模块化设计
我们的关键观察是,许多表格深度学习模型都遵循三个组件的模块化设计:
如下图所示:

首先,输入的
DataFrame
具有不同的列,它被转换为TensorFrame
,其中列根据它们的stype
(语义类型,如分类、数值和文本)进行组织。然后,
TensorFrame
被输入到FeatureEncoder
中,它将每个stype
特征转换为一个三维的Tensor
。不同
stypes
的Tensors
随后被连接成一个形状为[batch_size, num_cols, num_channels]
的单个Tensor
x
。然后通过
TableConvs
迭代更新Tensor
x
。更新后的
Tensor
x
作为输入传递给Decoder
,以生成形状为[batch_size, out_channels]
的输出Tensor
。
1. FeatureEncoder
FeatureEncoder
将输入的 TensorFrame
转换为 x
,一个大小为 [batch_size, num_cols, channels]
的 torch.Tensor
。
该类可以包含可学习的参数和 NaN(缺失值)处理。
StypeWiseFeatureEncoder
继承自 FeatureEncoder
。
它接受 TensorFrame
作为输入,并通过 stype_encoder_dict
指定的特定类型的特征编码器应用于每个 Tensor
的 stype
以获取每个 stype
的嵌入。
然后,不同 stypes
的嵌入被连接起来,形成最终的3维 Tensor
x
,其形状为 [batch_size, num_cols, channels]
。
注意
存在面向用户和内部的stypes
类型。
面向用户的stypes
在Dataset
级别上声明,用户可以在给定的DataFrame
中为每一列指定stype
。
在物化过程中,面向用户的stype
的原始数据将被转换为内部stype
的数据。
我们将内部stype
称为面向用户的stype
的父级。
例如,stype.text_embedded
是一个面向用户的stype
,因为它声明了存储在DataFrame
中的原始数据的语义类型。
在物化过程中,我们将存储为文本的原始数据转换为嵌入,这使得它与存储为stype.embedding
的数据没有区别。
因此,列的相应语义类型在TensorFrame
中变为stype.embedding
。
我们将stype.embedding
视为stype.text_embedded
的父类。
在stype_encoder_dict
中仅支持父语义类型。
这种设计的动机是,在内部,相同stype
的数据可以分组以提高效率。
以下是StypeWiseFeatureEncoder
的示例用法,包括
EmbeddingEncoder
用于编码stype.categorical
列,
LinearEmbeddingEncoder
用于编码stype.embedding
列,
以及LinearEncoder
用于编码stype.numerical
列。
from torch_frame import stype
from torch_frame.nn import (
StypeWiseFeatureEncoder,
EmbeddingEncoder,
LinearEmbeddingEncoder,
LinearEncoder,
)
stype_encoder_dict = {
stype.categorical: EmbeddingEncoder(),
stype.numerical: LinearEncoder(),
stype.embedding: LinearEmbeddingEncoder(),
}
encoder = StypeWiseFeatureEncoder(
out_channels=channels,
col_stats=col_stats,
col_names_dict=col_names_dict,
stype_encoder_dict=stype_encoder_dict,
)
还有其他实现的编码器,例如用于numerical
列的LinearBucketEncoder
和ExcelFormerEncoder
。
请参阅torch_frame.nn
以获取内置编码器的完整列表。
你也可以通过继承StypeEncoder
为给定的stype
实现自定义编码器。
2. TableConv
表格卷积层继承自 TableConv
。
它接受形状为 [batch_size, num_cols, channels]
的三维 Tensor
x
作为输入,
并根据其他列的嵌入更新列嵌入;从而建模不同列值之间的复杂交互。
下面,我们展示了一个基于自注意力的简单表格卷积,用于建模列之间的交互。
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Linear
from torch_frame.nn import TableConv
class SelfAttentionConv(TableConv):
def __init__(self, channels: int):
super().__init__()
self.channels = channels
# Linear functions for modeling key/query/value in self-attention.
self.lin_k = Linear(channels, channels)
self.lin_q = Linear(channels, channels)
self.lin_v = Linear(channels, channels)
def forward(self, x: Tensor) -> Tensor:
# [batch_size, num_cols, channels]
x_key = self.lin_k(x)
x_query = self.lin_q(x)
x_value = self.lin_v(x)
prod = x_query.bmm(x_key.transpose(2, 1)) / math.sqrt(self.channels)
# Attention weights between all pairs of columns.
attn = F.softmax(prod, dim=-1)
# Mix `x_value` based on the attention weights
out = attn.bmm(x_value)
return out
初始化和调用它是直接的。
conv = SelfAttentionConv(32)
x = conv(x)
查看 torch_frame.nn
获取内置卷积层的完整列表。
3. Decoder
Decoder
将输入的 Tensor
x
转换为 out
,一个形状为 [batch_size, out_channels]
的 Tensor
,表示原始 DataFrame
的行嵌入。
下面是一个简单的Decoder
示例,它对列嵌入进行均值池化,然后进行线性变换。
import torch
from torch import Tensor
from torch.nn import Linear
from torch_frame.nn import Decoder
class MeanDecoder(Decoder):
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x: Tensor) -> Tensor:
# Mean pooling over the column dimension
# [batch_size, num_cols, in_channels] -> [batch_size, in_channels]
out = torch.mean(x, dim=1)
# [batch_size, out_channels]
return self.lin(out)
查看 torch_frame.nn
获取内置解码器的完整列表。