平滑且强凸的最近Brenier势

这个例子旨在展示如何在POT中使用SSNB [58]。 SSNB计算一个具有L-Lipschitz梯度的l-强凸势函数 \(\varphi\),使得 \(\nabla \varphi \# \mu \approx \nu\)。这种正则性只能在环境空间的分区组件上强加,这相比于施加全局正则性是一种放松。

在这个例子中,我们考虑源测度 \(\mu_s\),它是在\(\mathbb{R}^2\)中单位正方形上的均匀测度,目标测度 \(\mu_t\)\(\mu_x\) 通过\(T(x_1, x_2) = (x_1 + 2\mathrm{sign}(x_2), 2 * x_2)\)的映像。映射 \(T\) 是非光滑的,我们希望使用一个在分区\(\lbrace x_1 <=0, x_1>0\rbrace\)上规则的“布伦尼尔风格”映射 \(\nabla \varphi\) 来逼近它,这对这个特定数据集非常适合。

我们表示“边界势能” \(\varphi_l, \varphi_u\) 的梯度(来自 [59],定理 3.14),它们界定了在 [58],定义 1 的意义上是最优的任何 SSNB 势能:

\[\varphi \in \mathrm{argmin}_{\varphi \in \mathcal{F}}\ \mathrm{W}_2(\nabla \varphi \#\mu_s, \mu_t),\]

其中 \(\mathcal{F}\) 是在每个集合 \(E_k\) l-强凸的空间函数,具有L-Lipschitz梯度,给定 \((E_k)_{k \in [K]}\) 是环境源空间的一个划分。

我们在少量的拟合样本和较少的迭代下进行优化,因为解决SSNB问题在计算上是相当昂贵的。

此示例需要 CVXPY

# Author: Eloi Tanguy <eloi.tanguy@u-paris.fr>
# License: MIT License

# sphinx_gallery_thumbnail_number = 3

import matplotlib.pyplot as plt
import numpy as np
import ot

生成拟合数据

n_fitting_samples = 30
rng = np.random.RandomState(seed=0)
Xs = rng.uniform(-1, 1, size=(n_fitting_samples, 2))
Xs_classes = (Xs[:, 0] < 0).astype(int)
Xt = np.stack([Xs[:, 0] + 2 * np.sign(Xs[:, 0]), 2 * Xs[:, 1]], axis=-1)

plt.scatter(
    Xs[Xs_classes == 0, 0], Xs[Xs_classes == 0, 1], c="blue", label="source class 0"
)
plt.scatter(
    Xs[Xs_classes == 1, 0],
    Xs[Xs_classes == 1, 1],
    c="dodgerblue",
    label="source class 1",
)
plt.scatter(Xt[:, 0], Xt[:, 1], c="red", label="target")
plt.axis("equal")
plt.title("Splitting sphere dataset")
plt.legend(loc="upper right")
plt.show()
Splitting sphere dataset

拟合最近的布雷尼尔势

L = 3  # need L > 2 to allow the 2*y term, default is 1.4
phi, G = ot.mapping.nearest_brenier_potential_fit(
    Xs, Xt, Xs_classes, its=10, init_method="barycentric", gradient_lipschitz_constant=L
)
/home/circleci/.local/lib/python3.10/site-packages/cvxpy/reductions/solvers/solving_chain.py:356: FutureWarning:
    You specified your problem should be solved by ECOS. Starting in
    CXVPY 1.6.0, ECOS will no longer be installed by default with CVXPY.
    Please either add ECOS as an explicit install dependency to your project
    or switch to our new default solver, Clarabel, by either not specifying a
    solver argument or specifying ``solver=cp.CLARABEL``. To suppress this
    warning while continuing to use ECOS, you can filter this warning using
    Python's ``warnings`` module until you are using 1.6.0.

  warnings.warn(ECOS_DEP_DEPRECATION_MSG, FutureWarning)

绘制源数据的图像

plt.clf()
plt.scatter(Xs[:, 0], Xs[:, 1], c="dodgerblue", label="source")
plt.scatter(Xt[:, 0], Xt[:, 1], c="red", label="target")
for i in range(n_fitting_samples):
    plt.plot([Xs[i, 0], G[i, 0]], [Xs[i, 1], G[i, 1]], color="black", alpha=0.5)
plt.title("Images of in-data source samples by the fitted SSNB")
plt.legend(loc="upper right")
plt.axis("equal")
plt.show()
Images of in-data source samples by the fitted SSNB

为源分布的随机样本计算预测(由 nabla phi 提供的图像)

n_predict_samples = 50
Ys = rng.uniform(-1, 1, size=(n_predict_samples, 2))
Ys_classes = (Ys[:, 0] < 0).astype(int)
phi_lu, G_lu = ot.mapping.nearest_brenier_potential_predict_bounds(
    Xs, phi, G, Ys, Xs_classes, Ys_classes, gradient_lipschitz_constant=L
)

为下界潜力的梯度绘制预测

plt.clf()
plt.scatter(Xs[:, 0], Xs[:, 1], c="dodgerblue", label="source")
plt.scatter(Xt[:, 0], Xt[:, 1], c="red", label="target")
for i in range(n_predict_samples):
    plt.plot(
        [Ys[i, 0], G_lu[0, i, 0]], [Ys[i, 1], G_lu[0, i, 1]], color="black", alpha=0.5
    )
plt.title("Images of new source samples by $\\nabla \\varphi_l$")
plt.legend(loc="upper right")
plt.axis("equal")
plt.show()
Images of new source samples by $\nabla \varphi_l$

绘制上界势能梯度的预测

plt.clf()
plt.scatter(Xs[:, 0], Xs[:, 1], c="dodgerblue", label="source")
plt.scatter(Xt[:, 0], Xt[:, 1], c="red", label="target")
for i in range(n_predict_samples):
    plt.plot(
        [Ys[i, 0], G_lu[1, i, 0]], [Ys[i, 1], G_lu[1, i, 1]], color="black", alpha=0.5
    )
plt.title("Images of new source samples by $\\nabla \\varphi_u$")
plt.legend(loc="upper right")
plt.axis("equal")
plt.show()
Images of new source samples by $\nabla \varphi_u$

脚本的总运行时间: (0分钟 47.016秒)

由 Sphinx-Gallery 生成的画廊