平滑且强凸的最近Brenier势

这个例子旨在展示如何在POT中使用SSNB [58]。 SSNB计算一个具有L-Lipschitz梯度的l-强凸势函数 \(\varphi\),使得 \(\nabla \varphi \# \mu \approx \nu\)。这种正则性只能在环境空间的分区组件上强加,这相比于施加全局正则性是一种放松。

在这个例子中,我们考虑源测度 \(\mu_s\),它是在\(\mathbb{R}^2\)中单位正方形上的均匀测度,目标测度 \(\mu_t\)\(\mu_x\) 通过\(T(x_1, x_2) = (x_1 + 2\mathrm{sign}(x_2), 2 * x_2)\)的映像。映射 \(T\) 是非光滑的,我们希望使用一个在分区\(\lbrace x_1 <=0, x_1>0\rbrace\)上规则的“布伦尼尔风格”映射 \(\nabla \varphi\) 来逼近它,这对这个特定数据集非常适合。

我们表示“边界势能” \(\varphi_l, \varphi_u\) 的梯度(来自 [59],定理 3.14),它们界定了在 [58],定义 1 的意义上是最优的任何 SSNB 势能:

\[\varphi \in \mathrm{argmin}_{\varphi \in \mathcal{F}}\ \mathrm{W}_2(\nabla \varphi \#\mu_s, \mu_t),\]

其中 \(\mathcal{F}\) 是在每个集合 \(E_k\) l-强凸的空间函数,具有L-Lipschitz梯度,给定 \((E_k)_{k \in [K]}\) 是环境源空间的一个划分。

我们在少量的拟合样本和较少的迭代下进行优化,因为解决SSNB问题在计算上是相当昂贵的。

此示例需要 CVXPY

# Author: Eloi Tanguy <eloi.tanguy@u-paris.fr>
# License: MIT License

# sphinx_gallery_thumbnail_number = 3

import matplotlib.pyplot as plt
import numpy as np
import ot

生成拟合数据

n_fitting_samples = 30
rng = np.random.RandomState(seed=0)
Xs = rng.uniform(-1, 1, size=(n_fitting_samples, 2))
Xs_classes = (Xs[:, 0] < 0).astype(int)
Xt = np.stack([Xs[:, 0] + 2 * np.sign(Xs[:, 0]), 2 * Xs[:, 1]], axis=-1)

plt.scatter(
    Xs[Xs_classes == 0, 0], Xs[Xs_classes == 0, 1], c="blue", label="source class 0"
)
plt.scatter(
    Xs[Xs_classes == 1, 0],
    Xs[Xs_classes == 1, 1],
    c="dodgerblue",
    label="source class 1",
)
plt.scatter(Xt[:, 0], Xt[:, 1], c="red", label="target")
plt.axis("equal")
plt.title("Splitting sphere dataset")
plt.legend(loc="upper right")
plt.show()
Splitting sphere dataset

拟合最近的布雷尼尔势

L = 3  # need L > 2 to allow the 2*y term, default is 1.4
phi, G = ot.mapping.nearest_brenier_potential_fit(
    Xs, Xt, Xs_classes, its=10, init_method="barycentric", gradient_lipschitz_constant=L
)

绘制源数据的图像

plt.clf()
plt.scatter(Xs[:, 0], Xs[:, 1], c="dodgerblue", label="source")
plt.scatter(Xt[:, 0], Xt[:, 1], c="red", label="target")
for i in range(n_fitting_samples):
    plt.plot([Xs[i, 0], G[i, 0]], [Xs[i, 1], G[i, 1]], color="black", alpha=0.5)
plt.title("Images of in-data source samples by the fitted SSNB")
plt.legend(loc="upper right")
plt.axis("equal")
plt.show()
Images of in-data source samples by the fitted SSNB

为源分布的随机样本计算预测(由 nabla phi 提供的图像)

n_predict_samples = 50
Ys = rng.uniform(-1, 1, size=(n_predict_samples, 2))
Ys_classes = (Ys[:, 0] < 0).astype(int)
phi_lu, G_lu = ot.mapping.nearest_brenier_potential_predict_bounds(
    Xs, phi, G, Ys, Xs_classes, Ys_classes, gradient_lipschitz_constant=L
)

为下界潜力的梯度绘制预测

plt.clf()
plt.scatter(Xs[:, 0], Xs[:, 1], c="dodgerblue", label="source")
plt.scatter(Xt[:, 0], Xt[:, 1], c="red", label="target")
for i in range(n_predict_samples):
    plt.plot(
        [Ys[i, 0], G_lu[0, i, 0]], [Ys[i, 1], G_lu[0, i, 1]], color="black", alpha=0.5
    )
plt.title("Images of new source samples by $\\nabla \\varphi_l$")
plt.legend(loc="upper right")
plt.axis("equal")
plt.show()
Images of new source samples by $\nabla \varphi_l$

绘制上界势能梯度的预测

plt.clf()
plt.scatter(Xs[:, 0], Xs[:, 1], c="dodgerblue", label="source")
plt.scatter(Xt[:, 0], Xt[:, 1], c="red", label="target")
for i in range(n_predict_samples):
    plt.plot(
        [Ys[i, 0], G_lu[1, i, 0]], [Ys[i, 1], G_lu[1, i, 1]], color="black", alpha=0.5
    )
plt.title("Images of new source samples by $\\nabla \\varphi_u$")
plt.legend(loc="upper right")
plt.axis("equal")
plt.show()
Images of new source samples by $\nabla \varphi_u$

脚本的总运行时间: (0分钟52.842秒)

由 Sphinx-Gallery 生成的画廊