ot.optim

用于正则化OT或其半放松版本的通用求解器。

函数

ot.optim.cg(a, b, M, reg, f, df, G0=None, line_search=None, numItermax=200, numItermaxEmd=100000, stopThr=1e-09, stopThr2=1e-09, verbose=False, log=False, nx=None, **kwargs)[源]

解决带条件梯度的广义正则化OT问题

该函数解决以下优化问题:

\[ \begin{align}\begin{aligned}\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \cdot f(\gamma)\\s.t. \ \gamma \mathbf{1} &= \mathbf{a}\\ \gamma^T \mathbf{1} &= \mathbf{b}\\ \gamma &\geq 0\end{aligned}\end{align} \]

其中 :

  • \(\mathbf{M}\) 是 (ns, nt) 计量成本矩阵

  • \(f\) 是正则化项(而 df 是它的梯度)

  • \(\mathbf{a}\)\(\mathbf{b}\) 是源权重和目标权重(总和为1)

用于解决该问题的算法是条件梯度,如[1]中所讨论的

Parameters:
  • a (类数组, 形状 (ns,)) – 源域中的样本权重

  • b (数组类型, 形状 (nt,)) – 目标领域中的样本

  • M (类数组, 形状 (ns, nt)) – 损失矩阵

  • reg (float) – 正则化项 >0

  • G0 (类似数组, 形状 (ns,nt), 可选) – 初始猜测(默认是独立联合密度)

  • line_search (function,) – 寻找最佳步长的函数。默认值为 None,并调用一个包装器来 line_search_armijo。

  • numItermax (int, 可选) – 最大迭代次数

  • numItermaxEmd (int, 可选) – emd的最大迭代次数

  • stopThr (float, 可选) – 相对变化的停止阈值 (>0)

  • stopThr2 (float, 可选的) – 绝对变化的停止阈值 (>0)

  • verbose (bool, 可选) – 在迭代过程中打印信息

  • log (bool, 可选) – 如果为真,则记录日志

  • nx (后端, 可选) – 如果保持默认值 None,后端将根据其他输入推断出来。

  • **kwargs (dict) – 线搜索的参数

Returns:

  • gamma ((ns x nt) ndarray) – 给定参数的最优运输矩阵

  • log (dict) – 日志字典仅在参数中log==True时返回

参考文献

另请参见

ot.lp.emd

未正则化的最优传输

ot.bregman.sinkhorn

熵正则化的最优运输

使用 ot.optim.cg 的示例

使用通用求解器的正则化OT

带通用求解器的正则化OT
ot.optim.gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, numInnerItermax=200, stopThr=1e-09, stopThr2=1e-09, verbose=False, log=False, **kwargs)[源]

使用广义条件梯度求解一般正则化OT问题

该函数解决以下优化问题:

\[ \begin{align}\begin{aligned}\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg_1}\cdot\Omega(\gamma) + \mathrm{reg_2}\cdot f(\gamma)\\s.t. \ \gamma \mathbf{1} &= \mathbf{a}\\ \gamma^T \mathbf{1} &= \mathbf{b}\\ \gamma &\geq 0\end{aligned}\end{align} \]

其中 :

  • \(\mathbf{M}\) 是 (ns, nt) 计量成本矩阵

  • \(\Omega\) 是熵正则化项 \(\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})\)

  • \(f\) 是正则化项(而 df 是它的梯度)

  • \(\mathbf{a}\)\(\mathbf{b}\) 是源权重和目标权重(总和为1)

用于解决该问题的算法是广义条件梯度,如[5, 7]中讨论的那样

Parameters:
  • a (类数组, 形状 (ns,)) – 源域中的样本权重

  • b (类数组, (nt,)) – 目标域中的样本

  • M (类数组, 形状 (ns, nt)) – 损失矩阵

  • reg1 (float) – 熵正则化项 >0

  • reg2 (float) – 第二个正则化项 >0

  • G0 (类似数组, 形状 (ns, nt), 可选) – 初始猜测(默认是独立联合密度)

  • numItermax (int, 可选) – 最大迭代次数

  • numInnerItermax (int, 可选) – Sinkhorn的最大迭代次数

  • stopThr (float, 可选) – 相对变化的停止阈值 (>0)

  • stopThr2 (float, 可选的) – 绝对变化的停止阈值 (>0)

  • verbose (bool, 可选) – 在迭代过程中打印信息

  • log (bool, 可选) – 如果为真,则记录日志

Returns:

  • gamma (ndarray, shape (ns, nt)) – 给定参数的最优运输矩阵

  • log (dict) – 仅在参数中log==True时返回日志字典

参考文献

另请参见

ot.optim.cg

条件梯度

使用 ot.optim.gcg 的示例

使用通用求解器的正则化OT

带通用求解器的正则化OT
ot.optim.generic_conditional_gradient(a, b, M, f, df, reg1, reg2, lp_solver, line_search, G0=None, numItermax=200, stopThr=1e-09, stopThr2=1e-09, verbose=False, log=False, nx=None, **kwargs)[源]

使用条件梯度或广义条件梯度解决一般正则化OT问题或其半放松版本,具体取决于提供的线性规划求解器。

如果设置为条件梯度,该函数解决以下优化问题:

\[ \begin{align}\begin{aligned}\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg_1} \cdot f(\gamma)\\s.t. \ \gamma \mathbf{1} &= \mathbf{a}\\ \gamma^T \mathbf{1} &= \mathbf{b} (可选约束)\\ \gamma &\geq 0\end{aligned}\end{align} \]

其中 :

  • \(\mathbf{M}\) 是 (ns, nt) 计量成本矩阵

  • \(f\) 是正则化项(而 df 是它的梯度)

  • \(\mathbf{a}\)\(\mathbf{b}\) 是源权重和目标权重(总和为1)

用于解决该问题的算法是条件梯度,如[1]中所讨论的

该函数解决以下优化问题,如果设置一个广义条件梯度:

\[ \begin{align}\begin{aligned}\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg_1}\cdot f(\gamma) + \mathrm{reg_2}\cdot\Omega(\gamma)\\s.t. \ \gamma \mathbf{1} &= \mathbf{a}\\ \gamma^T \mathbf{1} &= \mathbf{b}\\ \gamma &\geq 0\end{aligned}\end{align} \]

其中 :

  • \(\Omega\) 是熵正则化项 \(\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})\)

用于解决该问题的算法是广义条件梯度,如[5, 7]中讨论的那样

Parameters:
  • a (类数组, 形状 (ns,)) – 源域中的样本权重

  • b (类数组, 形状 (nt,)) – 目标领域中的样本权重

  • M (类数组, 形状 (ns, nt)) – 损失矩阵

  • f (函数) – 作为参数的正则化函数,接受一个运输矩阵

  • df (function) – 作为参数的运输矩阵的正则化函数的梯度

  • reg1 (float) – 正则化项 >0

  • reg2 (float,) – 熵正则化项 >0。如果设置为 None 则被忽略。

  • lp_solver (function,) –

    用于方向寻找的线性规划求解器(广义)条件梯度。 这个函数必须采用以下形式 lp_solver(a, b, Mi, **kwargs),其中 p: ab 是两个领域中的样本权重; Mi 是正则化目标的梯度;通过 kwargs 获取最优参数。 它必须输出一个可接受的运输计划。

    例如,对于一般的正则化 OT 问题与条件梯度 [1]

    def lp_solver(a, b, M, **kwargs):

    return ot.emd(a, b, M)

    或者使用广义条件梯度 [5, 7]

    def lp_solver(a, b, Mi, **kwargs):

    return ot.sinkhorn(a, b, Mi)

  • line_search (function,) –

    寻找最优步长的函数。此函数必须采用以下形式 line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs),其中:cost 为成本函数,G 为运输计划,deltaG 为由 lp_solver 提供的条件 梯度方向,Mi 为正则化目标的梯度,cost_G 为 G 的成本,df_G 为 G 处正则化器的梯度。 支持两种类型的输出:

    例如,ot.optim.line_search_armijo(通用求解器), ot.gromov.solve_gromov_linesearch(FGW 问题), solve_semirelaxed_gromov_linesearch(srFGW 问题)和 gcg_linesearch(广义 cg),输出:线搜索步长 alpha, 求解器中使用的迭代次数(如果适用)以及在步长 alpha 上的损失值。这些可以被调用,例如:

    def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs):

    return ot.optim.line_search_armijo(cost, G, deltaG, Mi, cost_G, **kwargs)

    例如,ot.gromov.solve_partial_gromov_linesearch 对于部分 (F)GW 问题,最终输出增加下一个步骤的梯度阅读,作为先前计算的梯度的凸组合,利用正则化器的二次形式。

  • G0 (类似数组, 形状 (ns,nt), 可选) – 初始猜测(默认是独立联合密度)

  • numItermax (int, 可选) – 最大迭代次数

  • stopThr (float, 可选) – 相对变化的停止阈值 (>0)

  • stopThr2 (float, 可选的) – 绝对变化的停止阈值 (>0)

  • verbose (bool, 可选) – 在迭代过程中打印信息

  • log (bool, 可选) – 如果为真,则记录日志

  • nx (后端, 可选) – 如果保持默认值 None,后端将根据其他输入推断出来。

  • **kwargs (dict) – 线搜索的参数

Returns:

  • gamma ((ns x nt) ndarray) – 给定参数的最优运输矩阵

  • log (dict) – 日志字典仅在参数中log==True时返回

参考文献

另请参见

ot.lp.emd

未正则化的最优传输

ot.bregman.sinkhorn

熵正则化的最优运输

ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval, args=(), c1=0.0001, alpha0=0.99, alpha_min=0.0, alpha_max=None, nx=None, **kwargs)[源]

与矩阵一起工作的Armijo线搜索函数

寻找满足阿米霍条件的\(f(x_k + \alpha \cdot p_k)\)的近似最小值。

注意

如果损失函数 f 返回一个浮点数(响应一个一维数组),则返回的 alpha 和 fa 是浮点数(响应一维数组)。

Parameters:
  • f (可调用的) – 损失函数

  • xk (类数组) – 初始位置

  • pk (数组类型) – 下降方向

  • gfk (类似数组) – 在\(x_k\)处的f的梯度

  • old_fval (浮点数一维数组) – 在 \(x_k\) 的损失值

  • args (tuple, 可选) – 传递给 f 的参数

  • c1 (float, 可选) – \(c_1\) armijo 规则中的常数 (>0)

  • alpha0 (float, 可选) – 初始步长 (>0)

  • alpha_min (float, default=0.) – alpha的最小值

  • alpha_max (float, 可选) – alpha的最大值

  • nx (backend, 可选) – 如果默认值为 None,将进行后端测试。

Returns:

  • alpha (浮点数或1维数组) – 满足阿米霍条件的步长

  • fc (整数) – 函数调用次数

  • fa (浮点数或1维数组) – 步长alpha处的损失值

ot.optim.partial_cg(a, b, a_extended, b_extended, M, reg, f, df, G0=None, line_search=<function line_search_armijo>, numItermax=200, stopThr=1e-09, stopThr2=1e-09, warn=True, verbose=False, log=False, **kwargs)[源]

解决具有条件梯度的一般正则化部分OT问题

该函数解决以下优化问题:

\[ \begin{align}\begin{aligned}\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \cdot f(\gamma)\\s.t. \ \gamma \mathbf{1} &= \mathbf{a}\\ \gamma \mathbf{1} &= \mathbf{b}\\ \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}\\ \gamma &\geq 0\end{aligned}\end{align} \]

其中 :

  • \(\mathbf{M}\) 是 (ns, nt) 计量成本矩阵

  • \(f\) 是正则化项(而 df 是它的梯度)

  • \(\mathbf{a}\)\(\mathbf{b}\) 是源权重和目标权重

  • m 是要运输的质量量

用于解决该问题的算法是条件梯度,如[1]中所讨论的

Parameters:
  • a (类数组, 形状 (ns,)) – 源域中的样本权重

  • b (类似数组, 形状 (nt,)) – 目前在目标域中估计的样本权重

  • a_extended (类数组, 形状 (ns + nb_dummies,)) – 具有添加的虚拟节点的源领域中的样本权重

  • b_extended (array-like, shape (nt + nb_dummies,)) – 当前在目标领域中估计的样本权重,带有附加的虚拟节点

  • M (类数组, 形状 (ns, nt)) – 损失矩阵

  • reg (float) – 正则化项 >0

  • G0 (类似数组, 形状 (ns,nt), 可选) – 初始猜测(默认是独立联合密度)

  • line_search (function,) – 找到最佳步长的函数。默认是armijo线搜索。

  • numItermax (int, 可选) – 最大迭代次数

  • stopThr (float, 可选) – 相对变化的停止阈值 (>0)

  • stopThr2 (float, 可选的) – 绝对变化的停止阈值 (>0)

  • warn (bool, 可选。) – 当EMD没有收敛时是否引发警告。

  • verbose (bool, 可选) – 在迭代过程中打印信息

  • log (bool, 可选) – 如果为真,则记录日志

  • **kwargs (dict) – 线搜索的参数

Returns:

  • gamma ((ns x nt) ndarray) – 给定参数的最优运输矩阵

  • log (dict) – 日志字典仅在参数中log==True时返回

参考文献

ot.optim.semirelaxed_cg(a, b, M, reg, f, df, G0=None, line_search=None, numItermax=200, stopThr=1e-09, stopThr2=1e-09, verbose=False, log=False, nx=None, **kwargs)[源]

解决一般的正则化和半放松的OT问题,使用条件梯度法

该函数解决以下优化问题:

\[ \begin{align}\begin{aligned}\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \cdot f(\gamma)\\s.t. \ \gamma \mathbf{1} &= \mathbf{a}\\ \gamma &\geq 0\end{aligned}\end{align} \]

其中 :

  • \(\mathbf{M}\) 是 (ns, nt) 计量成本矩阵

  • \(f\) 是正则化项(而 df 是它的梯度)

  • \(\mathbf{a}\)\(\mathbf{b}\) 是源权重和目标权重(总和为1)

用于解决该问题的算法是条件梯度,如[1]中所讨论的

Parameters:
  • a (类数组, 形状 (ns,)) – 源域中的样本权重

  • b (类似数组, 形状 (nt,)) – 目前在目标域中估计的样本权重

  • M (类数组, 形状 (ns, nt)) – 损失矩阵

  • reg (float) – 正则化项 >0

  • G0 (类似数组, 形状 (ns,nt), 可选) – 初始猜测(默认是独立联合密度)

  • line_search (function,) – 寻找最佳步长的函数。默认值为 None,并调用一个包装器来 line_search_armijo。

  • numItermax (int, 可选) – 最大迭代次数

  • stopThr (float, 可选) – 相对变化的停止阈值 (>0)

  • stopThr2 (float, 可选的) – 绝对变化的停止阈值 (>0)

  • verbose (bool, 可选) – 在迭代过程中打印信息

  • log (bool, 可选) – 如果为真,则记录日志

  • nx (后端, 可选) – 如果保持默认值 None,后端将根据其他输入推断出来。

  • **kwargs (dict) – 线搜索的参数

Returns:

  • gamma ((ns x nt) ndarray) – 给定参数的最优运输矩阵

  • log (dict) – 日志字典仅在参数中log==True时返回

参考文献

ot.optim.solve_1d_linesearch_quad(a, b)[源]

对于任何凸或非凸的一维二次函数 f,求解以下问题:

\[\mathop{\arg \min}_{0 \leq x \leq 1} \quad f(x) = ax^{2} + bx + c\]
Parameters:
  • a (float张量 (1,)) - 二次函数的系数

  • b (floattensors (1,)) – 二次函数的系数

Returns:

x – 导致最小成本的最优值

Return type:

float

ot.optim.cg(a, b, M, reg, f, df, G0=None, line_search=None, numItermax=200, numItermaxEmd=100000, stopThr=1e-09, stopThr2=1e-09, verbose=False, log=False, nx=None, **kwargs)[源]

解决带条件梯度的广义正则化OT问题

该函数解决以下优化问题:

\[ \begin{align}\begin{aligned}\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \cdot f(\gamma)\\s.t. \ \gamma \mathbf{1} &= \mathbf{a}\\ \gamma^T \mathbf{1} &= \mathbf{b}\\ \gamma &\geq 0\end{aligned}\end{align} \]

其中 :

  • \(\mathbf{M}\) 是 (ns, nt) 计量成本矩阵

  • \(f\) 是正则化项(而 df 是它的梯度)

  • \(\mathbf{a}\)\(\mathbf{b}\) 是源权重和目标权重(总和为1)

用于解决该问题的算法是条件梯度,如[1]中所讨论的

Parameters:
  • a (类数组, 形状 (ns,)) – 源域中的样本权重

  • b (数组类型, 形状 (nt,)) – 目标领域中的样本

  • M (类数组, 形状 (ns, nt)) – 损失矩阵

  • reg (float) – 正则化项 >0

  • G0 (类似数组, 形状 (ns,nt), 可选) – 初始猜测(默认是独立联合密度)

  • line_search (function,) – 寻找最佳步长的函数。默认值为 None,并调用一个包装器来 line_search_armijo。

  • numItermax (int, 可选) – 最大迭代次数

  • numItermaxEmd (int, 可选) – emd的最大迭代次数

  • stopThr (float, 可选) – 相对变化的停止阈值 (>0)

  • stopThr2 (float, 可选的) – 绝对变化的停止阈值 (>0)

  • verbose (bool, 可选) – 在迭代过程中打印信息

  • log (bool, 可选) – 如果为真,则记录日志

  • nx (后端, 可选) – 如果保持默认值 None,后端将根据其他输入推断出来。

  • **kwargs (dict) – 线搜索的参数

Returns:

  • gamma ((ns x nt) ndarray) – 给定参数的最优运输矩阵

  • log (dict) – 日志字典仅在参数中log==True时返回

参考文献

另请参见

ot.lp.emd

未正则化的最优传输

ot.bregman.sinkhorn

熵正则化的最优运输

ot.optim.gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, numInnerItermax=200, stopThr=1e-09, stopThr2=1e-09, verbose=False, log=False, **kwargs)[源]

使用广义条件梯度求解一般正则化OT问题

该函数解决以下优化问题:

\[ \begin{align}\begin{aligned}\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg_1}\cdot\Omega(\gamma) + \mathrm{reg_2}\cdot f(\gamma)\\s.t. \ \gamma \mathbf{1} &= \mathbf{a}\\ \gamma^T \mathbf{1} &= \mathbf{b}\\ \gamma &\geq 0\end{aligned}\end{align} \]

其中 :

  • \(\mathbf{M}\) 是 (ns, nt) 计量成本矩阵

  • \(\Omega\) 是熵正则化项 \(\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})\)

  • \(f\) 是正则化项(而 df 是它的梯度)

  • \(\mathbf{a}\)\(\mathbf{b}\) 是源权重和目标权重(总和为1)

用于解决该问题的算法是广义条件梯度,如[5, 7]中讨论的那样

Parameters:
  • a (类数组, 形状 (ns,)) – 源域中的样本权重

  • b (类数组, (nt,)) – 目标域中的样本

  • M (类数组, 形状 (ns, nt)) – 损失矩阵

  • reg1 (float) – 熵正则化项 >0

  • reg2 (float) – 第二个正则化项 >0

  • G0 (类似数组, 形状 (ns, nt), 可选) – 初始猜测(默认是独立联合密度)

  • numItermax (int, 可选) – 最大迭代次数

  • numInnerItermax (int, 可选) – Sinkhorn的最大迭代次数

  • stopThr (float, 可选) – 相对变化的停止阈值 (>0)

  • stopThr2 (float, 可选的) – 绝对变化的停止阈值 (>0)

  • verbose (bool, 可选) – 在迭代过程中打印信息

  • log (bool, 可选) – 如果为真,则记录日志

Returns:

  • gamma (ndarray, shape (ns, nt)) – 给定参数的最优运输矩阵

  • log (dict) – 仅在参数中log==True时返回日志字典

参考文献

另请参见

ot.optim.cg

条件梯度

ot.optim.generic_conditional_gradient(a, b, M, f, df, reg1, reg2, lp_solver, line_search, G0=None, numItermax=200, stopThr=1e-09, stopThr2=1e-09, verbose=False, log=False, nx=None, **kwargs)[源]

使用条件梯度或广义条件梯度解决一般正则化OT问题或其半放松版本,具体取决于提供的线性规划求解器。

如果设置为条件梯度,该函数解决以下优化问题:

\[ \begin{align}\begin{aligned}\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg_1} \cdot f(\gamma)\\s.t. \ \gamma \mathbf{1} &= \mathbf{a}\\ \gamma^T \mathbf{1} &= \mathbf{b} (可选约束)\\ \gamma &\geq 0\end{aligned}\end{align} \]

其中 :

  • \(\mathbf{M}\) 是 (ns, nt) 计量成本矩阵

  • \(f\) 是正则化项(而 df 是它的梯度)

  • \(\mathbf{a}\)\(\mathbf{b}\) 是源权重和目标权重(总和为1)

用于解决该问题的算法是条件梯度,如[1]中所讨论的

该函数解决以下优化问题,如果设置一个广义条件梯度:

\[ \begin{align}\begin{aligned}\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg_1}\cdot f(\gamma) + \mathrm{reg_2}\cdot\Omega(\gamma)\\s.t. \ \gamma \mathbf{1} &= \mathbf{a}\\ \gamma^T \mathbf{1} &= \mathbf{b}\\ \gamma &\geq 0\end{aligned}\end{align} \]

其中 :

  • \(\Omega\) 是熵正则化项 \(\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})\)

用于解决该问题的算法是广义条件梯度,如[5, 7]中讨论的那样

Parameters:
  • a (类数组, 形状 (ns,)) – 源域中的样本权重

  • b (类数组, 形状 (nt,)) – 目标领域中的样本权重

  • M (类数组, 形状 (ns, nt)) – 损失矩阵

  • f (函数) – 作为参数的正则化函数,接受一个运输矩阵

  • df (function) – 作为参数的运输矩阵的正则化函数的梯度

  • reg1 (float) – 正则化项 >0

  • reg2 (float,) – 熵正则化项 >0。如果设置为 None 则被忽略。

  • lp_solver (function,) –

    用于方向寻找的线性规划求解器(广义)条件梯度。 这个函数必须采用以下形式 lp_solver(a, b, Mi, **kwargs),其中 p: ab 是两个领域中的样本权重; Mi 是正则化目标的梯度;通过 kwargs 获取最优参数。 它必须输出一个可接受的运输计划。

    例如,对于一般的正则化 OT 问题与条件梯度 [1]

    def lp_solver(a, b, M, **kwargs):

    return ot.emd(a, b, M)

    或者使用广义条件梯度 [5, 7]

    def lp_solver(a, b, Mi, **kwargs):

    return ot.sinkhorn(a, b, Mi)

  • line_search (function,) –

    寻找最优步长的函数。此函数必须采用以下形式 line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs),其中:cost 为成本函数,G 为运输计划,deltaG 为由 lp_solver 提供的条件 梯度方向,Mi 为正则化目标的梯度,cost_G 为 G 的成本,df_G 为 G 处正则化器的梯度。 支持两种类型的输出:

    例如,ot.optim.line_search_armijo(通用求解器), ot.gromov.solve_gromov_linesearch(FGW 问题), solve_semirelaxed_gromov_linesearch(srFGW 问题)和 gcg_linesearch(广义 cg),输出:线搜索步长 alpha, 求解器中使用的迭代次数(如果适用)以及在步长 alpha 上的损失值。这些可以被调用,例如:

    def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs):

    return ot.optim.line_search_armijo(cost, G, deltaG, Mi, cost_G, **kwargs)

    例如,ot.gromov.solve_partial_gromov_linesearch 对于部分 (F)GW 问题,最终输出增加下一个步骤的梯度阅读,作为先前计算的梯度的凸组合,利用正则化器的二次形式。

  • G0 (类似数组, 形状 (ns,nt), 可选) – 初始猜测(默认是独立联合密度)

  • numItermax (int, 可选) – 最大迭代次数

  • stopThr (float, 可选) – 相对变化的停止阈值 (>0)

  • stopThr2 (float, 可选的) – 绝对变化的停止阈值 (>0)

  • verbose (bool, 可选) – 在迭代过程中打印信息

  • log (bool, 可选) – 如果为真,则记录日志

  • nx (后端, 可选) – 如果保持默认值 None,后端将根据其他输入推断出来。

  • **kwargs (dict) – 线搜索的参数

Returns:

  • gamma ((ns x nt) ndarray) – 给定参数的最优运输矩阵

  • log (dict) – 日志字典仅在参数中log==True时返回

参考文献

另请参见

ot.lp.emd

未正则化的最优传输

ot.bregman.sinkhorn

熵正则化的最优运输

ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval, args=(), c1=0.0001, alpha0=0.99, alpha_min=0.0, alpha_max=None, nx=None, **kwargs)[源]

与矩阵一起工作的Armijo线搜索函数

寻找满足阿米霍条件的\(f(x_k + \alpha \cdot p_k)\)的近似最小值。

注意

如果损失函数 f 返回一个浮点数(响应一个一维数组),则返回的 alpha 和 fa 是浮点数(响应一维数组)。

Parameters:
  • f (可调用的) – 损失函数

  • xk (类数组) – 初始位置

  • pk (数组类型) – 下降方向

  • gfk (类似数组) – 在\(x_k\)处的f的梯度

  • old_fval (浮点数一维数组) – 在 \(x_k\) 的损失值

  • args (tuple, 可选) – 传递给 f 的参数

  • c1 (float, 可选) – \(c_1\) armijo 规则中的常数 (>0)

  • alpha0 (float, 可选) – 初始步长 (>0)

  • alpha_min (float, default=0.) – alpha的最小值

  • alpha_max (float, 可选) – alpha的最大值

  • nx (backend, 可选) – 如果默认值为 None,将进行后端测试。

Returns:

  • alpha (浮点数或1维数组) – 满足阿米霍条件的步长

  • fc (整数) – 函数调用次数

  • fa (浮点数或1维数组) – 步长alpha处的损失值

ot.optim.partial_cg(a, b, a_extended, b_extended, M, reg, f, df, G0=None, line_search=<function line_search_armijo>, numItermax=200, stopThr=1e-09, stopThr2=1e-09, warn=True, verbose=False, log=False, **kwargs)[源]

解决具有条件梯度的一般正则化部分OT问题

该函数解决以下优化问题:

\[ \begin{align}\begin{aligned}\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \cdot f(\gamma)\\s.t. \ \gamma \mathbf{1} &= \mathbf{a}\\ \gamma \mathbf{1} &= \mathbf{b}\\ \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}\\ \gamma &\geq 0\end{aligned}\end{align} \]

其中 :

  • \(\mathbf{M}\) 是 (ns, nt) 计量成本矩阵

  • \(f\) 是正则化项(而 df 是它的梯度)

  • \(\mathbf{a}\)\(\mathbf{b}\) 是源权重和目标权重

  • m 是要运输的质量量

用于解决该问题的算法是条件梯度,如[1]中所讨论的

Parameters:
  • a (类数组, 形状 (ns,)) – 源域中的样本权重

  • b (类似数组, 形状 (nt,)) – 目前在目标域中估计的样本权重

  • a_extended (类数组, 形状 (ns + nb_dummies,)) – 具有添加的虚拟节点的源领域中的样本权重

  • b_extended (array-like, shape (nt + nb_dummies,)) – 当前在目标领域中估计的样本权重,带有附加的虚拟节点

  • M (类数组, 形状 (ns, nt)) – 损失矩阵

  • reg (float) – 正则化项 >0

  • G0 (类似数组, 形状 (ns,nt), 可选) – 初始猜测(默认是独立联合密度)

  • line_search (function,) – 找到最佳步长的函数。默认是armijo线搜索。

  • numItermax (int, 可选) – 最大迭代次数

  • stopThr (float, 可选) – 相对变化的停止阈值 (>0)

  • stopThr2 (float, 可选的) – 绝对变化的停止阈值 (>0)

  • warn (bool, 可选。) – 当EMD没有收敛时是否引发警告。

  • verbose (bool, 可选) – 在迭代过程中打印信息

  • log (bool, 可选) – 如果为真,则记录日志

  • **kwargs (dict) – 线搜索的参数

Returns:

  • gamma ((ns x nt) ndarray) – 给定参数的最优运输矩阵

  • log (dict) – 日志字典仅在参数中log==True时返回

参考文献

ot.optim.semirelaxed_cg(a, b, M, reg, f, df, G0=None, line_search=None, numItermax=200, stopThr=1e-09, stopThr2=1e-09, verbose=False, log=False, nx=None, **kwargs)[源]

解决一般的正则化和半放松的OT问题,使用条件梯度法

该函数解决以下优化问题:

\[ \begin{align}\begin{aligned}\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \cdot f(\gamma)\\s.t. \ \gamma \mathbf{1} &= \mathbf{a}\\ \gamma &\geq 0\end{aligned}\end{align} \]

其中 :

  • \(\mathbf{M}\) 是 (ns, nt) 计量成本矩阵

  • \(f\) 是正则化项(而 df 是它的梯度)

  • \(\mathbf{a}\)\(\mathbf{b}\) 是源权重和目标权重(总和为1)

用于解决该问题的算法是条件梯度,如[1]中所讨论的

Parameters:
  • a (类数组, 形状 (ns,)) – 源域中的样本权重

  • b (类似数组, 形状 (nt,)) – 目前在目标域中估计的样本权重

  • M (类数组, 形状 (ns, nt)) – 损失矩阵

  • reg (float) – 正则化项 >0

  • G0 (类似数组, 形状 (ns,nt), 可选) – 初始猜测(默认是独立联合密度)

  • line_search (function,) – 寻找最佳步长的函数。默认值为 None,并调用一个包装器来 line_search_armijo。

  • numItermax (int, 可选) – 最大迭代次数

  • stopThr (float, 可选) – 相对变化的停止阈值 (>0)

  • stopThr2 (float, 可选的) – 绝对变化的停止阈值 (>0)

  • verbose (bool, 可选) – 在迭代过程中打印信息

  • log (bool, 可选) – 如果为真,则记录日志

  • nx (后端, 可选) – 如果保持默认值 None,后端将根据其他输入推断出来。

  • **kwargs (dict) – 线搜索的参数

Returns:

  • gamma ((ns x nt) ndarray) – 给定参数的最优运输矩阵

  • log (dict) – 日志字典仅在参数中log==True时返回

参考文献

ot.optim.solve_1d_linesearch_quad(a, b)[源]

对于任何凸或非凸的一维二次函数 f,求解以下问题:

\[\mathop{\arg \min}_{0 \leq x \leq 1} \quad f(x) = ax^{2} + bx + c\]
Parameters:
  • a (float张量 (1,)) - 二次函数的系数

  • b (floattensors (1,)) – 二次函数的系数

Returns:

x – 导致最小成本的最优值

Return type:

float