深度表格模型的模块化设计
我们的关键观察是,许多表格深度学习模型都遵循三个组件的模块化设计:
如下图所示:
首先,输入的
DataFrame具有不同的列,它被转换为TensorFrame,其中列根据它们的stype(语义类型,如分类、数值和文本)进行组织。然后,
TensorFrame被输入到FeatureEncoder中,它将每个stype特征转换为一个三维的Tensor。不同
stypes的Tensors随后被连接成一个形状为[batch_size, num_cols, num_channels]的单个Tensorx。然后通过
TableConvs迭代更新Tensorx。更新后的
Tensorx作为输入传递给Decoder,以生成形状为[batch_size, out_channels]的输出Tensor。
1. FeatureEncoder
FeatureEncoder 将输入的 TensorFrame 转换为 x,一个大小为 [batch_size, num_cols, channels] 的 torch.Tensor。
该类可以包含可学习的参数和 NaN(缺失值)处理。
StypeWiseFeatureEncoder 继承自 FeatureEncoder。
它接受 TensorFrame 作为输入,并通过 stype_encoder_dict 指定的特定类型的特征编码器应用于每个 Tensor 的 stype 以获取每个 stype 的嵌入。
然后,不同 stypes 的嵌入被连接起来,形成最终的3维 Tensor x,其形状为 [batch_size, num_cols, channels]。
注意
存在面向用户和内部的stypes类型。
面向用户的stypes在Dataset级别上声明,用户可以在给定的DataFrame中为每一列指定stype。
在物化过程中,面向用户的stype的原始数据将被转换为内部stype的数据。
我们将内部stype称为面向用户的stype的父级。
例如,stype.text_embedded是一个面向用户的stype,因为它声明了存储在DataFrame中的原始数据的语义类型。
在物化过程中,我们将存储为文本的原始数据转换为嵌入,这使得它与存储为stype.embedding的数据没有区别。
因此,列的相应语义类型在TensorFrame中变为stype.embedding。
我们将stype.embedding视为stype.text_embedded的父类。
在stype_encoder_dict中仅支持父语义类型。
这种设计的动机是,在内部,相同stype的数据可以分组以提高效率。
以下是StypeWiseFeatureEncoder的示例用法,包括
EmbeddingEncoder用于编码stype.categorical列,
LinearEmbeddingEncoder用于编码stype.embedding列,
以及LinearEncoder用于编码stype.numerical列。
from torch_frame import stype
from torch_frame.nn import (
StypeWiseFeatureEncoder,
EmbeddingEncoder,
LinearEmbeddingEncoder,
LinearEncoder,
)
stype_encoder_dict = {
stype.categorical: EmbeddingEncoder(),
stype.numerical: LinearEncoder(),
stype.embedding: LinearEmbeddingEncoder(),
}
encoder = StypeWiseFeatureEncoder(
out_channels=channels,
col_stats=col_stats,
col_names_dict=col_names_dict,
stype_encoder_dict=stype_encoder_dict,
)
还有其他实现的编码器,例如用于numerical列的LinearBucketEncoder和ExcelFormerEncoder。
请参阅torch_frame.nn以获取内置编码器的完整列表。
你也可以通过继承StypeEncoder为给定的stype实现自定义编码器。
2. TableConv
表格卷积层继承自 TableConv。
它接受形状为 [batch_size, num_cols, channels] 的三维 Tensor x 作为输入,
并根据其他列的嵌入更新列嵌入;从而建模不同列值之间的复杂交互。
下面,我们展示了一个基于自注意力的简单表格卷积,用于建模列之间的交互。
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Linear
from torch_frame.nn import TableConv
class SelfAttentionConv(TableConv):
def __init__(self, channels: int):
super().__init__()
self.channels = channels
# Linear functions for modeling key/query/value in self-attention.
self.lin_k = Linear(channels, channels)
self.lin_q = Linear(channels, channels)
self.lin_v = Linear(channels, channels)
def forward(self, x: Tensor) -> Tensor:
# [batch_size, num_cols, channels]
x_key = self.lin_k(x)
x_query = self.lin_q(x)
x_value = self.lin_v(x)
prod = x_query.bmm(x_key.transpose(2, 1)) / math.sqrt(self.channels)
# Attention weights between all pairs of columns.
attn = F.softmax(prod, dim=-1)
# Mix `x_value` based on the attention weights
out = attn.bmm(x_value)
return out
初始化和调用它是直接的。
conv = SelfAttentionConv(32)
x = conv(x)
查看 torch_frame.nn 获取内置卷积层的完整列表。
3. Decoder
Decoder 将输入的 Tensor x 转换为 out,一个形状为 [batch_size, out_channels] 的 Tensor,表示原始 DataFrame 的行嵌入。
下面是一个简单的Decoder示例,它对列嵌入进行均值池化,然后进行线性变换。
import torch
from torch import Tensor
from torch.nn import Linear
from torch_frame.nn import Decoder
class MeanDecoder(Decoder):
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x: Tensor) -> Tensor:
# Mean pooling over the column dimension
# [batch_size, num_cols, in_channels] -> [batch_size, in_channels]
out = torch.mean(x, dim=1)
# [batch_size, out_channels]
return self.lin(out)
查看 torch_frame.nn 获取内置解码器的完整列表。