线性分配问题

线性分配问题#

线性分配问题是组合优化中的一个基本问题。

在这个问题中,我们给定一个\(n \times m\)的成本矩阵。 目标是计算一个分配,即一组行和列的对,使得:

  • 每行最多分配一列。

  • 每列最多分配一行。

  • 任务的总数是 \(\min(n, m)\)

  • 该分配最小化了成本的总和。

等价地,给定一个加权的完全二分图,问题是要找到一个最大基数的匹配,使得匹配中包含的边的权重之和最小。

正式地,问题如下。给定 \(C \in \mathbb{R}^{n \times m}\),解决以下 整数线性规划

\[ \begin{align*} \text{最小化} \quad & \sum_{i \in [n]} \sum_{j \in [m]} C_{ij} X_{ij} \\ \text{约束条件} \quad & X_{ij} \in \{0, 1\} & \forall i \in [n], j \in [m] \\ & \sum_{i \in [n]} X_{ij} \leq 1 & \forall j \in [m] \\ & \sum_{j \in [m]} X_{ij} \leq 1 & \forall i \in [n] \\ & \sum_{i \in [n]} \sum_{j \in [m]} X_{ij} = \min(n, m) \end{align*} \]

Hungarian algorithm 是一个用于此问题的立方时间算法。

首先,我们安装NetworkX,这是一个允许我们绘制图形的Python库。 你可以通过在终端运行以下命令来完成此操作:

pip install -U networkx

接下来,我们导入将要使用的库:

import networkx as nx
from jax import random
import optax
from matplotlib import pyplot as plt
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Cell In[1], line 1
----> 1 import networkx as nx
      2 from jax import random
      3 import optax

ModuleNotFoundError: No module named 'networkx'

我们随机采样一个成本矩阵:

n = 5  # number of rows
m = 3  # number of columns

key = random.key(0)
costs = random.normal(key, (n, m))
print(costs)
[[ 0.08482574  1.9097648   0.29561743]
 [ 1.120948    0.33432344 -0.82606775]
 [ 0.6481277  -1.0353061  -0.7824839 ]
 [-0.4539462   0.6297971   0.81524646]
 [-0.32787678 -1.1234448  -1.6607416 ]]

我们可以将成本矩阵可视化如下:

plt.imshow(costs, cmap="gray");
../../_images/bd9817fc6a5aeaef4b5223616d95fa6a98d599f934bda1a463b27c1d3328cdc2.png

我们还可以将成本可视化为加权二分图。下面,行显示为左侧的节点,列显示为右侧的节点。

G = nx.Graph()
rows = [f"row {i}" for i in range(n)]
cols = [f"col {j}" for j in range(m)]
edges = [(rows[i], cols[j], {"cost": costs[i, j]}) for i in range(n) for j in range(m)]
G.add_nodes_from(rows + cols)
G.add_edges_from(edges)
layout = nx.bipartite_layout(G, rows)
nx.draw(G, layout)
nx.draw_networkx_edge_labels(
    G,
    layout,
    edge_labels={(u, v): f"{info["cost"]:g}" for u, v, info in edges},
    rotate=False,
    font_size=8,
    bbox=dict(
        pad=0.0,
        facecolor="white",
        edgecolor="none",
    ),
);
../../_images/b2a2e0a97e0c8b34c7b200dad204b4d2d0b4252a60f18549e1a5736848644ccd.png

为了解决这个问题,我们在成本矩阵上调用optax.assignment.hungarian_algorithm()

sol_i, sol_j = optax.assignment.hungarian_algorithm(costs)
print(sol_i, sol_j)
[2 3 4] [1 0 2]

我们可以将解决方案可视化如下:

def highlight_cell(x, y, **kwargs):
    rect = plt.Rectangle((x - 0.5, y - 0.5), 1, 1, fill=False, **kwargs)
    plt.gca().add_patch(rect)
    return rect

plt.imshow(costs, cmap="gray")

for i, j in zip(sol_i, sol_j):
    highlight_cell(j, i, color="red", linewidth=3)

plt.show()
../../_images/94303af322e7edc398985adece1e3ee586dc1011f9ae2611f6c8480353dcaa12.png

我们还可以通过在之前的二分图上绘制来可视化解决方案。下面,解决方案中包含的节点和边以红色显示。

nx.draw(
    G,
    layout,
    node_color=["red" if i in sol_i else "grey" for i in range(n)] + ["red" if j in sol_j else "grey" for j in range(m)],
    edge_color=["red" if (i, j) in zip(sol_i, sol_j) else "grey" for i in range(n) for j in range(m)],
)
nx.draw_networkx_edge_labels(
    G,
    layout,
    edge_labels={(u, v): f"{info["cost"]:g}" for u, v, info in edges},
    rotate=False,
    font_size=8,
    bbox=dict(
        pad=0.0,
        facecolor="white",
        edgecolor="none",
    ),
);
../../_images/719e94f5f0a784dbe1e2b1cce06a7de6eb0a0ddc69a25a99fd1c57cc05f8270c.png

最优传输#

线性分配求解器可用于解决最优传输问题:给定一个点的多重集 \(X \in \mathbb{R}^{n \times d}\) 和另一个点的多重集 \(Y \in \mathbb{R}^{n \times d}\),找到一个排列 \(\pi \in \text{Sym}(n)\),使得总传输成本最小:

\[ C(\pi) = \sum_{i \in [n]} d(X_i, Y_{\pi(i)}) \]

其中 \(d\) 是一个 metric,例如在 \(\mathbb{R}^d\) 上的 Euclidean distance

下面是一个图示示例,其中\(X\)是一组设施位置\(Y\)是一组必须相互匹配的客户位置

import jax
import optax
from jax import numpy as jnp, random
from matplotlib import collections, pyplot as plt, rcParams

def get_optimal_transport(x, y):
    assert x.ndim == 2
    assert x.shape == y.shape
    displacements = x[:, None] - y[None, :]
    distance_matrix = jnp.linalg.norm(displacements, axis=-1)
    i, j = optax.assignment.hungarian_algorithm(distance_matrix)
    total_distance = distance_matrix[i, j].sum()
    return (i, j), total_distance

def main():
    num_points = 200
    markersize = 16.0
    
    key = random.key(0)
    keys = random.split(key)
    x = random.normal(keys[0], (num_points, 2))
    y = random.normal(keys[1], (num_points, 2)) + jnp.array([0.2, 0.0])

    (i, j), total_distance = get_optimal_transport(x, y)

    fig, ax = plt.subplots(constrained_layout=True)
    
    data = jnp.stack((x[i], y[j]), 1)
    lc = collections.LineCollection(data, linewidth=1.0, color="lightgrey", zorder=0, label="assignment")
    ax.add_collection(lc)

    ax.scatter(*x.T, s=markersize, edgecolor="none", label="facility")
    ax.scatter(*y.T, s=markersize, edgecolor="none", label="client")

    ax.set(title=f"Optimal transport distance: {total_distance:g}")
    ax.legend()
    plt.show()

main()
../../_images/dbb220ab9e23f5440a4fa911eb8ad89c7732d9192dccbae9af9bf97a3e533f26.png

这反过来可以用来估计两个分布之间的Wasserstein距离,通过从每个分布中采样大量点,然后计算这些批次之间的最优传输成本。

更准确地说,如果\(W_p\)表示\(p\)-Wasserstein距离,且\(P\)\(Q\)分别是具有样本\(X\)\(Y\)经验分布,那么:

\[ W_p(P, Q) = \inf_{\pi \in \text{Sym}(n)} \left( \frac{1}{n} \sum_{i=1}^n \|X_i - Y_{\pi(i)}\|^p \right)^{1/p} \]