Mambular: 使用 Mamba 架构的表格深度学习#

Mambular 是一个 Python 包,它将先进的深度学习架构的力量带到了表格数据中,提供了一系列用于回归、分类和分布回归任务的模型。Mambular 模型设计时考虑到了易用性,遵循 scikit-learn 的 BaseEstimator 接口,使其与熟悉的 scikit-learn 生态系统高度兼容。这意味着你可以像使用任何传统的 scikit-learn 模型一样,使用 Mambular 模型进行拟合、预测和评估,但同时具有深度学习带来的性能和灵活性。

特性#

  • 综合模型套件:包含回归、分类和分布回归模块,适用于广泛的表格数据任务。

  • 最先进的架构:利用了多种在处理表格数据方面表现出色的先进架构。Mambular模型包括强大的Mamba模块 Gu和Dao,并且可以包含双向处理以及特征交互层。

  • 无缝集成:专为与scikit-learn无缝协作而设计,便于轻松融入现有的机器学习管道、交叉验证和超参数调优工作流程。

  • 广泛预处理:配备了一个强大的预处理模块,支持多种数据转换技术,确保您的数据为模型训练做好最佳准备。

  • 类似Sklearn的API:熟悉的scikit-learn fitpredictpredict_proba 方法意味着对于已经习惯于scikit-learn的人来说,学习曲线最小。

  • PyTorch Lightning 内部机制:基于 PyTorch Lightning,Mambular 模型受益于简化的训练流程、易于定制以及分布式训练和 16 位精度等高级功能。

模型#

模型

描述

Mambular

使用 Mamba 块的顺序模型 Gu 和 Dao 专门设计用于各种表格数据任务。

FTTransformer

利用 Gorishniy 等人提出的基于transformer编码器的模型,用于表格数据。

MLP

用于处理表格数据任务的经典多层感知器 (MLP) 模型。

ResNet

一种针对表格数据应用的ResNet架构的适应性设计。

TabTransformer

基于Transformer的表格数据模型,由`Huang等人 <https://arxiv.org/abs/2012.06678>`_提出,增强了特征学习能力。

MambaTab

使用 这里 描述的联合输入表示的 Mamba-Block 的表格模型。不是一个序列模型。

TabulaRNN

用于表格数据的递归神经网络。尚未包含在基准测试中。

文档#

你可以在 这里 找到 Mamba-Tabular API 文档。

安装#

使用 pip 安装 Mambular:

pip install mambular

预处理#

Mambular 通过一套全面的技术简化了模型开发的前处理阶段,为 Mamba 架构准备数据。我们的预处理模块设计得既强大又易于使用,提供多种选项以高效地转换您的表格数据。

数据类型检测与转换#

Mambular 自动识别数据集中每个特征的类型,并应用最适合数值和分类变量的转换。这包括:

  • 序数编码:分类特征被无缝转换为数值,保留其固有顺序,使其适合模型使用。

  • 独热编码:对于名义数据,Mambular采用独热编码来捕捉类别的存在与否,而不引入顺序性。

  • 分箱:数值特征可以被离散化为箱,这是一种在某些建模情境中处理连续变量的有用技术。

  • 决策树分箱:可选地,Mambular可以使用决策树来找到数值特征的最佳分箱策略,从而增强模型的可解释性和性能。

  • 归一化:Mambular 可以轻松处理数值特征,无需将其专门转换为分类特征。标准预处理步骤,如按特征进行归一化,是可行的。

  • 标准化:同样地,可以使用标准化而不是归一化来基于均值和标准差对特征进行缩放。

  • PLE(周期线性编码):该技术可应用于数值特征,通过编码周期性来提升表格深度学习方法的性能。

  • 分位数变换:数值特征可以被转换为服从均匀分布或正态分布,从而提高模型对异常值的鲁棒性。

  • 样条变换:将分段多项式函数应用于数值特征,更有效地捕捉非线性关系。

  • 多项式特征:生成多项式和交互特征,增加特征空间以捕捉数据中更复杂的关系。

处理缺失值#

我们的预处理流程通过使用数值特征的均值插补和分类特征的众数插补,有效地处理了缺失数据。这确保了您的模型接收到完整的数据输入,而无需手动干预。此外,Mambular 在分类预处理中包含经典的 标记,可以在推理过程中管理未知的分类值。

拟合模型#

在 mambular 中拟合模型非常简单。mambular 中的所有模型都是 sklearn BaseEstimators。因此,所有模型都实现了 .fit 方法。此外,这还允许使用所有其他 sklearn 固有的方法,例如它们内置的超参数优化工具。

from mambular.models import MambularClassifier
# Initialize and fit your model
model = MambularClassifier(
    d_model=64,
    n_layers=8,
    numerical_preprocessing="ple",
    n_bins=50
)

# X can be a dataframe or something that can be easily transformed into a pd.DataFrame as a np.array
model.fit(X, y, max_epochs=150, lr=1e-04)

预测也很容易获得:

# simple predictions
preds = model.predict(X)

# Predict probabilities
preds = model.predict_proba(X)

使用 MambularLSS 进行分布回归#

Mambular 通过其 MambularLSS 模块引入了一种分布回归方法,允许用户对响应变量的全部分布进行建模,而不仅仅是其均值。这种方法在理解响应分布的变异性、偏度或峰度与预测其中心趋势同样重要的情况下特别有价值。mambular 中所有可用的模型也都可以作为分布模型使用。

MambularLSS 的主要特点:#

  • 全分布建模:与预测单一值(例如均值)的传统回归模型不同,MambularLSS 模型对响应变量的整个分布进行建模。这使得预测更加信息丰富,包括分位数、方差和高阶矩。

  • 可定制的分布类型: MambularLSS 支持多种分布族(例如,高斯分布、泊松分布、二项分布),使其能够适应不同类型的响应变量,从连续数据到计数数据。

  • 位置、尺度、形状参数:该模型预测与分布的位置、尺度和形状相对应的参数,提供了对数据潜在分布特征的细致理解。

  • 增强的预测不确定性:通过建模完整分布,MambularLSS 提供了更丰富的预测不确定性信息,使得在不确定环境中能够进行更稳健的决策过程。

可用的分布类:#

MambularLSS 提供了广泛的分布类,以满足各种统计建模需求。可用的分布类包括:

  • normal:用于建模均值周围对称分布的连续数据的正态分布。

  • poisson:用于建模计数数据的泊松分布,例如表示在固定时间间隔内发生的事件数量。

  • gamma:用于建模偏斜且在零处有界的连续数据的伽马分布,常用于等待时间的建模。

  • beta: 用于建模介于0和1之间的数据的Beta分布,适用于比例和百分比。

  • dirichlet:用于建模多元数据的狄利克雷分布,其中各分量相关,且总和约束为1。

  • studentt:用于建模尾部比正态分布更重的数据的Student’s T分布,在小样本量时非常有用。

  • negativebinom:用于建模相对于泊松分布具有过度离散的计数数据的负二项分布。

  • inversegamma:逆伽马分布,常用于贝叶斯推断中作为尺度参数的先验分布。

  • categorical:用于建模具有两个以上类别的分类数据的分类分布。

这些分布类允许 MambularLSS 灵活地建模各种数据类型和分布,为用户提供了捕捉其数据全部复杂性所需的工具。

开始使用 MambularLSS:#

要将分布式回归集成到使用 MambularLSS 的工作流程中,首先通过所需的配置初始化模型,类似于其他 Mambular 模型:

from mambular.models import MambularLSS

# Initialize the MambularLSS model
model = MambularLSS(
    dropout=0.2,
    d_model=64,
    n_layers=8,

)

# Fit the model to your data
model.fit(
    X,
    y,
    max_epochs=150,
    lr=1e-04,
    patience=10,
    family="normal" # define your distribution
    )

实现你自己的模型#

Mambular 允许用户轻松地将他们的自定义模型集成到现有逻辑中。这个过程设计得非常简单,使得创建一个 PyTorch 模型并定义其前向传递变得容易。你不是从 nn.Module 继承,而是从 Mambular 的 BaseModel 继承。每个 Mambular 模型接受三个主要参数:类别数量(例如,回归为 1,二分类为 2),cat_feature_infonum_feature_info 分别用于分类和数值特征信息。此外,你可以提供一个配置参数,这可以是自定义配置或提供的默认配置之一。

使用 Mambular 的关键优势之一是前向传递的输入是张量列表。虽然这可能不常见,但对于以不同方式处理不同数据类型的模型来说,这是非常有利的。例如,TabTransformer 模型利用这一特性分别处理分类数据和数值数据,对每种数据类型应用不同的转换和处理步骤。

以下是如何使用 Mambular 实现自定义模型的方法:

  1. 首先,定义你的配置:配置类允许你指定模型的超参数和其他设置。这可以通过一个简单的数据类来完成。

from dataclasses import dataclass

@dataclass
class MyConfig:
    lr: float = 1e-04
    lr_patience: int = 10
    weight_decay: float = 1e-06
    lr_factor: float = 0.1
  1. 其次,定义你的模型:定义你的自定义模型就像为 nn.Module 那样。主要区别在于你将继承自 BaseModel 并使用提供的特征信息来构建你的层。要将你的模型集成到现有的API中,你只需要定义架构和前向传播。

from mambular.base_models import BaseModel
import torch
import torch.nn

class MyCustomModel(BaseModel):
    def __init__(
        self,
        cat_feature_info,
        num_feature_info,
        num_classes: int = 1,
        config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])

        input_dim = 0
        for feature_name, input_shape in num_feature_info.items():
            input_dim += input_shape
        for feature_name, input_shape in cat_feature_info.items():
            input_dim += 1

        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, num_features, cat_features):
        x = num_features + cat_features
        x = torch.cat(x, dim=1)

        # Pass through linear layer
        output = self.linear(x)
        return output
  1. 利用 Mambular API:您可以通过以下方式构建一个回归、分类或分布回归模型,该模型可以利用 Mambular 的所有内置方法:

from mambular.models import SklearnBaseRegressor

class MyRegressor(SklearnBaseRegressor):
    def __init__(self, **kwargs):
        super().__init__(model=MyCustomModel, config=MyConfig, **kwargs)
  1. 训练和评估你的模型:你现在可以像使用其他 Mambular 模型一样,对你的自定义模型进行拟合、评估和预测。对于分类或分布回归,分别继承自 SklearnBaseClassifierSklearnBaseLSS

regressor = MyRegressor(numerical_preprocessing="ple")
regressor.fit(X_train, y_train, max_epochs=50)

引用#

如果你发现这个项目对你的研究有帮助,请考虑引用:

@article{thielmann2024mambular,
  title={Mambular: A Sequential Model for Tabular Deep Learning},
  author={Thielmann, Anton Frederik and Kumar, Manish and Weisser, Christoph and Reuter, Arik and S{\"a}fken, Benjamin and Samiee, Soheila},
  journal={arXiv preprint arXiv:2408.06291},
  year={2024}
}

许可证#

整个代码库在 MIT 许可证下。