注意
转到末尾 以下载完整的示例代码
重心¶
这个例子展示了三种计算时间序列重心的方法。
有关可用方法的概述,请参见tslearn.barycenters模块。
tslearn 提供了三种方法来计算给定时间序列集合的重心:
欧几里得重心 仅仅是每个时间点上各个点的算术平均值,最小化每个点的欧几里得距离之和。如下所示,它与基于DTW的方法非常不同,可能经常不适用。然而,它是所示方法中最快的。
DTW 重心平均 (DBA) 是一种迭代优化的重心,从一个(可能)不好的候选开始,并不断改进,直到满足收敛条件。优化可以通过 (a) 期望最大化 [1] 和 (b) 随机次梯度下降 [2] 来实现。经验表明,后者“通常更稳定,并且在更短的时间内找到更好的解决方案” [2]。
Soft-DTW 重心 使用可微分的损失函数来迭代地找到一个重心 [3]。该方法本身以及参数 \(\gamma=1.0\) 在关于 DTW 的部分中有更详细的描述。还有一个专门的 示例 可用。
[1] F. Petitjean, A. Ketterlin & P. Gancarski. 一种用于动态时间规整的全局平均方法,及其在聚类中的应用。模式识别, Elsevier, 2011, 卷. 44, 期. 3, 页. 678-693.
[2] D. Schultz & B. Jain. 动态时间规整空间中的非光滑分析和次梯度方法用于平均。模式识别,74,340-358。
[3] M. Cuturi & M. Blondel. Soft-DTW: 一种用于时间序列的可微损失函数。ICML 2017.
# Author: Romain Tavenard, Felix Divo
# License: BSD 3 clause
import numpy
import matplotlib.pyplot as plt
from tslearn.barycenters import \
euclidean_barycenter, \
dtw_barycenter_averaging, \
dtw_barycenter_averaging_subgradient, \
softdtw_barycenter
from tslearn.datasets import CachedDatasets
# fetch the example data set
numpy.random.seed(0)
X_train, y_train, _, _ = CachedDatasets().load_dataset("Trace")
X = X_train[y_train == 2]
length_of_sequence = X.shape[1]
def plot_helper(barycenter):
# plot all points of the data set
for series in X:
plt.plot(series.ravel(), "k-", alpha=.2)
# plot the given barycenter of them
plt.plot(barycenter.ravel(), "r-", linewidth=2)
# plot the four variants with the same number of iterations and a tolerance of
# 1e-3 where applicable
ax1 = plt.subplot(4, 1, 1)
plt.title("Euclidean barycenter")
plot_helper(euclidean_barycenter(X))
plt.subplot(4, 1, 2, sharex=ax1)
plt.title("DBA (vectorized version of Petitjean's EM)")
plot_helper(dtw_barycenter_averaging(X, max_iter=50, tol=1e-3))
plt.subplot(4, 1, 3, sharex=ax1)
plt.title("DBA (subgradient descent approach)")
plot_helper(dtw_barycenter_averaging_subgradient(X, max_iter=50, tol=1e-3))
plt.subplot(4, 1, 4, sharex=ax1)
plt.title("Soft-DTW barycenter ($\gamma$=1.0)")
plot_helper(softdtw_barycenter(X, gamma=1., max_iter=50, tol=1e-3))
# clip the axes for better readability
ax1.set_xlim([0, length_of_sequence])
# show the plot(s)
plt.tight_layout()
plt.show()
脚本总运行时间: (0 分钟 3.754 秒)