• 文章
  • Triplet Loss - Advanced Intro
返回机器学习

三元组损失 - 高级介绍

尤素福·萨里戈兹

·

2022年3月24日

Triplet Loss - Advanced Intro

三元组损失是什么?

三元组损失最早在FaceNet:一个用于人脸识别和聚类的统一嵌入中于2015年提出,此后它成为监督相似性或度量学习中最流行的损失函数之一。简单来说,三元组损失鼓励不同的样本对与任何相似样本对之间至少保持一定的距离。数学上,损失值可以计算为 $L=max(d(a,p) - d(a,n) + m, 0)$,其中:

  • $p$,即正样本,是与$a$(即锚点)具有相同标签的样本,
  • $n$, 即负数,是另一个标签与 $a$ 不同的样本,
  • $d$ 是一个用于测量这三个样本之间距离的函数,
  • 并且 $m$ 是一个边距值,用于保持负样本之间的距离。

本文使用欧几里得距离,但使用任何其他距离度量同样有效,例如余弦距离。

这个函数有一个学习目标,可以如以下所示进行可视化:

Triplet Loss learning objective

三元组损失学习目标

请注意,Triplet Loss 并没有像 Contrastive Loss 那样促使将锚点和正样本编码到向量空间中的同一点。这让 Triplet Loss 可以容忍一定的类内方差,这与 Contrastive Loss 不同,因为后者本质上强迫锚点与任何正样本之间的距离为 $0$。换句话说,Triplet Loss 允许以某种方式拉伸簇以包含离群值,同时仍然确保来自不同簇的样本之间有一个边际,例如,负对。

此外,三元组损失(Triplet Loss)不那么贪婪。与对比损失(Contrastive Loss)不同,当不同的样本与相似的样本可以容易区分时,三元组损失就已经满足了。如果没有来自负样本的干扰,它不会改变正聚类中距离的关系。这是因为三元组损失试图确保负对之间的距离和正对之间的距离之间存在一个间隔。然而,对比损失仅在比较不相似的对时考虑间隔值,它完全不关心相似对当时的位置。这意味着对比损失可能会更早达到局部最小值,而三元组损失可能会继续更好地组织向量空间。

我们来演示一下两个损失函数如何通过动画组织向量空间。为了简化可视化,向量在二维空间中用点表示,且它们是从正态分布中随机选择的。

Animation that shows how Contrastive Loss moves points in the course of training.

动画显示对比损失如何在训练过程中移动点。

Animation that shows how Triplet Loss moves points in the course of training.

动画显示了Triplet Loss在训练过程中如何移动点。

从两个损失函数的数学解释来看,三元组损失在理论上更强,但三元组损失还具有一些额外的技巧,这些技巧帮助其表现得更好。最重要的是,三元组损失引入了在线三元组挖掘策略,例如,自动形成最有用的三元组。

三元组挖掘为何重要?

三元损失的公式显示它一次处理三个对象:

  • anchor,
  • positive - 一个与锚点具有相同标签的样本,
  • negative - 一个与锚标签和正标签不同的样本。

在一个简单的实现中,我们可以在每个周期开始时形成这样的样本三元组,然后在整个周期内将这些三元组的批次输入模型。这被称为“离线策略”。然而,这在几个方面并不是那么高效:

  • 它需要传递 $3n$ 个样本才能得到 $n$ 组的损失值。
  • 并不是所有这些三元组对模型学习有用,例如,产生正的损失值。
  • 即使我们在每个纪元的开始使用我将在本系列中实现的一种方法形成“有用”的三元组, 它们在纪元的某个时刻可能会变得“无用”,因为模型权重将不断更新。

相反,我们可以获取一批 $n$ 样本及其相关标签,并即时形成三元组。这被称为“在线策略”。通常,这会生成 $n^3$ 个可能的三元组,但实际上只有一部分这样的可能三元组是有效的。即便在这种情况下,我们也会从远比离线策略更多的三元组中计算出损失值。

给定一个三元组 (a, p, n),只有在满足以下条件时它是有效的:

  • ap 具有相同的标签,
  • ap 是不同的样本,
  • 并且 n 的标签与 ap 不同。

这些约束可能看起来需要在嵌套循环中进行昂贵的计算,但可以通过距离矩阵、掩膜和广播等技巧高效地实现。本系列的其余部分将重点介绍这些技巧的实现。

距离矩阵

距离矩阵是一个形状为 $(n, n)$ 的矩阵,用于存储两个 $n$ 大小集合中所有可能项对之间的距离值。这个矩阵可以用于向量化计算,否则会需要低效的循环。它的计算也可以优化,我们将实现 欧几里得距离矩阵技巧 (PDF),由 Samuel Albanie 解释。你可能想阅读这份三页的文档,以获取这个技巧的完整直觉,但简要说明如下:

  • 计算两个向量集合的点积,例如,在我们的案例中是嵌入。
  • 提取这个矩阵的对角线,该对角线包含每个嵌入的平方欧几里得范数。
  • 根据以下公式计算平方欧几里得距离矩阵: $||a - b||^2 = ||a||^2 - 2 ⟨a, b⟩ + ||b||^2$
  • 获取这个矩阵的平方根用于非方距离。

我们将使用PyTorch来实现,所以让我们从导入开始。

import torch
import torch.nn as nn
import torch.nn.functional as F

eps = 1e-8 # an arbitrary small value to be used for numerical stability tricks

def euclidean_distance_matrix(x):
  """Efficient computation of Euclidean distance matrix

  Args:
    x: Input tensor of shape (batch_size, embedding_dim)
    
  Returns:
    Distance matrix of shape (batch_size, batch_size)
  """
  # step 1 - compute the dot product

  # shape: (batch_size, batch_size)
  dot_product = torch.mm(x, x.t())

  # step 2 - extract the squared Euclidean norm from the diagonal

  # shape: (batch_size,)
  squared_norm = torch.diag(dot_product)

  # step 3 - compute squared Euclidean distances

  # shape: (batch_size, batch_size)
  distance_matrix = squared_norm.unsqueeze(0) - 2 * dot_product + squared_norm.unsqueeze(1)

  # get rid of negative distances due to numerical instabilities
  distance_matrix = F.relu(distance_matrix)

  # step 4 - compute the non-squared distances
  
  # handle numerical stability
  # derivative of the square root operation applied to 0 is infinite
  # we need to handle by setting any 0 to eps
  mask = (distance_matrix == 0.0).float()

  # use this mask to set indices with a value of 0 to eps
  distance_matrix += mask * eps

  # now it is safe to get the square root
  distance_matrix = torch.sqrt(distance_matrix)

  # undo the trick for numerical stability
  distance_matrix *= (1.0 - mask)

  return distance_matrix

无效的三元组掩蔽

现在我们可以计算一个批次中所有可能嵌入对的距离矩阵, 我们可以应用广播来枚举所有可能三元组的距离差,并将它们表示为形状为 (batch_size, batch_size, batch_size) 的张量。 然而,正如我之前提到的,只有这些 $n^3$ 三元组的一个子集实际上是有效的,我们需要一个相应的掩码来正确计算损失值。我们将分三个步骤实现这样的帮助函数:

  • 计算不重复索引的掩码,例如,(i != j and j != k)
  • 计算有效的锚点-正样本-负样本三元组的掩码,例如, labels[i] == labels[j] and labels[j] != labels[k]
  • 合并两个掩码。
def get_triplet_mask(labels):
  """compute a mask for valid triplets

  Args:
    labels: Batch of integer labels. shape: (batch_size,)

  Returns:
    Mask tensor to indicate which triplets are actually valid. Shape: (batch_size, batch_size, batch_size)
    A triplet is valid if:
    `labels[i] == labels[j] and labels[i] != labels[k]`
    and `i`, `j`, `k` are different.
  """
  # step 1 - get a mask for distinct indices

  # shape: (batch_size, batch_size)
  indices_equal = torch.eye(labels.size()[0], dtype=torch.bool, device=labels.device)
  indices_not_equal = torch.logical_not(indices_equal)
  # shape: (batch_size, batch_size, 1)
  i_not_equal_j = indices_not_equal.unsqueeze(2)
  # shape: (batch_size, 1, batch_size)
  i_not_equal_k = indices_not_equal.unsqueeze(1)
  # shape: (1, batch_size, batch_size)
  j_not_equal_k = indices_not_equal.unsqueeze(0)
  # Shape: (batch_size, batch_size, batch_size)
  distinct_indices = torch.logical_and(torch.logical_and(i_not_equal_j, i_not_equal_k), j_not_equal_k)

  # step 2 - get a mask for valid anchor-positive-negative triplets

  # shape: (batch_size, batch_size)
  labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
  # shape: (batch_size, batch_size, 1)
  i_equal_j = labels_equal.unsqueeze(2)
  # shape: (batch_size, 1, batch_size)
  i_equal_k = labels_equal.unsqueeze(1)
  # shape: (batch_size, batch_size, batch_size)
  valid_indices = torch.logical_and(i_equal_j, torch.logical_not(i_equal_k))

  # step 3 - combine two masks
  mask = torch.logical_and(distinct_indices, valid_indices)

  return mask

在线三元组挖掘的批量全策略

现在我们准备好实际实现三元组损失了。 三元组损失涉及几种策略来形成或选择三元组,最简单的方法是使用可以从一批样本中形成的所有有效三元组。 这可以通过我们已经实现的实用函数中的四个简单步骤来实现:

  • 获取批次中嵌入的所有可能对的距离矩阵。
  • 对这个矩阵应用广播,以计算所有可能三元组的损失值。
  • 将无效或简单三元组的损失值设置为 $0$。
  • 对剩余的正值进行平均以返回一个标量损失。

我将开始实施这个策略,后续会有更复杂的策略作为单独的帖子。

class BatchAllTtripletLoss(nn.Module):
  """Uses all valid triplets to compute Triplet loss

  Args:
    margin: Margin value in the Triplet Loss equation
  """
  def __init__(self, margin=1.):
    super().__init__()
    self.margin = margin
    
  def forward(self, embeddings, labels):
    """computes loss value.

    Args:
      embeddings: Batch of embeddings, e.g., output of the encoder. shape: (batch_size, embedding_dim)
      labels: Batch of integer labels associated with embeddings. shape: (batch_size,)

    Returns:
      Scalar loss value.
    """
    # step 1 - get distance matrix
    # shape: (batch_size, batch_size)
    distance_matrix = euclidean_distance_matrix(embeddings)

    # step 2 - compute loss values for all triplets by applying broadcasting to distance matrix

    # shape: (batch_size, batch_size, 1)
    anchor_positive_dists = distance_matrix.unsqueeze(2)
    # shape: (batch_size, 1, batch_size)
    anchor_negative_dists = distance_matrix.unsqueeze(1)
    # get loss values for all possible n^3 triplets
    # shape: (batch_size, batch_size, batch_size)
    triplet_loss = anchor_positive_dists - anchor_negative_dists + self.margin

    # step 3 - filter out invalid or easy triplets by setting their loss values to 0

    # shape: (batch_size, batch_size, batch_size)
    mask = get_triplet_mask(labels)
    triplet_loss *= mask
    # easy triplets have negative loss values
    triplet_loss = F.relu(triplet_loss)

    # step 4 - compute scalar loss value by averaging positive losses
    num_positive_losses = (triplet_loss > eps).float().sum()
    triplet_loss = triplet_loss.sum() / (num_positive_losses + eps)

    return triplet_loss

结论

我提到过Triplet Loss在数学上和Contrastive Loss有所不同,不仅在于其样本选择策略,我在这篇文章中通过使用几种技巧有效地实现了在线三元组挖掘的batch-all策略。

还有其他更复杂的策略,例如批量困难和批量半困难挖掘,但它们的实现以及我在本文中用于提高效率的技巧讨论,值得单独撰写文章。

未来的文章将涵盖这些主题,以及关于避免向量坍缩和控制类内和类间方差的一些技巧的附加讨论。

这个页面有用吗?

感谢您的反馈!🙏

我们很遗憾听到这个消息。 😔 你可以 编辑 这个页面在 GitHub上,或者 create 一个 GitHub 问题。