GraphPipeline
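GraphPipelines (tpot2.GraphPipeline) work much like scikit-learn's Pipeline class. Instead of a list of steps, a GraphPipeline takes a directed acyclic graph of steps built with networkx (networkx.DiGraph). In a GraphPipeline, each parent node takes its input from its child nodes: the leaf nodes receive the raw input (X, y), and the root node is the final classifier/regressor.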
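Node labels can be anything, but each label must be unique, one per sklearn estimator instance. Each node has an attribute named "instance" that holds the scikit-learn estimator instance.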
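Because edges point from a node to the nodes it takes input from, the direction is easy to get backwards. The stdlib-only sketch below (Python 3.9+ for graphlib) models the graph used in this page's example as a plain dict mapping each node to its inputs, then derives the leaves, the root, and an order in which every node's inputs are processed first. The placeholder strings only trace the data flow; this is an illustration of the graph semantics, not how tpot2 evaluates or combines inputs internally.

```python
from graphlib import TopologicalSorter

# Each node maps to the nodes it takes input from (its children).
# This mirrors the example's edges, where g.add_edge(parent, child)
# means "parent takes child's output as input".
inputs = {
    "scaler": [],
    "svc": ["scaler"],
    "LogisticRegression": ["scaler"],
    "LogisticRegression2": ["LogisticRegression", "svc"],
}

# Leaves have no inputs, so they receive the raw (X, y).
leaves = [node for node, children in inputs.items() if not children]

# The root is the node that no other node takes as input.
used_as_input = {child for children in inputs.values() for child in children}
roots = [node for node in inputs if node not in used_as_input]

# TopologicalSorter treats the dict values as predecessors, so each
# node appears only after all of its inputs.
fit_order = list(TopologicalSorter(inputs).static_order())

# Trace the data flow with placeholder strings instead of real estimators.
outputs = {}
for node in fit_order:
    args = ", ".join(outputs[child] for child in inputs[node]) or "X"
    outputs[node] = f"{node}({args})"

print("leaves:", leaves)  # leaves: ['scaler']
print("root:", roots)     # root: ['LogisticRegression2']
print(outputs["LogisticRegression2"])
# LogisticRegression2(LogisticRegression(scaler(X)), svc(scaler(X)))
```

The leaf comes first and the root last in any valid order; the relative order of "svc" and "LogisticRegression" is not fixed, since neither depends on the other.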
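GraphPipeline allows classifiers and regressors to be used in the middle of a pipeline. In that case, GraphPipeline will try to use the output of predict_proba, decision_function, or predict, in that order. If cross_val_predict_cv is set, downstream models are trained on the output of sklearn.model_selection.cross_val_predict (final predictions are still made with models trained on the full data).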
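The fallback order for intermediate classifiers and regressors can be sketched as a simple attribute check. pick_output_method and the three toy classes below are hypothetical, written only for illustration; tpot2's own selection logic may differ in detail. Passing an explicit method instead of "auto" mirrors the method parameter in the Parameters list.

```python
def pick_output_method(estimator, method="auto"):
    """Hypothetical helper: name of the method whose output feeds downstream nodes."""
    if method != "auto":
        return method  # an explicit choice overrides the fallback order
    for name in ("predict_proba", "decision_function", "predict"):
        if hasattr(estimator, name):
            return name
    raise AttributeError(f"{type(estimator).__name__} has no prediction method")


# Toy estimators exposing different subsets of the prediction methods.
class ProbaModel:
    def predict_proba(self, X):
        return [[0.5, 0.5] for _ in X]

    def predict(self, X):
        return [0 for _ in X]


class MarginModel:
    def decision_function(self, X):
        return [0.0 for _ in X]

    def predict(self, X):
        return [0 for _ in X]


class LabelOnlyModel:
    def predict(self, X):
        return [0 for _ in X]


for model in (ProbaModel(), MarginModel(), LabelOnlyModel()):
    print(type(model).__name__, "->", pick_output_method(model))
# ProbaModel -> predict_proba
# MarginModel -> decision_function
# LabelOnlyModel -> predict
```

Preferring probabilities or decision scores over hard labels gives the downstream model a richer, continuous signal to learn from.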
Parameters
----------
graph: networkx.DiGraph
A directed graph where the nodes are sklearn estimators and each edge points from a node to one of its inputs.
cross_val_predict_cv: int, cross-validation generator or an iterable, optional
Determines the cross-validation splitting strategy used to generate the out-of-fold predictions of inner classifiers or regressors.
method: str, optional
The prediction method to use for the inner classifiers or regressors. If 'auto', it will try to use predict_proba, decision_function, or predict in that order.
memory: str or object with the joblib.Memory interface, optional
Used to cache the input and outputs of nodes to prevent refitting or computationally heavy transformations. By default, no caching is performed. If a string is given, it is the path to the caching directory.
use_label_encoder: bool, optional
If True, a label encoder is used to encode the labels as integers from 0 to N-1, where N is the number of classes. If False, no label encoder is used.
Mainly useful for classifiers (such as XGBoost) that require integer labels in that range.
Can also be a sklearn.preprocessing.LabelEncoder object. If so, that label encoder is used.
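What use_label_encoder does to the targets can be illustrated with a small stdlib sketch: each distinct label is mapped to an integer by its position in sorted order, which is also how sklearn.preprocessing.LabelEncoder orders its classes_. fit_label_encoder is a hypothetical helper for illustration, not part of tpot2.

```python
def fit_label_encoder(labels):
    """Hypothetical helper: map each distinct label to 0..N-1 in sorted order."""
    classes = sorted(set(labels))
    to_int = {label: i for i, label in enumerate(classes)}
    return classes, to_int


labels = ["dog", "cat", "dog", "bird"]
classes, to_int = fit_label_encoder(labels)

encoded = [to_int[label] for label in labels]  # what the estimator is trained on
decoded = [classes[i] for i in encoded]        # maps predictions back to original labels

print(classes)            # ['bird', 'cat', 'dog']
print(encoded)            # [2, 1, 2, 0]
print(decoded == labels)  # True
```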
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import networkx as nx
from tpot2 import GraphPipeline
import sklearn.metrics

X, y = make_classification(random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                    random_state=0)

# Each node's "instance" attribute holds the estimator for that step.
g = nx.DiGraph()
g.add_node("scaler", instance=StandardScaler())
g.add_node("svc", instance=SVC())
g.add_node("LogisticRegression", instance=LogisticRegression())
g.add_node("LogisticRegression2", instance=LogisticRegression())

# add_edge(parent, child): the parent takes the child's output as input.
# "scaler" is the leaf (raw X, y); "LogisticRegression2" is the root.
g.add_edge("svc", "scaler")
g.add_edge("LogisticRegression", "scaler")
g.add_edge("LogisticRegression2", "LogisticRegression")
g.add_edge("LogisticRegression2", "svc")

est = GraphPipeline(g)
est.plot()
est.fit(X_train, y_train)
print("score")
print(sklearn.metrics.roc_auc_score(y_test, est.predict_proba(X_test)[:, 1]))
score
0.8974358974358974
Cross-validation predictions
In some cases, setting cross_val_predict_cv can improve performance.
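One common reason is that an inner model's predictions on its own training data tend to be overly optimistic, so a downstream model trained on them learns to trust them too much. cross_val_predict avoids this by producing out-of-fold predictions. The stdlib sketch below shows the idea with a toy "model" that simply predicts the mean training target; kfold_indices and out_of_fold_predictions are hypothetical helpers, and the contiguous, unshuffled folds are an assumption (scikit-learn's actual splitter depends on the cv argument and the estimator type).

```python
from statistics import mean


def kfold_indices(n, k):
    """Split range(n) into k contiguous, unshuffled folds."""
    base, extra = divmod(n, k)
    folds, start = [], 0
    for i in range(k):
        size = base + (1 if i < extra else 0)  # first `extra` folds get one more
        folds.append(list(range(start, start + size)))
        start += size
    return folds


def out_of_fold_predictions(y, k):
    """Predict each sample with a toy model fit only on the other folds."""
    preds = [None] * len(y)
    for held_out in kfold_indices(len(y), k):
        held = set(held_out)
        train_y = [t for i, t in enumerate(y) if i not in held]
        fitted = mean(train_y)  # toy "model": the mean training target
        for i in held_out:
            preds[i] = fitted
    return preds


y = [0, 0, 1, 1]
oof = out_of_fold_predictions(y, k=2)

# A toy model fit on all of y, analogous to the full-data model.
full_model = mean(y)
full_data_preds = [full_model] * len(y)

print(oof)              # [1, 1, 0, 0]
print(full_data_preds)  # [0.5, 0.5, 0.5, 0.5]
```

Each out-of-fold value comes from a model that never saw that sample's target, which is why it differs from what the full-data model produces; the downstream model is trained on the former, while prediction still uses the latter.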
est = GraphPipeline(g, cross_val_predict_cv=10)
est.plot()
est.fit(X_train, y_train)
print("score")
print(sklearn.metrics.roc_auc_score(y_test, est.predict_proba(X_test)[:,1]))
score
0.9166666666666666
You can access the individual steps of a GraphPipeline through each node's label.
svc = est.graph.nodes["svc"]["instance"]
svc
SVC()