ot.weak

弱最优传输求解器

函数

ot.weak.weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0=None, **kwargs)[源]

解决两个经验分布之间的弱最优运输问题

\[ \begin{align}\begin{aligned}\gamma = \mathop{\arg \min}_\gamma \quad \sum_i \mathbf{a}_i \left(\mathbf{X^a}_i - \frac{1}{\mathbf{a}_i} \sum_j \gamma_{ij} \mathbf{X^b}_j \right)^2\\s.t. \ \gamma \mathbf{1} = \mathbf{a}\\ \gamma^T \mathbf{1} = \mathbf{b}\\ \gamma \geq 0\end{aligned}\end{align} \]

其中 :

  • \(X^a\)\(X^b\) 是样本矩阵。

  • \(\mathbf{a}\)\(\mathbf{b}\) 是样本权重

注意

此函数与后端兼容,可用于所有兼容后端的数组。但该算法使用C++ CPU后端,这可能导致在GPU数组上产生复制开销。

使用条件梯度算法来解决在 [39] 中提出的问题。

Parameters:
  • Xa ((ns,d) 数组类型, 浮点数) – 源样本

  • Xb ((nt,d) 数组类型, 浮点数) – 目标样本

  • a ((ns,) 数组类似, 浮点数) – 源直方图(如果空列表,则均匀权重)

  • b ((nt,) 类数组, 浮点数) – 目标直方图(如果为空列表则为均匀权重)

  • G0 ((ns,nt) 数组类似, 浮动) – 初始猜测(默认是独立的联合密度)

  • numItermax (int, 可选) – 最大迭代次数

  • numItermaxEmd (int, 可选) – emd的最大迭代次数

  • stopThr (float, 可选) – 相对变化的停止阈值 (>0)

  • stopThr2 (float, 可选的) – 绝对变化的停止阈值 (>0)

  • verbose (bool, 可选) – 在迭代过程中打印信息

  • log (bool, 可选) – 如果为真,则记录日志

Returns:

  • gamma (类似数组,形状为 (ns, nt)) – 给定参数的最优运输矩阵

  • log (字典,可选) – 如果输入日志为真,将返回一个包含成本、对偶变量和退出状态的字典

参考文献

另请参见

ot.bregman.sinkhorn

熵正则化最优传输

ot.optim.cg

通用正则化OT

ot.weak.weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0=None, **kwargs)[源]

解决两个经验分布之间的弱最优运输问题

\[ \begin{align}\begin{aligned}\gamma = \mathop{\arg \min}_\gamma \quad \sum_i \mathbf{a}_i \left(\mathbf{X^a}_i - \frac{1}{\mathbf{a}_i} \sum_j \gamma_{ij} \mathbf{X^b}_j \right)^2\\s.t. \ \gamma \mathbf{1} = \mathbf{a}\\ \gamma^T \mathbf{1} = \mathbf{b}\\ \gamma \geq 0\end{aligned}\end{align} \]

其中 :

  • \(X^a\)\(X^b\) 是样本矩阵。

  • \(\mathbf{a}\)\(\mathbf{b}\) 是样本权重

注意

此函数与后端兼容,可用于所有兼容后端的数组。但该算法使用C++ CPU后端,这可能导致在GPU数组上产生复制开销。

使用条件梯度算法来解决在 [39] 中提出的问题。

Parameters:
  • Xa ((ns,d) 数组类型, 浮点数) – 源样本

  • Xb ((nt,d) 数组类型, 浮点数) – 目标样本

  • a ((ns,) 数组类似, 浮点数) – 源直方图(如果空列表,则均匀权重)

  • b ((nt,) 类数组, 浮点数) – 目标直方图(如果为空列表则为均匀权重)

  • G0 ((ns,nt) 数组类似, 浮动) – 初始猜测(默认是独立的联合密度)

  • numItermax (int, 可选) – 最大迭代次数

  • numItermaxEmd (int, 可选) – emd的最大迭代次数

  • stopThr (float, 可选) – 相对变化的停止阈值 (>0)

  • stopThr2 (float, 可选的) – 绝对变化的停止阈值 (>0)

  • verbose (bool, 可选) – 在迭代过程中打印信息

  • log (bool, 可选) – 如果为真,则记录日志

Returns:

  • gamma (类似数组,形状为 (ns, nt)) – 给定参数的最优运输矩阵

  • log (字典,可选) – 如果输入日志为真,将返回一个包含成本、对偶变量和退出状态的字典

参考文献

另请参见

ot.bregman.sinkhorn

熵正则化最优传输

ot.optim.cg

通用正则化OT