tslearn.metrics.SoftDTWLossPyTorch¶
- tslearn.metrics.SoftDTWLossPyTorch(gamma=1.0, normalize=False, dist_func=None)[source]¶
PyTorch中的Soft-DTW损失函数。
Soft-DTW最初在[1]中提出,并在我们的关于DTW及其变体的用户指南页面中进行了更详细的讨论。
Soft-DTW 的计算方式如下:
\[\text{soft-DTW}_{\gamma}(X, Y) = \min_{\pi}{}^\gamma \sum_{(i, j) \in \pi} d \left( X_i, Y_j \right)\]其中 \(d\) 是一个支持 PyTorch 自动微分的距离函数或不相似度度量,\(\min^\gamma\) 是参数 \(\gamma\) 的软最小运算符,定义如下:
\[\min{}^\gamma \left( a_{1}, ..., a_{n} \right) = - \gamma \log \sum_{i=1}^{n} e^{- a_{i} / \gamma}\]在极限情况下 \(\gamma = 0\),\(\min^\gamma\) 简化为一个硬最小操作符。当 \(d\) 是平方欧几里得距离时,软DTW被定义为DTW不相似度度量的平方。
与DTW相反,soft-DTW的下界不为零,我们甚至有:
\[\text{soft-DTW}_{\gamma}(X, Y) \rightarrow - \infty \text{ when } \gamma \rightarrow + \infty\]在[2]中,定义了新的相异性度量,这些度量依赖于soft-DTW。 特别是,引入了soft-DTW散度来抵消soft-DTW的非正性:
\[D_{\gamma} \left( X, Y \right) = \text{soft-DTW}_{\gamma}(X, Y) - \frac{1}{2} \left( \text{soft-DTW}_{\gamma}(X, X) + \text{soft-DTW}_{\gamma}(Y, Y) \right)\]这种差异的优势在于当\(X = Y\)时被最小化,并且在这种情况下恰好为0。
- Parameters:
- gammafloat
正则化参数。 它应该是严格正的。 值越小,平滑程度越低(更接近真实的DTW)。
- normalizebool
如果为True,则使用Soft-DTW差异。 Soft-DTW差异始终为正。 可选,默认值:False。
- dist_funccallable
距离函数或不相似度度量。 它接受形状为 (batch_size, ts_length, dim) 的两个输入参数。 它应支持 PyTorch 自动微分。 可选,默认值:None 如果为 None,则使用平方欧几里得距离。
另请参阅
soft_dtw计算两个时间序列之间的Soft-DTW度量。
cdist_soft_dtw使用Soft-DTW度量计算交叉相似性矩阵。
cdist_soft_dtw_normalized使用归一化版本的Soft-DTW度量计算交叉相似性矩阵。
参考文献
[1]Marco Cuturi 和 Mathieu Blondel. “Soft-DTW: 一种用于时间序列的可微损失函数”, ICML 2017.
[2]Mathieu Blondel, Arthur Mensch & Jean-Philippe Vert. “时间序列之间的可微散度”, 国际人工智能与统计会议, 2021.
示例
>>> import torch >>> from tslearn.metrics import SoftDTWLossPyTorch >>> soft_dtw_loss = SoftDTWLossPyTorch(gamma=0.1) >>> x = torch.zeros((4, 3, 2), requires_grad=True) >>> y = torch.arange(0, 24).reshape(4, 3, 2) >>> soft_dtw_loss_mean_value = soft_dtw_loss(x, y).mean() >>> print(soft_dtw_loss_mean_value) tensor(1081., grad_fn=<MeanBackward0>) >>> soft_dtw_loss_mean_value.backward() >>> print(x.grad.shape) torch.Size([4, 3, 2]) >>> print(x.grad) tensor([[[ 0.0000, -0.5000], [ -1.0000, -1.5000], [ -2.0000, -2.5000]], [[ -3.0000, -3.5000], [ -4.0000, -4.5000], [ -5.0000, -5.5000]], [[ -6.0000, -6.5000], [ -7.0000, -7.5000], [ -8.0000, -8.5000]], [[ -9.0000, -9.5000], [-10.0000, -10.5000], [-11.0000, -11.5000]]])