Model Persistence

Many tslearn models can be saved to disk and later loaded to make predictions. This is especially useful when a model takes a long time to train.

Available formats: hdf5, json, pickle

Save a model to disk:

model.to_<format>(path)

Load a model from disk (a class method that returns a new model instance):

ModelClass.from_<format>(path)

Basic usage

# Instantiate a model
model = ModelClass(<hyper-parameters>)

# Train the model
model.fit(X_train)

# Save the model to disk
model.to_hdf5('./trained_model.hdf5')

# Load the model from disk (from_hdf5 is a class method that returns a new model)
model = ModelClass.from_hdf5('./trained_model.hdf5')

# Make predictions
y = model.predict(X_test)

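Note

For the following models, the training data are saved to disk along with the model, which can produce a large model file if the training set is large: KNeighborsTimeSeries, KNeighborsTimeSeriesClassifier, and KernelKMeans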

[Figure: k-Shape clustering of the Trace subset, three panels titled "Cluster 1", "Cluster 2", "Cluster 3"]

Training output (verbose=True): the inertia printed at each iteration decreases from 0.009 to 0.002.

# Example using KShape

import numpy
import matplotlib.pyplot as plt

from tslearn.clustering import KShape
from tslearn.datasets import CachedDatasets
from tslearn.preprocessing import TimeSeriesScalerMeanVariance

seed = 0
numpy.random.seed(seed)
X_train, y_train, X_test, y_test = CachedDatasets().load_dataset("Trace")

# Keep first 3 classes
X_train = X_train[y_train < 4]
numpy.random.shuffle(X_train)
# Keep only 50 time series
X_train = TimeSeriesScalerMeanVariance().fit_transform(X_train[:50])
sz = X_train.shape[1]

# Instantiate k-Shape model
ks = KShape(n_clusters=3, verbose=True, random_state=seed)

# Train
ks.fit(X_train)

# Save model
ks.to_hdf5('./ks_trained.hdf5')

# Load model
trained_ks = KShape.from_hdf5('./ks_trained.hdf5')

# Use loaded model to make predictions
y_pred = trained_ks.predict(X_train)

plt.figure()
for yi in range(3):
    plt.subplot(3, 1, 1 + yi)
    for xx in X_train[y_pred == yi]:
        plt.plot(xx.ravel(), "k-", alpha=.2)
    plt.plot(trained_ks.cluster_centers_[yi].ravel(), "r-")
    plt.xlim(0, sz)
    plt.ylim(-4, 4)
    plt.title("Cluster %d" % (yi + 1))

plt.tight_layout()
plt.show()

Total running time of the script: (0 minutes 4.923 seconds)
