Comparison of Fused Gromov-Wasserstein solvers

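This example illustrates the computation of FGW for attributed graphs with 4 different solvers, which estimate the distance using Conditional Gradient [24], Sinkhorn projections [12, 51] and alternated Bregman projections [63, 64].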

We generate two graphs following Stochastic Block Models, endow them with node features, and compute their FGW matchings.
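For reference, all four solvers target the same kind of objective: a coupling T between the nodes, with marginals h2 and h3, that trades off a feature cost M against the mismatch of the structure matrices C2 and C3. The NumPy sketch below evaluates that objective for a given plan, assuming POT's convention (1 - alpha) <T, M>_F + alpha * sum_{i,j,k,l} L(C1[i, k], C2[j, l]) T[i, j] T[k, l] with the square loss used later on; `fgw_objective` is a hypothetical helper, not a POT function, and the dense 4-index tensor only suits small graphs.

```python
import numpy as np


def fgw_objective(M, C1, C2, T, alpha):
    """Square-loss FGW objective of a coupling T (hypothetical helper, small graphs only).

    (1 - alpha) * <T, M>_F
    + alpha * sum_{i,j,k,l} (C1[i, k] - C2[j, l]) ** 2 * T[i, j] * T[k, l]
    """
    # L[i, j, k, l] = (C1[i, k] - C2[j, l]) ** 2, shape (n1, n2, n1, n2)
    L = (C1[:, None, :, None] - C2[None, :, None, :]) ** 2
    linear_term = np.sum(M * T)  # feature (Wasserstein) part
    quadratic_term = np.einsum("ijkl,ij,kl->", L, T, T)  # structure (Gromov) part
    return (1 - alpha) * linear_term + alpha * quadratic_term
```

With two identical graphs matched by the identity coupling, the structure term vanishes, so only the feature term weighted by (1 - alpha) remains.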

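[12] Peyré, G., Cuturi, M., & Solomon, J. (2016). "Gromov-Wasserstein averaging of kernel and distance matrices". International Conference on Machine Learning (ICML).

[24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R., & Courty, N. (2019). "Optimal Transport for structured data with application on graphs". International Conference on Machine Learning (ICML).

[51] Xu, H., Luo, D., Zha, H., & Duke, L. C. (2019). "Gromov-Wasserstein learning for graph matching and node embedding". International Conference on Machine Learning (ICML).

[63] Li, J., Tang, J., Kong, L., Liu, H., Li, J., So, A. M. C., & Blanchet, J. (2023). "A Convergent Single-Loop Algorithm for Relaxation of Gromov-Wasserstein in Graph Data". International Conference on Learning Representations (ICLR).

[64] Ma, X., Chu, X., Wang, Y., Lin, Y., Zhao, J., Ma, L., & Zhu, W. (2023). "Fused Gromov-Wasserstein Graph Mixup for Graph-level Classifications". Thirty-seventh Conference on Neural Information Processing Systems (NeurIPS).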

# Author: Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 1

import numpy as np
import matplotlib.pylab as pl
from ot.gromov import (
    fused_gromov_wasserstein,
    entropic_fused_gromov_wasserstein,
    BAPG_fused_gromov_wasserstein,
)
import networkx
from networkx.generators.community import stochastic_block_model as sbm
from time import time

Generate two graphs following Stochastic Block Models with 2 and 3 clusters.
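To make the generative model concrete, the sketch below samples an SBM adjacency matrix directly in NumPy: each pair of nodes is connected independently with probability p[block_i][block_j]. It assumes an undirected graph without self-loops, which is the default of networkx's `stochastic_block_model`; `sample_sbm` is a hypothetical helper and will not reproduce networkx's exact random draws.

```python
import numpy as np


def sample_sbm(sizes, p, rng):
    """Sample a symmetric SBM adjacency matrix without self-loops (hypothetical helper)."""
    blocks = np.repeat(np.arange(len(sizes)), sizes)  # block label of each node
    probs = np.asarray(p)[blocks[:, None], blocks[None, :]]  # edge probability per pair
    upper = np.triu(rng.random(probs.shape) < probs, k=1)  # draw each pair i < j once
    A = (upper | upper.T).astype(float)  # symmetrize; the diagonal stays zero
    return A, blocks


A, blocks = sample_sbm([10, 10], [[1.0, 0.1], [0.1, 0.9]], np.random.default_rng(0))
```

With the probabilities p2 used here, almost all intra-cluster pairs are connected while inter-cluster edges are rare, which is the block structure the FGW matching is expected to recover.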

np.random.seed(0)

N2 = 20  # 2 communities
N3 = 30  # 3 communities
p2 = [[1.0, 0.1], [0.1, 0.9]]
p3 = [[1.0, 0.1, 0.0], [0.1, 0.95, 0.1], [0.0, 0.1, 0.9]]
G2 = sbm(seed=0, sizes=[N2 // 2, N2 // 2], p=p2)
G3 = sbm(seed=0, sizes=[N3 // 3, N3 // 3, N3 // 3], p=p3)
part_G2 = [G2.nodes[i]["block"] for i in range(N2)]
part_G3 = [G3.nodes[i]["block"] for i in range(N3)]

C2 = networkx.to_numpy_array(G2)
C3 = networkx.to_numpy_array(G3)


# We add node features with given mean - by clusters
# and inversely proportional to clusters' intra-connectivity

F2 = np.zeros((N2, 1))
for i, c in enumerate(part_G2):
    F2[i, 0] = np.random.normal(loc=c, scale=0.01)

F3 = np.zeros((N3, 1))
for i, c in enumerate(part_G3):
    F3[i, 0] = np.random.normal(loc=2.0 - c, scale=0.01)

# Compute pairwise squared Euclidean distances between node features
M = (F2**2).dot(np.ones((1, N3))) + np.ones((N2, 1)).dot((F3**2).T) - 2 * F2.dot(F3.T)

h2 = np.ones(C2.shape[0]) / C2.shape[0]
h3 = np.ones(C3.shape[0]) / C3.shape[0]
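The matrix M above holds squared (not plain) Euclidean distances, and the expanded formula relies on the features being one-dimensional: for d > 1 features, `(F2**2).dot(np.ones((1, N3)))` no longer has compatible shapes. A general version, sketched below with a hypothetical `sq_euclidean` helper, sums the squared coordinates first (POT's `ot.dist` uses the squared Euclidean metric by default).

```python
import numpy as np


def sq_euclidean(X, Y):
    """Pairwise squared Euclidean distances between the rows of X (n, d) and Y (m, d)."""
    XX = (X**2).sum(axis=1)[:, None]  # (n, 1) squared norms
    YY = (Y**2).sum(axis=1)[None, :]  # (1, m) squared norms
    D = XX + YY - 2 * X @ Y.T  # ||x||^2 + ||y||^2 - 2 <x, y>
    return np.maximum(D, 0.0)  # clip tiny negative values caused by round-off
```

For one-dimensional features this matches the expanded expression used for M, up to the clipping of round-off negatives.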

Compute their Fused Gromov-Wasserstein distances

alpha = 0.5


# Conditional Gradient algorithm
print("Conditional Gradient \n")
start_cg = time()
T_cg, log_cg = fused_gromov_wasserstein(
    M, C2, C3, h2, h3, "square_loss", alpha=alpha, tol_rel=1e-9, verbose=True, log=True
)
end_cg = time()
time_cg = 1000 * (end_cg - start_cg)
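Each runtime here is measured with `time()` around a single call. A small variant, sketched below with a hypothetical `timed_ms` helper, uses `time.perf_counter()`, Python's monotonic high-resolution clock intended for measuring short intervals. Note that the solvers run with `verbose=True`, so the printing of their logs is included in the measured runtimes, and single runs of a few milliseconds remain noisy.

```python
import time


def timed_ms(fn, *args, **kwargs):
    """Call fn(*args, **kwargs) and return (result, elapsed wall-clock time in ms)."""
    start = time.perf_counter()  # monotonic, high-resolution clock
    result = fn(*args, **kwargs)
    return result, 1000 * (time.perf_counter() - start)
```

For instance, `(T_cg, log_cg), time_cg = timed_ms(fused_gromov_wasserstein, M, C2, C3, h2, h3, "square_loss", alpha=alpha, log=True)` would replace the `start_cg`/`end_cg` bookkeeping.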

# Proximal Point algorithm with Kullback-Leibler as proximal operator
print("Proximal Point Algorithm \n")
start_ppa = time()
T_ppa, log_ppa = entropic_fused_gromov_wasserstein(
    M,
    C2,
    C3,
    h2,
    h3,
    "square_loss",
    alpha=alpha,
    epsilon=1.0,
    solver="PPA",
    tol=1e-9,
    log=True,
    verbose=True,
    warmstart=False,
    numItermax=10,
)
end_ppa = time()
time_ppa = 1000 * (end_ppa - start_ppa)
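The PPA solver above and the PGD solver below both wrap a Sinkhorn projection inside an outer loop. The NumPy sketch below shows one outer iteration under stated assumptions: square loss, symmetric C1 and C2, the gradient (1 - alpha) M + 2 alpha sum_{k,l} L[i, j, k, l] T[k, l], and a fixed number of Sinkhorn iterations. The PGD-style step [12] rescales the Gibbs kernel exp(-G / eps) of the current gradient G; the KL proximal point step [51] also multiplies the kernel by the previous plan, which is equivalent to using the cost G - eps log(T). The helper names are hypothetical and this is not POT's implementation.

```python
import numpy as np


def sinkhorn_scaling(a, b, K, n_iter=500):
    """Rescale a nonnegative kernel K so that its marginals match a and b (Sinkhorn)."""
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)  # enforce row sums a
        v = b / (K.T @ u)  # enforce column sums b
    return u[:, None] * K * v[None, :]


def fgw_gradient(M, C1, C2, T, alpha):
    """Gradient in T of the square-loss FGW objective (C1, C2 assumed symmetric)."""
    L = (C1[:, None, :, None] - C2[None, :, None, :]) ** 2  # L[i, j, k, l]
    return (1 - alpha) * M + 2 * alpha * np.einsum("ijkl,kl->ij", L, T)


def entropic_step(M, C1, C2, T, a, b, alpha, eps, proximal=False):
    """One outer step: PGD-style kernel exp(-G/eps), or KL proximal kernel T * exp(-G/eps)."""
    G = fgw_gradient(M, C1, C2, T, alpha)
    K = np.exp(-(G - G.min()) / eps)  # constant shift is absorbed by the scalings
    if proximal:
        K = K * T  # Gibbs kernel of the cost G - eps * log(T)
    return sinkhorn_scaling(a, b, K)
```

Because the diagonal scalings never turn a zero into a nonzero, the proximal step keeps the zero entries of the previous plan, whereas exp(-G / eps) is strictly positive. This is consistent with the dense PGD plan and the nonzero sparsity of the PPA plan reported below, and with the `log` divide-by-zero warning in the PPA output.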

# Projected Gradient algorithm with entropic regularization
print("Projected Gradient Descent \n")
start_pgd = time()
T_pgd, log_pgd = entropic_fused_gromov_wasserstein(
    M,
    C2,
    C3,
    h2,
    h3,
    "square_loss",
    alpha=alpha,
    epsilon=0.01,
    solver="PGD",
    tol=1e-9,
    log=True,
    verbose=True,
    warmstart=False,
    numItermax=10,
)
end_pgd = time()
time_pgd = 1000 * (end_pgd - start_pgd)

# Alternated Bregman Projected Gradient algorithm with Kullback-Leibler as proximal operator
print("Bregman Alternated Projected Gradient \n")
start_bapg = time()
T_bapg, log_bapg = BAPG_fused_gromov_wasserstein(
    M,
    C2,
    C3,
    h2,
    h3,
    "square_loss",
    alpha=alpha,
    epsilon=1.0,
    tol=1e-9,
    marginal_loss=True,
    verbose=True,
    log=True,
)
end_bapg = time()
time_bapg = 1000 * (end_bapg - start_bapg)
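The BAPG solver [63] avoids a full Sinkhorn projection at every iteration. Under our reading of the method, it alternates a gradient-weighted KL projection onto the row constraint with one onto the column constraint. The NumPy sketch below shows one such alternating step; `bapg_step` and its `grad_fn` callback are hypothetical names, treating `eps` as the counterpart of the `epsilon` argument is an assumption, and POT's actual update may differ in its details.

```python
import numpy as np


def kl_project_rows(T, a):
    """KL projection onto {T : T 1 = a}: rescale each row to its target mass."""
    return T * (a / T.sum(axis=1))[:, None]


def kl_project_cols(T, b):
    """KL projection onto {T : T^T 1 = b}: rescale each column to its target mass."""
    return T * (b / T.sum(axis=0))[None, :]


def bapg_step(T, grad_fn, a, b, eps):
    """One alternating step: gradient-weighted row projection, then column projection.

    grad_fn(T) is assumed to return the gradient of the objective at T.
    """
    T_half = kl_project_rows(T * np.exp(-grad_fn(T) / eps), a)
    return kl_project_cols(T_half * np.exp(-grad_fn(T_half) / eps), b)
```

Each half-step enforces only one marginal, so the returned plan matches `b` exactly but `a` only approximately. This is in line with the larger marginal error reported for BAPG in the figure below.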

print(
    "Fused Gromov-Wasserstein distance estimated with Conditional Gradient solver: "
    + str(log_cg["fgw_dist"])
)
print(
    "Fused Gromov-Wasserstein distance estimated with Proximal Point solver: "
    + str(log_ppa["fgw_dist"])
)
print(
    "Entropic Fused Gromov-Wasserstein distance estimated with Projected Gradient solver: "
    + str(log_pgd["fgw_dist"])
)
print(
    "Fused Gromov-Wasserstein distance estimated with Bregman Alternated Projected Gradient solver: "
    + str(log_bapg["fgw_dist"])
)

# compute OT sparsity level
T_cg_sparsity = 100 * (T_cg == 0.0).astype(np.float64).sum() / (N2 * N3)
T_ppa_sparsity = 100 * (T_ppa == 0.0).astype(np.float64).sum() / (N2 * N3)
T_pgd_sparsity = 100 * (T_pgd == 0.0).astype(np.float64).sum() / (N2 * N3)
T_bapg_sparsity = 100 * (T_bapg == 0.0).astype(np.float64).sum() / (N2 * N3)
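The sparsity above counts only exact zeros, so an entropic plan whose entries are tiny but positive still counts as fully dense. A tolerance-based variant, sketched below with a hypothetical `sparsity_percent` helper and an arbitrary `tol`, makes that choice explicit; with `tol=0.0` it reproduces the computation above.

```python
import numpy as np


def sparsity_percent(T, tol=0.0):
    """Percentage of entries of T whose absolute value is at most tol."""
    return 100.0 * np.mean(np.abs(T) <= tol)


T_example = np.array([[0.5, 0.0], [1e-12, 0.5]])
exact = sparsity_percent(T_example)  # only the exact zero counts
loose = sparsity_percent(T_example, tol=1e-9)  # the 1e-12 entry counts as well
```

Here `exact` is 25.0 and `loose` is 50.0.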

# Methods using Sinkhorn/Bregman projections tend to produce feasibility errors on the
# marginal constraints

err_cg = np.linalg.norm(T_cg.sum(1) - h2) + np.linalg.norm(T_cg.sum(0) - h3)
err_ppa = np.linalg.norm(T_ppa.sum(1) - h2) + np.linalg.norm(T_ppa.sum(0) - h3)
err_pgd = np.linalg.norm(T_pgd.sum(1) - h2) + np.linalg.norm(T_pgd.sum(0) - h3)
err_bapg = np.linalg.norm(T_bapg.sum(1) - h2) + np.linalg.norm(T_bapg.sum(0) - h3)
Conditional Gradient

It.  |Loss        |Relative loss|Absolute loss
------------------------------------------------
    0|8.271184e-01|0.000000e+00|0.000000e+00
    1|4.211305e-01|9.640431e-01|4.059879e-01
    2|4.024660e-01|4.637523e-02|1.866445e-02
    3|3.936346e-01|2.243555e-02|8.831410e-03
    4|3.891614e-01|1.149450e-02|4.473216e-03
    5|3.854134e-01|9.724554e-03|3.747973e-03
    6|3.850574e-01|9.244899e-04|3.559817e-04
    7|3.841819e-01|2.279017e-03|8.755571e-04
    8|3.819396e-01|5.870728e-03|2.242264e-03
    9|3.784264e-01|9.283767e-03|3.513222e-03
   10|3.772225e-01|3.191339e-03|1.203845e-03
   11|3.764565e-01|2.034974e-03|7.660790e-04
   12|3.761179e-01|9.000610e-04|3.385291e-04
   13|3.761179e-01|0.000000e+00|0.000000e+00
Proximal Point Algorithm

/home/circleci/project/ot/bregman/_sinkhorn.py:667: UserWarning: Sinkhorn did not converge. You might want to increase the number of iterations `numItermax` or the regularization parameter `reg`.
  warnings.warn(
It.  |Err
-------------------
    0|1.536990e-02|
   10|7.472502e-04|
   20|6.129779e-04|
   30|5.936118e-04|
   40|6.334121e-04|
   50|6.852583e-04|
   60|7.134797e-04|
   70|7.177453e-04|
   80|7.236970e-04|
   90|7.465712e-04|
  100|7.858590e-04|
  110|8.312306e-04|
  120|8.789776e-04|
  130|9.493878e-04|
  140|1.029873e-03|
  150|1.095678e-03|
  160|1.177744e-03|
  170|1.257392e-03|
  180|1.242762e-03|
  190|1.125684e-03|
It.  |Err
-------------------
  200|9.770321e-04|
  210|8.368455e-04|
  220|7.094853e-04|
  230|6.159374e-04|
  240|5.564338e-04|
  250|4.974174e-04|
  260|4.297098e-04|
  270|3.639787e-04|
  280|3.106686e-04|
  290|2.716049e-04|
  300|2.433353e-04|
  310|2.227167e-04|
  320|2.080617e-04|
  330|1.980812e-04|
  340|1.912468e-04|
  350|1.858901e-04|
  360|1.805235e-04|
  370|1.740559e-04|
  380|1.658655e-04|
  390|1.557834e-04|
It.  |Err
-------------------
  400|1.440242e-04|
  410|1.310793e-04|
  420|1.175881e-04|
/home/circleci/project/ot/backend.py:1168: RuntimeWarning: divide by zero encountered in log
  return np.log(a)
  430|1.042086e-04|
  440|9.151311e-05|
  450|7.992882e-05|
  460|6.972309e-05|
  470|6.101506e-05|
  480|5.379844e-05|
  490|4.797194e-05|
  500|4.337741e-05|
  510|3.984020e-05|
  520|3.720212e-05|
  530|3.534028e-05|
  540|3.417109e-05|
  550|3.364311e-05|
  560|3.372351e-05|
  570|3.438312e-05|
  580|3.558368e-05|
  590|3.726958e-05|
It.  |Err
-------------------
  600|3.936418e-05|
  610|4.176966e-05|
  620|4.436877e-05|
  630|4.702796e-05|
  640|4.960175e-05|
  650|5.193888e-05|
  660|5.389046e-05|
  670|5.531979e-05|
  680|5.611295e-05|
  690|5.618871e-05|
  700|5.550627e-05|
  710|5.406954e-05|
  720|5.192695e-05|
  730|4.916691e-05|
  740|4.590932e-05|
  750|4.229430e-05|
  760|3.846976e-05|
  770|3.457933e-05|
  780|3.075230e-05|
  790|2.709643e-05|
It.  |Err
-------------------
  800|2.369414e-05|
  810|2.060184e-05|
  820|1.785186e-05|
  830|1.545590e-05|
  840|1.340916e-05|
  850|1.169442e-05|
  860|1.028566e-05|
  870|9.150926e-06|
  880|8.254840e-06|
  890|7.560795e-06|
  900|7.033105e-06|
  910|6.638947e-06|
  920|6.349865e-06|
  930|6.142583e-06|
  940|5.999120e-06|
  950|5.906371e-06|
  960|5.855388e-06|
  970|5.840590e-06|
  980|5.859006e-06|
  990|5.909627e-06|
Projected Gradient Descent

It.  |Err
-------------------
    0|4.981055e-02|
   10|1.110756e-01|
   20|1.139916e-01|
   30|1.158952e-01|
   40|1.159648e-01|
   50|1.159715e-01|
   60|1.159728e-01|
   70|1.159732e-01|
   80|1.159732e-01|
   90|1.159733e-01|
  100|1.159733e-01|
  110|1.159733e-01|
  120|1.159733e-01|
  130|1.159733e-01|
  140|1.159733e-01|
  150|1.159733e-01|
  160|1.159733e-01|
  170|1.159733e-01|
  180|1.159733e-01|
  190|1.159733e-01|
It.  |Err
-------------------
  200|1.159733e-01|
  210|1.159733e-01|
  220|1.159733e-01|
  230|1.159733e-01|
  240|1.159733e-01|
  250|1.159733e-01|
  260|1.159733e-01|
  270|1.159733e-01|
  280|1.159733e-01|
  290|1.159733e-01|
  300|1.159733e-01|
  310|1.159733e-01|
  320|1.159733e-01|
  330|1.159733e-01|
  340|1.159733e-01|
  350|1.159733e-01|
  360|1.159733e-01|
  370|1.159733e-01|
  380|1.159733e-01|
  390|1.159733e-01|
It.  |Err
-------------------
  400|1.159733e-01|
  410|1.159733e-01|
  420|1.159733e-01|
  430|1.159733e-01|
  440|1.159733e-01|
  450|1.159733e-01|
  460|1.159733e-01|
  470|1.159733e-01|
  480|1.159733e-01|
  490|1.159733e-01|
  500|1.159733e-01|
  510|1.159733e-01|
  520|1.159733e-01|
  530|1.159733e-01|
  540|1.159733e-01|
  550|1.159733e-01|
  560|1.159733e-01|
  570|1.159733e-01|
  580|1.159733e-01|
  590|1.159733e-01|
It.  |Err
-------------------
  600|1.159733e-01|
  610|1.159733e-01|
  620|1.159733e-01|
  630|1.159733e-01|
  640|1.159733e-01|
  650|1.159733e-01|
  660|1.159733e-01|
  670|1.159733e-01|
  680|1.159733e-01|
  690|1.159733e-01|
  700|1.159733e-01|
  710|1.159733e-01|
  720|1.159733e-01|
  730|1.159733e-01|
  740|1.159733e-01|
  750|1.159733e-01|
  760|1.159733e-01|
  770|1.159733e-01|
  780|1.159733e-01|
  790|1.159733e-01|
It.  |Err
-------------------
  800|1.159733e-01|
  810|1.159733e-01|
  820|1.159733e-01|
  830|1.159733e-01|
  840|1.159733e-01|
  850|1.159733e-01|
  860|1.159733e-01|
  870|1.159733e-01|
  880|1.159733e-01|
  890|1.159733e-01|
  900|1.159733e-01|
  910|1.159733e-01|
  920|1.159733e-01|
  930|1.159733e-01|
  940|1.159733e-01|
  950|1.159733e-01|
  960|1.159733e-01|
  970|1.159733e-01|
  980|1.159733e-01|
  990|1.159733e-01|
Bregman Alternated Projected Gradient

It.  |Err
-------------------
    0|2.710197e-02|
   10|7.722091e-04|
   20|1.141907e-03|
   30|1.705361e-03|
   40|2.217013e-03|
   50|2.780133e-03|
   60|2.754405e-03|
   70|2.618865e-03|
   80|2.155365e-03|
   90|1.467885e-03|
  100|1.234502e-03|
  110|1.427166e-03|
  120|1.580081e-03|
  130|1.480216e-03|
  140|1.016301e-03|
  150|7.489908e-04|
  160|5.527976e-04|
  170|3.724086e-04|
  180|2.505381e-04|
  190|1.890667e-04|
It.  |Err
-------------------
  200|1.660689e-04|
  210|1.596126e-04|
  220|1.580415e-04|
  230|1.573214e-04|
  240|1.554877e-04|
  250|1.506710e-04|
  260|1.413884e-04|
  270|1.273670e-04|
  280|1.099326e-04|
  290|9.144168e-05|
  300|7.415348e-05|
  310|5.942604e-05|
  320|4.761838e-05|
  330|3.844073e-05|
  340|3.135983e-05|
  350|2.585715e-05|
  360|2.152819e-05|
  370|1.808727e-05|
  380|1.533873e-05|
  390|1.314610e-05|
It.  |Err
-------------------
  400|1.140922e-05|
  410|1.004962e-05|
  420|9.001810e-06|
  430|8.208638e-06|
  440|7.619328e-06|
  450|7.189212e-06|
  460|6.880145e-06|
  470|6.660812e-06|
  480|6.506527e-06|
  490|6.398538e-06|
  500|6.323025e-06|
  510|6.270041e-06|
  520|6.232558e-06|
  530|6.205670e-06|
  540|6.185991e-06|
  550|6.171195e-06|
  560|6.159695e-06|
  570|6.150412e-06|
  580|6.142611e-06|
  590|6.135798e-06|
It.  |Err
-------------------
  600|6.129637e-06|
  610|6.123902e-06|
  620|6.118440e-06|
  630|6.113149e-06|
  640|6.107960e-06|
  650|6.102825e-06|
  660|6.097715e-06|
  670|6.092609e-06|
  680|6.087492e-06|
  690|6.082355e-06|
  700|6.077194e-06|
  710|6.072003e-06|
  720|6.066779e-06|
  730|6.061523e-06|
  740|6.056231e-06|
  750|6.050905e-06|
  760|6.045543e-06|
  770|6.040145e-06|
  780|6.034711e-06|
  790|6.029242e-06|
It.  |Err
-------------------
  800|6.023737e-06|
  810|6.018197e-06|
  820|6.012621e-06|
  830|6.007010e-06|
  840|6.001365e-06|
  850|5.995684e-06|
  860|5.989968e-06|
  870|5.984218e-06|
  880|5.978434e-06|
  890|5.972615e-06|
  900|5.966762e-06|
  910|5.960875e-06|
  920|5.954955e-06|
  930|5.949001e-06|
  940|5.943013e-06|
  950|5.936992e-06|
  960|5.930938e-06|
  970|5.924851e-06|
  980|5.918731e-06|
  990|5.912579e-06|
Fused Gromov-Wasserstein distance estimated with Conditional Gradient solver: 0.3761179313933098
Fused Gromov-Wasserstein distance estimated with Proximal Point solver: 0.3671471715862438
Entropic Fused Gromov-Wasserstein distance estimated with Projected Gradient solver: 0.21736592892258025
Fused Gromov-Wasserstein distance estimated with Projected Gradient solver: 0.2576635020911173

Visualization of the Fused Gromov-Wasserstein matchings

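We color the nodes of the graph on the left according to their cluster, then project these colors onto the graph on the right using the optimal transport plan of the FGW matching: each node on the right takes the color of the node on the left that sends it the most mass. Cross-domain links are drawn with an intensity proportional to the mass sent (relative to the largest mass sent by the same source node), with a minimal intensity of 0.1 added whenever the mass sent is nonzero. For each matching, nodes are sized proportionally to their mass, computed from the marginals of the OT plan, to illustrate potential feasibility errors. NB: colors refer to clusters, not to node features.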
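The color projection and link intensities described above reduce to a few array operations. The sketch below assumes a plan `T` of shape (n_source, n_target) and uses hypothetical helpers `project_labels` and `link_alpha`, which mirror the logic of `draw_transp_colored_GW` in vectorized form.

```python
import numpy as np


def project_labels(T, source_labels):
    """Give each target node the label of the source node that sends it the most mass."""
    return np.asarray(source_labels)[np.argmax(T, axis=0)]


def link_alpha(T, min_alpha=0.1):
    """Link transparency: mass relative to the row maximum, plus min_alpha, capped at 1."""
    row_max = T.max(axis=1, keepdims=True)
    alpha = np.minimum(T / np.where(row_max > 0, row_max, 1.0) + min_alpha, 1.0)
    return np.where(T > 0, alpha, 0.0)  # no link is drawn where no mass is sent


T_toy = np.array([[0.4, 0.1, 0.0], [0.0, 0.0, 0.5]])
colors = project_labels(T_toy, ["C0", "C1"])
alphas = link_alpha(T_toy)
```

Here `colors` is `["C0", "C0", "C1"]`, and the weak link from the first source node to the second target node gets transparency 0.1 / 0.4 + 0.1 = 0.35.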

# Add weights on the edges for visualization later on
weight_intra_G2 = 5
weight_inter_G2 = 0.5
weight_intra_G3 = 1.0
weight_inter_G3 = 1.5

weightedG2 = networkx.Graph()
part_G2 = [G2.nodes[i]["block"] for i in range(N2)]

for node in G2.nodes():
    weightedG2.add_node(node)
for i, j in G2.edges():
    if part_G2[i] == part_G2[j]:
        weightedG2.add_edge(i, j, weight=weight_intra_G2)
    else:
        weightedG2.add_edge(i, j, weight=weight_inter_G2)

weightedG3 = networkx.Graph()
part_G3 = [G3.nodes[i]["block"] for i in range(N3)]

for node in G3.nodes():
    weightedG3.add_node(node)
for i, j in G3.edges():
    if part_G3[i] == part_G3[j]:
        weightedG3.add_edge(i, j, weight=weight_intra_G3)
    else:
        weightedG3.add_edge(i, j, weight=weight_inter_G3)


def draw_graph(
    G,
    C,
    nodes_color_part,
    Gweights=None,
    pos=None,
    edge_color="black",
    node_size=None,
    shiftx=0,
    seed=0,
):
    if pos is None:
        pos = networkx.spring_layout(G, scale=1.0, seed=seed)

    if shiftx != 0:
        for k, v in pos.items():
            v[0] = v[0] + shiftx

    alpha_edge = 0.7
    width_edge = 1.8
    if Gweights is None:
        networkx.draw_networkx_edges(
            G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color
        )
    else:
        # We make more visible connections between activated nodes
        n = len(Gweights)
        edgelist_activated = []
        edgelist_deactivated = []
        for i in range(n):
            for j in range(n):
                if Gweights[i] * Gweights[j] * C[i, j] > 0:
                    edgelist_activated.append((i, j))
                elif C[i, j] > 0:
                    edgelist_deactivated.append((i, j))

        networkx.draw_networkx_edges(
            G,
            pos,
            edgelist=edgelist_activated,
            width=width_edge,
            alpha=alpha_edge,
            edge_color=edge_color,
        )
        networkx.draw_networkx_edges(
            G,
            pos,
            edgelist=edgelist_deactivated,
            width=width_edge,
            alpha=0.1,
            edge_color=edge_color,
        )

    if Gweights is None:
        for node, node_color in enumerate(nodes_color_part):
            networkx.draw_networkx_nodes(
                G,
                pos,
                nodelist=[node],
                node_size=node_size,
                alpha=1,
                node_color=node_color,
            )
    else:
        scaled_Gweights = Gweights / (0.5 * Gweights.max())
        nodes_size = node_size * scaled_Gweights
        for node, node_color in enumerate(nodes_color_part):
            networkx.draw_networkx_nodes(
                G,
                pos,
                nodelist=[node],
                node_size=nodes_size[node],
                alpha=1,
                node_color=node_color,
            )
    return pos


def draw_transp_colored_GW(
    G1,
    C1,
    G2,
    C2,
    part_G1,
    p1,
    p2,
    T,
    pos1=None,
    pos2=None,
    shiftx=4,
    switchx=False,
    node_size=70,
    seed_G1=0,
    seed_G2=0,
):
    starting_color = 0
    # get graphs partition and their coloring
    part1 = part_G1.copy()
    unique_colors = ["C%s" % (starting_color + i) for i in np.unique(part1)]
    nodes_color_part1 = []
    for cluster in part1:
        nodes_color_part1.append(unique_colors[cluster])

    nodes_color_part2 = []
    # Color each node of G2 like the G1 node sending it the most mass (argmax over each column of T)
    for i in range(len(G2.nodes())):
        j = np.argmax(T[:, i])
        nodes_color_part2.append(nodes_color_part1[j])
    pos1 = draw_graph(
        G1,
        C1,
        nodes_color_part1,
        Gweights=p1,
        pos=pos1,
        node_size=node_size,
        shiftx=0,
        seed=seed_G1,
    )
    pos2 = draw_graph(
        G2,
        C2,
        nodes_color_part2,
        Gweights=p2,
        pos=pos2,
        node_size=node_size,
        shiftx=shiftx,
        seed=seed_G2,
    )

    for k1, v1 in pos1.items():
        max_Tk1 = np.max(T[k1, :])
        for k2, v2 in pos2.items():
            if T[k1, k2] > 0:
                pl.plot(
                    [pos1[k1][0], pos2[k2][0]],
                    [pos1[k1][1], pos2[k2][1]],
                    "-",
                    lw=0.7,
                    alpha=min(T[k1, k2] / max_Tk1 + 0.1, 1.0),
                    color=nodes_color_part1[k1],
                )
    return pos1, pos2


node_size = 40
fontsize = 13
seed_G2 = 0
seed_G3 = 4

pl.figure(2, figsize=(15, 3.5))
pl.clf()
pl.subplot(141)
pl.axis("off")

pl.title(
    "(CG) FGW=%s\n \n OT sparsity = %s \n marg. error = %s \n runtime = %s"
    % (
        np.round(log_cg["fgw_dist"], 3),
        str(np.round(T_cg_sparsity, 2)) + " %",
        np.round(err_cg, 4),
        str(np.round(time_cg, 2)) + " ms",
    ),
    fontsize=fontsize,
)

pos1, pos2 = draw_transp_colored_GW(
    weightedG2,
    C2,
    weightedG3,
    C3,
    part_G2,
    p1=T_cg.sum(1),
    p2=T_cg.sum(0),
    T=T_cg,
    shiftx=1.5,
    node_size=node_size,
    seed_G1=seed_G2,
    seed_G2=seed_G3,
)

pl.subplot(142)
pl.axis("off")

pl.title(
    "(PPA) FGW=%s\n \n OT sparsity = %s \n marg. error = %s \n runtime = %s"
    % (
        np.round(log_ppa["fgw_dist"], 3),
        str(np.round(T_ppa_sparsity, 2)) + " %",
        np.round(err_ppa, 4),
        str(np.round(time_ppa, 2)) + " ms",
    ),
    fontsize=fontsize,
)

pos1, pos2 = draw_transp_colored_GW(
    weightedG2,
    C2,
    weightedG3,
    C3,
    part_G2,
    p1=T_ppa.sum(1),
    p2=T_ppa.sum(0),
    T=T_ppa,
    pos1=pos1,
    pos2=pos2,
    shiftx=0.0,
    node_size=node_size,
    seed_G1=0,
    seed_G2=0,
)

pl.subplot(143)
pl.axis("off")

pl.title(
    "(PGD) Entropic FGW=%s\n \n OT sparsity = %s \n marg. error = %s \n runtime = %s"
    % (
        np.round(log_pgd["fgw_dist"], 3),
        str(np.round(T_pgd_sparsity, 2)) + " %",
        np.round(err_pgd, 4),
        str(np.round(time_pgd, 2)) + " ms",
    ),
    fontsize=fontsize,
)

pos1, pos2 = draw_transp_colored_GW(
    weightedG2,
    C2,
    weightedG3,
    C3,
    part_G2,
    p1=T_pgd.sum(1),
    p2=T_pgd.sum(0),
    T=T_pgd,
    pos1=pos1,
    pos2=pos2,
    shiftx=0.0,
    node_size=node_size,
    seed_G1=0,
    seed_G2=0,
)


pl.subplot(144)
pl.axis("off")

pl.title(
    "(BAPG) FGW=%s\n \n OT sparsity = %s \n marg. error = %s \n runtime = %s"
    % (
        np.round(log_bapg["fgw_dist"], 3),
        str(np.round(T_bapg_sparsity, 2)) + " %",
        np.round(err_bapg, 4),
        str(np.round(time_bapg, 2)) + " ms",
    ),
    fontsize=fontsize,
)

pos1, pos2 = draw_transp_colored_GW(
    weightedG2,
    C2,
    weightedG3,
    C3,
    part_G2,
    p1=T_bapg.sum(1),
    p2=T_bapg.sum(0),
    T=T_bapg,
    pos1=pos1,
    pos2=pos2,
    shiftx=0.0,
    node_size=node_size,
    seed_G1=0,
    seed_G2=0,
)

pl.tight_layout()

pl.show()
Figure: FGW matchings obtained with each solver. Panel titles: (CG) FGW=0.376, OT sparsity = 90.33 %, marg. error = 0.0, runtime = 4.39 ms; (PPA) FGW=0.367, OT sparsity = 33.33 %, marg. error = 0.004, runtime = 362.31 ms; (PGD) Entropic FGW=0.217, OT sparsity = 0.0 %, marg. error = 0.0711, runtime = 351.01 ms; (BAPG) FGW=0.258, OT sparsity = 50.0 %, marg. error = 0.0902, runtime = 68.15 ms.

Total running time of the script: (0 minutes 3.107 seconds)
