可解释的回归
在本笔记本中，我们将拟合可解释的提升机（EBM）、线性回归和回归树模型。拟合完成后，我们将利用这些模型的透明性，查看并理解它们的全局解释和局部解释。
本笔记本可以在我们 GitHub 仓库的 examples 文件夹中找到。
# install interpret if not already installed
try:
import interpret
except ModuleNotFoundError:
!pip install --quiet interpret scikit-learn
import numpy as np
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from interpret import show
from interpret.perf import RegressionPerf
from interpret import set_visualize_provider
from interpret.provider import InlineProvider
# Route show() output through the inline visualization provider.
set_visualize_provider(InlineProvider())

# Load the California housing dataset and pull out features, target and
# the feature names used to label the explanations below.
dataset = fetch_california_housing()
X, y = dataset.data, dataset.target
names = dataset.feature_names

# Use one seed for both NumPy's global RNG and the train/test split.
seed = 42
np.random.seed(seed)

# Hold out 20% of the rows for testing.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.20,
    random_state=seed,
)
探索数据集
from interpret import show
from interpret.data import Marginal
# Build a Marginal data explanation of the training split and display it.
marginal_explainer = Marginal(names)
marginal = marginal_explainer.explain_data(X_train, y_train, name='Train Data')
show(marginal)
训练可解释的提升机(EBM)
from interpret.glassbox import ExplainableBoostingRegressor, LinearRegression, RegressionTree
# Construct the EBM with the dataset's feature names and interactions=3,
# then fit it on the training split.
# NOTE(review): interactions=3 presumably caps the number of interaction
# terms the model adds — confirm against the interpret documentation.
ebm = ExplainableBoostingRegressor(feature_names=names, interactions=3)
ebm.fit(X_train, y_train)
ExplainableBoostingRegressor(feature_names=['MedInc', 'HouseAge', 'AveRooms',
'AveBedrms', 'Population',
'AveOccup', 'Latitude',
'Longitude'],
interactions=3)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
ExplainableBoostingRegressor(feature_names=['MedInc', 'HouseAge', 'AveRooms',
'AveBedrms', 'Population',
'AveOccup', 'Latitude',
'Longitude'],
interactions=3)
EBM 是玻璃盒模型，因此我们可以对其进行编辑
# Post-process the fitted EBM: monotonize the "MedInc" feature with increasing=True.
# NOTE(review): presumably this constrains the learned MedInc contribution to be
# non-decreasing — confirm against the interpret documentation.
ebm.monotonize("MedInc", increasing=True)
ExplainableBoostingRegressor(feature_names=['MedInc', 'HouseAge', 'AveRooms',
'AveBedrms', 'Population',
'AveOccup', 'Latitude',
'Longitude'],
interactions=3)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
ExplainableBoostingRegressor(feature_names=['MedInc', 'HouseAge', 'AveRooms',
'AveBedrms', 'Population',
'AveOccup', 'Latitude',
'Longitude'],
interactions=3)
全局解释：模型整体学到了什么
# Build the EBM's global explanation (labelled 'EBM') and display it.
ebm_global = ebm.explain_global(name='EBM')
show(ebm_global)
局部解释:单个预测是如何做出的
# Build local explanations for the first five test rows, passing their
# true targets alongside the features.
ebm_local = ebm.explain_local(X_test[:5], y_test[:5], name='EBM')
# NOTE(review): the second argument (0) presumably selects which explained
# row show() displays first — confirm against the interpret documentation.
show(ebm_local, 0)
评估EBM性能
# Build a regression performance explanation for the EBM on the test
# split and display it.
ebm_perf_explainer = RegressionPerf(ebm, names)
ebm_perf = ebm_perf_explainer.explain_perf(X_test, y_test, name='EBM')
show(ebm_perf)
让我们测试一些其他可解释的模型
from interpret.glassbox import LinearRegression, RegressionTree
# Construct both glassbox baselines first; the tree reuses the notebook seed.
lr = LinearRegression(names)
rt = RegressionTree(names, random_state=seed)

# Fit each model on the same training split. The tree is fitted last so it
# remains the final expression of the cell.
lr.fit(X_train, y_train)
rt.fit(X_train, y_train)
<interpret.glassbox._decisiontree.RegressionTree at 0x7fb5e0686dc0>
使用仪表板比较性能
# Performance explanation of the linear regression model on the test split.
lr_perf = RegressionPerf(lr, names).explain_perf(
    X_test,
    y_test,
    name='Linear Regression',
)
show(lr_perf)

# Performance explanation of the regression tree on the same test split.
rt_perf = RegressionPerf(rt, names).explain_perf(
    X_test,
    y_test,
    name='Regression Tree',
)
show(rt_perf)
Glassbox:我们所有的模型都有全局和局部解释
# Build and display the global explanation of the linear regression model.
lr_global = lr.explain_global(name='Linear Regression')
show(lr_global)
# Build and display the global explanation of the regression tree model.
rt_global = rt.explain_global(name='Regression Tree')
show(rt_global)