Assignment problem
Hungarian algorithm
- optax.assignment.hungarian_algorithm(cost_matrix)
The Hungarian algorithm for the linear assignment problem.
In this problem, we are given an \(n \times m\) cost matrix. The goal is to compute an assignment, i.e. a set of pairs of rows and columns, in such a way that:
- At most one column is assigned to each row.
- At most one row is assigned to each column.
- The total number of assignments is \(\min(n, m)\).
- The assignment minimizes the sum of costs.
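The first three conditions are structural and can be checked directly on a candidate assignment given as a pair of index arrays. The sketch below is a hypothetical helper (`is_valid_assignment` is not part of optax) that checks them; it deliberately does not check the fourth condition, since verifying optimality requires solving the problem.

```python
from typing import Sequence


def is_valid_assignment(rows: Sequence[int], cols: Sequence[int], n: int, m: int) -> bool:
    """Hypothetical helper: checks the structural conditions for an n x m problem."""
    if len(rows) != len(cols):
        return False
    # Every index must lie inside the n x m cost matrix.
    if any(not 0 <= r < n for r in rows) or any(not 0 <= c < m for c in cols):
        return False
    # At most one column per row, and at most one row per column.
    if len(set(rows)) != len(rows) or len(set(cols)) != len(cols):
        return False
    # The total number of assignments must be min(n, m).
    return len(rows) == min(n, m)


print(is_valid_assignment([0, 1, 3], [0, 2, 1], n=4, m=3))  # True
print(is_valid_assignment([0, 0, 3], [0, 2, 1], n=4, m=3))  # False: row 0 used twice
```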
Equivalently, given a weighted complete bipartite graph, the problem is to find a maximum-cardinality matching that minimizes the sum of the weights of the edges included in the matching.
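For very small matrices, the matching view suggests an exhaustive reference solver: enumerate every maximum-cardinality matching and keep the cheapest one. The sketch below (the name `brute_force_assignment` is hypothetical) does this in factorial time, so it is only useful for cross-checking a real solver on tiny, non-empty inputs.

```python
import itertools
from typing import List, Sequence, Tuple


def brute_force_assignment(cost: Sequence[Sequence[float]]) -> Tuple[List[int], List[int]]:
    """Hypothetical reference solver: tries every maximum-cardinality matching."""
    n, m = len(cost), len(cost[0])
    if n <= m:
        # Every row is matched: choose an ordered set of n distinct columns.
        candidates = (list(zip(range(n), p)) for p in itertools.permutations(range(m), n))
    else:
        # Every column is matched: choose an ordered set of m distinct rows.
        candidates = (list(zip(p, range(m))) for p in itertools.permutations(range(n), m))
    best_total, best_pairs = None, None
    for pairs in candidates:
        total = sum(cost[r][c] for r, c in pairs)
        if best_total is None or total < best_total:
            best_total, best_pairs = total, pairs
    pairs = sorted(best_pairs)
    return [r for r, _ in pairs], [c for _, c in pairs]


rows, cols = brute_force_assignment([[8, 4, 7], [5, 2, 3], [9, 6, 7], [9, 4, 8]])
print(rows, cols)  # [0, 1, 3] [0, 2, 1]
```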
Formally, the problem is as follows. Given \(C \in \mathbb{R}^{n \times m }\), solve the following integer linear program:
\[\begin{align*} \text{minimize} \quad & \sum_{i \in [n]} \sum_{j \in [m]} C_{ij} X_{ij} \\ \text{subject to} \quad & X_{ij} \in \{0, 1\} & \forall i \in [n], j \in [m] \\ & \sum_{i \in [n]} X_{ij} \leq 1 & \forall j \in [m] \\ & \sum_{j \in [m]} X_{ij} \leq 1 & \forall i \in [n] \\ & \sum_{i \in [n]} \sum_{j \in [m]} X_{ij} = \min(n, m) \end{align*}\]
The Hungarian algorithm is a cubic-time algorithm that solves this problem.
This implementation is based on that of the Scenic library (see references).
Unlike base_hungarian_algorithm, this version yields a simpler Jaxpr and appears to be faster.
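To illustrate where the cubic running time comes from, here is a minimal pure-Python sketch of a common textbook variant (shortest augmenting paths with row and column potentials `u` and `v`), which grows the matching by one row per outer iteration in \(O(n^2 m)\) time for \(n \le m\). It is not the Scenic-based JAX implementation used by optax; the function name `hungarian` is hypothetical, and tall matrices are handled by solving the transposed problem.

```python
import math
from typing import List, Sequence, Tuple


def hungarian(cost: Sequence[Sequence[float]]) -> Tuple[List[int], List[int]]:
    """Hypothetical textbook sketch; returns (rows, cols) sorted by row."""
    n, m = len(cost), len(cost[0])
    if n > m:
        # Solve the transposed problem, then swap the roles of rows and columns.
        cols, rows = hungarian([list(col) for col in zip(*cost)])
        order = sorted(range(len(rows)), key=lambda k: rows[k])
        return [rows[k] for k in order], [cols[k] for k in order]
    u = [0.0] * (n + 1)  # row potentials (1-indexed, index 0 unused)
    v = [0.0] * (m + 1)  # column potentials (index 0 is a dummy column)
    p = [0] * (m + 1)    # p[j] = row matched to column j (0 means unmatched)
    way = [0] * (m + 1)  # predecessor column on the current alternating path
    for i in range(1, n + 1):
        p[0] = i  # search for an augmenting path starting from row i
        j0 = 0
        minv = [math.inf] * (m + 1)
        used = [False] * (m + 1)
        while True:
            used[j0] = True
            i0, delta, j1 = p[j0], math.inf, 0
            for j in range(1, m + 1):
                if not used[j]:
                    cur = cost[i0 - 1][j - 1] - u[i0] - v[j]  # reduced cost
                    if cur < minv[j]:
                        minv[j], way[j] = cur, j0
                    if minv[j] < delta:
                        delta, j1 = minv[j], j
            # Shift potentials so at least one new reduced cost becomes zero.
            for j in range(m + 1):
                if used[j]:
                    u[p[j]] += delta
                    v[j] -= delta
                else:
                    minv[j] -= delta
            j0 = j1
            if p[j0] == 0:  # reached a free column
                break
        # Flip the alternating path, growing the matching by one pair.
        while j0:
            j1 = way[j0]
            p[j0] = p[j1]
            j0 = j1
    pairs = sorted((p[j] - 1, j - 1) for j in range(1, m + 1) if p[j])
    return [r for r, _ in pairs], [c for _, c in pairs]


rows, cols = hungarian([[8, 4, 7], [5, 2, 3], [9, 6, 7], [9, 4, 8]])
print(rows, cols)  # [0, 1, 3] [0, 2, 1]
```

The potentials keep every reduced cost \(C_{ij} - u_i - v_j\) non-negative, with matched pairs at zero, which is what certifies optimality when the loop finishes.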
- Parameters:
cost_matrix – An \(n \times m\) matrix of costs.
- Returns:
A pair (i, j) where i is an array of row indices and j is an array of column indices. The cost of the assignment is cost_matrix[i, j].sum().
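The expression cost_matrix[i, j] uses paired integer-array indexing: it selects one entry per (row, column) pair rather than a submatrix. The sketch below shows this with NumPy (whose indexing jax.numpy mirrors), using a hand-written optimal assignment for the first cost matrix in the examples; the order of indices that optax actually returns may differ. It also builds the 0/1 matrix \(X\) of the integer program and checks that its objective matches.

```python
import numpy as np

cost = np.array([[8, 4, 7], [5, 2, 3], [9, 6, 7], [9, 4, 8]])
# One optimal assignment for this matrix, written out by hand.
i = np.array([0, 1, 3])
j = np.array([0, 2, 1])

# Paired indexing picks cost[0, 0], cost[1, 2], cost[3, 1], not a 3x3 block.
print(cost[i, j])  # [8 3 4]
total = cost[i, j].sum()  # 15

# The same assignment as the indicator matrix X of the integer program.
X = np.zeros(cost.shape, dtype=int)
X[i, j] = 1
assert (X.sum(axis=0) <= 1).all() and (X.sum(axis=1) <= 1).all()
assert X.sum() == min(cost.shape)
assert (cost * X).sum() == total
```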
Examples
>>> import optax
>>> from jax import numpy as jnp
>>> cost = jnp.array(
...  [
...    [8, 4, 7],
...    [5, 2, 3],
...    [9, 6, 7],
...    [9, 4, 8],
...  ])
>>> i, j = optax.assignment.hungarian_algorithm(cost)
>>> print("cost:", cost[i, j].sum())
cost: 15
>>> cost = jnp.array(
...  [
...    [90, 80, 75, 70],
...    [35, 85, 55, 65],
...    [125, 95, 90, 95],
...    [45, 110, 95, 115],
...    [50, 100, 90, 100],
...  ])
>>> i, j = optax.assignment.hungarian_algorithm(cost)
>>> print("cost:", cost[i, j].sum())
cost: 265
References
Dehghani et al., Scenic: A JAX Library for Computer Vision Research and Beyond, 2022