plot_scatter_predictions

plot_scatter_predictions(y, y_pred, title=None)[source]

绘制一个散点图,比较给定数据集的实际值和预测值。

这个散点图通常用于绘制时间序列外部回归方法的预测结果,因为输出是连续的。在时间序列分类的情况下,它将类似于混淆矩阵。

Parameters:
y: np.array

实际值。

y_pred: np.array

预测值。

title: str, default = None

要在图表顶部显示的标题。

Returns:
figmatplotlib.figure.Figure
axmatplotlib.axes.Axes

示例

>>> from aeon.visualisation import plot_scatter_predictions
>>> from aeon.datasets import load_covid_3month
>>> from aeon.regression.feature_based import FreshPRINCERegressor  
>>> X_train, y_train = load_covid_3month(split="train")
>>> X_test, y_test = load_covid_3month(split="test")
>>> fp = FreshPRINCERegressor(n_estimators=10)  
>>> fp.fit(X_train, y_train)  
>>> y_pred_fp = fp.predict(X_test)  
>>> plot = plot_scatter_predictions(y_test, y_pred_fp, title="FP-Covid3Month")        
>>> plot.show()  
>>> plot.savefig("scatterplot_predictions.pdf")