ot.unbalanced
与不平衡最优运输问题相关的求解器。
- ot.unbalanced.barycenter_unbalanced(A, M, reg, reg_m, method='sinkhorn', weights=None, numItermax=1000, stopThr=1e-06, verbose=False, log=False, **kwargs)[源]
计算\(\mathbf{A}\)的熵不平衡瓦瑟斯坦重心。
该函数解决以下优化问题,带有 \(\mathbf{a}\)
\[\mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{u_{reg}}(\mathbf{a},\mathbf{a}_i)\]其中 :
\(W_{u_{reg}}(\cdot,\cdot)\) 是不平衡的熵正则化Wasserstein距离(见
ot.unbalanced.sinkhorn_unbalanced())\(\mathbf{a}_i\) 是矩阵 \(\mathbf{A}\) 列中的训练分布
reg 和 \(\mathbf{M}\) 分别是正则化项和 OT 的成本矩阵
reg_mis 是边际松弛超参数
用于解决该问题的算法是广义Sinkhorn-Knopp矩阵缩放算法,如[10]中所提出的。
- Parameters:
A (类数组 (维度, 历史数量)) – 历史数量 训练分布 \(\mathbf{a}_i\) 的维度 维度
M (类似数组 (维度, 维度)) – OT 的基础度量矩阵。
reg (float) – 熵正则化项 > 0
reg_m (float) – 边际放松项 > 0
权重 (类数组 (n_hists,) 可选) – 每个分布的权重(重心坐标)如果为 None,则使用均匀权重。
numItermax (int, 可选) – 最大迭代次数
stopThr (float, 可选) – 错误的停止阈值 (> 0)
verbose (bool, 可选) – 在迭代过程中打印信息
log (bool, 可选) – 如果为真,则记录日志
- Returns:
a ((dim,) 类数组) – 不平衡Wasserstein重心
log (字典) – 仅当参数log==True时返回日志字典
参考文献
- ot.unbalanced.barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, numItermax=1000, stopThr=1e-06, verbose=False, log=False)[源]
计算\(\mathbf{A}\)的熵不平衡瓦瑟斯坦重心。
该函数解决以下优化问题,带有 \(\mathbf{a}\)
\[\mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{u_{reg}}(\mathbf{a},\mathbf{a}_i)\]其中 :
\(W_{u_{reg}}(\cdot,\cdot)\) 是不平衡的熵正则化Wasserstein距离(见
ot.unbalanced.sinkhorn_unbalanced())\(\mathbf{a}_i\) 是矩阵 \(\mathbf{A}\) 列中的训练分布
reg 和 \(\mathbf{M}\) 分别是正则化项和 OT 的成本矩阵
reg_mis 是边际松弛超参数
用于解决该问题的算法是广义Sinkhorn-Knopp矩阵缩放算法,如[10]中所提出的。
- Parameters:
A (类数组 (维度, 历史数量)) – 历史数量 训练分布 \(\mathbf{a}_i\) 的维度 维度
M (类似数组 (维度, 维度)) – OT 的基础度量矩阵。
reg (float) – 熵正则化项 > 0
reg_m (float) – 边际放松项 > 0
权重 (类数组 (n_hists,) 可选) – 每个分布的权重(重心坐标)如果为 None,则使用均匀权重。
numItermax (int, 可选) – 最大迭代次数
stopThr (float, 可选) – 错误的停止阈值 (> 0)
verbose (bool, 可选) – 在迭代过程中打印信息
log (bool, 可选) – 如果 True 记录 log
- Returns:
a ((dim,) 类数组) – 非平衡瓦瑟斯坦重心
log (字典) – 只有在参数中\(log==True\)时返回日志字典
参考文献
[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). 用于正则化运输问题的迭代Bregman投影。 SIAM科学计算杂志,37(2),A1111-A1138。
[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). 用于不平衡运输问题的缩放算法。arXiv 预印本 arXiv:1607.05816.
- ot.unbalanced.barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1000.0, numItermax=1000, stopThr=1e-06, verbose=False, log=False)[源]
计算带稳定性的熵不平衡wasserstein重心 \(\mathbf{A}\)。
该函数解决以下优化问题:
\[\mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{u_{reg}}(\mathbf{a},\mathbf{a}_i)\]其中 :
\(W_{u_{reg}}(\cdot,\cdot)\) 是不平衡的熵正则化Wasserstein距离(见
ot.unbalanced.sinkhorn_unbalanced())\(\mathbf{a}_i\) 是矩阵 \(\mathbf{A}\) 列中的训练分布
reg 和 \(\mathbf{M}\) 分别是正则化项和 OT 的成本矩阵
reg_mis 是边际松弛超参数
用于解决该问题的算法是广义Sinkhorn-Knopp矩阵缩放算法,如[10]中所提出的。
- Parameters:
A (类数组 (维度, 历史数量)) – 历史数量 训练分布 \(\mathbf{a}_i\) 的维度 维度
M (类似数组 (维度, 维度)) – OT 的基础度量矩阵。
reg (float) – 熵正则化项 > 0
reg_m (float) – 边际放松项 > 0
tau (float) – 对数域吸收的稳定性阈值。
权重 (类数组 (n_hists,) 可选) – 每个分布的权重(重心坐标)如果为 None,则使用均匀权重。
numItermax (int, 可选) – 最大迭代次数
stopThr (float, 可选) – 错误的停止阈值 (> 0)
verbose (bool, 可选) – 在迭代过程中打印信息
log (bool, 可选) – 如果 True 记录 log
- Returns:
a ((dim,) 类数组) – 不平衡的 Wasserstein 中心
log (字典) – 仅在参数中 \(log==True\) 时返回日志字典
参考文献
[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). 用于正则化运输问题的迭代Bregman投影. SIAM科学计算杂志, 37(2), A1111-A1138.
[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). 用于不平衡运输问题的扩展算法。arXiv预印本 arXiv:1607.05816。
- ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', G0=None, numItermax=1000, stopThr=1e-15, method='L-BFGS-B', verbose=False, log=False)[源]
解决不平衡的最优运输问题,并使用L-BFGS-B算法返回OT计划。该函数解决以下优化问题:
\[ \begin{align}\begin{aligned}W = \arg \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \mathrm{div}(\gamma, \mathbf{c}) + \mathrm{reg_{m1}} \cdot \mathrm{div_m}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{div_m}(\gamma^T \mathbf{1}, \mathbf{b})\\s.t. \gamma \geq 0\end{aligned}\end{align} \]其中:
\(\mathbf{M}\) 是 (dim_a, dim_b) 费用矩阵
\(\mathbf{a}\) 和 \(\mathbf{b}\) 是源和目标不平衡分布
\(\mathbf{c}\) 是正则化的参考分布
\(\mathrm{div_m}\) 是一种散度,可以是 Kullback-Leibler 散度,
或半平方\(\ell_2\)散度,或总变差 - \(\mathrm{div}\)是一个散度,可以是Kullback-Leibler散度, 或半平方\(\ell_2\)散度
注意
此函数与后端兼容,将在所有兼容的后端上工作。首先,它将所有数组转换为Numpy数组,然后使用来自scipy.optimize的L-BFGS-B算法来解决优化问题。
- Parameters:
a (数组类型 (dim_a,)) – 未归一化的维度 dim_a 的直方图 如果 a 是空列表或数组([]), 则 a 被设置为均匀分布。
b (类数组 (dim_b,)) – 未归一化的直方图,维度为 dim_b 如果 b 是一个空列表或数组 ([]), 则 b 被设定为均匀分布。
M (类似数组 (dim_a, dim_b)) – 损失矩阵
reg (float) – 正则化项 >=0
c (类数组 (dim_a, dim_b), 可选 (默认 = None)) – 正则化的参考度量。 如果为 None,则使用 \(\mathbf{c} = \mathbf{a} \mathbf{b}^T\)。
reg_m (浮点数 或 可索引对象,长度为1 或 2) – 边际放宽项:非负(包括0)但不能是无穷大。如果 \(\mathrm{reg_{m}}\) 是标量或长度为1的可索引对象,则相同的\(\mathrm{reg_{m}}\) 应用于两个边际放宽。如果 \(\mathrm{reg_{m}}\) 是数组,它必须是一个Numpy数组。
reg_div (string, optional) – 用于正则化的散度。可以取三种值:‘entropy’(负熵),或‘kl’(Kullback-Leibler)或‘l2’(半平方),或者一个返回正则项及其导数的两个可调用函数的元组。请注意,可调用函数应该能够处理Numpy数组而不是后端的张量。
regm_div (string, optional) – 用于量化边际之间差异的散度。可以取三个值:‘kl’ (Kullback-Leibler) 或 ‘l2’ (半平方) 或 ‘tv’ (总变化)
G0 (类似数组 (dim_a, dim_b)) – 传输矩阵的初始化
numItermax (int, 可选) – 最大迭代次数
stopThr (float, 可选) – 错误的停止阈值 (> 0)
verbose (bool, 可选) – 在迭代过程中打印信息
log (bool, 可选) – 如果为真,则记录日志
- Returns:
gamma ((dim_a, dim_b) 类数组) – 给定参数的最佳运输矩阵
log (字典) – 仅在 log 为 True 时返回的日志字典
示例
>>> import ot >>> import numpy as np >>> a=[.5, .5] >>> b=[.5, .5] >>> M=[[1., 36.],[9., 4.]] >>> np.round(ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=0, reg_m=5, reg_div='kl', regm_div='kl'), 2) array([[0.45, 0. ], [0. , 0.34]]) >>> np.round(ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=0, reg_m=5, reg_div='l2', regm_div='l2'), 2) array([[0.4, 0. ], [0. , 0.1]])
参考文献
[41] Chapel, L., Flamary, R., Wu, H., Févotte, C., 和 Gasso, G. (2021). 通过非负惩罚线性回归进行不平衡最优传输。NeurIPS。
另请参见
ot.lp.emd2未正则化的OT损失
ot.unbalanced.sinkhorn_unbalanced2熵正则化OT损失
- ot.unbalanced.lbfgsb_unbalanced2(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', G0=None, returnCost='linear', numItermax=1000, stopThr=1e-15, method='L-BFGS-B', verbose=False, log=False)[源]
解决不平衡的最优传输问题,并使用L-BFGS-B返回OT成本。该函数解决以下优化问题:
\[ \begin{align}\begin{aligned}\min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \mathrm{div}(\gamma, \mathbf{c}) + \mathrm{reg_{m1}} \cdot \mathrm{div_m}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{div_m}(\gamma^T \mathbf{1}, \mathbf{b})\\s.t. \gamma \geq 0\end{aligned}\end{align} \]其中:
\(\mathbf{M}\) 是 (dim_a, dim_b) 费用矩阵
\(\mathbf{a}\) 和 \(\mathbf{b}\) 是源和目标不平衡分布
\(\mathbf{c}\) 是正则化的参考分布
\(\mathrm{div_m}\) 是一种散度,可以是 Kullback-Leibler 散度,
或半平方\(\ell_2\)散度,或总变差 - \(\mathrm{div}\)是一个散度,可以是Kullback-Leibler散度, 或半平方\(\ell_2\)散度
注意
此函数与后端兼容,将在所有兼容的后端上工作。首先,它将所有数组转换为Numpy数组,然后使用来自scipy.optimize的L-BFGS-B算法来解决优化问题。
- Parameters:
a (数组类型 (dim_a,)) – 未归一化的维度 dim_a 的直方图 如果 a 是空列表或数组([]), 则 a 被设置为均匀分布。
b (类数组 (dim_b,)) – 未归一化的直方图,维度为 dim_b 如果 b 是一个空列表或数组 ([]), 则 b 被设定为均匀分布。
M (类似数组 (dim_a, dim_b)) – 损失矩阵
reg (float) – 正则化项 >=0
c (类数组 (dim_a, dim_b), 可选 (默认 = None)) – 正则化的参考度量。 如果为 None,则使用 \(\mathbf{c} = \mathbf{a} \mathbf{b}^T\)。
reg_m (浮点数 或 可索引对象,长度为1 或 2) – 边际放宽项:非负(包括0)但不能是无穷大。如果 \(\mathrm{reg_{m}}\) 是标量或长度为1的可索引对象,则相同的\(\mathrm{reg_{m}}\) 应用于两个边际放宽。如果 \(\mathrm{reg_{m}}\) 是数组,它必须是一个Numpy数组。
reg_div (string, optional) – 用于正则化的散度。可以取三种值:‘entropy’(负熵),或‘kl’(Kullback-Leibler)或‘l2’(半平方),或者一个返回正则项及其导数的两个可调用函数的元组。请注意,可调用函数应该能够处理Numpy数组而不是后端的张量。
regm_div (string, optional) – 用于量化边际之间差异的散度。可以取三个值:‘kl’ (Kullback-Leibler) 或 ‘l2’ (半平方) 或 ‘tv’ (总变化)
G0 (类似数组 (dim_a, dim_b)) – 传输矩阵的初始化
returnCost (string, optional (default = "linear")) – 如果 returnCost = “linear”,则返回不平衡OT损失的线性部分。 如果 returnCost = “total”,则返回总的不平衡OT损失。
numItermax (int, 可选) – 最大迭代次数
stopThr (float, 可选) – 错误的停止阈值 (> 0)
verbose (bool, 可选) – 在迭代过程中打印信息
log (bool, 可选) – 如果为真,则记录日志
- Returns:
ot_cost (array-like) – \(\mathbf{a}\) 和 \(\mathbf{b}\) 之间的 OT 成本
log (dict) – 仅在 log 为 True 时返回的日志字典
示例
>>> import ot >>> import numpy as np >>> a=[.5, .5] >>> b=[.5, .5] >>> M=[[1., 36.],[9., 4.]] >>> np.round(ot.unbalanced.lbfgsb_unbalanced2(a, b, M, reg=0, reg_m=5, reg_div='kl', regm_div='kl'), 2) 1.79 >>> np.round(ot.unbalanced.lbfgsb_unbalanced2(a, b, M, reg=0, reg_m=5, reg_div='l2', regm_div='l2'), 2) 0.8
参考文献
[41] Chapel, L., Flamary, R., Wu, H., Févotte, C., 和 Gasso, G. (2021). 通过非负惩罚线性回归进行不平衡最优传输。NeurIPS。
另请参见
ot.lp.emd2未正则化的OT损失
ot.unbalanced.sinkhorn_unbalanced2熵正则化OT损失
- ot.unbalanced.mm_unbalanced(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, numItermax=1000, stopThr=1e-15, verbose=False, log=False)[源]
解决不平衡的最优运输问题并返回OT计划。 该函数解决以下优化问题:
\[ \begin{align}\begin{aligned}W = \arg \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg_{m1}} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{c})\\s.t. \gamma \geq 0\end{aligned}\end{align} \]其中:
\(\mathbf{M}\) 是 (dim_a, dim_b) 费用矩阵
\(\mathbf{a}\) 和 \(\mathbf{b}\) 是源和目标不平衡分布
\(\mathbf{c}\) 是正则化的参考分布
div 是一种散度,既可以是Kullback-Leibler散度,也可以是半平方\(\ell_2\)散度
用于解决该问题的算法是最大化-最小化算法,如在 [41] 中提出的。
- Parameters:
a (数组类型 (dim_a,)) – 未归一化的维度 dim_a 的直方图 如果 a 是空列表或数组([]), 则 a 被设置为均匀分布。
b (类数组 (dim_b,)) – 未归一化的直方图,维度为 dim_b 如果 b 是一个空列表或数组 ([]), 则 b 被设定为均匀分布。
M (类似数组 (dim_a, dim_b)) – 损失矩阵
reg_m (浮点数 或 可索引对象 长度为 1 或 2) – 边际松弛项:非负但不能为无穷大。 如果 \(\mathrm{reg_{m}}\) 是标量或长度为 1 的可索引对象, 那么同样的 \(\mathrm{reg_{m}}\) 适用于两个边际松弛。 如果 \(\mathrm{reg_{m}}\) 是数组, 它必须与输入数组 (a, b, M) 具有相同的后端。
reg (float, 可选 (默认 = 0)) – 正则化项 >= 0。默认情况下,求解无正则化的问题
c (类数组 (dim_a, dim_b), 可选 (默认 = None)) – 正则化的参考度量。 如果为 None,则使用 \(\mathbf{c} = \mathbf{a} \mathbf{b}^T\)。
div (string, optional) – 用于量化边际之间差异的发散度。可以取两个值:‘kl’(Kullback-Leibler)或 ‘l2’(半平方)
G0 (类似数组 (dim_a, dim_b)) – 传输矩阵的初始化
numItermax (int, 可选) – 最大迭代次数
stopThr (float, 可选) – 错误的停止阈值 (> 0)
verbose (bool, 可选) – 在迭代过程中打印信息
log (bool, 可选) – 如果为真,则记录日志
- Returns:
gamma ((dim_a, dim_b) 类数组) – 给定参数的最佳运输矩阵
log (字典) – 仅在 log 为 True 时返回的日志字典
示例
>>> import ot >>> import numpy as np >>> a=[.5, .5] >>> b=[.5, .5] >>> M=[[1., 36.],[9., 4.]] >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 5, div='kl'), 2) array([[0.45, 0. ], [0. , 0.34]]) >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 5, div='l2'), 2) array([[0.4, 0. ], [0. , 0.1]])
参考文献
[41] Chapel, L., Flamary, R., Wu, H., Févotte, C., 和 Gasso, G. (2021). 通过非负惩罚线性回归进行不平衡最优传输。NeurIPS。
另请参见
ot.lp.emd未正则化的OT
ot.unbalanced.sinkhorn_unbalanced熵正则化最优传输
- ot.unbalanced.mm_unbalanced2(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, returnCost='linear', numItermax=1000, stopThr=1e-15, verbose=False, log=False)[源]
求解不平衡最优运输问题并返回 OT 成本。该函数解决以下优化问题:
\[ \begin{align}\begin{aligned}\min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg_{m1}} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{c})\\s.t. \gamma \geq 0\end{aligned}\end{align} \]其中:
\(\mathbf{M}\) 是 (dim_a, dim_b) 费用矩阵
\(\mathbf{a}\) 和 \(\mathbf{b}\) 是源和目标不平衡分布
\(\mathbf{c}\) 是正则化的参考分布
\(\mathrm{div}\) 是发散度,可以是Kullback-Leibler发散度或半平方的 \(\ell_2\) 发散度
用于解决该问题的算法是最大化-最小化算法,如在 [41] 中提出的。
- Parameters:
a (数组类型 (dim_a,)) – 未归一化的维度 dim_a 的直方图 如果 a 是空列表或数组([]), 则 a 被设置为均匀分布。
b (类数组 (dim_b,)) – 未归一化的直方图,维度为 dim_b 如果 b 是一个空列表或数组 ([]), 则 b 被设定为均匀分布。
M (类似数组 (dim_a, dim_b)) – 损失矩阵
reg_m (浮点数 或 可索引对象 长度为 1 或 2) – 边际松弛项:非负但不能为无穷大。 如果 \(\mathrm{reg_{m}}\) 是标量或长度为 1 的可索引对象, 那么同样的 \(\mathrm{reg_{m}}\) 适用于两个边际松弛。 如果 \(\mathrm{reg_{m}}\) 是数组, 它必须与输入数组 (a, b, M) 具有相同的后端。
reg (float, 可选 (默认 = 0)) – 熵正则化项 >= 0。 默认情况下,解决无正则化的问题
c (类似数组 (dim_a, dim_b), 可选 (默认 = None)) – 正则化的参考度量。 如果为 None,则使用 \(\mathbf{c} = mathbf{a} mathbf{b}^T\)。
div (string, optional) – 用于量化边际之间差异的发散度。可以取两个值:‘kl’(Kullback-Leibler)或 ‘l2’(半平方)
G0 (类似数组 (dim_a, dim_b)) – 传输矩阵的初始化
returnCost (string, optional (default = "linear")) – 如果 returnCost = “linear”,则返回不平衡OT损失的线性部分。 如果 returnCost = “total”,则返回总的不平衡OT损失。
numItermax (int, 可选) – 最大迭代次数
stopThr (float, 可选) – 错误的停止阈值 (> 0)
verbose (bool, 可选) – 在迭代过程中打印信息
log (bool, 可选) – 如果为真,则记录日志
- Returns:
ot_cost (array-like) – \(\mathbf{a}\) 和 \(\mathbf{b}\) 之间的 OT 成本
log (dict) – 仅在 log 为 True 时返回的日志字典
示例
>>> import ot >>> import numpy as np >>> a=[.5, .5] >>> b=[.5, .5] >>> M=[[1., 36.],[9., 4.]] >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 5, div='l2'), 2) 0.8 >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 5, div='kl'), 2) 1.79
参考文献
[41] Chapel, L., Flamary, R., Wu, H., Févotte, C., 和 Gasso, G. (2021). 通过非负惩罚线性回归进行不平衡最优传输。NeurIPS。
另请参见
ot.lp.emd2未正则化的OT损失
ot.unbalanced.sinkhorn_unbalanced2熵正则化OT损失
- ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type='kl', c=None, warmstart=None, numItermax=1000, stopThr=1e-06, verbose=False, log=False, **kwargs)[源]
解决熵正则化的不平衡最优运输问题并返回OT方案
该函数解决以下优化问题:
\[ \begin{align}\begin{aligned}W = \arg \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \cdot \mathrm{KL}(\gamma, \mathbf{c}) + \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})\\s.t. \gamma \geq 0\end{aligned}\end{align} \]其中 :
\(\mathbf{M}\) 是 (dim_a, dim_b) 费用矩阵
\(\mathbf{a}\) 和 \(\mathbf{b}\) 是源和目标不平衡分布
\(\mathbf{c}\) 是正则化的参考分布
KL是Kullback-Leibler散度
用于解决问题的算法是泛化的Sinkhorn-Knopp矩阵缩放算法,如[10, 25]中所提到的
- Parameters:
a (数组类型 (dim_a,)) – 未归一化的维度 dim_a 的直方图 如果 a 是空列表或数组([]), 则 a 被设置为均匀分布。
b (类数组 (dim_b,)) – 一个或多个维度为 dim_b 的未规范化直方图。 如果 b 是一个空列表或数组 ([]),那么 b 被设置为均匀分布。 如果有多个,计算所有的OT成本 \((\mathbf{a}, \mathbf{b}_i)_i\)
M (类似数组 (dim_a, dim_b)) – 损失矩阵
reg (float) – 熵正则化项 > 0
reg_m (浮点数 或 可索引对象,长度为 1 或 2) – 边际放松项。 如果 \(\mathrm{reg_{m}}\) 是一个标量或长度为 1 的可索引对象, 则相同的 \(\mathrm{reg_{m}}\) 应用于两个边际放松。 可以使用 \(\mathrm{reg_{m}}=float("inf")\) 来恢复熵平衡的 OT。 对于半放松情况,可以使用以下任一方法: \(\mathrm{reg_{m}}=(float("inf"), 标量)\) 或 \(\mathrm{reg_{m}}=(标量, float("inf"))\)。 如果 \(\mathrm{reg_{m}}\) 是一个数组, 它必须与输入数组 (a, b, M) 具有相同的后端。
reg_type (string, optional) – 正则化项。可以取两个值: + 负熵:‘entropy’: \(\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}\)。 这与\(\Omega(\gamma) = \text{KL}(\gamma, 1_{dim_a} 1_{dim_b}^T)\)在常数范围内是等价的。 + Kullback-Leibler散度:‘kl’: \(\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)\)。
c (数组类型 (dim_a, dim_b), 可选 (默认=None)) – 参考测量,用于正则化。 如果为 None,则使用 \(\mathbf{c} = \mathbf{a} \mathbf{b}^T\)。 如果 \(\texttt{reg_type}='entropy'\),则 \(\mathbf{c} = 1_{dim_a} 1_{dim_b}^T\)。
warmstart (tuple 的 数组, 形状 (dim_a, dim_b), 可选) – 双重潜在值的初始化。如果提供,应该给出双重潜在值 (即 u, v Sinkhorn 缩放向量的对数)。
numItermax (int, 可选) – 最大迭代次数
stopThr (float, 可选) – 错误的停止阈值 (> 0)
verbose (bool, 可选) – 在迭代过程中打印信息
log (bool, 可选) – 如果 True 记录 log
- Returns:
如果 n_hists == 1 –
- gamma(dim_a, dim_b) array-like
给定参数的最优运输矩阵
- logdict
仅当 log 为 True 时返回的日志字典
否则 –
- ot_cost(n_hists,) array-like
在 \(\mathbf{a}\) 和每个直方图 \(\mathbf{b}_i\) 之间的 OT 成本
- logdict
仅当 log 为 True 时返回的日志字典
示例
>>> import ot >>> import numpy as np >>> a=[.5, .5] >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] >>> np.round(ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.), 7) array([[0.3220536, 0.1184769], [0.1184769, 0.3220536]])
参考文献
[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). 用于不平衡运输问题的扩展算法。arXiv预印本 arXiv:1607.05816。
[25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : 使用Wasserstein损失进行学习,神经信息处理系统进展 (NIPS) 2015
另请参见
ot.lp.emd未正则化的OT
ot.optim.cg通用正则化OT
- ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type='kl', c=None, warmstart=None, tau=100000.0, numItermax=1000, stopThr=1e-06, verbose=False, log=False, **kwargs)[源]
求解熵正则化的不平衡最优运输问题并返回损失
该函数使用对数域稳定化方法解决以下优化问题,如[10]中所提议:
\[ \begin{align}\begin{aligned}W = \arg \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \cdot \mathrm{KL}(\gamma, \mathbf{c}) + \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})\\s.t. \gamma \geq 0\end{aligned}\end{align} \]其中 :
\(\mathbf{M}\) 是 (dim_a, dim_b) 费用矩阵
\(\mathbf{a}\) 和 \(\mathbf{b}\) 是源和目标不平衡分布
\(\mathbf{c}\) 是正则化的参考分布
KL是Kullback-Leibler散度
用于解决该问题的算法是广义Sinkhorn-Knopp矩阵缩放算法,如[10, 25]中所提到的。
- Parameters:
a (数组类型 (dim_a,)) – 未归一化的维度 dim_a 的直方图 如果 a 是空列表或数组([]), 则 a 被设置为均匀分布。
b (类数组 (dim_b,)) – 一个或多个维度为 dim_b 的未规范化直方图。 如果 b 是一个空列表或数组 ([]),那么 b 被设置为均匀分布。 如果有多个,计算所有的OT成本 \((\mathbf{a}, \mathbf{b}_i)_i\)
M (类似数组 (dim_a, dim_b)) – 损失矩阵
reg (float) – 熵正则化项 > 0
reg_m (浮点数 或 可索引对象,长度为 1 或 2) – 边际放松项。 如果 \(\mathrm{reg_{m}}\) 是一个标量或长度为 1 的可索引对象, 则相同的 \(\mathrm{reg_{m}}\) 应用于两个边际放松。 可以使用 \(\mathrm{reg_{m}}=float("inf")\) 来恢复熵平衡的 OT。 对于半放松情况,可以使用以下任一方法: \(\mathrm{reg_{m}}=(float("inf"), 标量)\) 或 \(\mathrm{reg_{m}}=(标量, float("inf"))\)。 如果 \(\mathrm{reg_{m}}\) 是一个数组, 它必须与输入数组 (a, b, M) 具有相同的后端。
方法 (str) – 求解器使用的方法,可以是‘sinkhorn’,‘sinkhorn_stabilized’或‘sinkhorn_reg_scaling’,具体参数请参见这些函数
reg_type (string, optional) – 正则化项。可以取两个值: + 负熵:‘entropy’: \(\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}\)。 这与\(\Omega(\gamma) = \text{KL}(\gamma, 1_{dim_a} 1_{dim_b}^T)\)在常数范围内是等价的。 + Kullback-Leibler散度:‘kl’: \(\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)\)。
c (数组类型 (dim_a, dim_b), 可选 (默认=None)) – 参考测量,用于正则化。 如果为 None,则使用 \(\mathbf{c} = \mathbf{a} \mathbf{b}^T\)。 如果 \(\texttt{reg_type}='entropy'\),则 \(\mathbf{c} = 1_{dim_a} 1_{dim_b}^T\)。
warmstart (tuple 的 数组, 形状 (dim_a, dim_b), 可选) – 双重潜在值的初始化。如果提供,应该给出双重潜在值 (即 u, v Sinkhorn 缩放向量的对数)。
tau (float) – log缩放中u或v的最大值阈值
numItermax (int, 可选) – 最大迭代次数
stopThr (float, 可选) – 错误的停止阈值 (>0)
verbose (bool, 可选) – 在迭代过程中打印信息
log (bool, 可选) – 如果 True 记录 log
- Returns:
如果 n_hists == 1 –
- gamma(dim_a, dim_b) array-like
给定参数的最优运输矩阵
- logdict
仅在 log 为 True 时返回的日志字典
否则 –
- ot_cost(n_hists,) array-like
在 \(\mathbf{a}\) 和每个直方图 \(\mathbf{b}_i\) 之间的 OT 成本
- logdict
仅在 log 为 True 时返回的日志字典
示例
>>> import ot >>> import numpy as np >>> a=[.5, .5] >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] >>> np.round(ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, 1., 1.), 7) array([[0.3220536, 0.1184769], [0.1184769, 0.3220536]])
参考文献
[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). 不平衡运输问题的规模化算法。arXiv 预印本 arXiv:1607.05816.
[25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : 使用Wasserstein损失进行学习,神经信息处理系统进展 (NIPS) 2015
另请参见
ot.lp.emd未正则化的OT
ot.optim.cg通用正则化OT
- ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', reg_type='kl', c=None, warmstart=None, numItermax=1000, stopThr=1e-06, verbose=False, log=False, **kwargs)[源]
求解不平衡的熵正则化最优传输问题并返回OT计划
该函数解决以下优化问题:
\[ \begin{align}\begin{aligned}W = \arg \min_\gamma \ \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \cdot \mathrm{KL}(\gamma, \mathbf{c}) + \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})\\s.t. \gamma \geq 0\end{aligned}\end{align} \]其中 :
\(\mathbf{M}\) 是 (dim_a, dim_b) 费用矩阵
\(\mathbf{a}\) 和 \(\mathbf{b}\) 是源和目标不平衡分布
\(\mathbf{c}\) 是正则化的参考分布
KL是Kullback-Leibler散度
用于解决该问题的算法是广义Sinkhorn-Knopp矩阵缩放算法,如[10, 25]中所提到的。
- Parameters:
a (数组类型 (dim_a,)) – 未归一化的维度 dim_a 的直方图 如果 a 是空列表或数组([]), 则 a 被设置为均匀分布。
b (类数组 (dim_b,)) – 一个或多个维度为 dim_b 的未规范化直方图。 如果 b 是一个空列表或数组 ([]),那么 b 被设置为均匀分布。 如果有多个,计算所有的OT成本 \((\mathbf{a}, \mathbf{b}_i)_i\)
M (类似数组 (dim_a, dim_b)) – 损失矩阵
reg (float) – 熵正则化项 > 0
reg_m (浮点数 或 可索引对象,长度为 1 或 2) – 边际放松项。 如果 \(\mathrm{reg_{m}}\) 是一个标量或长度为 1 的可索引对象, 则相同的 \(\mathrm{reg_{m}}\) 应用于两个边际放松。 可以使用 \(\mathrm{reg_{m}}=float("inf")\) 来恢复熵平衡的 OT。 对于半放松情况,可以使用以下任一方法: \(\mathrm{reg_{m}}=(float("inf"), 标量)\) 或 \(\mathrm{reg_{m}}=(标量, float("inf"))\)。 如果 \(\mathrm{reg_{m}}\) 是一个数组, 它必须与输入数组 (a, b, M) 具有相同的后端。
方法 (str) – 求解器使用的方法,可为‘sinkhorn’,‘sinkhorn_stabilized’,‘sinkhorn_translation_invariant’或‘sinkhorn_reg_scaling’,具体参数请参见这些函数
reg_type (string, optional) – 正则化项。可以取两个值: + 负熵:‘entropy’: \(\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}\)。 这与\(\Omega(\gamma) = \text{KL}(\gamma, 1_{dim_a} 1_{dim_b}^T)\)在常数范围内是等价的。 + Kullback-Leibler散度:‘kl’: \(\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)\)。
c (数组类型 (dim_a, dim_b), 可选 (默认=None)) – 参考测量,用于正则化。 如果为 None,则使用 \(\mathbf{c} = \mathbf{a} \mathbf{b}^T\)。 如果 \(\texttt{reg_type}='entropy'\),则 \(\mathbf{c} = 1_{dim_a} 1_{dim_b}^T\)。
warmstart (tuple 的 数组, 形状 (dim_a, dim_b), 可选) – 双重潜在值的初始化。如果提供,应该给出双重潜在值 (即 u, v Sinkhorn 缩放向量的对数)。
numItermax (int, 可选) – 最大迭代次数
stopThr (float, 可选) – 错误的停止阈值 (>0)
verbose (bool, 可选) – 在迭代过程中打印信息
log (bool, 可选) – 如果 True 记录 log
- Returns:
if n_hists == 1 –
- gamma(dim_a, dim_b) array-like
给定参数的最优运输矩阵
- logdict
仅在log为True时返回的日志字典
else –
- ot_distance(n_hists,) array-like
\(\mathbf{a}\)与每个直方图\(\mathbf{b}_i\)之间的OT距离
- logdict
仅在log为True时返回的日志字典
示例
>>> import ot >>> import numpy as np >>> a=[.5, .5] >>> b=[.5, .5] >>> M=[[0., 1.], [1., 0.]] >>> np.round(ot.sinkhorn_unbalanced(a, b, M, 1, 1), 7) array([[0.3220536, 0.1184769], [0.1184769, 0.3220536]])
参考文献
[2] M. Cuturi, Sinkhorn 距离:快速计算最优运输,神经信息处理系统进展 (NIPS) 26, 2013
[9] Schmitzer, B. (2016). 稳定稀疏缩放算法用于熵正则化运输问题。arXiv 预印本 arXiv:1610.06519.
[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). 用于不平衡运输问题的扩展算法。arXiv预印本 arXiv:1607.05816。
[25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : 使用Wasserstein损失进行学习,神经信息处理系统进展 (NIPS) 2015
[73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). 更快的非平衡最优运输:平移不变的Sinkhorn和1维Frank-Wolfe。 在国际人工智能与统计会议上 (pp. 4995-5021). PMLR.
另请参见
ot.unbalanced.sinkhorn_knopp_unbalanced非平衡经典Sinkhorn [10]
ot.unbalanced.sinkhorn_stabilized_unbalanced不平衡稳定化Sinkhorn [9, 10]
ot.unbalanced.sinkhorn_reg_scaling_unbalanced带有 epsilon 缩放的非平衡 Sinkhorn [9, 10]
ot.unbalanced.sinkhorn_unbalanced_translation_invariant翻译不变的非平衡Sinkhorn [73]
- ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', reg_type='kl', c=None, warmstart=None, returnCost='linear', numItermax=1000, stopThr=1e-06, verbose=False, log=False, **kwargs)[源]
解决熵正则化不平衡最优运输问题并返回成本
该函数解决以下优化问题:
\[ \begin{align}\begin{aligned}\min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \cdot \mathrm{KL}(\gamma, \mathbf{c}) + \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})\\s.t. \gamma\geq 0\end{aligned}\end{align} \]其中 :
\(\mathbf{M}\) 是 (dim_a, dim_b) 费用矩阵
\(\mathbf{a}\) 和 \(\mathbf{b}\) 是源和目标不平衡分布
\(\mathbf{c}\) 是正则化的参考分布
KL是Kullback-Leibler散度
用于解决该问题的算法是广义Sinkhorn-Knopp矩阵缩放算法,如[10, 25]中所提到的。
- Parameters:
a (数组类型 (dim_a,)) – 未归一化的维度 dim_a 的直方图 如果 a 是空列表或数组([]), 则 a 被设置为均匀分布。
b (类数组 (dim_b,)) – 一个或多个维度为 dim_b 的未规范化直方图。 如果 b 是一个空列表或数组 ([]),那么 b 被设置为均匀分布。 如果有多个,计算所有的OT成本 \((\mathbf{a}, \mathbf{b}_i)_i\)
M (类似数组 (dim_a, dim_b)) – 损失矩阵
reg (float) – 熵正则化项 > 0
reg_m (浮点数 或 可索引对象,长度为 1 或 2) – 边际放松项。 如果 \(\mathrm{reg_{m}}\) 是一个标量或长度为 1 的可索引对象, 则相同的 \(\mathrm{reg_{m}}\) 应用于两个边际放松。 可以使用 \(\mathrm{reg_{m}}=float("inf")\) 来恢复熵平衡的 OT。 对于半放松情况,可以使用以下任一方法: \(\mathrm{reg_{m}}=(float("inf"), 标量)\) 或 \(\mathrm{reg_{m}}=(标量, float("inf"))\)。 如果 \(\mathrm{reg_{m}}\) 是一个数组, 它必须与输入数组 (a, b, M) 具有相同的后端。
方法 (str) – 求解器使用的方法,可为‘sinkhorn’,‘sinkhorn_stabilized’,‘sinkhorn_translation_invariant’或‘sinkhorn_reg_scaling’,具体参数请参见这些函数
reg_type (string, optional) – 正则化项。可以取两个值: + 负熵:‘entropy’: \(\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}\)。 这与\(\Omega(\gamma) = \text{KL}(\gamma, 1_{dim_a} 1_{dim_b}^T)\)在常数范围内是等价的。 + Kullback-Leibler散度:‘kl’: \(\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)\)。
c (数组类型 (dim_a, dim_b), 可选 (默认=None)) – 参考测量,用于正则化。 如果为 None,则使用 \(\mathbf{c} = \mathbf{a} \mathbf{b}^T\)。 如果 \(\texttt{reg_type}='entropy'\),则 \(\mathbf{c} = 1_{dim_a} 1_{dim_b}^T\)。
warmstart (元组 of 数组, 形状 (dim_a, dim_b), 可选) – 对偶势的初始化。如果提供,应该给出对偶势 (即u,v sinkhorn缩放向量的对数)。
returnCost (string, optional (default = "linear")) – 如果 returnCost = “linear”,则返回不平衡OT损失的线性部分。 如果 returnCost = “total”,则返回总的不平衡OT损失。
numItermax (int, 可选) – 最大迭代次数
stopThr (float, 可选) – 错误的停止阈值 (>0)
verbose (bool, 可选) – 在迭代过程中打印信息
log (bool, 可选) – 如果 True 记录 log
- Returns:
ot_cost ((n_hists,) 类似于数组) – \(\mathbf{a}\) 与每个直方图 \(\mathbf{b}_i\) 之间的OT成本
log (字典) – 仅在 log 为 True 时返回日志字典
示例
>>> import ot >>> import numpy as np >>> a=[.5, .10] >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] >>> np.round(ot.unbalanced.sinkhorn_unbalanced2(a, b, M, 1., 1.), 8) 0.19600125
参考文献
[2] M. Cuturi, Sinkhorn 距离:快速计算最优运输,神经信息处理系统进展 (NIPS) 26, 2013
[9] Schmitzer, B. (2016). 稳定稀疏缩放算法用于熵正则化运输问题。arXiv 预印本 arXiv:1610.06519.
[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). 用于不平衡运输问题的扩展算法。arXiv预印本 arXiv:1607.05816。
[25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : 使用Wasserstein损失进行学习,神经信息处理系统进展 (NIPS) 2015
[73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). 更快的非平衡最优运输:平移不变的Sinkhorn和1维Frank-Wolfe。 在国际人工智能与统计会议上 (pp. 4995-5021). PMLR.
另请参见
ot.unbalanced.sinkhorn_knopp非平衡经典Sinkhorn [10]
ot.unbalanced.sinkhorn_stabilized不平衡稳定化Sinkhorn [9, 10]
ot.unbalanced.sinkhorn_reg_scaling带有 epsilon 缩放的非平衡 Sinkhorn [9, 10]
ot.unbalanced.sinkhorn_unbalanced_translation_invariant翻译不变的非平衡Sinkhorn [73]
- ot.unbalanced.sinkhorn_unbalanced_translation_invariant(a, b, M, reg, reg_m, reg_type='kl', c=None, warmstart=None, numItermax=1000, stopThr=1e-06, verbose=False, log=False, **kwargs)[源]
解决熵正则化的不平衡最优运输问题并返回OT方案
该函数解决以下优化问题:
\[ \begin{align}\begin{aligned}W = \arg \min_\gamma \ \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \cdot \mathrm{KL}(\gamma, \mathbf{c}) + \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})\\s.t. \gamma \geq 0\end{aligned}\end{align} \]其中 :
\(\mathbf{M}\) 是 (dim_a, dim_b) 费用矩阵
\(\Omega\) 是熵正则化项,KL 效度
\(\mathbf{a}\) 和 \(\mathbf{b}\) 是源和目标不平衡分布
KL是Kullback-Leibler散度
用于解决该问题的算法是翻译不变Sinkhorn算法,如[73]中所提出的
- Parameters:
a (类数组 (dim_a,)) – 未归一化的维度 dim_a 的直方图
b (数组类型 (dim_b,) 或 数组类型 (dim_b, n_hists)) – 一个或多个维度为 dim_b 的未归一化直方图 如果有多个,计算所有的OT距离 (a, b_i)
M (类似数组 (dim_a, dim_b)) – 损失矩阵
reg (float) – 熵正则化项 > 0
reg_m (浮点数 或 可索引对象 的 长度为 1 或 2) – 边际松弛项。 如果 reg_m 是一个标量或长度为 1 的可索引对象, 则相同的 reg_m 应用于两个边际松弛。 熵平衡 OT 可以通过 reg_m=float(“inf”) 恢复。 对于半放松情况,使用 reg_m=(float(“inf”), scalar) 或 reg_m=(scalar, float(“inf”))。 如果 reg_m 是一个数组,它必须与输入数组 (a, b, M) 具有相同的后端。
reg_type (字符串, 可选) – 正则化项。可以取两个值: ‘entropy’ (负熵) \(\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}\),或 ‘kl’ (Kullback-Leibler) \(\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)\)。
c (数组类型 (dim_a, dim_b), 可选 (默认=None)) – 参考测量,用于正则化。 如果为 None,则使用 \(\mathbf{c} = \mathbf{a} \mathbf{b}^T\)。 如果 \(\texttt{reg_type}='entropy'\),则 \(\mathbf{c} = 1_{dim_a} 1_{dim_b}^T\)。
warmstart (元组 of 数组, 形状 (dim_a, dim_b), 可选) – 对偶势的初始化。如果提供,应该给出对偶势 (即u,v sinkhorn缩放向量的对数)。
numItermax (int, 可选) – 最大迭代次数
stopThr (float, 可选) – 错误的停止阈值 (> 0)
verbose (bool, 可选) – 在迭代过程中打印信息
log (bool, 可选) – 如果为真,则记录日志
- Returns:
if n_hists == 1 –
- gamma(dim_a, dim_b) array-like
给定参数的最优运输矩阵
- logdict
仅在log为True时返回的日志字典
else –
- ot_distance(n_hists,) array-like
\(\mathbf{a}\)与每个直方图\(\mathbf{b}_i\)之间的OT距离
- logdict
仅在log为True时返回的日志字典
示例
>>> import ot >>> a=[.5, .5] >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] >>> ot.unbalanced.sinkhorn_unbalanced_translation_invariant(a, b, M, 1., 1.) array([[0.32205357, 0.11847689], [0.11847689, 0.32205357]])
参考文献
[73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). 更快的非平衡最优运输:平移不变的Sinkhorn和1维Frank-Wolfe。 在国际人工智能与统计会议上 (pp. 4995-5021). PMLR.