Perturbed optimizers#
In this notebook we review a general method for turning any function \(f\) that maps a pytree to another pytree into a differentiable approximation \(f_\varepsilon\) using perturbations, following the method of Berthet et al. (2020).
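For a random variable \(Z\) drawn from a distribution with a continuous positive density \(\mu\) and a function \(f: X \to Y\), the perturbed approximation of \(f\) is defined for any \(x \in X\) by
\[f_\varepsilon(x) = \mathbb{E}[f(x + \varepsilon Z)].\]
We illustrate this on some examples, including the case of an optimizer function \(y^*\) over a set \(C\), defined for any cost \(\theta \in \mathbb{R}^d\) by
\[y^*(\theta) = \mathop{\mathrm{arg\,max}}_{y \in C} \langle y, \theta \rangle.\]
In this case, the perturbed optimizer is given by
\[y_\varepsilon^*(\theta) = \mathbb{E}\big[\mathop{\mathrm{arg\,max}}_{y \in C} \langle y, \theta + \varepsilon Z \rangle\big].\]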
import jax
import jax.numpy as jnp
from jax import tree_util as jtu
import optax.tree
from optax import perturbations
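As a minimal sketch of the definition above, the expectation \(\mathbb{E}[f(x + \varepsilon Z)]\) can be approximated by averaging over noise samples. The helper name `perturbed_mc` and the choice of Gaussian noise are assumptions made for this illustration; this is not the optax implementation. For \(f\) equal to the sign function and Gaussian \(Z\), the perturbed function has the closed form \(\mathrm{erf}(x / (\varepsilon \sqrt{2}))\), which gives something to compare against.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import erf


def perturbed_mc(f, key, x, eps, num_samples):
    """Monte Carlo estimate of E[f(x + eps * Z)] with Z standard normal.

    Hypothetical helper for illustration only.
    """
    z = jax.random.normal(key, (num_samples,) + x.shape)
    # Apply f to every noisy copy of x, then average over the samples.
    return jnp.mean(jax.vmap(lambda zi: f(x + eps * zi))(z), axis=0)


x = jnp.array([-1.0, -0.1, 0.0, 0.3, 2.0])
eps = 0.5
smoothed = perturbed_mc(jnp.sign, jax.random.PRNGKey(0), x, eps, 50_000)
# Closed form of E[sign(x + eps * Z)] for Gaussian Z.
closed_form = erf(x / (eps * jnp.sqrt(2.0)))
print(smoothed)
print(closed_form)
```

The sign function is piecewise constant, so its derivative is zero almost everywhere, while the smoothed version varies continuously with `x`.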
Argmax one-hot#
We consider an optimizer such as the following argmax_one_hot function. It transforms a real-valued vector into a binary vector with a 1 at the position of the largest coefficient and 0 elsewhere. It corresponds to \(y^*\) for \(C\) being the unit simplex. We run it on example input values.
One-hot function#
def argmax_one_hot(x, axis=-1):
  return jax.nn.one_hot(jnp.argmax(x, axis=axis), x.shape[axis])
values = jnp.array([-0.6, 1.9, -0.2, 1.1, -1.0])
one_hot_vec = argmax_one_hot(values)
print(one_hot_vec)
[0. 1. 0. 0. 0.]
One-hot with perturbations#
Our implementation transforms the argmax_one_hot function into a perturbed one that we call pert_one_hot. In this case we use Gumbel noise for the perturbation.
N_SAMPLES = 100
SIGMA = 0.5
GUMBEL = perturbations.Gumbel()
rng = jax.random.PRNGKey(1)
pert_one_hot = perturbations.make_perturbed_fun(fun=argmax_one_hot,
                                                num_samples=N_SAMPLES,
                                                sigma=SIGMA,
                                                noise=GUMBEL)
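In this particular case, the exact expectation is the usual softmax function applied to values / SIGMA, and the sampled output approximates it. This is not true in general: for most choices of \(C\) and noise there is no closed form for \(y_\varepsilon^*\).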
rngs = jax.random.split(rng, 2)
rng = rngs[0]
pert_argmax = pert_one_hot(rng, values)
print(f'computation with {N_SAMPLES} samples, sigma = {SIGMA}')
print(f'perturbed argmax = {pert_argmax}')
soft_max = jax.nn.softmax(values/SIGMA)
print(f'softmax = {soft_max}')
print(f'norm of softmax = {jnp.linalg.norm(soft_max):.2e}')
print(f'norm of difference = {jnp.linalg.norm(pert_argmax - soft_max):.2e}')
computation with 100 samples, sigma = 0.5
perturbed argmax = [0.02 0.87 0.01 0.09999999 0. ]
softmax = [0.00549293 0.8152234 0.01222475 0.16459078 0.00246813]
norm of softmax = 8.32e-01
norm of difference = 8.60e-02
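The gap between the sampled output and the softmax is Monte Carlo error, so it should shrink as the number of samples grows. The following sketch checks this in plain JAX, without optax; the helper `gumbel_argmax_mean` is a hypothetical name introduced for this illustration.

```python
import jax
import jax.numpy as jnp

theta = jnp.array([-0.6, 1.9, -0.2, 1.1, -1.0])
sigma = 0.5
target = jax.nn.softmax(theta / sigma)


def gumbel_argmax_mean(key, theta, sigma, num_samples):
    """Average of one-hot argmaxes of theta + sigma * Gumbel noise."""
    g = jax.random.gumbel(key, (num_samples, theta.shape[0]))
    idx = jnp.argmax(theta + sigma * g, axis=-1)
    return jnp.mean(jax.nn.one_hot(idx, theta.shape[0]), axis=0)


key = jax.random.PRNGKey(0)
errors = {}
for m in (10, 1_000, 100_000):
    est = gumbel_argmax_mean(key, theta, sigma, m)
    # Euclidean distance between the sampled average and the softmax.
    errors[m] = float(jnp.linalg.norm(est - target))
    print(m, errors[m])
```

More samples give a closer match at a higher cost per call, which is the trade-off controlled by num_samples.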
Gradients for one-hot with perturbations#
The perturbed optimizer \(y_\varepsilon^*\) is differentiable, and its gradient can be computed with stochastic estimation automatically, using jax.grad.
We create a scalar loss loss_simplex of the perturbed optimizer \(y^*_\varepsilon\). For values equal to a vector \(\theta\), we can automatically compute gradients of
\[\ell(\theta) = \| y_{\text{true}} - y^*_\varepsilon(\theta) \|_2^2\]
with respect to values.
# Example loss function
def loss_simplex(values, rng):
  n = values.shape[0]
  v_true = jnp.arange(n) + 2
  y_true = v_true / jnp.sum(v_true)
  y_pred = pert_one_hot(rng, values)
  return jnp.sum((y_true - y_pred) ** 2)
loss_simplex(values, rngs[1])
Array(0.7062, dtype=float32)
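We can compute the gradient of \(\ell\) directly. For the squared-error loss defined in loss_simplex, the chain rule combines the residual with the Jacobian of the perturbed optimizer:
\[\nabla_\theta \ell(\theta) = 2\, \partial_\theta y^*_\varepsilon(\theta)^\top \big(y^*_\varepsilon(\theta) - y_{\text{true}}\big).\]
The computation of the Jacobian \(\partial_\theta y^*_\varepsilon(\theta)\) is handled automatically, using an estimation method given by Berthet et al. (2020), [Prop. 3.1].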
# Gradient of the loss w.r.t input values
gradient = jax.grad(loss_simplex)(values, rngs[1])
print(gradient)
[-0.09853157 0.10874727 -0.11743014 -0.17878106 0.16792142]
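To give an idea of how a Jacobian of the perturbed optimizer can be estimated from samples, here is a minimal sketch of a score-function estimator for Gumbel noise, \(\partial_\theta y^*_\varepsilon(\theta) = \mathbb{E}[y^*(\theta + \varepsilon Z)\, \nabla\nu(Z)^\top] / \varepsilon\), where \(\exp(-\nu)\) is the noise density. The function name `gumbel_jacobian_estimate` is hypothetical, and the sketch is written in plain JAX under the assumption of independent standard Gumbel noise; it is not claimed to match the library's internal code. The result is compared with the Jacobian of the softmax, which is the closed form in this particular case.

```python
import jax
import jax.numpy as jnp


def gumbel_jacobian_estimate(key, theta, sigma, num_samples):
    """Estimates the Jacobian of E[one_hot(argmax(theta + sigma * Z))]."""
    n = theta.shape[0]
    z = jax.random.gumbel(key, (num_samples, n))
    y = jax.nn.one_hot(jnp.argmax(theta + sigma * z, axis=-1), n)
    # Standard Gumbel density is exp(-nu(z)) with nu(z) = z + exp(-z),
    # so grad nu(z) = 1 - exp(-z), applied coordinate-wise.
    score = 1.0 - jnp.exp(-z)
    # Average of the outer products y_i score_i^T, scaled by 1 / sigma.
    return jnp.einsum('mi,mj->ij', y, score) / (num_samples * sigma)


theta = jnp.array([-0.6, 1.9, -0.2, 1.1, -1.0])
sigma = 0.5
jac_mc = gumbel_jacobian_estimate(jax.random.PRNGKey(0), theta, sigma, 200_000)
# Closed form for this case: Jacobian of softmax(theta / sigma).
jac_softmax = jax.jacobian(lambda t: jax.nn.softmax(t / sigma))(theta)
print(jnp.max(jnp.abs(jac_mc - jac_softmax)))
```

The estimator only evaluates the piecewise-constant argmax, never its derivative, which is why it applies to optimizers with no useful gradient of their own.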
We illustrate the use of this method by running 200 steps of gradient descent on \(\theta_t\) so that it minimizes this loss.
# Run 200 steps of gradient descent on the values so that the perturbed one-hot matches y_true
steps = 200
values_t = values
eta = 0.5
grad_func = jax.jit(jax.grad(loss_simplex))
for t in range(steps):
  rngs = jax.random.split(rngs[1], 2)
  values_t = values_t - eta * grad_func(values_t, rngs[1])
rngs = jax.random.split(rngs[1], 2)
n = values.shape[0]
v_true = jnp.arange(n) + 2
y_true = v_true / jnp.sum(v_true)
print(f'initial values = {values}')
print(f'initial one-hot = {argmax_one_hot(values)}')
print(f'initial diff. one-hot = {pert_one_hot(rngs[0], values)}')
print()
print(f'values after GD = {values_t}')
print(f'one-hot after GD = {argmax_one_hot(values_t)}')
print(f'diff. one-hot after GD = {pert_one_hot(rngs[1], values_t)}')
print(f'target diff. one-hot = {y_true}')
initial values = [-0.6 1.9 -0.2 1.1 -1. ]
initial one-hot = [0. 1. 0. 0. 0.]
initial diff. one-hot = [0.01 0.83 0.01 0.14999999 0. ]
values after GD = [-0.11097738 0.10103489 0.28753668 0.3747991 0.47736812]
one-hot after GD = [0. 0. 0. 0. 1.]
diff. one-hot after GD = [0.08 0.17999999 0.21 0.26999998 0.26 ]
target diff. one-hot = [0.1 0.15 0.2 0.25 0.3 ]
Differentiable ranking#
Ranking function#
We consider an optimizer such as the following ranking function. It transforms a real-valued vector of size \(n\) into a vector whose coefficients are a permutation of \(\{0,\ldots, n-1\}\) giving the rank of each coefficient of the original vector. It corresponds to \(y^*\) for \(C\) being the permutahedron. We run it on example input values.
# Function outputting a vector of ranks
def ranking(values):
  return jnp.argsort(jnp.argsort(values))
# Example on random values
n = 6
rng = jax.random.PRNGKey(0)
values = jax.random.normal(rng, (n,))
print(f'values = {values}')
print(f'ranking = {ranking(values)}')
values = [ 1.6226422 2.0252647 -0.43359444 -0.07861735 0.1760909 -0.97208923]
ranking = [4 5 1 2 3 0]
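The double argsort in ranking can look opaque at first. The small example below, using made-up input values, separates the two steps: the first argsort gives the indices that sort the vector, and the second inverts that permutation to give each entry's rank.

```python
import jax.numpy as jnp

x = jnp.array([10.0, 30.0, 20.0, 0.0])
# order[k] is the index of the k-th smallest entry of x.
order = jnp.argsort(x)
# Inverting the permutation: ranks[i] is the position of x[i] in sorted order.
ranks = jnp.argsort(order)
print(order)  # [3 0 2 1]
print(ranks)  # [1 3 2 0]
```

Indexing the sorted vector with the ranks recovers the original vector, which is a quick way to check that ranks is the inverse of order.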
Ranking with perturbations#
As above, our implementation transforms this function into a perturbed one that we call pert_ranking. In this case we use Gumbel noise for the perturbation.
N_SAMPLES = 100
SIGMA = 0.2
GUMBEL = perturbations.Gumbel()
pert_ranking = perturbations.make_perturbed_fun(ranking,
                                                num_samples=N_SAMPLES,
                                                sigma=SIGMA,
                                                noise=GUMBEL)
# Expectation of the perturbed ranks on these values
rngs = jax.random.split(rng, 2)
diff_ranks = pert_ranking(rngs[0], values)
print(f'values = {values}')
print(f'diff_ranks = {diff_ranks}')
values = [ 1.6226422 2.0252647 -0.43359444 -0.07861735 0.1760909 -0.97208923]
diff_ranks = [4.11 4.89 1.17 2.02 2.76 0.05]
Gradients for ranking with perturbations#
As above, the perturbed optimizer \(y_\varepsilon^*\) is differentiable, and its gradient can be computed with stochastic estimation automatically, using jax.grad.
We showcase this on a loss of \(y^*_\varepsilon(\theta)\) that can be directly differentiated with respect to values, which play the role of \(\theta\).
# Example loss function
def loss_example(values, rng):
  n = values.shape[0]
  y_true = ranking(jnp.arange(n))
  y_pred = pert_ranking(rng, values)
  return jnp.sum((y_true - y_pred) ** 2)
print(loss_example(values, rngs[1]))
59.774796
# Gradient of the objective w.r.t input values
gradient = jax.grad(loss_example)(values, rngs[1])
print(gradient)
[-1.4866201 -1.724831 2.797777 -0.1345453 -1.9688779 -1.8026495]
As above, we showcase this example on gradient descent to minimize this loss.
steps = 20
values_t = values
eta = 0.1
grad_func = jax.jit(jax.grad(loss_example))
for t in range(steps):
  rngs = jax.random.split(rngs[1], 2)
  values_t = values_t - eta * grad_func(values_t, rngs[1])
rngs = jax.random.split(rngs[1], 2)
y_true = ranking(jnp.arange(n))
print(f'initial values = {values}')
print(f'initial ranks = {ranking(values)}')
print(f'initial diff. ranks = {pert_ranking(rngs[0], values)}')
print()
print(f'values after GD = {values_t}')
print(f'ranks after GD = {ranking(values_t)}')
print(f'diff. ranks after GD = {pert_ranking(rngs[1], values_t)}')
print(f'target diff. ranks = {y_true}')
initial values = [ 1.6226422 2.0252647 -0.43359444 -0.07861735 0.1760909 -0.97208923]
initial ranks = [4 5 1 2 3 0]
initial diff. ranks = [4.0899997 4.91 1.1 1.99 2.84 0.07 ]
values after GD = [-1.9037365 1.7597162 -0.8777193 0.09295582 3.3749492 1.4055638 ]
ranks after GD = [0 4 1 2 5 3]
diff. ranks after GD = [0.02 3.86 0.98999995 1.99 5. 3.1399999 ]
target diff. ranks = [0 1 2 3 4 5]
General input / outputs (Pytrees)#
This method can be applied in the forward pass to any function that takes pytrees as input and output, and it can also be used to compute derivatives, as illustrated below.
tree_a = (jnp.array((0.1, 0.4, 0.5)),
          {'k1': jnp.array((0.1, 0.2)),
           'k2': jnp.array((0.1, 0.1))},
          jnp.array((0.4, 0.3, 0.2, 0.1)))
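Perturbing a pytree amounts to adding independent noise to each of its leaves. The following is a minimal sketch of that step in plain JAX; the helper `perturb_tree` is a hypothetical name used only here, and the way the key is split across leaves is an assumption rather than a description of the library's internals.

```python
import jax
import jax.numpy as jnp


def perturb_tree(key, tree, sigma):
    """Adds independent sigma-scaled Gumbel noise to every leaf of a pytree."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    # One independent key per leaf, so no two leaves share noise.
    keys = jax.random.split(key, len(leaves))
    noisy = [leaf + sigma * jax.random.gumbel(k, leaf.shape)
             for k, leaf in zip(keys, leaves)]
    return jax.tree_util.tree_unflatten(treedef, noisy)


tree = (jnp.array([0.1, 0.4, 0.5]), {'k1': jnp.array([0.1, 0.2])})
noisy_tree = perturb_tree(jax.random.PRNGKey(0), tree, sigma=0.2)
print(noisy_tree)
```

The output has the same tree structure and leaf shapes as the input, so any function defined on the original pytree can be applied to its noisy copies and the results averaged.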
Tree argmax#
This piecewise constant function applies the argmax to every leaf array of the pytree.
argmax_tree = lambda x: jax.tree.map(argmax_one_hot, x)
argmax_tree(tree_a)
(Array([0., 0., 1.], dtype=float32),
{'k1': Array([0., 1.], dtype=float32), 'k2': Array([1., 0.], dtype=float32)},
Array([1., 0., 0., 0.], dtype=float32))
The perturbed approximation averages these argmaxes over noisy copies of the input, giving a softmax-like output on every leaf.
N_SAMPLES = 100
# Reuses SIGMA = 0.2 from the ranking section.
pert_argmax_fun = perturbations.make_perturbed_fun(argmax_tree,
                                                   num_samples=N_SAMPLES,
                                                   sigma=SIGMA)
pert_argmax_fun(rng, tree_a)
(Array([0.07, 0.35, 0.58], dtype=float32),
{'k1': Array([0.39999998, 0.59999996], dtype=float32),
'k2': Array([0.5, 0.5], dtype=float32)},
Array([0.59, 0.24, 0.09, 0.08], dtype=float32))
Scalar loss#
def pert_loss(inputs, rng):
  pert_softmax = pert_argmax_fun(rng, inputs)
  argmax = argmax_tree(inputs)
  diffs = jax.tree.map(lambda x, y: jnp.sum((x - y) ** 2 / 4), argmax, pert_softmax)
  return optax.tree.sum(diffs)
init_loss = pert_loss(tree_a, rng)
print(f'initial loss value = {init_loss:.3f}')
initial loss value = 0.341
Gradient computation#
The gradient of the scalar loss can be evaluated with jax.grad.
grad = jax.grad(pert_loss)(tree_a, rng)
print('Gradient of the scalar loss')
print()
grad
Gradient of the scalar loss
(Array([ 0.27816886, 0.34292346, -0.5273831 ], dtype=float32),
{'k1': Array([ 0.3098759, -0.3945589], dtype=float32),
'k2': Array([-0.35475507, 1.0202795 ], dtype=float32)},
Array([0.09084772, 0.21252295, 0.23311305, 0.3052454 ], dtype=float32))
A small step in the direction opposite to the gradient reduces the loss value.
eta = 1e-1
loss_step = pert_loss(optax.tree.add_scale(tree_a, -eta, grad), rng)
print(f'initial loss value = {init_loss:.3f}')
print(f'loss after gradient step = {loss_step:.3f}')
initial loss value = 0.341
loss after gradient step = 0.210