线性分配问题#
线性分配问题是组合优化中的一个基本问题。
在这个问题中,我们给定一个\(n \times m\)的成本矩阵。 目标是计算一个分配,即一组行和列的对,使得:
每行最多分配一列。
每列最多分配一行。
任务的总数是 \(\min(n, m)\)。
该分配最小化了成本的总和。
等价地,给定一个加权的完全二分图,问题是要找到一个最大基数的匹配,使得匹配中包含的边的权重之和最小。
正式地,问题如下。给定 \(C \in \mathbb{R}^{n \times m}\),解决以下 整数线性规划:
Hungarian algorithm 是一个用于此问题的立方时间算法。
首先,我们安装NetworkX,这是一个允许我们绘制图形的Python库。 你可以通过在终端运行以下命令来完成此操作:
pip install -U networkx
接下来,我们导入将要使用的库:
import networkx as nx
from jax import random
import optax
from matplotlib import pyplot as plt
---------------------------------------------------------------------------
ModuleNotFoundError Traceback (most recent call last)
Cell In[1], line 1
----> 1 import networkx as nx
2 from jax import random
3 import optax
ModuleNotFoundError: No module named 'networkx'
我们随机采样一个成本矩阵:
n = 5 # number of rows
m = 3 # number of columns
key = random.key(0)
costs = random.normal(key, (n, m))
print(costs)
[[ 0.08482574 1.9097648 0.29561743]
[ 1.120948 0.33432344 -0.82606775]
[ 0.6481277 -1.0353061 -0.7824839 ]
[-0.4539462 0.6297971 0.81524646]
[-0.32787678 -1.1234448 -1.6607416 ]]
我们可以将成本矩阵可视化如下:
plt.imshow(costs, cmap="gray");
我们还可以将成本可视化为加权二分图。下面,行显示为左侧的节点,列显示为右侧的节点。
G = nx.Graph()
rows = [f"row {i}" for i in range(n)]
cols = [f"col {j}" for j in range(m)]
edges = [(rows[i], cols[j], {"cost": costs[i, j]}) for i in range(n) for j in range(m)]
G.add_nodes_from(rows + cols)
G.add_edges_from(edges)
layout = nx.bipartite_layout(G, rows)
nx.draw(G, layout)
nx.draw_networkx_edge_labels(
G,
layout,
edge_labels={(u, v): f"{info["cost"]:g}" for u, v, info in edges},
rotate=False,
font_size=8,
bbox=dict(
pad=0.0,
facecolor="white",
edgecolor="none",
),
);
为了解决这个问题,我们在成本矩阵上调用optax.assignment.hungarian_algorithm()。
sol_i, sol_j = optax.assignment.hungarian_algorithm(costs)
print(sol_i, sol_j)
[2 3 4] [1 0 2]
我们可以将解决方案可视化如下:
def highlight_cell(x, y, **kwargs):
rect = plt.Rectangle((x - 0.5, y - 0.5), 1, 1, fill=False, **kwargs)
plt.gca().add_patch(rect)
return rect
plt.imshow(costs, cmap="gray")
for i, j in zip(sol_i, sol_j):
highlight_cell(j, i, color="red", linewidth=3)
plt.show()
我们还可以通过在之前的二分图上绘制来可视化解决方案。下面,解决方案中包含的节点和边以红色显示。
nx.draw(
G,
layout,
node_color=["red" if i in sol_i else "grey" for i in range(n)] + ["red" if j in sol_j else "grey" for j in range(m)],
edge_color=["red" if (i, j) in zip(sol_i, sol_j) else "grey" for i in range(n) for j in range(m)],
)
nx.draw_networkx_edge_labels(
G,
layout,
edge_labels={(u, v): f"{info["cost"]:g}" for u, v, info in edges},
rotate=False,
font_size=8,
bbox=dict(
pad=0.0,
facecolor="white",
edgecolor="none",
),
);
最优传输#
线性分配求解器可用于解决最优传输问题:给定一个点的多重集 \(X \in \mathbb{R}^{n \times d}\) 和另一个点的多重集 \(Y \in \mathbb{R}^{n \times d}\),找到一个排列 \(\pi \in \text{Sym}(n)\),使得总传输成本最小:
其中 \(d\) 是一个 metric,例如在 \(\mathbb{R}^d\) 上的 Euclidean distance。
下面是一个图示示例,其中\(X\)是一组设施位置,\(Y\)是一组必须相互匹配的客户位置。
import jax
import optax
from jax import numpy as jnp, random
from matplotlib import collections, pyplot as plt, rcParams
def get_optimal_transport(x, y):
assert x.ndim == 2
assert x.shape == y.shape
displacements = x[:, None] - y[None, :]
distance_matrix = jnp.linalg.norm(displacements, axis=-1)
i, j = optax.assignment.hungarian_algorithm(distance_matrix)
total_distance = distance_matrix[i, j].sum()
return (i, j), total_distance
def main():
num_points = 200
markersize = 16.0
key = random.key(0)
keys = random.split(key)
x = random.normal(keys[0], (num_points, 2))
y = random.normal(keys[1], (num_points, 2)) + jnp.array([0.2, 0.0])
(i, j), total_distance = get_optimal_transport(x, y)
fig, ax = plt.subplots(constrained_layout=True)
data = jnp.stack((x[i], y[j]), 1)
lc = collections.LineCollection(data, linewidth=1.0, color="lightgrey", zorder=0, label="assignment")
ax.add_collection(lc)
ax.scatter(*x.T, s=markersize, edgecolor="none", label="facility")
ax.scatter(*y.T, s=markersize, edgecolor="none", label="client")
ax.set(title=f"Optimal transport distance: {total_distance:g}")
ax.legend()
plt.show()
main()
这反过来可以用来估计两个分布之间的Wasserstein距离,通过从每个分布中采样大量点,然后计算这些批次之间的最优传输成本。
更准确地说,如果\(W_p\)表示\(p\)-Wasserstein距离,且\(P\)和\(Q\)分别是具有样本\(X\)和\(Y\)的经验分布,那么: