损失#

convex_kl_divergence(log_predictions, targets)

计算Kullback-Leibler散度损失的凸版本。

cosine_distance(predictions, targets[, ...])

计算目标与预测之间的余弦距离。

cosine_similarity(predictions, targets[, ...])

计算目标与预测之间的余弦相似度。

ctc_loss(logits, logit_paddings, labels, ...)

计算CTC损失。

ctc_loss_with_forward_probs(logits, ...[, ...])

计算CTC损失和CTC前向概率。

hinge_loss(predictor_outputs, targets)

计算二分类的铰链损失。

huber_loss(predictions[, targets, delta])

Huber 损失,类似于接近零的 L2 损失,远离零的 L1 损失。

kl_divergence(log_predictions, targets[, ...])

计算 Kullback-Leibler 散度(相对熵)损失。

kl_divergence_with_log_targets(...[, axis, ...])

计算Kullback-Leibler散度(相对熵)损失。

l2_loss(predictions[, targets])

计算一组预测的L2损失。

log_cosh(predictions[, targets])

计算一组预测的对数余弦损失。

make_fenchel_young_loss(max_fun)

从最大函数创建Fenchel-Young损失。

multiclass_hinge_loss(scores, labels)

多类别铰链损失。

multiclass_perceptron_loss(scores, labels)

多类感知器损失。

multiclass_sparsemax_loss(scores, labels)

多类稀疏最大损失。

ntxent(embeddings, labels[, temperature])

归一化温度缩放的交叉熵损失 (NT-Xent).

perceptron_loss(predictor_outputs, targets)

二元感知器损失。

poly_loss_cross_entropy(logits, labels[, ...])

计算logits与labels之间的PolyLoss。

ranking_softmax_loss(logits, labels, *[, ...])

排名 softmax 损失。

safe_softmax_cross_entropy(logits, labels)

计算一组logits和标签之间的softmax交叉熵。

sigmoid_binary_cross_entropy(logits, labels)

根据logits和labels计算逐元素sigmoid交叉熵。

sigmoid_focal_loss(logits, labels[, alpha, ...])

Sigmoid焦点损失。

smooth_labels(labels, alpha)

应用标签平滑。

softmax_cross_entropy(logits, labels[, ...])

计算 logits 和 labels 之间的 softmax 交叉熵。

softmax_cross_entropy_with_integer_labels(...)

计算logits与整数标签之间的softmax交叉熵。

sparsemax_loss(logits, labels)

二元 sparsemax 损失。

squared_error(predictions[, targets])

计算一组预测的平方误差。

triplet_margin_loss(anchors, positives, ...)

返回一批嵌入的三元组损失。

凸Kullback-Leibler散度#

optax.losses.convex_kl_divergence(log_predictions: chex.Array, targets: chex.Array, axis: int | tuple[int, ...] | None = -1, where: TypeAliasForwardRef('chex.Array') | None = None) chex.Array[来源]#

计算Kullback-Leibler散度损失的一个凸版本。

测量如果使用目标概率分布而不是预测概率分布所获得的信息增益。此版本在 p (目标) 和 q (log_predictions) 上是联合凸的。

Parameters:
  • log_predictions – 预测分布的概率,形状为 […, dim]。预期为对数空间,以避免下溢。

  • targets – 目标分布的概率,形状为 […, dim]。预计为严格正数。

  • axis – 计算的轴或多个轴。

  • where – 包含在计算中的元素。

Returns:

预测分布与目标分布的Kullback-Leibler散度,形状为[…].

参考文献

Kullback 和 Leibler, 信息与充分性, 1951

在版本 0.2.4 中更改:添加了 axiswhere 参数。

余弦距离#

optax.losses.cosine_distance(predictions: chex.Array, targets: chex.Array, epsilon: float = 0.0, axis: int | tuple[int, ...] | None = -1, where: TypeAliasForwardRef('chex.Array') | None = None) chex.Array[来源]#

计算目标和预测之间的余弦距离。

这里实现的余弦 距离,测量两个向量的 不相似性,作为余弦 相似性 的相反:1 - cos(theta)

Parameters:
  • predictions – 预测向量,形状为 […, dim]

  • targets – 真实目标向量,形状为 […, dim]

  • epsilon – 余弦相似度中分母项的最小范数。

  • axis – 要计算的轴或多个轴。

  • where – 要包含在计算中的元素。

Returns:

余弦距离,形状为 […]

参考文献

余弦距离, 维基百科。

在版本 0.2.4 中更改: 添加了 axiswhere 参数。

余弦相似度#

optax.losses.cosine_similarity(predictions: chex.Array, targets: chex.Array, epsilon: float = 0.0, axis: int | tuple[int, ...] | None = -1, where: TypeAliasForwardRef('chex.Array') | None = None) chex.Array[来源]#

计算目标与预测之间的余弦相似度。

余弦 相似度 是衡量向量之间相似性的一个度量,定义为它们之间角度的余弦,这也是这些向量的内积,经过归一化以具有单位范数。

Parameters:
  • predictions – 预测的向量,形状为 […, dim]

  • targets – 真实目标向量,形状为 […, dim]

  • epsilon – 余弦相似度分母中项的最小范数。

  • axis – 计算的轴或多个轴。

  • where – 要包含在计算中的元素。

Returns:

余弦相似度测量,形状 […]

参考文献

余弦相似度, 维基百科。

在版本 0.2.4 中更改:添加了 axiswhere 参数。

连接主义时间分类损失#

optax.losses.ctc_loss(logits: chex.Array, logit_paddings: chex.Array, labels: chex.Array, label_paddings: chex.Array, blank_id: int = 0, log_epsilon: float = -100000.0) chex.Array[来源]#

计算CTC损失。

有关详细信息,请参见 ctc_loss_with_forward_probs 的文档字符串。

Parameters:
  • logits – (B, T, K) 数组,包含每个类别的 logits,其中 B 表示批大小,T 表示 logits 中的最大时间帧,K 表示包括空白类别在内的类别数量。

  • logit_paddings – (B, T)-array. Padding indicators for logits. Each element must be either 1.0 or 0.0, and logitpaddings[b, t] == 1.0 denotes that logits[b, t, :] are padded values.

  • labels – (B, N)-数组,包含参考整数标签,其中 N 表示标签序列中的最大时间帧。

  • label_paddings – (B, N)-array. Padding indicators for labels. Each element must be either 1.0 or 0.0, and labelpaddings[b, n] == 1.0 denotes that labels[b, n] is a padded label. In the current implementation, labels must be right-padded, i.e. each row labelpaddings[b, :] must be repetition of zeroes, followed by repetition of ones.

  • blank_id – 空白符号的ID。 logits[b, :, blank_id] 用作空白符号的概率。

  • log_epsilon – 数值稳定的 log(+0) 近似值。

Returns:

(B,)-一个数组,包含批次中每个序列的损失值。

optax.losses.ctc_loss_with_forward_probs(logits: chex.Array, logit_paddings: chex.Array, labels: chex.Array, label_paddings: chex.Array, blank_id: int = 0, log_epsilon: float = -100000.0) tuple[TypeAliasForwardRef('chex.Array'), TypeAliasForwardRef('chex.Array'), TypeAliasForwardRef('chex.Array')][来源]#

计算CTC损失和CTC前向概率。

CTC损失是一个基于模型的对数似然的损失函数,它引入了一个特殊的空白符号 \(\phi\) 来表示可变长度的输出序列。

该函数返回的前向概率作为辅助结果,分为两部分:空白的alpha概率和非空白的alpha概率。定义如下:

\[\alpha_{\mathrm{BLANK}}(t, n) = \sum_{\pi_{1:t-1}} p(\pi_t = \phi | \pi_{1:t-1}, y_{1:n-1}, \cdots), \\ \alpha_{\mathrm{LABEL}}(t, n) = \sum_{\pi_{1:t-1}} p(\pi_t = y_n | \pi_{1:t-1}, y_{1:n-1}, \cdots). \]

在这里,\(\pi\) 表示参考中的对齐序列 [Graves et al, 2006],这是 labels 的空白插入表示形式。返回值是上述概率的对数。

Parameters:
  • logits – (B, T, K)-数组,包含每个类别的logits,其中B表示批量大小,T表示logits中的最大时间帧数,K表示类别数量,包括用于空白的类别。

  • logit_paddings – (B, T)-array. Padding indicators for logits. Each element must be either 1.0 or 0.0, and logitpaddings[b, t] == 1.0 denotes that logits[b, t, :] are padded values.

  • labels – 具有参考整数标签的(B, N)数组,其中N表示标签序列中的最大时间帧。

  • label_paddings – (B, N)-array. Padding indicators for labels. Each element must be either 1.0 or 0.0, and labelpaddings[b, n] == 1.0 denotes that labels[b, n] is a padded label. In the current implementation, labels must be right-padded, i.e. each row labelpaddings[b, :] must be repetition of zeroes, followed by repetition of ones.

  • blank_id – 空白标记的标识符。 logits[b, :, blank_id] 被用作 空白符号的概率。

  • log_epsilon – 数值稳定的 log(+0) 近似值。

Returns:

A tuple (loss_value, logalpha_blank, logalpha_nonblank). Here, loss_value is a (B,)-array containing the loss values for each sequence in the batch, logalpha_blank and logalpha_nonblank are (T, B, N+1)-arrays where the (t, b, n)-th element denotes log alpha_B(t, n) and log alpha_L(t, n), respectively, for b-th sequence in the batch.

参考文献

Graves et al, 连接主义时间分类:用递归神经网络标记未分段的序列数据, 2006

芬切尔-杨损失#

optax.losses.make_fenchel_young_loss(max_fun: MaxFun)[来源]#

从最大函数创建一个Fenchel-Young损失。

Parameters:

max_fun – 建立在Fenchel-Young损失上的最大函数。

Returns:

一个具有相同签名的Fenchel-Young损失函数。

示例

给定一个最大函数,例如对数和指数,你可以如下轻松构造一个Fenchel-Young损失:

>>> from jax.scipy.special import logsumexp
>>> fy_loss = optax.losses.make_fenchel_young_loss(max_fun=logsumexp)
Reference:

Blondel 等人 使用 Fenchel-Young 损失进行学习, 2020

警告

生成的损失接受任意数量的前导维度,fy_loss 在最后一个维度上操作。该函数的 jaxopt 版本将会把任何向量展平为一个大的 1D 向量。

铰链损失#

optax.losses.hinge_loss(predictor_outputs: chex.Array, targets: chex.Array) chex.Array[来源]#

计算二元分类的铰链损失。

Parameters:
  • predictor_outputs – 决策函数的输出。

  • targets – 目标值。目标值应严格在集合 {-1, 1} 中。

Returns:

损失值。

optax.losses.multiclass_hinge_loss(scores: chex.Array, labels: chex.Array) chex.Array[来源]#

多类别铰链损失。

Parameters:
  • scores – 模型产生的分数(浮点数)。

  • labels – 真实的整数标签。

Returns:

损失值

参考文献

铰链损失, 维基百科

在版本 0.2.3 中添加。

Huber损失#

optax.losses.huber_loss(predictions: chex.Array, targets: TypeAliasForwardRef('chex.Array') | None = None, delta: float = 1.0) chex.Array[来源]#

Huber 损失,类似于接近零的 L2 损失,远离零的 L1 损失。

如果对huber loss应用梯度下降,则等同于在反向传播中将l2_loss的梯度裁剪到[-delta, delta]

Parameters:
  • predictions – 一个任意形状的向量 […].

  • targets – 一个形状可广播到predictions的向量;如果没有提供,假定为一个零的向量。

  • delta – huber损失变换的边界,默认为1。

Returns:

逐元素的Huber损失,具有与predictions相同的形状。

参考文献

Huber loss, 维基百科.

Kullback-Leibler 散度#

optax.losses.kl_divergence(log_predictions: chex.Array, targets: chex.Array, axis: int | tuple[int, ...] | None = -1, where: TypeAliasForwardRef('chex.Array') | None = None) chex.Array[来源]#

计算Kullback-Leibler散度(相对熵)损失。

衡量如果使用目标概率分布而不是预测概率分布所获得的信息增益。

Parameters:
  • log_predictions – 预测分布的概率,形状为 […, dim]。预期为对数空间,以避免下溢。

  • targets – 目标分布的概率,形状为 […, dim]。预计为严格正数。

  • axis – 要计算的轴或多个轴。

  • where – 要包含在计算中的元素。

Returns:

预测分布与目标分布的Kullback-Leibler散度,形状为[…].

参考文献

Kullback 和 Leibler, 信息与充分性, 1951

在版本 0.2.4 中更改: 添加了 axiswhere 参数。

optax.losses.kl_divergence_with_log_targets(log_predictions: chex.Array, log_targets: chex.Array, axis: int | tuple[int, ...] | None = -1, where: TypeAliasForwardRef('chex.Array') | None = None) chex.Array[来源]#

计算Kullback-Leibler散度(相对熵)损失。

kl_div_loss 的版本,其中目标以对数空间给出。

Parameters:
  • log_predictions – 预测分布的概率,形状为 […, dim]。预期为对数空间,以避免下溢。

  • log_targets – 目标分布的概率,形状为[…, dim]。预计在对数空间中。

  • axis – 要计算的轴或多个轴。

  • where – 要包含在计算中的元素。

Returns:

预测分布与目标分布的Kullback-Leibler散度,形状为[…].

在版本 0.2.4 中更改: 添加了 axiswhere 参数。

L2 平方损失#

optax.losses.squared_error(predictions: chex.Array, targets: TypeAliasForwardRef('chex.Array') | None = None) chex.Array[来源]#

计算一组预测的平方误差。

均方误差可以通过 squared_error(a, b).mean() 计算。

Parameters:
  • predictions – 一个任意形状的向量 […].

  • targets – 一个形状可以广播到 predictions 的向量;如果未提供,则假定为一个零向量。

Returns:

逐元素平方差,与predictions具有相同的形状。

注意

l2_loss = 0.5 * squared_error,其中0.5项在Bishop的“模式识别与机器学习”中是标准的,但在Tibshirani的“统计学习的要素”中则不是。

optax.losses.l2_loss(predictions: chex.Array, targets: TypeAliasForwardRef('chex.Array') | None = None) chex.Array[来源]#

计算一组预测的L2损失。

Parameters:
  • predictions – 一个任意形状的向量 […].

  • targets – 一个可广播到predictions形状的向量;如果没有提供,则假定它是一个零的向量。

Returns:

逐元素平方差,与预测值具有相同形状。

注意

0.5 这个术语在 Bishop 的《模式识别与机器学习》中是标准的,但在 Tibshirani 的《统计学习的要素》中则不是。

对数双曲余弦损失#

optax.losses.log_cosh(predictions: chex.Array, targets: TypeAliasForwardRef('chex.Array') | None = None) chex.Array[来源]#

计算一组预测的对数余弦损失。

log(cosh(x)) 对于小的 x 大约是 (x**2) / 2,对于大的 x 是 abs(x) - log(2)。它是 Huber 损失的一个二次可微替代方案。

Parameters:
  • predictions – 一个任意形状的向量 […].

  • targets – 一个形状可以广播到predictions的向量;如果没有提供,则被假定为一个零的向量。

Returns:

log-cosh损失,与predictions具有相同形状。

参考文献

陈等,对数双曲余弦损失改善变分自编码器 ,2019

归一化温度缩放交叉熵 (NT-Xent) 损失#

optax.losses.ntxent(embeddings: chex.Array, labels: chex.Array, temperature: chex.Numeric = 0.07) chex.Numeric[来源]#

标准化温度缩放交叉熵损失 (NT-Xent)。

示例

>>> import jax
>>> import optax
>>> import jax.numpy as jnp
>>>
>>> key = jax.random.key(42)
>>> key1, key2, key3 = jax.random.split(key, 3)
>>> x = jax.random.normal(key1, shape=(4,2))
>>> labels = jnp.array([0, 0, 1, 1])
>>>
>>> print("input:", x)
input: [[ 0.07592554 -0.48634264]
 [ 1.2903206   0.5196119 ]
 [ 0.30040437  0.31034866]
 [ 0.5761609  -0.8074621 ]]
>>> print("labels:", labels)
labels: [0 0 1 1]
>>>
>>> w = jax.random.normal(key2, shape=(2,1)) # params
>>> b = jax.random.normal(key3, shape=(1,)) # params
>>> out = x @ w + b # model
>>>
>>> print("Embeddings:", out)
Embeddings: [[0.08969027]
 [1.6291292 ]
 [0.8622629 ]
 [0.13612625]]
>>> loss = optax.ntxent(out, labels)
>>> print("loss:", loss)
loss: 1.0986123
Parameters:
  • embeddings – 嵌入批次,形状为 [批次, 特征长度]

  • labels – 代表正对的组的标签。例如,如果你有一批4个嵌入,前两个和最后两个是正对,您的labels应该像 [0, 0, 1, 1]。形状 [batch]

  • temperature - 温度缩放参数。

Returns:

一个标量损失值,NT-Xent值在所有正对上取平均

参考文献

T. Chen et al 对视觉表征的对比学习的简单框架, 2020

kevinmusgrave.github.io/pytorch-metric-learning/losses/#ntxentloss

在版本 0.2.3 中添加。

多项式损失交叉熵#

optax.losses.poly_loss_cross_entropy(logits: chex.Array, labels: chex.Array, epsilon: float = 2.0, axis: int | tuple[int, ...] | None = -1, where: TypeAliasForwardRef('chex.Array') | None = None) chex.Array[来源]#

计算 logits 和标签之间的 PolyLoss。

PolyLoss 是一种损失函数,它将常用的分类损失函数分解为一系列加权多项式基。它的灵感来自于交叉熵损失和焦点损失在 \((1 - P_t)^j\) 基础上的泰勒展开。

\[L_{Poly} = \sum_1^\infty \alpha_j \cdot (1 - P_t)^j \\ L_{Poly-N} = (\epsilon_1 + 1) \cdot (1 - P_t) + \ldots + \\ (\epsilon_N + \frac{1}{N}) \cdot (1 - P_t)^N + \frac{1}{N + 1} \cdot (1 - P_t)^{N + 1} + \ldots = \\ - \log(P_t) + \sum_{j = 1}^N \epsilon_j \cdot (1 - P_t)^j \]

此函数提供了 \(L_{Poly-N}\) 的简化版本,仅改变了第一个多项式项的系数。

Parameters:
  • logits – 未归一化的对数概率,形状为 […, num_classes]

  • labels – 有效的概率分布(非负,和为1),例如一个热编码,指定每个输入的正确类别;必须具有可以广播到 […, num_classes] 的形状。

  • epsilon – 第一个多项式项的系数。根据论文,推荐以下值: - 对于ImageNet 2d图像分类,epsilon = 2.0。 - 对于2d实例分割和物体检测,epsilon = -1.0。 - 还建议根据任务调整此值,例如通过使用网格搜索。

  • axis – 要计算的轴或多个轴。

  • where – 要包含在计算中的元素。

Returns:

每个预测与相应目标分布之间的多项式损失,形状为 […]

参考文献

Leng 等, PolyLoss: 分类损失函数的多项式扩展视角, 2022

在版本 0.2.4 中更改: 添加了 axiswhere 参数。

感知器#

optax.losses.perceptron_loss(predictor_outputs: chex.Numeric, targets: chex.Numeric) chex.Numeric[来源]#

二进制感知器损失。

Parameters:
  • predictor_outputs – 模型生成的分数(浮点数)。

  • targets – 目标值。目标值应严格在集合 {-1, 1} 中。

Returns:

损失值。

参考文献

感知器, 维基百科

optax.losses.multiclass_perceptron_loss(scores: chex.Array, labels: chex.Array) chex.Array[来源]#

多类别感知机损失。

Parameters:
  • scores – 模型产生的分数。

  • labels – 真实的整数标签。

Returns:

损失值。

参考文献

迈克尔·科林斯。隐马尔可夫模型的区分训练方法:与感知机算法的理论与实验。EMNLP 2002

在版本 0.2.2 中添加。

排名Softmax损失#

optax.losses.ranking_softmax_loss(logits: chex.Array, labels: chex.Array, *, where: TypeAliasForwardRef('chex.Array') | None = None, weights: TypeAliasForwardRef('chex.Array') | None = None, reduce_fn: ~collections.abc.Callable[[...], TypeAliasForwardRef('chex.Array')] | None = <function mean>) chex.Array[来源]#

排名软max损失。

定义:

\[\ell(s, y) = -\sum_i y_i \log \frac{\exp(s_i)}{\sum_j \exp(s_j)} \]
Parameters:
  • logits – 一个 [..., list_size]-数组,表示每个项目的得分。

  • labels – 一个 [..., list_size]-Array,表示每个项目的相关性标签。

  • where – 一个可选的 [..., list_size]-Array,指示 哪些项目在计算损失时是有效的。对于那些为 False 的项目,在计算损失时将被忽略。

  • weights – 可选的 [..., list_size]-Array,表示每个项目的权重。

  • reduce_fn – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.

Returns:

排名softmax损失。

sigmoid 二元交叉熵#

optax.losses.sigmoid_binary_cross_entropy(logits, labels)[来源]#

根据 logits 和标签计算逐元素的 Sigmoid 交叉熵。

此函数可用于二分类或多分类(其中每个类别是一个独立的二元预测,不同类别之间不是互斥的,例如预测一张图像同时包含猫和狗。)

由于此函数是重载的,请确保您的 logitslabels 彼此兼容。如果您传入的是二进制 labels(值在 {0, 1} 中),请确保您的 logits 仅对应于类 1。如果您传入的是逐类目标概率或独热 labels,请确保您的 logits 也是多类的。如果您依赖隐式广播来重塑 logitslabels,请特别小心。

Parameters:
  • logits – 每个元素是二元预测的未归一化对数概率。请参阅上面关于与labels兼容性的说明。

  • 标签 – 二进制标签,其值为 {0,1} 或多类目标概率。请参见上面关于 logits 的兼容性说明。

Returns:

每个二进制预测的交叉熵,与 logits 形状相同。

参考文献

Goodfellow 等人, 深度学习, 2016

Sigmoid焦点损失#

optax.losses.sigmoid_focal_loss(logits: chex.Array, labels: chex.Array, alpha: float | None = None, gamma: float = 2.0) chex.Array[来源]#

Sigmoid焦点损失。

焦点损失是一种针对不平衡问题的重加权交叉熵。
如果类别不是互斥的,请使用此损失函数。
有关更多信息,请参见 sigmoid_binary_cross_entropy

Parameters:
  • logits – 浮点数组。每个示例的预测。每个示例的预测。

  • labels – 浮点数数组。标签和 logits 必须具有相同的形状。 标签数组必须包含数据集每个元素的二元分类标签(类别外为 0,类别内为 1)。

  • alpha – (可选)权重因子,范围为(0,1),用于平衡正例与负例。默认为None(无权重)。

  • gamma – 调制因子的指数 (1 - p_t)。平衡简单与困难示例。

Returns:

一个与logits和目标数组形状相同的损失值数组。

参考文献

Lin 等. Focal Loss for Dense Object Detection, 2017

平滑标签#

optax.losses.smooth_labels(labels: chex.Array, alpha: float) 数组[来源]#

应用标签平滑。

标签平滑通常与交叉熵损失结合使用。 平滑的标签倾向于减小对数值间隙,并且已经证明这可以 通过防止过于自信的预测来提供更好的模型校准。

Parameters:
  • labels – 要平滑的独热标签。

  • alpha – 平滑因子。

Returns:

输入标签的一种平滑版本。

参考文献

Muller等, 当标签平滑有帮助时?, 2019

Soft-max 交叉熵#

optax.losses.safe_softmax_cross_entropy(logits: chex.Array, labels: chex.Array) chex.Array[来源]#

计算一组逻辑值和标签之间的softmax交叉熵。

Contrarily to optax.softmax_cross_entropy() this function handles labels*logsoftmax(logits) as 0 when logits=-inf and labels=0, following the convention that 0 log 0 = 0.

Parameters:
  • logits – 未归一化的对数概率,形状为 […, num_classes]

  • labels – 有效的概率分布(非负,和为1),例如一个指定每个输入正确类别的独热编码;必须具有可广播到 […, num_classes] 的形状。

Returns:

每个预测与相应目标分布之间的交叉熵,形状为 […]

optax.losses.softmax_cross_entropy(logits: chex.Array, labels: chex.Array, axis: int | tuple[int, ...] | None = -1, where: TypeAliasForwardRef('chex.Array') | None = None) chex.Array[来源]#

计算一组逻辑值和标签之间的softmax交叉熵。

这个损失函数通常用于多类分类任务。它衡量预测概率分布(通过对logits应用softmax函数获得)与真实概率分布(由独热编码标签表示)之间的差异。这个损失也被称为分类交叉熵。

Let \(x\) denote the logits array of size [batch_size, num_classes] and \(y\) denote the labels array of size [batch_size, num_classes]. Then this function returns a vector \(\sigma\) of size [batch_size] defined as:

\[\sigma_i = - \sum_j y_{i j} \log\left(\frac{\exp(x_{i j})}{\sum_k \exp(x_{i k})}\right) \,. \]
Parameters:
  • logits – 未归一化的对数概率,形状为 [batch_size, num_classes].

  • labels – 一热编码标签,形状为 [batch_size, num_classes]。每一行表示单个示例的真实类别分布。

  • axis – 要计算的轴或多个轴。

  • where – 要包含在计算中的元素。

Returns:

每个预测与相应目标分布之间的交叉熵,形状为 [batch_size]

示例

>>> import optax
>>> import jax.numpy as jnp
>>> jnp.set_printoptions(precision=4)
>>> # example: batch_size = 2, num_classes = 3
>>> logits = jnp.array([[1.2, -0.8, -0.5], [0.9, -1.2, 1.1]])
>>> labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
>>> print(optax.softmax_cross_entropy(logits, labels))
[0.2761 2.9518]

参考文献

交叉熵损失, 维基百科

多项式逻辑回归, 维基百科

另请参见

该函数类似于 optax.losses.softmax_cross_entropy_with_integer_labels(), 但接受的是独热标签而不是整数标签。

optax.losses.safe_softmax_cross_entropy() 提供了一种替代实现,该实现与 logits=-inf 的处理方式不同。

在版本 0.2.4 中更改: 添加了 axiswhere 参数。

optax.losses.softmax_cross_entropy_with_integer_labels(logits: chex.Array, labels: chex.Array, axis: int | tuple[int, ...] = -1, where: TypeAliasForwardRef('chex.Array') | None = None) chex.Array[来源]#

计算logits与整数标签之间的softmax交叉熵。

这个损失函数对有整数标签且未进行独热编码的分类问题很有用。这个损失函数也被称为类别交叉熵。

Let \(x\) denote the logits array of size [batch_size, num_classes] and \(y\) denote the labels array of size [batch_size]. Then this function returns a vector \(\sigma\) of size [batch_size] defined as:

\[\sigma_i = \log\left(\frac{\exp(x_{i y_i})}{\sum_j \exp(x_{i j})}\right)\,. \]
Parameters:
  • logits – 未归一化的对数概率,形状为 [batch_size, num_classes].

  • labels – 整数,指定每个输入的正确类别,形状为 [batch_size]。类别标签假定在 0 和 num_classes - 1 之间(含0和1).

  • axis – Axis or axes along which to compute. If a tuple of axes is passed then num_classes must match the total number of elements in axis dimensions and a label is interpreted as a flat index in a logits slice of shape logits[axis].

  • where – 要包含在计算中的元素。

Returns:

每个预测与相应目标分布之间的交叉熵,形状为 [batch_size]

示例

>>> import optax
>>> import jax.numpy as jnp
>>> jnp.set_printoptions(precision=4)
>>> # example: batch_size = 2, num_classes = 3
>>> logits = jnp.array([[1.2, -0.8, -0.5], [0.9, -1.2, 1.1]])
>>> labels = jnp.array([0, 1])
>>> print(optax.softmax_cross_entropy_with_integer_labels(logits, labels))
[0.2761 2.9518]
>>> import jax.numpy as jnp
>>> import numpy as np
>>> import optax
>>> jnp.set_printoptions(precision=4)
>>> # example: batch_size = (1, 2), num_classes = 12 (i.e. 3 * 4)
>>> shape = (1, 2, 3, 4)
>>> logits = jnp.arange(np.prod(shape), dtype=jnp.float32).reshape(shape)
>>> # elements indices in slice of shape (3, 4)
>>> ix = jnp.array([[1, 2]])
>>> jx = jnp.array([[1, 3]])
>>> labels = jnp.ravel_multi_index((ix, jx), shape[2:])
>>> cross_entropy = optax.softmax_cross_entropy_with_integer_labels(
...     logits, labels, axis=(2, 3))
>>> print(cross_entropy)
[[6.4587 0.4587]]

参考文献

交叉熵损失, 维基百科

多项式逻辑回归, 维基百科

另请参见

这个函数与 optax.losses.softmax_cross_entropy()类似,但接受整数标签而不是独热标签。

在版本 0.2.4 中更改: 添加了 axiswhere 参数。

稀疏最大值#

optax.losses.sparsemax_loss(logits: chex.Array, labels: chex.Array) chex.Array[来源]#

二元稀疏最大损失。

当且仅当 jax.nn.sparse_sigmoid(logits) == labels 时,这个损失为零。

Parameters:
  • logits – 模型产生的分数(浮点数)。

  • labels – 真实标签整数 (0 或 1)。

Returns:

损失值

参考文献

使用Fenchel-Young损失进行学习。Mathieu Blondel, André F. T. Martins, Vlad Niculae。JMLR 2020。(第4.4节)

在版本 0.2.3 中添加。

optax.losses.multiclass_sparsemax_loss(scores: chex.Array, labels: chex.Array) chex.Array[来源]#

多类别稀疏最大损失。

Parameters:
  • scores – 模型产生的分数。

  • labels – 真实的整数标签。

Returns:

损失值

参考文献

Martins et al, 从 Softmax 到 Sparsemax: 一种稀疏注意力和多标签分类模型 , 2016.

三元组边距损失#

optax.losses.triplet_margin_loss(anchors: chex.Array, positives: chex.Array, negatives: chex.Array, axis: int = -1, norm_degree: chex.Numeric = 2, margin: chex.Numeric = 1.0, eps: chex.Numeric = 1e-06) chex.Array[来源]#

返回一批嵌入的三元组损失。

示例

>>> import jax.numpy as jnp, optax, chex
>>> jnp.set_printoptions(precision=4)
>>> anchors = jnp.array([[0.0, 0.0], [1.0, 1.0]])
>>> positives = jnp.array([[0.1, 0.1], [1.1, 1.1]])
>>> negatives = jnp.array([[1.0, 0.0], [0.0, 1.0]])
>>> output = optax.losses.triplet_margin_loss(anchors, positives, negatives,
...                                           margin=1.0)
>>> print(output)
[0.1414 0.1414]
Parameters:
  • anchors – 一组锚点嵌入,形状为 [batch, feature_dim]。

  • positives – 一组正嵌入(类似于锚点),形状为 [batch, feature_dim]。

  • negatives – 一组负嵌入(与锚点不同),形状为 [batch, feature_dim]。

  • axis – 计算距离的轴(默认是 -1)。

  • norm_degree – 距离计算的范数度(默认值为2,表示欧几里得距离)。

  • margin – 正距离应小于负距离的最小边距。

  • eps – 一个小的 epsilon 值,用以确保距离计算中的数值稳定性。

Returns:

返回计算出的三元组损失作为数组。

参考文献

V. Balntas 等, 使用三元损失学习浅层卷积特征描述符 _, 2016