Perturbed optimizers


In this notebook, we review a universal method, following Berthet et al. (2020), that uses perturbations to transform any function \(f\) mapping a pytree to another pytree into a differentiable approximation \(f_\varepsilon\).

For a random variable \(Z\) with a continuous, positive density \(\mu\), a scale \(\varepsilon > 0\), and a function \(f: X \to Y\), the perturbed approximation \(f_\varepsilon\) is defined for any \(x \in X\) by

\[f_\varepsilon(x) = \mathbf{E}[f (x + \varepsilon Z )]\, .\]
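To make the definition concrete, here is a minimal sketch in plain JAX that replaces the expectation by an average over num_samples draws, applied leaf by leaf to a pytree. It assumes standard Gaussian noise for \(Z\), and the helper perturbed_mc is hypothetical, not part of optax. As a check, smoothing a step function with Gaussian noise should give the Gaussian CDF, \(\mathbf{E}[\mathbf{1}\{x + \varepsilon Z > 0\}] = \Phi(x/\varepsilon)\).

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def perturbed_mc(f, x, rng, num_samples=1000, sigma=0.5):
  """Monte Carlo estimate of E[f(x + sigma * Z)], with Z ~ N(0, I) on every leaf.

  Hypothetical helper: `f` may map a pytree to a pytree and must work under jax.vmap.
  """
  leaves, treedef = jax.tree_util.tree_flatten(x)
  keys = jax.random.split(rng, len(leaves))
  # One block of standard Gaussian noise per leaf, with a leading sample axis.
  noise_leaves = [
      jax.random.normal(k, (num_samples,) + jnp.shape(leaf))
      for k, leaf in zip(keys, leaves)
  ]
  noise = jax.tree_util.tree_unflatten(treedef, noise_leaves)
  # Perturbed inputs x + sigma * Z_m, broadcast over the sample axis.
  perturbed = jax.tree_util.tree_map(lambda leaf, z: leaf + sigma * z, x, noise)
  # Evaluate f on every sample, then average each output leaf over the samples.
  outputs = jax.vmap(f)(perturbed)
  return jax.tree_util.tree_map(lambda out: jnp.mean(out, axis=0), outputs)


# Usage: a piecewise-constant step function applied to a small pytree.
step = lambda tree: jax.tree_util.tree_map(lambda v: (v > 0).astype(jnp.float32), tree)
x = {'a': jnp.array([-1.0, 0.0, 1.0]), 'b': jnp.array([0.5])}
smoothed = perturbed_mc(step, x, jax.random.PRNGKey(0), num_samples=20000)
print(smoothed)
print(norm.cdf(x['a'] / 0.5))  # closed form for the 'a' leaf, for comparison
```

The output keeps the pytree structure of the input, and the discontinuous step becomes a smooth function of \(x\).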

We illustrate this on several examples, including the case of an optimizer function \(y^*\) over a set \(C\), defined for any cost \(\theta \in \mathbb{R}^d\) by

\[y^*(\theta) = \mathop{\mathrm{arg\,max}}_{y \in C} \langle y, \theta \rangle\, .\]

In this case, the perturbed optimizer is given by

\[y_\varepsilon^*(\theta) = \mathbf{E}[\mathop{\mathrm{arg\,max}}_{y\in C} \langle y, \theta + \varepsilon Z \rangle]\, .\]
import jax
import jax.numpy as jnp
from jax import tree_util as jtu

import optax.tree
from optax import perturbations

Argmax one-hot

We consider an optimizer, such as the following argmax_one_hot function. It transforms a real-valued vector into a binary vector with a 1 at the coefficient with the largest value and 0 elsewhere. It corresponds to \(y^*\) for \(C\) being the unit simplex. We run it on example input values.
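Because a linear objective over a polytope is maximized at a vertex, \(y^*\) over the unit simplex can be computed by scoring its vertices \(e_1, \ldots, e_n\). The sketch below, with a hypothetical helper argmax_over_vertices, checks on the same example values that this reproduces the one-hot vector.

```python
import jax
import jax.numpy as jnp


def argmax_over_vertices(theta, vertices):
  """Return the row y of `vertices` maximizing <y, theta> (hypothetical helper)."""
  scores = vertices @ theta          # <y, theta> for every candidate vertex y
  return vertices[jnp.argmax(scores)]


theta = jnp.array([-0.6, 1.9, -0.2, 1.1, -1.0])
simplex_vertices = jnp.eye(theta.shape[0])   # vertices of the unit simplex: e_1, ..., e_n
y_star = argmax_over_vertices(theta, simplex_vertices)
print(y_star)  # [0. 1. 0. 0. 0.]
```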

One-hot function

def argmax_one_hot(x, axis=-1):
  # 1 at the index of the largest entry along axis, 0 elsewhere.
  return jax.nn.one_hot(jnp.argmax(x, axis=axis), x.shape[axis])
values = jnp.array([-0.6, 1.9, -0.2, 1.1, -1.0])

one_hot_vec = argmax_one_hot(values)
print(one_hot_vec)
[0. 1. 0. 0. 0.]

One-hot with perturbations

Our implementation transforms the argmax_one_hot function into a perturbed one that we call pert_one_hot. In this case we use Gumbel noise for the perturbation.

N_SAMPLES = 100
SIGMA = 0.5
GUMBEL = perturbations.Gumbel()

rng = jax.random.PRNGKey(1)
pert_one_hot = perturbations.make_perturbed_fun(fun=argmax_one_hot,
                                                num_samples=N_SAMPLES,
                                                sigma=SIGMA,
                                                noise=GUMBEL)

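In this particular case, with Gumbel noise, the perturbed argmax is equal in expectation to the usual softmax function applied to \(\theta / \varepsilon\) (here values/SIGMA); with finitely many samples, the output is a Monte Carlo estimate of it. This is not always true: in general, there is no closed form for \(y_\varepsilon^*\).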

rngs = jax.random.split(rng, 2)

rng = rngs[0]

pert_argmax = pert_one_hot(rng, values)
print(f'computation with {N_SAMPLES} samples, sigma = {SIGMA}')
print(f'perturbed argmax = {pert_argmax}')
# Closed form of the Gumbel-perturbed argmax in the infinite-sample limit
soft_max = jax.nn.softmax(values/SIGMA)
print(f'softmax = {soft_max}')
print(f'norm of softmax = {jnp.linalg.norm(soft_max):.2e}')
print(f'norm of difference = {jnp.linalg.norm(pert_argmax - soft_max):.2e}')
computation with 100 samples, sigma = 0.5
perturbed argmax = [0.02       0.87       0.01       0.09999999 0.        ]
softmax = [0.00549293 0.8152234  0.01222475 0.16459078 0.00246813]
norm of softmax = 8.32e-01
norm of difference = 8.60e-02
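The gap between the perturbed argmax and the softmax above is Monte Carlo error from using only 100 samples. The identity itself is the Gumbel-max trick: \(\mathop{\mathrm{arg\,max}}_i (\theta_i + \varepsilon G_i)\), with i.i.d. standard Gumbel \(G_i\), is distributed as a categorical variable with probabilities softmax(θ/ε). A plain-JAX sketch, with a hypothetical helper gumbel_perturbed_argmax that does not call optax, shows the error shrinking as the number of samples grows.

```python
import jax
import jax.numpy as jnp


def gumbel_perturbed_argmax(theta, rng, num_samples, sigma):
  """Monte Carlo estimate of E[one_hot(argmax(theta + sigma * G))], G i.i.d. standard Gumbel."""
  g = jax.random.gumbel(rng, (num_samples,) + theta.shape)
  winners = jnp.argmax(theta + sigma * g, axis=-1)            # one argmax per sample
  return jnp.mean(jax.nn.one_hot(winners, theta.shape[-1]), axis=0)


theta = jnp.array([-0.6, 1.9, -0.2, 1.1, -1.0])
sigma = 0.5
exact = jax.nn.softmax(theta / sigma)
for num_samples in (100, 10_000, 100_000):
  estimate = gumbel_perturbed_argmax(theta, jax.random.PRNGKey(0), num_samples, sigma)
  err = jnp.max(jnp.abs(estimate - exact))
  print(f'{num_samples:>7} samples: max abs error = {err:.4f}')
```

The standard deviation of each coordinate of the estimate is at most \(0.5/\sqrt{N}\) for \(N\) samples, so the error decreases roughly like \(1/\sqrt{N}\).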

Gradients for one-hot with perturbations

The perturbed optimizer \(y_\varepsilon^*\) is differentiable, and its derivatives can be estimated stochastically and computed automatically with jax.grad.

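We define a scalar loss loss_simplex that compares the output of the perturbed optimizer \(y^*_\varepsilon\) with a target \(y_{\text{true}}\) through the squared error

\[\ell_\text{simplex}(y_\varepsilon^*; y_{\text{true}}) = \|y_\varepsilon^* - y_{\text{true}}\|_2^2\, .\]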

For values equal to a vector \(\theta\), we can compute gradients of

\[\ell(\theta) = \ell_\text{simplex}(y_\varepsilon^*(\theta); y_{\text{true}})\]

with respect to values, automatically.

# Example loss function

def loss_simplex(values, rng):
  n = values.shape[0]
  v_true = jnp.arange(n) + 2
  y_true = v_true / jnp.sum(v_true)
  y_pred = pert_one_hot(rng, values)
  return jnp.sum((y_true - y_pred) ** 2)

loss_simplex(values, rngs[1])
Array(0.7062, dtype=float32)

By the chain rule, the gradient of \(\ell\) is given by

\[\nabla_\theta \ell(\theta) = \partial_\theta y^*_\varepsilon(\theta) \cdot \nabla_1 \ell_{\text{simplex}}(y^*_\varepsilon(\theta); y_{\text{true}})\]
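In the product above, the Jacobian is applied transposed, as a vector-Jacobian product. With Gumbel noise and infinitely many samples, \(y^*_\varepsilon(\theta)\) is softmax(θ/ε), so the chain rule can be checked exactly. The sketch below swaps in that closed form for the stochastic estimate (an assumption made only for illustration; y_eps and ell_simplex are hypothetical names) and compares jax.grad of the composition with jax.vjp and with an explicit transposed Jacobian.

```python
import jax
import jax.numpy as jnp

sigma = 0.5
y_true = jnp.array([0.1, 0.15, 0.2, 0.25, 0.3])
theta = jnp.array([-0.6, 1.9, -0.2, 1.1, -1.0])


def y_eps(theta):
  # Closed form of the Gumbel-perturbed argmax in the infinite-sample limit.
  return jax.nn.softmax(theta / sigma)


def ell_simplex(y):
  return jnp.sum((y_true - y) ** 2)


# Gradient of the composition l(theta) = ell_simplex(y_eps(theta)).
direct = jax.grad(lambda t: ell_simplex(y_eps(t)))(theta)

# Chain rule: pull back grad ell_simplex through the Jacobian of y_eps.
y, vjp_fn = jax.vjp(y_eps, theta)
grad_ell = jax.grad(ell_simplex)(y)
(chained,) = vjp_fn(grad_ell)

# Same product with an explicit Jacobian: jac[i, j] = d y_i / d theta_j.
jac = jax.jacobian(y_eps)(theta)
explicit = jac.T @ grad_ell
print(direct)
print(chained)
print(explicit)
```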

The Jacobian \(\partial_\theta y^*_\varepsilon(\theta)\) is estimated automatically, using the method given by Berthet et al. (2020, Prop. 3.1).
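To illustrate how such an estimator can work, here is a sketch under the assumption, made in Berthet et al. (2020), that the noise has a density of the form \(\mu(z) \propto \exp(-\nu(z))\). Differentiating the expectation through the density gives \(\partial_\theta y^*_\varepsilon(\theta) = \mathbf{E}[y^*(\theta + \varepsilon Z)\, \nabla \nu(Z)^\top] / \varepsilon\), which needs no derivative of the piecewise-constant \(y^*\). For standard Gumbel noise, \(\nu(z) = z + e^{-z}\) coordinate-wise, so \(\nabla \nu(z) = 1 - e^{-z}\). The helper gumbel_jacobian_estimate is hypothetical and written in plain JAX; optax's own implementation may differ in its details. Since the Gumbel-perturbed argmax is softmax(θ/ε) in expectation, the exact softmax Jacobian \((\mathrm{diag}(p) - p p^\top)/\varepsilon\) serves as a reference.

```python
import jax
import jax.numpy as jnp


def argmax_one_hot(x, axis=-1):
  return jax.nn.one_hot(jnp.argmax(x, axis=axis), x.shape[axis])


def gumbel_jacobian_estimate(f, theta, rng, num_samples=100_000, sigma=0.5):
  """Estimate J[i, j] = d E[f(theta + sigma * G)]_i / d theta_j, G i.i.d. standard Gumbel.

  Gumbel density is exp(-nu(z)) with nu(z) = z + exp(-z), hence grad nu(z) = 1 - exp(-z)
  and J = E[f(theta + sigma * G) grad_nu(G)^T] / sigma.
  """
  g = jax.random.gumbel(rng, (num_samples,) + theta.shape)
  y = jax.vmap(f)(theta + sigma * g)           # (num_samples, d_out) perturbed outputs
  score = 1.0 - jnp.exp(-g)                    # grad nu evaluated at each noise sample
  return jnp.einsum('mi,mj->ij', y, score) / (num_samples * sigma)


theta = jnp.array([-0.6, 1.9, -0.2, 1.1, -1.0])
sigma = 0.5
jac_est = gumbel_jacobian_estimate(argmax_one_hot, theta, jax.random.PRNGKey(0), sigma=sigma)

# Exact Jacobian of softmax(theta / sigma), the infinite-sample limit.
p = jax.nn.softmax(theta / sigma)
jac_exact = (jnp.diag(p) - jnp.outer(p, p)) / sigma
print(jnp.max(jnp.abs(jac_est - jac_exact)))
```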

# Gradient of the loss w.r.t input values

gradient = jax.grad(loss_simplex)(values, rngs[1])
print(gradient)
[-0.09853157  0.10874727 -0.11743014 -0.17878106  0.16792142]

We illustrate the use of this method by running 200 steps of gradient descent on \(\theta_t\) so that it minimizes this loss.

# Doing 200 steps of gradient descent on the values to match the target distribution y_true

steps = 200
values_t = values
eta = 0.5

grad_func = jax.jit(jax.grad(loss_simplex))

for t in range(steps):
  rngs = jax.random.split(rngs[1], 2)
  values_t = values_t - eta * grad_func(values_t, rngs[1])
rngs = jax.random.split(rngs[1], 2)

n = values.shape[0]
v_true = jnp.arange(n) + 2
y_true = v_true / jnp.sum(v_true)

print(f'initial values = {values}')
print(f'initial one-hot = {argmax_one_hot(values)}')
print(f'initial diff. one-hot = {pert_one_hot(rngs[0], values)}')
print()
print(f'values after GD = {values_t}')
print(f'one-hot after GD = {argmax_one_hot(values_t)}')
print(f'diff. one-hot after GD = {pert_one_hot(rngs[1], values_t)}')
print(f'target diff. one-hot = {y_true}')
initial values = [-0.6  1.9 -0.2  1.1 -1. ]
initial one-hot = [0. 1. 0. 0. 0.]
initial diff. one-hot = [0.01       0.83       0.01       0.14999999 0.        ]

values after GD = [-0.11097738  0.10103489  0.28753668  0.3747991   0.47736812]
one-hot after GD = [0. 0. 0. 0. 1.]
diff. one-hot after GD = [0.08       0.17999999 0.21       0.26999998 0.26      ]
target diff. one-hot = [0.1  0.15 0.2  0.25 0.3 ]

Differentiable ranking

Ranking function

We now consider another optimizer, the following ranking function. It transforms a real-valued vector of size \(n\) into a vector whose coefficients are a permutation of \(\{0,\ldots, n-1\}\), giving the rank of each coefficient of the original vector in increasing order (0 for the smallest). It corresponds to \(y^*\) for \(C\) being the permutahedron. We run it on example input values.
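For small \(n\), the claim that ranking is \(y^*\) over the permutahedron can be checked by brute force: a linear objective is maximized at a vertex, and the vertices of the permutahedron are the permutations of \((0, \ldots, n-1)\). The sketch below uses a hypothetical helper argmax_over_permutations, which is only practical for small \(n\) since it enumerates all \(n!\) permutations.

```python
import itertools

import jax.numpy as jnp


def ranking(values):
  return jnp.argsort(jnp.argsort(values))


def argmax_over_permutations(theta):
  """Brute-force argmax of <y, theta> over all permutations y of (0, ..., n-1)."""
  n = theta.shape[0]
  # (n!, n) array whose rows are the vertices of the permutahedron.
  perms = jnp.array(list(itertools.permutations(range(n))))
  return perms[jnp.argmax(perms @ theta)]


theta = jnp.array([0.3, -1.2, 2.5, 0.7])
print(argmax_over_permutations(theta))  # [1 0 3 2]
print(ranking(theta))                   # [1 0 3 2]
```

By the rearrangement inequality, the best permutation pairs the largest rank with the largest coefficient, which is exactly what the double argsort computes.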

# Function outputting a vector of ranks

def ranking(values):
  # argsort of argsort gives the rank of each entry (0 for the smallest).
  return jnp.argsort(jnp.argsort(values))
# Example on random values

n = 6

rng = jax.random.PRNGKey(0)
values = jax.random.normal(rng, (n,))

print(f'values = {values}')
print(f'ranking = {ranking(values)}')
values = [ 1.6226422   2.0252647  -0.43359444 -0.07861735  0.1760909  -0.97208923]
ranking = [4 5 1 2 3 0]

Ranking with perturbations

As above, our implementation transforms this function into a perturbed one that we call pert_ranking. In this case we use Gumbel noise for the perturbation.

N_SAMPLES = 100
SIGMA = 0.2
GUMBEL = perturbations.Gumbel()

pert_ranking = perturbations.make_perturbed_fun(ranking,
                                                num_samples=N_SAMPLES,
                                                sigma=SIGMA,
                                                noise=GUMBEL)
# Monte Carlo estimate of the expected perturbed ranks on these values

rngs = jax.random.split(rng, 2)

diff_ranks = pert_ranking(rngs[0], values)
print(f'values = {values}')

print(f'diff_ranks = {diff_ranks}')
values = [ 1.6226422   2.0252647  -0.43359444 -0.07861735  0.1760909  -0.97208923]
diff_ranks = [4.11 4.89 1.17 2.02 2.76 0.05]
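Each sample of the perturbed ranking is a permutation of \(\{0, \ldots, n-1\}\), so the perturbed ranks are an average of vertices of the permutahedron and lie inside it; in particular, they always sum to \(n(n-1)/2\), which is 15 for \(n = 6\), consistent with diff_ranks above. The plain-JAX sketch below, with a hypothetical helper gumbel_perturbed_ranking and inputs rounded from the example, mirrors these settings (100 samples, \(\sigma = 0.2\), Gumbel noise) without calling optax.

```python
import jax
import jax.numpy as jnp


def ranking(values):
  return jnp.argsort(jnp.argsort(values))


def gumbel_perturbed_ranking(theta, rng, num_samples=100, sigma=0.2):
  """Monte Carlo estimate of E[ranking(theta + sigma * G)], G i.i.d. standard Gumbel."""
  g = jax.random.gumbel(rng, (num_samples,) + theta.shape)
  samples = jax.vmap(ranking)(theta + sigma * g)      # each row is a permutation of 0..n-1
  return jnp.mean(samples.astype(jnp.float32), axis=0)


theta = jnp.array([1.62, 2.03, -0.43, -0.08, 0.18, -0.97])
soft_ranks = gumbel_perturbed_ranking(theta, jax.random.PRNGKey(0))
n = theta.shape[0]
print(soft_ranks)
print(jnp.sum(soft_ranks), n * (n - 1) / 2)   # every sample sums to n(n-1)/2
```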

Gradients for ranking with perturbations

As above, the perturbed optimizer \(y_\varepsilon^*\) is differentiable, and its derivatives can be estimated stochastically and computed automatically with jax.grad.

We showcase this on a loss of \(y^*_\varepsilon(\theta)\) that can be differentiated directly with respect to values, equal to \(\theta\).

# Example loss function

def loss_example(values, rng):
  n = values.shape[0]
  y_true = ranking(jnp.arange(n))
  y_pred = pert_ranking(rng, values)
  return jnp.sum((y_true - y_pred) ** 2)

print(loss_example(values, rngs[1]))
59.774796
# Gradient of the objective w.r.t input values

gradient = jax.grad(loss_example)(values, rngs[1])
print(gradient)
[-1.4866201 -1.724831   2.797777  -0.1345453 -1.9688779 -1.8026495]

As above, we run gradient descent on the values to minimize this loss.

steps = 20
values_t = values
eta = 0.1

grad_func = jax.jit(jax.grad(loss_example))

for t in range(steps):
  rngs = jax.random.split(rngs[1], 2)
  values_t = values_t - eta * grad_func(values_t, rngs[1])
rngs = jax.random.split(rngs[1], 2)

y_true = ranking(jnp.arange(n))

print(f'initial values = {values}')
print(f'initial ranks = {ranking(values)}')
print(f'initial diff. ranks = {pert_ranking(rngs[0], values)}')
print()
print(f'values after GD = {values_t}')
print(f'ranks after GD = {ranking(values_t)}')
print(f'diff. ranks after GD = {pert_ranking(rngs[1], values_t)}')
print(f'target diff. ranks = {y_true}')
initial values = [ 1.6226422   2.0252647  -0.43359444 -0.07861735  0.1760909  -0.97208923]
initial ranks = [4 5 1 2 3 0]
initial diff. ranks = [4.0899997 4.91      1.1       1.99      2.84      0.07     ]

values after GD = [-1.9037365   1.7597162  -0.8777193   0.09295582  3.3749492   1.4055638 ]
ranks after GD = [0 4 1 2 5 3]
diff. ranks after GD = [0.02       3.86       0.98999995 1.99       5.         3.1399999 ]
target diff. ranks = [0 1 2 3 4 5]

General inputs and outputs (pytrees)

This method applies to any function that takes a pytree as input and returns a pytree as output: the perturbed function can be evaluated directly and also differentiated, as illustrated below.

tree_a = (jnp.array((0.1, 0.4, 0.5)),
          {'k1': jnp.array((0.1, 0.2)),
           'k2': jnp.array((0.1, 0.1))},
          jnp.array((0.4, 0.3, 0.2, 0.1)))

Tree argmax

This piecewise constant function applies argmax_one_hot to every leaf array of the pytree.

argmax_tree = lambda x: jax.tree.map(argmax_one_hot, x)
argmax_tree(tree_a)
(Array([0., 0., 1.], dtype=float32),
 {'k1': Array([0., 1.], dtype=float32), 'k2': Array([1., 0.], dtype=float32)},
 Array([1., 0., 0., 0.], dtype=float32))

Its perturbed approximation smooths each leaf-wise argmax into a perturbed softmax.

N_SAMPLES = 100
SIGMA = 0.2  # same noise scale as in the ranking example above

pert_argmax_fun = perturbations.make_perturbed_fun(argmax_tree,
                                                   num_samples=N_SAMPLES,
                                                   sigma=SIGMA)
pert_argmax_fun(rng, tree_a)
(Array([0.07, 0.35, 0.58], dtype=float32),
 {'k1': Array([0.39999998, 0.59999996], dtype=float32),
  'k2': Array([0.5, 0.5], dtype=float32)},
 Array([0.59, 0.24, 0.09, 0.08], dtype=float32))

Scalar loss

def pert_loss(inputs, rng):
  pert_softmax = pert_argmax_fun(rng, inputs)
  argmax = argmax_tree(inputs)
  diffs = jax.tree.map(lambda x, y: jnp.sum((x - y) ** 2 / 4), argmax, pert_softmax)
  return optax.tree.sum(diffs)
init_loss = pert_loss(tree_a, rng)

print(f'initial loss value = {init_loss:.3f}')
initial loss value = 0.341

Gradient computation

The gradient of the scalar loss can be evaluated with jax.grad; it is a pytree with the same structure as the input.

grad = jax.grad(pert_loss)(tree_a, rng)

print('Gradient of the scalar loss')
print()
grad
Gradient of the scalar loss
(Array([ 0.27816886,  0.34292346, -0.5273831 ], dtype=float32),
 {'k1': Array([ 0.3098759, -0.3945589], dtype=float32),
  'k2': Array([-0.35475507,  1.0202795 ], dtype=float32)},
 Array([0.09084772, 0.21252295, 0.23311305, 0.3052454 ], dtype=float32))

A small step in the negative gradient direction reduces the loss value.

eta = 1e-1

loss_step = pert_loss(optax.tree.add_scale(tree_a, -eta, grad), rng)

print(f'initial loss value = {init_loss:.3f}')
print(f'loss after gradient step = {loss_step:.3f}')
initial loss value = 0.341
loss after gradient step = 0.210
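The same update can also be written without optax.tree, using jax.tree_util.tree_map. The sketch below assumes, as the call above suggests, that optax.tree.add_scale(tree_a, -eta, grad) computes tree_a - eta * grad leaf by leaf; sgd_step and the small pytrees are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp


def sgd_step(params, grads, eta):
  """One gradient-descent step on a pytree: params - eta * grads, leaf by leaf."""
  return jax.tree_util.tree_map(lambda p, g: p - eta * g, params, grads)


params = (jnp.array([0.1, 0.4, 0.5]), {'k1': jnp.array([0.1, 0.2])})
grads = (jnp.array([1.0, -1.0, 0.0]), {'k1': jnp.array([0.5, 0.5])})
new_params = sgd_step(params, grads, eta=0.1)
print(new_params)
```

Because params and grads share the same tree structure, tree_map pairs their leaves one to one, which is why jax.grad returning a pytree shaped like the input is convenient here.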