分配问题#

hungarian_algorithm(cost_matrix)

匈牙利算法用于线性分配问题。

匈牙利算法#

optax.assignment.hungarian_algorithm(cost_matrix)[来源]#

匈牙利算法用于线性分配问题。

这个问题 中, 我们给定一个 \(n \times m\) 成本矩阵。目标是计算一个分配,即一组行和列的对,以满足以下条件:

  • 每行最多分配一个列。

  • 每列最多分配一行。

  • 作业的总数是 \(\min(n, m)\)

  • 该任务最小化成本之和。

等价地,给定一个加权的完全二分图,问题是找到一个最大基数的匹配,最小化包含在匹配中的边的权重总和。

正式地,问题如下。给定 \(C \in \mathbb{R}^{n \times m }\),求解以下 整数线性规划

\[\begin{align*} \text{最小化} \quad & \sum_{i \in [n]} \sum_{j \in [m]} C_{ij} X_{ij} \\ \text{满足条件} \quad & X_{ij} \in \{0, 1\} & \forall i \in [n], j \in [m] \\ & \sum_{i \in [n]} X_{ij} \leq 1 & \forall j \in [m] \\ & \sum_{j \in [m]} X_{ij} \leq 1 & \forall i \in [n] \\ & \sum_{i \in [n]} \sum_{j \in [m]} X_{ij} = \min(n, m) \end{align*}\]

匈牙利算法是一个立方时间算法,可以解决这个问题。

该实现基于Scenic库的实现(请参见参考文献)。

base_hungarian_algorithm 不同,这个版本产生了一个更简单的 Jaxpr,并且似乎更快。

Parameters:

cost_matrix – 成本矩阵。

Returns:

A pair (i, j) where i is an array of row indices and j is an array of column indices. The cost of the assignment is cost_matrix[i, j].sum().

示例

>>> import optax
>>> from jax import numpy as jnp
>>> cost = jnp.array(
...  [
...    [8, 4, 7],
...    [5, 2, 3],
...    [9, 6, 7],
...    [9, 4, 8],
...  ])
>>> i, j = optax.assignment.hungarian_algorithm(cost)
>>> print("cost:", cost[i, j].sum())
cost: 15
>>> cost = jnp.array(
...  [
...    [90, 80, 75, 70],
...    [35, 85, 55, 65],
...    [125, 95, 90, 95],
...    [45, 110, 95, 115],
...    [50, 100, 90, 100],
...  ])
>>> i, j = optax.assignment.hungarian_algorithm(cost)
>>> print("cost:", cost[i, j].sum())
cost: 265

参考文献

Dehghani 等人, Scenic: A JAX 库用于计算机视觉研究及其他, 2022