import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

from pytorch_tabular.utils import make_mixed_dataset

# %load_ext autoreload
# %autoreload 2
data, cat_col_names, num_col_names = make_mixed_dataset(
    task="classification", n_samples=10000, n_features=20, n_categories=4
)
train, test = train_test_split(data, random_state=42)
train, val = train_test_split(train, random_state=42)

Importing the Library

from pytorch_tabular import TabularModel
from pytorch_tabular.models import GANDALFConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig

Defining the Configs

data_config = DataConfig(
    target=[
        "target"
    ],  # target should always be a list. Multi-targets are only supported for regression. Multi-task classification is not implemented
    continuous_cols=num_col_names,
    categorical_cols=cat_col_names,
)
trainer_config = TrainerConfig(
    auto_lr_find=True,  # Runs the LRFinder to automatically derive a learning rate
    batch_size=1024,
    max_epochs=100,
    accelerator="auto",  # can be 'cpu','gpu', 'tpu', or 'ipu'
)
optimizer_config = OptimizerConfig()

head_config = LinearHeadConfig(
    layers="",  # 头部没有增加额外的层,仅包含一个映射层至输出维度。
    dropout=0.1,
    initialization="kaiming",
).__dict__  # Convert to dict to pass to the model config (OmegaConf doesn't accept objects)

model_config = GANDALFConfig(
    task="classification",
    gflu_stages=3,  # Number of stages in the GFLU block
    gflu_dropout=0.0,  # Dropout in each of the GFLU stages
    gflu_feature_init_sparsity=0.1,  # Sparsity of the initial feature selection
    head="LinearHead",  # Linear head
    head_config=head_config,  # Linear head config
    learning_rate=1e-3,
)


tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    verbose=False,
)

Training the Model

tabular_model.fit(train=train, validation=val)
Seed set to 42

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory saved_models exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:293: The number of training batches (6) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.

Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]
`Trainer.fit` stopped: `max_steps=100` reached.
Learning rate set to 0.09120108393559097
Restoring states from the checkpoint path at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_fd8f0e95-a529-4577-9cd3-e484ff32cbef.ckpt
Restored all states from the checkpoint at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_fd8f0e95-a529-4577-9cd3-e484ff32cbef.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓
┃   ┃ Name             ┃ Type             ┃ Params ┃
┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩
│ 0 │ _backbone        │ GANDALFBackbone  │ 14.5 K │
│ 1 │ _embedding_layer │ Embedding1dLayer │     92 │
│ 2 │ _head            │ Sequential       │     60 │
│ 3 │ loss             │ CrossEntropyLoss │      0 │
└───┴──────────────────┴──────────────────┴────────┘
Trainable params: 14.6 K                                                        
Non-trainable params: 0                                                         
Total params: 14.6 K                                                            
Total estimated model params size (MB): 0                                       



<pytorch_lightning.trainer.trainer.Trainer at 0x7fdde261ea50>

Evaluating the Model

result = tabular_model.evaluate(test)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.9196000099182129     │
│         test_loss              0.203481063246727     │
└───────────────────────────┴───────────────────────────┘
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.



Native Feature Importance

Some models, like GANDALF, GATE, and FTTransformer, have native feature importance, similar to the feature importance you get with GBDTs.

tabular_model.feature_importance().sort_values("importance", ascending=False)
    Features     importance
0   cat_col_5    5.482856e-01
3   cat_col_19   3.333333e-01
1   cat_col_9    1.183810e-01
11  num_col_8    1.562194e-08
10  num_col_7    1.480499e-08
13  num_col_12   1.365922e-08
19  num_col_18   1.304813e-08
12  num_col_11   1.212310e-08
2   cat_col_10   1.032072e-08
9   num_col_6    1.018892e-08
17  num_col_16   1.012315e-08
14  num_col_13   1.005833e-08
5   num_col_1    9.984563e-09
18  num_col_17   9.899406e-09
8   num_col_4    9.853183e-09
7   num_col_3    9.719336e-09
15  num_col_14   9.535079e-09
4   num_col_0    8.864745e-09
16  num_col_15   8.863941e-09
6   num_col_2    8.387005e-09
# Bar chart of the top 10 features
tabular_model.feature_importance().sort_values("importance", ascending=False).head(
    10
).plot.bar(logy=True)
<Axes: >
[Figure: bar chart of the top 10 feature importances, log scale]

Local Feature Attributions

We can also use techniques like SHAP to get local feature attributions. This is a very powerful technique for explaining a model's predictions. We can use the explain method to get the local feature attributions for a given input.

PyTorch Tabular supports these methods from captum for all models except Tabnet, TabTransformer, and MDN:

  • GradientShap: https://captum.ai/api/gradient_shap.html
  • IntegratedGradients: https://captum.ai/api/integrated_gradients.html
  • DeepLift: https://captum.ai/api/deep_lift.html
  • DeepLiftShap: https://captum.ai/api/deep_lift_shap.html
  • InputXGradient: https://captum.ai/api/input_x_gradient.html
  • FeaturePermutation: https://captum.ai/api/feature_permutation.html
  • FeatureAblation: https://captum.ai/api/feature_ablation.html
  • KernelShap: https://captum.ai/api/kernel_shap.html

PyTorch Tabular also supports explaining a single instance as well as a batch of instances. Larger batches will, however, take longer to explain. The exceptions are the FeaturePermutation and FeatureAblation methods, which only make sense for a batch of instances, as sketched below.
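
A minimal sketch of the batch case, reusing the tabular_model and test objects from the cells above (the slice size of 100 is arbitrary):

# Explain a batch of rows with FeatureAblation; like FeaturePermutation,
# it operates over a batch, and no baseline is needed here.
exp_batch = tabular_model.explain(test.head(100), method="FeatureAblation")
print(exp_batch.shape)  # one row of attributions per input row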

Most of these explainability methods require a baseline, against which the attributions of the input are compared. The baseline can be a scalar value, a tensor of the same shape as the input, or a special string like "b|100", which means 100 samples drawn from the training data. If no baseline is provided, the default baseline (zeros) is used.
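
As a rough illustration of these baseline options, again reusing tabular_model and test from above (the attribution methods chosen here are incidental):

# 1) Scalar baseline: attributions are computed against an all-zeros-like input of that value
exp_scalar = tabular_model.explain(test.head(1), method="IntegratedGradients", baselines=0)
# 2) Special string: "b|100" draws 100 samples from the training data as the baseline
exp_sampled = tabular_model.explain(test.head(1), method="GradientShap", baselines="b|100")
# 3) No baseline: the default (zeros) is used
exp_default = tabular_model.explain(test.head(1), method="IntegratedGradients")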

Single Instance

exp = tabular_model.explain(test.head(1), method="GradientShap", baselines="b|10000")
exp = exp.T.sort_values(0, ascending=False)
exp.columns = ["GradientSHAP"]
exp.index.name = "Features"
exp.reset_index(inplace=True)
exp["colors"] = "red"
exp.loc[exp["GradientSHAP"] > 0, "colors"] = "green"
# Draw the plot
plt.figure(figsize=(14, 10), dpi=80)

# Draw horizontal lines
plt.hlines(
    y=exp.index,
    linewidth=5,
    xmin=0,
    xmax=exp.GradientSHAP,
    colors=exp.colors.values,
    alpha=0.5,
)
# Decorations
# Set the x- and y-axis labels
plt.gca().set(ylabel="Features", xlabel="GradientSHAP")

# Put the feature names on the y-axis
plt.yticks(exp.index, exp.Features, fontsize=12)

# Title of the chart
plt.title("GradientSHAP Local Explanation", fontdict={"size": 20})

# Optional grid layout
plt.grid(linestyle="--", alpha=0.5)

# Show the diverging bar chart
plt.show()
[Figure: diverging bar chart of GradientSHAP attributions for a single test instance]

Multiple Instances

exp = tabular_model.explain(
    test[test.cat_col_19 == 3], method="GradientShap", baselines="b|10000"
)
exp.head()
num_col_0 num_col_1 num_col_2 num_col_3 num_col_4 num_col_6 num_col_7 num_col_8 num_col_11 num_col_12 num_col_13 num_col_14 num_col_15 num_col_16 num_col_17 num_col_18 cat_col_5 cat_col_9 cat_col_10 cat_col_19
0 -1.724874 -1.447601 0.005033 -0.023054 0.019009 0.008514 -0.083980 0.004687 -0.007507 -0.126826 0.010685 -0.100001 -0.000743 0.006760 -0.023292 0.008852 0.243883 0.129083 0.101006 -0.103718
1 0.033266 0.354682 -0.023931 -0.027005 0.010930 0.001140 0.056476 -0.021277 0.001867 0.622130 0.026182 0.008232 -0.061105 0.012860 0.041835 -0.006741 1.755563 0.008189 -0.069742 0.023084
2 0.274286 0.525516 0.021727 -0.084166 0.020813 0.001058 -0.048389 0.031756 -0.002266 -0.270582 0.013404 -0.043228 0.048533 -0.013822 0.038167 0.021330 -0.239157 -0.150969 -0.063134 -0.168519
3 0.016431 -0.107677 0.015629 0.012411 -0.006554 0.001735 0.009270 0.024726 -0.001018 0.586440 0.003225 -0.019198 0.015317 -0.002911 0.001742 -0.002891 0.287464 0.171883 -0.011178 -0.051022
4 -0.004894 1.684316 0.016587 -0.357520 0.022842 -0.013848 0.005630 0.025678 -0.004070 0.802940 -0.027120 0.062992 0.179163 0.001076 0.019801 -0.022297 -1.746769 0.793595 -0.042146 0.055857
exp_agg = exp.abs().sum().to_frame()
exp_agg = exp_agg.sort_values(0, ascending=True)
exp_agg.columns = ["GradientSHAP"]
exp_agg.index.name = "Features"
exp_agg.reset_index(inplace=True)
exp_agg["colors"] = "red"
exp_agg.loc[exp_agg["GradientSHAP"] > 0, "colors"] = "green"
# Draw the plot
plt.figure(figsize=(14, 10), dpi=80)

# Draw horizontal lines
plt.hlines(
    y=exp_agg.index,
    linewidth=5,
    xmin=0,
    xmax=exp_agg.GradientSHAP,
    colors=exp_agg.colors.values,
    alpha=0.5,
)
# Decorations
# Set the x- and y-axis labels
plt.gca().set(ylabel="Features", xlabel="GradientSHAP")

# Put the feature names on the y-axis
plt.yticks(exp_agg.index, exp_agg.Features, fontsize=12)

# Title of the chart
plt.title("GradientSHAP Global Explanation", fontdict={"size": 20})

# Optional grid layout
plt.grid(linestyle="--", alpha=0.5)

# Show the diverging bar chart
plt.show()
[Figure: diverging bar chart of aggregated GradientSHAP attributions across the batch]