半放松(融合)Gromov-Wasserstein 示例

此示例旨在展示如何使用半放松的Gromov-Wasserstein和半放松的融合Gromov-Wasserstein散度。

sr(F)GW 在两个图 G1 和 G2 之间搜索节点 G2 的重新加权,使其与 G1 的 (F)GW 距离最小。

首先,我们生成两个遵循随机块模型的图,然后展示如何计算它们的srGW匹配并进行说明。这些图随后被赋予节点特征,我们使用相同的过程来处理srFGW。

[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. “半松弛 Gromov-Wasserstein 散度及其在图上的应用” 国际学习表示会议 (ICLR), 2021.

# Author: Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 1

import numpy as np
import matplotlib.pylab as pl
from ot.gromov import (
    semirelaxed_gromov_wasserstein,
    semirelaxed_fused_gromov_wasserstein,
    gromov_wasserstein,
    fused_gromov_wasserstein,
)
import networkx
from networkx.generators.community import stochastic_block_model as sbm

生成遵循2个和3个聚类的随机区块模型的两个图形。

N2 = 20  # 2 communities
N3 = 30  # 3 communities
p2 = [[1.0, 0.1], [0.1, 0.9]]
p3 = [[1.0, 0.1, 0.0], [0.1, 0.95, 0.1], [0.0, 0.1, 0.9]]
G2 = sbm(seed=0, sizes=[N2 // 2, N2 // 2], p=p2)
G3 = sbm(seed=0, sizes=[N3 // 3, N3 // 3, N3 // 3], p=p3)


C2 = networkx.to_numpy_array(G2)
C3 = networkx.to_numpy_array(G3)

h2 = np.ones(C2.shape[0]) / C2.shape[0]
h3 = np.ones(C3.shape[0]) / C3.shape[0]

# Add weights on the edges for visualization later on
weight_intra_G2 = 5
weight_inter_G2 = 0.5
weight_intra_G3 = 1.0
weight_inter_G3 = 1.5

weightedG2 = networkx.Graph()
part_G2 = [G2.nodes[i]["block"] for i in range(N2)]

for node in G2.nodes():
    weightedG2.add_node(node)
for i, j in G2.edges():
    if part_G2[i] == part_G2[j]:
        weightedG2.add_edge(i, j, weight=weight_intra_G2)
    else:
        weightedG2.add_edge(i, j, weight=weight_inter_G2)

weightedG3 = networkx.Graph()
part_G3 = [G3.nodes[i]["block"] for i in range(N3)]

for node in G3.nodes():
    weightedG3.add_node(node)
for i, j in G3.edges():
    if part_G3[i] == part_G3[j]:
        weightedG3.add_edge(i, j, weight=weight_intra_G3)
    else:
        weightedG3.add_edge(i, j, weight=weight_inter_G3)

计算它们的半松弛Gromov-Wasserstein散度

# 0) GW(C2, h2, C3, h3) for reference
OT, log = gromov_wasserstein(C2, C3, h2, h3, symmetric=True, log=True)
gw = log["gw_dist"]

# 1) srGW(C2, h2, C3)
OT_23, log_23 = semirelaxed_gromov_wasserstein(
    C2, C3, h2, symmetric=True, log=True, G0=None
)
srgw_23 = log_23["srgw_dist"]

# 2) srGW(C3, h3, C2)

OT_32, log_32 = semirelaxed_gromov_wasserstein(
    C3, C2, h3, symmetric=None, log=True, G0=OT.T
)
srgw_32 = log_32["srgw_dist"]

print("GW(C2, C3) = ", gw)
print("srGW(C2, h2, C3) = ", srgw_23)
print("srGW(C3, h3, C2) = ", srgw_32)
GW(C2, C3) =  0.24722222222222215
srGW(C2, h2, C3) =  0.07000000000000006
srGW(C3, h3, C2) =  0.17111111111111116

半放松Gromov-Wasserstein匹配的可视化

我们对右侧图的节点进行着色 - 然后根据srGW匹配的最优传输计划投影其节点颜色

def draw_graph(
    G,
    C,
    nodes_color_part,
    Gweights=None,
    pos=None,
    edge_color="black",
    node_size=None,
    shiftx=0,
    seed=0,
):
    if pos is None:
        pos = networkx.spring_layout(G, scale=1.0, seed=seed)

    if shiftx != 0:
        for k, v in pos.items():
            v[0] = v[0] + shiftx

    alpha_edge = 0.7
    width_edge = 1.8
    if Gweights is None:
        networkx.draw_networkx_edges(
            G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color
        )
    else:
        # We make more visible connections between activated nodes
        n = len(Gweights)
        edgelist_activated = []
        edgelist_deactivated = []
        for i in range(n):
            for j in range(n):
                if Gweights[i] * Gweights[j] * C[i, j] > 0:
                    edgelist_activated.append((i, j))
                elif C[i, j] > 0:
                    edgelist_deactivated.append((i, j))

        networkx.draw_networkx_edges(
            G,
            pos,
            edgelist=edgelist_activated,
            width=width_edge,
            alpha=alpha_edge,
            edge_color=edge_color,
        )
        networkx.draw_networkx_edges(
            G,
            pos,
            edgelist=edgelist_deactivated,
            width=width_edge,
            alpha=0.1,
            edge_color=edge_color,
        )

    if Gweights is None:
        for node, node_color in enumerate(nodes_color_part):
            networkx.draw_networkx_nodes(
                G,
                pos,
                nodelist=[node],
                node_size=node_size,
                alpha=1,
                node_color=node_color,
            )
    else:
        scaled_Gweights = Gweights / (0.5 * Gweights.max())
        nodes_size = node_size * scaled_Gweights
        for node, node_color in enumerate(nodes_color_part):
            networkx.draw_networkx_nodes(
                G,
                pos,
                nodelist=[node],
                node_size=nodes_size[node],
                alpha=1,
                node_color=node_color,
            )
    return pos


def draw_transp_colored_srGW(
    G1,
    C1,
    G2,
    C2,
    part_G1,
    p1,
    p2,
    T,
    pos1=None,
    pos2=None,
    shiftx=4,
    switchx=False,
    node_size=70,
    seed_G1=0,
    seed_G2=0,
):
    starting_color = 0
    # get graphs partition and their coloring
    part1 = part_G1.copy()
    unique_colors = ["C%s" % (starting_color + i) for i in np.unique(part1)]
    nodes_color_part1 = []
    for cluster in part1:
        nodes_color_part1.append(unique_colors[cluster])

    nodes_color_part2 = []
    # T: getting colors assignment from argmin of columns
    for i in range(len(G2.nodes())):
        j = np.argmax(T[:, i])
        nodes_color_part2.append(nodes_color_part1[j])
    pos1 = draw_graph(
        G1,
        C1,
        nodes_color_part1,
        Gweights=p1,
        pos=pos1,
        node_size=node_size,
        shiftx=0,
        seed=seed_G1,
    )
    pos2 = draw_graph(
        G2,
        C2,
        nodes_color_part2,
        Gweights=p2,
        pos=pos2,
        node_size=node_size,
        shiftx=shiftx,
        seed=seed_G2,
    )
    for k1, v1 in pos1.items():
        for k2, v2 in pos2.items():
            if T[k1, k2] > 0:
                pl.plot(
                    [pos1[k1][0], pos2[k2][0]],
                    [pos1[k1][1], pos2[k2][1]],
                    "-",
                    lw=0.8,
                    alpha=0.5,
                    color=nodes_color_part1[k1],
                )
    return pos1, pos2


node_size = 40
fontsize = 10
seed_G2 = 0
seed_G3 = 4

pl.figure(1, figsize=(8, 2.5))
pl.clf()
pl.subplot(121)
pl.axis("off")
pl.axis
pl.title(
    r"srGW$(\mathbf{C_2},\mathbf{h_2},\mathbf{C_3}) =%s$" % (np.round(srgw_23, 3)),
    fontsize=fontsize,
)

hbar2 = OT_23.sum(axis=0)
pos1, pos2 = draw_transp_colored_srGW(
    weightedG2,
    C2,
    weightedG3,
    C3,
    part_G2,
    p1=None,
    p2=hbar2,
    T=OT_23,
    shiftx=1.5,
    node_size=node_size,
    seed_G1=seed_G2,
    seed_G2=seed_G3,
)
pl.subplot(122)
pl.axis("off")
hbar3 = OT_32.sum(axis=0)
pl.title(
    r"srGW$(\mathbf{C_3}, \mathbf{h_3},\mathbf{C_2}) =%s$" % (np.round(srgw_32, 3)),
    fontsize=fontsize,
)
pos1, pos2 = draw_transp_colored_srGW(
    weightedG3,
    C3,
    weightedG2,
    C2,
    part_G3,
    p1=None,
    p2=hbar3,
    T=OT_32,
    pos1=pos2,
    pos2=pos1,
    shiftx=3.0,
    node_size=node_size,
    seed_G1=0,
    seed_G2=0,
)
pl.tight_layout()

pl.show()
srGW$(\mathbf{C_2},\mathbf{h_2},\mathbf{C_3}) =0.07$, srGW$(\mathbf{C_3}, \mathbf{h_3},\mathbf{C_2}) =0.171$

添加节点特征

# We add node features with given mean - by clusters
# and inversely proportional to clusters' intra-connectivity

F2 = np.zeros((N2, 1))
for i, c in enumerate(part_G2):
    F2[i, 0] = np.random.normal(loc=c, scale=0.01)

F3 = np.zeros((N3, 1))
for i, c in enumerate(part_G3):
    F3[i, 0] = np.random.normal(loc=2.0 - c, scale=0.01)

计算它们的半松弛融合Gromov-Wasserstein散度

alpha = 0.5
# Compute pairwise euclidean distance between node features
M = (F2**2).dot(np.ones((1, N3))) + np.ones((N2, 1)).dot((F3**2).T) - 2 * F2.dot(F3.T)

# 0) FGW_alpha(C2, F2, h2, C3, F3, h3) for reference

OT, log = fused_gromov_wasserstein(
    M, C2, C3, h2, h3, symmetric=True, alpha=alpha, log=True
)
fgw = log["fgw_dist"]

# 1) srFGW(C2, F2, h2, C3, F3)
OT_23, log_23 = semirelaxed_fused_gromov_wasserstein(
    M, C2, C3, h2, symmetric=True, alpha=0.5, log=True, G0=None
)
srfgw_23 = log_23["srfgw_dist"]

# 2) srFGW(C3, F3, h3, C2, F2)

OT_32, log_32 = semirelaxed_fused_gromov_wasserstein(
    M.T, C3, C2, h3, symmetric=None, alpha=alpha, log=True, G0=None
)
srfgw_32 = log_32["srfgw_dist"]

print("FGW(C2, F2, C3, F3) = ", fgw)
print("srGW(C2, F2, h2, C3, F3) = ", srfgw_23)
print("srGW(C3, F3, h3, C2, F2) = ", srfgw_32)
FGW(C2, F2, C3, F3) =  0.3778254858275072
srGW(C2, F2, h2, C3, F3) =  0.03757413947207429
srGW(C3, F3, h3, C2, F2) =  0.23454191747683045

半放松融合Gromov-Wasserstein匹配的可视化

我们为右侧图中的节点着色 - 然后根据 srFGW 匹配的最优传输计划投影其节点颜色 注意:颜色指的是簇 - 而不是节点特征

pl.figure(2, figsize=(8, 2.5))
pl.clf()
pl.subplot(121)
pl.axis("off")
pl.axis
pl.title(
    r"srFGW$(\mathbf{C_2},\mathbf{F_2},\mathbf{h_2},\mathbf{C_3},\mathbf{F_3}) =%s$"
    % (np.round(srfgw_23, 3)),
    fontsize=fontsize,
)

hbar2 = OT_23.sum(axis=0)
pos1, pos2 = draw_transp_colored_srGW(
    weightedG2,
    C2,
    weightedG3,
    C3,
    part_G2,
    p1=None,
    p2=hbar2,
    T=OT_23,
    shiftx=1.5,
    node_size=node_size,
    seed_G1=seed_G2,
    seed_G2=seed_G3,
)
pl.subplot(122)
pl.axis("off")
hbar3 = OT_32.sum(axis=0)
pl.title(
    r"srFGW$(\mathbf{C_3}, \mathbf{F_3}, \mathbf{h_3}, \mathbf{C_2}, \mathbf{F_2}) =%s$"
    % (np.round(srfgw_32, 3)),
    fontsize=fontsize,
)
pos1, pos2 = draw_transp_colored_srGW(
    weightedG3,
    C3,
    weightedG2,
    C2,
    part_G3,
    p1=None,
    p2=hbar3,
    T=OT_32,
    pos1=pos2,
    pos2=pos1,
    shiftx=3.0,
    node_size=node_size,
    seed_G1=0,
    seed_G2=0,
)
pl.tight_layout()

pl.show()
srFGW$(\mathbf{C_2},\mathbf{F_2},\mathbf{h_2},\mathbf{C_3},\mathbf{F_3}) =0.038$, srFGW$(\mathbf{C_3}, \mathbf{F_3}, \mathbf{h_3}, \mathbf{C_2}, \mathbf{F_2}) =0.235$

脚本的总运行时间: (0分钟 1.723秒)

由 Sphinx-Gallery 生成的画廊