重心

这个例子展示了三种计算时间序列重心的方法。 有关可用方法的概述,请参见tslearn.barycenters模块。

tslearn 提供了三种方法来计算给定时间序列集合的重心:

  • 欧几里得重心 仅仅是每个时间点上各个点的算术平均值,最小化每个点的欧几里得距离之和。如下所示,它与基于DTW的方法非常不同,可能经常不适用。然而,它是所示方法中最快的。

  • DTW 重心平均 (DBA) 是一种迭代优化的重心,从一个(可能)不好的候选开始,并不断改进,直到满足收敛条件。优化可以通过 (a) 期望最大化 [1] 和 (b) 随机次梯度下降 [2] 来实现。经验表明,后者“通常更稳定,并且在更短的时间内找到更好的解决方案” [2]。

  • Soft-DTW 重心 使用可微分的损失函数来迭代地找到一个重心 [3]。该方法本身以及参数 \(\gamma=1.0\) 在关于 DTW 的部分中有更详细的描述。还有一个专门的 示例 可用。

[1] F. Petitjean, A. Ketterlin & P. Gancarski. 一种用于动态时间规整的全局平均方法,及其在聚类中的应用。模式识别, Elsevier, 2011, 卷. 44, 期. 3, 页. 678-693.

[2] D. Schultz & B. Jain. 动态时间规整空间中的非光滑分析和次梯度方法用于平均。模式识别,74,340-358。

[3] M. Cuturi & M. Blondel. Soft-DTW: 一种用于时间序列的可微损失函数。ICML 2017.

Euclidean barycenter, DBA (vectorized version of Petitjean's EM), DBA (subgradient descent approach), Soft-DTW barycenter ($\gamma$=1.0)
# Author: Romain Tavenard, Felix Divo
# License: BSD 3 clause

import numpy
import matplotlib.pyplot as plt

from tslearn.barycenters import \
    euclidean_barycenter, \
    dtw_barycenter_averaging, \
    dtw_barycenter_averaging_subgradient, \
    softdtw_barycenter
from tslearn.datasets import CachedDatasets

# fetch the example data set
numpy.random.seed(0)
X_train, y_train, _, _ = CachedDatasets().load_dataset("Trace")
X = X_train[y_train == 2]
length_of_sequence = X.shape[1]


def plot_helper(barycenter):
    # plot all points of the data set
    for series in X:
        plt.plot(series.ravel(), "k-", alpha=.2)
    # plot the given barycenter of them
    plt.plot(barycenter.ravel(), "r-", linewidth=2)


# plot the four variants with the same number of iterations and a tolerance of
# 1e-3 where applicable
ax1 = plt.subplot(4, 1, 1)
plt.title("Euclidean barycenter")
plot_helper(euclidean_barycenter(X))

plt.subplot(4, 1, 2, sharex=ax1)
plt.title("DBA (vectorized version of Petitjean's EM)")
plot_helper(dtw_barycenter_averaging(X, max_iter=50, tol=1e-3))

plt.subplot(4, 1, 3, sharex=ax1)
plt.title("DBA (subgradient descent approach)")
plot_helper(dtw_barycenter_averaging_subgradient(X, max_iter=50, tol=1e-3))

plt.subplot(4, 1, 4, sharex=ax1)
plt.title("Soft-DTW barycenter ($\gamma$=1.0)")
plot_helper(softdtw_barycenter(X, gamma=1., max_iter=50, tol=1e-3))

# clip the axes for better readability
ax1.set_xlim([0, length_of_sequence])

# show the plot(s)
plt.tight_layout()
plt.show()

脚本总运行时间: (0 分钟 3.754 秒)

Gallery generated by Sphinx-Gallery