Assignment problem


hungarian_algorithm(cost_matrix)

The Hungarian algorithm for the linear assignment problem.

Hungarian algorithm

optax.assignment.hungarian_algorithm(cost_matrix)

The Hungarian algorithm for the linear assignment problem.

In this problem, we are given an \(n \times m\) cost matrix. The goal is to compute an assignment, i.e., a set of (row, column) pairs, such that:

  • At most one column is assigned to each row.

  • At most one row is assigned to each column.

  • The total number of assignments is \(\min(n, m)\).

  • The assignment minimizes the sum of costs.
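To make these four conditions concrete, here is a tiny exhaustive solver in plain Python. It is only a sketch for checking small inputs: it enumerates every ordered choice of columns (or rows, for tall matrices), so it runs in exponential time. The function name brute_force_assignment is hypothetical and is not part of optax.

```python
import itertools


def brute_force_assignment(cost):
    """Exhaustively solve the rectangular linear assignment problem.

    `cost` is a list of n rows, each a list of m numbers. Returns
    (rows, cols, total) for one minimum-cost assignment of size min(n, m).
    Exponential time: intended only for checking tiny inputs.
    """
    n, m = len(cost), len(cost[0])
    best = None
    if n <= m:
        # Every row gets a distinct column: try each ordered choice of n columns.
        for cols in itertools.permutations(range(m), n):
            total = sum(cost[r][c] for r, c in enumerate(cols))
            if best is None or total < best[2]:
                best = (list(range(n)), list(cols), total)
    else:
        # Every column gets a distinct row: try each ordered choice of m rows.
        for rows in itertools.permutations(range(n), m):
            total = sum(cost[r][c] for c, r in enumerate(rows))
            if best is None or total < best[2]:
                best = (list(rows), list(range(m)), total)
    return best


rows, cols, total = brute_force_assignment([[8, 4, 7], [5, 2, 3], [9, 6, 7], [9, 4, 8]])
print(total)  # 15
```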

Equivalently, given a weighted complete bipartite graph, the problem is to find a maximum-cardinality matching that minimizes the sum of the weights of the edges included in the matching.

Formally, the problem is as follows. Given \(C \in \mathbb{R}^{n \times m }\), solve the following integer linear program:

\[\begin{align*} \text{minimize} \quad & \sum_{i \in [n]} \sum_{j \in [m]} C_{ij} X_{ij} \\ \text{subject to} \quad & X_{ij} \in \{0, 1\} & \forall i \in [n], j \in [m] \\ & \sum_{i \in [n]} X_{ij} \leq 1 & \forall j \in [m] \\ & \sum_{j \in [m]} X_{ij} \leq 1 & \forall i \in [n] \\ & \sum_{i \in [n]} \sum_{j \in [m]} X_{ij} = \min(n, m) \end{align*}\]
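The link between an assignment given as index arrays and the 0/1 matrix \(X\) of the integer program can be sketched as follows. The helpers assignment_matrix, is_feasible, and objective are hypothetical illustrations written with plain Python lists rather than JAX arrays; each check mirrors one line of the program above.

```python
def assignment_matrix(n, m, rows, cols):
    """Build the 0/1 matrix X of the integer program from (row, col) pairs."""
    X = [[0] * m for _ in range(n)]
    for r, c in zip(rows, cols):
        X[r][c] = 1
    return X


def is_feasible(X):
    """Check every constraint of the integer program for a matrix X."""
    n, m = len(X), len(X[0])
    binary = all(x in (0, 1) for row in X for x in row)
    # At most one row per column.
    col_ok = all(sum(X[i][j] for i in range(n)) <= 1 for j in range(m))
    # At most one column per row.
    row_ok = all(sum(X[i][j] for j in range(m)) <= 1 for i in range(n))
    # Exactly min(n, m) assignments in total.
    size_ok = sum(map(sum, X)) == min(n, m)
    return binary and col_ok and row_ok and size_ok


def objective(C, X):
    """The objective sum_{i, j} C_ij X_ij."""
    return sum(c * x for c_row, x_row in zip(C, X) for c, x in zip(c_row, x_row))


C = [[8, 4, 7], [5, 2, 3], [9, 6, 7], [9, 4, 8]]
X = assignment_matrix(4, 3, rows=[0, 3, 1], cols=[0, 1, 2])
print(is_feasible(X), objective(C, X))  # True 15
```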

The Hungarian algorithm solves this problem exactly in cubic time, e.g. \(O(n^3)\) for a square \(n \times n\) cost matrix.
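For intuition about how a cubic-time method works, the sketch below implements the classic shortest-augmenting-path variant of the Hungarian algorithm with row and column potentials, in plain Python. It is not optax's implementation, which is written in JAX and follows Scenic; the function name hungarian is hypothetical. For \(n \le m\) it runs in \(O(n^2 m)\) time, and taller matrices are handled by solving the transposed problem.

```python
import math


def hungarian(cost):
    """Minimum-cost rectangular assignment via the potentials method.

    `cost` is a list of n rows of m numbers. Returns (rows, cols), two lists
    of length min(n, m) sorted by row index.
    """
    n, m = len(cost), len(cost[0])
    if n > m:
        # The loop below needs n <= m, so solve the transposed problem
        # and swap the roles of rows and columns in the result.
        cols, rows = hungarian([list(col) for col in zip(*cost)])
        pairs = sorted(zip(rows, cols))
        return [r for r, _ in pairs], [c for _, c in pairs]

    # 1-indexed arrays; column 0 is a virtual column used as the path root.
    u = [0] * (n + 1)      # row potentials
    v = [0] * (m + 1)      # column potentials
    match = [0] * (m + 1)  # match[j] = row (1-indexed) assigned to column j, 0 if free
    way = [0] * (m + 1)    # previous column on the current augmenting path
    for i in range(1, n + 1):
        match[0] = i
        j0 = 0
        minv = [math.inf] * (m + 1)
        used = [False] * (m + 1)
        # Grow a Dijkstra-like tree over reduced costs until a free column is reached.
        while True:
            used[j0] = True
            i0, delta, j1 = match[j0], math.inf, 0
            for j in range(1, m + 1):
                if not used[j]:
                    reduced = cost[i0 - 1][j - 1] - u[i0] - v[j]
                    if reduced < minv[j]:
                        minv[j], way[j] = reduced, j0
                    if minv[j] < delta:
                        delta, j1 = minv[j], j
            # Shift potentials so that at least one new edge becomes tight.
            for j in range(m + 1):
                if used[j]:
                    u[match[j]] += delta
                    v[j] -= delta
                else:
                    minv[j] -= delta
            j0 = j1
            if match[j0] == 0:
                break
        # Flip the matching along the augmenting path back to column 0.
        while j0:
            j1 = way[j0]
            match[j0] = match[j1]
            j0 = j1

    pairs = sorted((match[j] - 1, j - 1) for j in range(1, m + 1) if match[j])
    return [r for r, _ in pairs], [c for _, c in pairs]


cost = [[8, 4, 7], [5, 2, 3], [9, 6, 7], [9, 4, 8]]
rows, cols = hungarian(cost)
print("cost:", sum(cost[r][c] for r, c in zip(rows, cols)))  # cost: 15
```

Maintaining potentials keeps every reduced cost non-negative, which is what allows each augmentation to be found with a single Dijkstra-style sweep instead of a fresh search.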

This implementation is based on that of the Scenic library (see references).

Unlike base_hungarian_algorithm, this version yields a simpler Jaxpr and appears to be faster.

Parameters:

cost_matrix – A matrix of costs.

Returns:

A pair (i, j) where i is an array of row indices and j is an array of column indices. The cost of the assignment is cost_matrix[i, j].sum().
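This return convention matches that of SciPy's scipy.optimize.linear_sum_assignment, which also returns row indices and column indices for rectangular cost matrices. Assuming NumPy and SciPy are installed, the following cross-checks the first cost matrix from the Examples with SciPy rather than optax. When several optimal assignments tie, the two libraries may return different index arrays with the same total cost.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

cost = np.array([[8, 4, 7], [5, 2, 3], [9, 6, 7], [9, 4, 8]])
# i holds row indices and j holds column indices, one pair per assignment.
i, j = linear_sum_assignment(cost)
# Fancy indexing picks out the assigned entries, as in cost_matrix[i, j].sum().
print("cost:", cost[i, j].sum())  # cost: 15
```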

Examples

>>> import optax
>>> from jax import numpy as jnp
>>> cost = jnp.array(
...  [
...    [8, 4, 7],
...    [5, 2, 3],
...    [9, 6, 7],
...    [9, 4, 8],
...  ])
>>> i, j = optax.assignment.hungarian_algorithm(cost)
>>> print("cost:", cost[i, j].sum())
cost: 15
>>> cost = jnp.array(
...  [
...    [90, 80, 75, 70],
...    [35, 85, 55, 65],
...    [125, 95, 90, 95],
...    [45, 110, 95, 115],
...    [50, 100, 90, 100],
...  ])
>>> i, j = optax.assignment.hungarian_algorithm(cost)
>>> print("cost:", cost[i, j].sum())
cost: 265

References

Dehghani et al., Scenic: A JAX Library for Computer Vision Research and Beyond, 2022