Explaining Local Classifier Per Level

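A minimal example showing how to use the HiClass Explainer to obtain SHAP values for a Local Classifier Per Level (LCPL) model. A detailed summary of the Explainer class is given in the Algorithms Overview section on hierarchical explainability. The SHAP values are computed on a synthetic platypus diseases dataset, which is available for download and is loaded in the code with load_platypus().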

[Figure: bar summary plot of SHAP values for the classes 'Covid' and 'Respiratory']

Out:

<xarray.Dataset>
Dimensions:          (class: 15, level: 3, sample: 246, feature: 9)
Coordinates:
  * class            (class) <U16 'Allergy' 'Bee Allergy' ... 'Respiratory'
  * level            (level) int64 0 1 2
Dimensions without coordinates: sample, feature
Data variables:
    node             (sample, level) object 'Respiratory' ... 'Milk Allergy'
    predicted_class  (sample, level) object 'Respiratory' ... 'Milk Allergy'
    predict_proba    (sample, level, class) float64 0.16 nan nan ... 0.02 nan
    classes          (sample, level, class) object 'Allergy' nan ... nan
    shap_values      (level, class, sample, feature) float64 0.01697 ... nan

from sklearn.ensemble import RandomForestClassifier
from hiclass import LocalClassifierPerLevel, Explainer
import shap
from hiclass.datasets import load_platypus

# Load train and test splits
X_train, X_test, Y_train, Y_test = load_platypus()

# Use random forest classifiers for every level
rfc = RandomForestClassifier()
classifier = LocalClassifierPerLevel(local_classifier=rfc, replace_classifiers=False)

# Train local classifiers per level
classifier.fit(X_train, Y_train)

# Define Explainer
explainer = Explainer(classifier, data=X_train, mode="tree")
explanations = explainer.explain(X_test.values)
print(explanations)

# Select the SHAP values for the classes 'Covid' (level 1) and
# 'Respiratory' (level 0), restricted to the test samples that are
# predicted as 'Covid' at level 1

covid_idx = classifier.predict(X_test)[:, 1] == "Covid"

shap_filter_covid = {"level": 1, "class": "Covid", "sample": covid_idx}
shap_filter_resp = {"level": 0, "class": "Respiratory", "sample": covid_idx}
shap_val_covid = explanations.sel(**shap_filter_covid)
shap_val_resp = explanations.sel(**shap_filter_resp)


# Visually compare the mean absolute SHAP values for the 'Covid' and
# 'Respiratory' classes in a bar plot.

# Feature names for the X-axis
feature_names = X_train.columns.values

# SHAP values for 'Covid'
shap_values_covid = shap_val_covid.shap_values.values

# SHAP values for 'Respiratory'
shap_values_resp = shap_val_resp.shap_values.values

shap.summary_plot(
    [shap_values_covid, shap_values_resp],
    features=X_test.iloc[covid_idx],
    feature_names=feature_names,
    plot_type="bar",
    class_names=["Covid", "Respiratory"],
)
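
As a rough illustration of what the bar plot above summarizes: with plot_type="bar", SHAP ranks features by the mean absolute SHAP value over the selected samples, computed separately for each class. The sketch below reproduces that aggregation in plain Python. The helper name mean_abs_by_feature and the toy numbers are hypothetical and chosen only for illustration; they are not part of HiClass or SHAP, and they are not taken from the platypus dataset.

```python
def mean_abs_by_feature(shap_matrix):
    """Average the absolute SHAP values over samples, one result per feature.

    shap_matrix is a list of rows with shape (n_samples, n_features).
    """
    n_samples = len(shap_matrix)
    n_features = len(shap_matrix[0])
    return [
        sum(abs(row[j]) for row in shap_matrix) / n_samples
        for j in range(n_features)
    ]


# Hypothetical SHAP values: 3 samples x 2 features (made-up numbers).
toy_covid = [[0.5, -0.25], [-0.5, 0.25], [0.5, 0.25]]
toy_resp = [[0.0, 1.0], [0.0, -1.0], [0.0, 1.0]]

importance_covid = mean_abs_by_feature(toy_covid)
importance_resp = mean_abs_by_feature(toy_resp)
print(importance_covid, importance_resp)
```

In the real example, the same aggregation would be applied to shap_values_covid and shap_values_resp, which have one row per test sample predicted as 'Covid' and one column per feature. Because signs are discarded, the bar plot shows how strongly a feature influences each class, not in which direction.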

Total running time of the script: (0 minutes 52.827 seconds)
