导入库¶
定义配置¶
data_config = DataConfig(
target=[
"target"
], # 目标应始终为一个列表。仅回归任务支持多目标。多任务分类功能尚未实现。
continuous_cols=num_col_names,
categorical_cols=cat_col_names,
)
trainer_config = TrainerConfig(
auto_lr_find=True, # 运行LRFinder以自动推导学习率
batch_size=1024,
max_epochs=100,
accelerator="auto", # can be 'cpu','gpu', 'tpu', or 'ipu'
)
optimizer_config = OptimizerConfig()
head_config = LinearHeadConfig(
layers="", # 头部没有增加额外的层,仅包含一个映射层至输出维度。
dropout=0.1,
initialization="kaiming",
).__dict__ # 转换为字典以传递给模型配置(OmegaConf不接受对象)
model_config = GANDALFConfig(
task="classification",
gflu_stages=3, # GFLU模块的级数
gflu_dropout=0.0, # 每个GFLU块中的丢弃
gflu_feature_init_sparsity=0.1, # 初始特征选择的稀疏性
head="LinearHead", # 线性磁头
head_config=head_config, # 线性磁头配置
learning_rate=1e-3,
)
tabular_model = TabularModel(
data_config=data_config,
model_config=model_config,
optimizer_config=optimizer_config,
trainer_config=trainer_config,
verbose=False,
)
训练模型¶
评估模型¶
本地全局特征重要性¶
一些模型,如GANDALF、GATE和FTTransformer,具有原生特征重要性,类似于您在使用GDBTs时获得的特征重要性。
局部特征归因¶
我们也可以使用像 SHAP 这样的技术来获取局部特征归因。这是一种非常强大的技术,用于解释模型的预测。我们可以使用 explain
方法来获取给定输入的局部特征归因。
PyTorch Tabular
支持来自 captum
的这些方法,适用于除 Tabnet、TabTransformer 和 MDN 之外的所有模型:
- GradientShap: https://captum.ai/api/gradient_shap.html
- IntegratedGradients: https://captum.ai/api/integrated_gradients.html
- DeepLift: https://captum.ai/api/deep_lift.html
- DeepLiftShap: https://captum.ai/api/deep_lift_shap.html
- InputXGradient: https://captum.ai/api/input_x_gradient.html
- FeaturePermutation: https://captum.ai/api/feature_permutation.html
- FeatureAblation: https://captum.ai/api/feature_ablation.html
- KernelShap: https://captum.ai/api/kernel_shap.html
PyTorch Tabular
还支持解释单个实例以及实例的批量。但是,较大的数据集将需要更长的时间来解释。例外的是 FeaturePermutation
和 FeatureAblation
方法,这对于大批量实例才有意义。
大多数这些可解释性方法需要一个基线。基线用于将输入的归因与基线的归因进行比较。基线可以是一个标量值,一个与输入形状相同的张量,或者是一个特殊字符串,如 "b|100",表示来自训练数据的 100 个样本。如果未提供基线,则使用默认基线(零)。
单实例¶
# 绘制图表
plt.figure(figsize=(14, 10), dpi=80)
# 绘制水平线
plt.hlines(
y=exp.index,
linewidth=5,
xmin=0,
xmax=exp.GradientSHAP,
colors=exp.colors.values,
alpha=0.5,
)
# 装饰品
# 设置x轴和y轴的标签
plt.gca().set(ylabel="Features", xlabel="GradientSHAP")
# 将日期设置为y轴
plt.yticks(exp.index, exp.Features, fontsize=12)
# 条形图标题
plt.title("GradientSHAP Local Explanation", fontdict={"size": 20})
# 可选的网格布局
plt.grid(linestyle="--", alpha=0.5)
# 显示分叉条形图
plt.show()
多个实例¶
# 绘制图表
plt.figure(figsize=(14, 10), dpi=80)
# 绘制水平线
plt.hlines(
y=exp_agg.index,
linewidth=5,
xmin=0,
xmax=exp_agg.GradientSHAP,
colors=exp_agg.colors.values,
alpha=0.5,
)
# 装饰品
# 设置x轴和y轴的标签
plt.gca().set(ylabel="Features", xlabel="GradientSHAP")
# 将日期设置到y轴
plt.yticks(exp_agg.index, exp_agg.Features, fontsize=12)
# 条形图标题
plt.title("GradientSHAP Global Explanation", fontdict={"size": 20})
# 可选的网格布局
plt.grid(linestyle="--", alpha=0.5)
# 显示发散条形图
plt.show()