Perturbations
- make_perturbed_fun – Returns a differentiable approximation of a function, using stochastic perturbations.
- Gumbel – Gumbel distribution.
- Normal – Normal distribution.
Gumbel noise
Make perturbed function
- optax.perturbations.make_perturbed_fun(fun: Callable[[base.ArrayTree], base.ArrayTree], num_samples: int = 1000, sigma: jax.typing.ArrayLike = 0.1, noise=<optax.perturbations._make_pert.Gumbel object>, use_baseline=True) → Callable[[base.PRNGKey, base.ArrayTree], base.ArrayTree]
Returns a differentiable approximation of a function, using stochastic perturbations.
Let \(f\) be a function, \(\sigma\) be a scalar, \(\mu\) be a noise distribution, and
\[f_\sigma(x) = \mathbb{E}_{z \sim \mu} f(x + \sigma z) \]
Under suitable conditions on \(\mu\) (for instance, a smooth density, as for the Gumbel and normal distributions), \(f_\sigma\) is a smoothed, differentiable approximation of \(f\), even if \(f\) itself is not differentiable.
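The expectation defining \(f_\sigma\) can be approximated by averaging \(f\) over sampled perturbations. The sketch below is a minimal illustration in JAX, assuming standard Gumbel noise for \(\mu\) (the default noise of make_perturbed_fun) and a 1-D array input; smoothed_value is a hypothetical helper, not part of optax. It estimates only the value of \(f_\sigma\), not its derivatives.

```python
import jax
import jax.numpy as jnp


def smoothed_value(f, key, x, sigma=0.1, num_samples=100_000):
    """Hypothetical helper: plain Monte Carlo estimate of f_sigma(x).

    Assumes mu is the standard Gumbel distribution and x is a 1-D array.
    """
    # One standard Gumbel draw per sample and per coordinate of x
    z = jax.random.gumbel(key, (num_samples,) + x.shape)
    # Evaluate f at every perturbed input x + sigma * z and average
    return jnp.mean(jax.vmap(f)(x + sigma * z))


f = lambda x: jnp.sum(jnp.maximum(x, 0.0))
key = jax.random.key(0)
x = jnp.array([1.0, 2.0, 3.0])
estimate = smoothed_value(f, key, x)
```

Standard Gumbel noise has mean equal to the Euler–Mascheroni constant (about 0.5772), so \(f_\sigma\) is shifted relative to \(f\): for this x every perturbed coordinate stays positive, and the estimate should land near 6 + 3 × 0.1 × 0.5772 ≈ 6.17 rather than at f(x) = 6.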
optax.perturbations.make_perturbed_fun() yields a stochastic function whose values and arbitrary-order derivatives (when computed through JAX’s automatic differentiation system) are unbiased Monte Carlo estimates of the corresponding values and derivatives of \(f_\sigma\). These estimates are computed using only values (not derivatives) of \(f\), at stochastic perturbations of the input. Thus \(f\) itself does not have to be differentiable.
- Parameters:
fun – The function to transform into a differentiable function. The signature currently supported is from pytree to pytree, whose leaves are JAX arrays.
num_samples – an int, the number of perturbed outputs to average over.
sigma – a float, the scale of the random perturbation.
noise – a distribution object that implements sample and log_prob methods, like optax.perturbations.Gumbel (which is the default).
use_baseline – Use the value of the function at the unperturbed input as a baseline for variance reduction.
- Returns:
A new function with the same signature as the original function, but with a leading random PRNG key argument.
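The noise parameter only needs an object with sample and log_prob methods. The sketch below shows one such object for standard Gumbel noise; the class name StandardGumbelNoise is hypothetical, and the method signatures sample(key, sample_shape, dtype) and log_prob(inputs) are assumptions modeled on the parameter description, so compare them against the source of optax.perturbations.Gumbel before passing a custom object to make_perturbed_fun.

```python
import jax
import jax.numpy as jnp


class StandardGumbelNoise:
    """Hypothetical noise object exposing `sample` and `log_prob`.

    The method signatures are assumptions, not a copy of optax's own class.
    """

    def sample(self, key, sample_shape, dtype=jnp.float32):
        # Independent standard Gumbel draws with the requested shape
        return jax.random.gumbel(key, sample_shape, dtype)

    def log_prob(self, inputs):
        # Elementwise log-density of the standard Gumbel distribution:
        # log p(z) = -z - exp(-z)
        return -inputs - jnp.exp(-inputs)


noise = StandardGumbelNoise()
samples = noise.sample(jax.random.key(0), (4, 3))
```

An additive constant in log_prob would not change a score function gradient, since only the gradient of the log-density enters it; this version happens to be exactly normalized.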
Example
>>> import jax
>>> import jax.numpy as jnp
>>> from optax.perturbations import make_perturbed_fun
>>> key = jax.random.key(0)
>>> x = jnp.array([0.0, 0.0, 0.0])
>>> f = lambda x: jnp.sum(jnp.maximum(x, 0.0))
>>> fn = make_perturbed_fun(f, 1_000, 0.1)
>>> with jnp.printoptions(precision=2):
...     print(jax.grad(fn, argnums=1)(key, x))
[0.69 0.72 0.58]
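For this particular f the smoothed gradient has a closed form, which gives a sanity check on the Monte Carlo output. With standard Gumbel noise (the default), \(\partial f_\sigma / \partial x_i = P(x_i + \sigma z_i > 0) = 1 - F(-x_i/\sigma)\), where \(F(t) = \exp(-e^{-t})\) is the standard Gumbel CDF. The helper exact_grad below is hypothetical and specific to this f.

```python
import jax.numpy as jnp


def exact_grad(x, sigma):
    """Hypothetical helper: exact gradient of f_sigma for f(x) = sum(max(x, 0)).

    Assumes standard Gumbel noise, so that
    d/dx_i E[max(x_i + sigma * z_i, 0)] = P(z_i > -x_i / sigma) = 1 - F(-x_i / sigma).
    """
    t = -x / sigma
    # 1 minus the standard Gumbel CDF evaluated at t
    return 1.0 - jnp.exp(-jnp.exp(-t))


g = exact_grad(jnp.array([0.0, 0.0, 0.0]), 0.1)
```

At x = 0 every coordinate of the exact gradient equals \(1 - e^{-1} \approx 0.632\), so the printed values above are noisy estimates scattered around that number; increasing num_samples shrinks the spread.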
Note
For the curious reader, \(f_\sigma\) can also be expressed as
\[f_\sigma(x) = \mathbb{E}_{y \sim \nu(x, \sigma)} f(y) \]
where \(\nu(x, \sigma)\) is the probability distribution of the random variable \(x + \sigma z\), with \(z \sim \mu\).
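Assuming \(\mu\) has a density \(p_\mu\) on \(\mathbb{R}^d\) and that differentiation under the expectation is allowed, \(\nu(x, \sigma)\) has density \(\nu(y; x, \sigma) = \sigma^{-d} p_\mu((y - x)/\sigma)\), and the derivative lands on this density instead of on \(f\):
\[\nabla_x f_\sigma(x) = \mathbb{E}_{y \sim \nu(x, \sigma)}\left[f(y)\, \nabla_x \log \nu(y; x, \sigma)\right] = -\frac{1}{\sigma}\, \mathbb{E}_{z \sim \mu}\left[f(x + \sigma z)\, \nabla \log p_\mu(z)\right] \]
Because \(\mathbb{E}_{z \sim \mu}[\nabla \log p_\mu(z)] = 0\), subtracting the value at the unperturbed input, as described for use_baseline, leaves this expectation unchanged:
\[\nabla_x f_\sigma(x) = -\frac{1}{\sigma}\, \mathbb{E}_{z \sim \mu}\left[\left(f(x + \sigma z) - f(x)\right) \nabla \log p_\mu(z)\right] \]
For standard Gumbel noise, \(\nabla \log p_\mu(z) = e^{-z} - 1\) elementwise.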
The gradient can then be obtained by the score function estimator, a.k.a. REINFORCE. We implement the score function estimator through the “magic box” operator introduced by Foerster et al., 2018, so that the returned function provides stochastic estimates of derivatives of any order by simply using JAX’s automatic differentiation system.
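The magic box can be written \(\boxdot(w) = \exp(w - \perp(w))\), where \(\perp\) stops gradients: it evaluates to 1, while its derivative is \(\boxdot(w)\,\nabla w\). Writing \(\log \nu(y; x, \sigma)\) for the log-density of \(\nu(x, \sigma)\) at \(y\), weighting each sample \(f(y_i)\) by \(\boxdot(\log \nu(y_i; x, \sigma))\) leaves the returned value unchanged but makes automatic differentiation produce score function estimates of the derivatives of \(f_\sigma\). The sketch below illustrates this in plain JAX under stated assumptions: standard Gumbel noise, a 1-D array input (not a general pytree), and a stop-gradient baseline \(f(x)\). The names perturbed_sketch, magic_box and gumbel_log_prob are hypothetical, and optax’s own implementation may differ in its details.

```python
import jax
import jax.numpy as jnp


def magic_box(w):
    # DiCE "magic box": evaluates to 1, but its derivative is magic_box(w) * dw,
    # so autodiff injects derivatives of the log-density w.
    return jnp.exp(w - jax.lax.stop_gradient(w))


def gumbel_log_prob(z):
    # Standard Gumbel log-density, summed over the coordinates of one sample
    return jnp.sum(-z - jnp.exp(-z), axis=-1)


def perturbed_sketch(f, num_samples=1000, sigma=0.1):
    """Hypothetical score-function smoothing of f (1-D array inputs only)."""

    def fn(key, x):
        z = jax.random.gumbel(key, (num_samples,) + x.shape)
        # Treat the perturbed inputs as fixed samples: no gradient flows through f
        y = jax.lax.stop_gradient(x + sigma * z)
        fy = jax.vmap(f)(y)
        # log nu(y; x, sigma) up to an additive constant; depends on x via (y - x)
        logp = gumbel_log_prob((y - x) / sigma)
        # Constant baseline f(x): changes neither the value nor the expectation
        # of any derivative, but typically lowers the variance
        baseline = jax.lax.stop_gradient(f(x))
        return jnp.mean((fy - baseline) * magic_box(logp)) + baseline

    return fn


f = lambda x: jnp.sum(jnp.maximum(x, 0.0))
fn = perturbed_sketch(f, num_samples=100_000, sigma=0.1)
value = fn(jax.random.key(0), jnp.zeros(3))
g_est = jax.grad(fn, argnums=1)(jax.random.key(0), jnp.zeros(3))
```

With these settings each entry of g_est should be close to \(1 - e^{-1} \approx 0.632\), the exact gradient of \(f_\sigma\) for this f at x = 0 under standard Gumbel noise. Because the baseline is wrapped in stop_gradient, it contributes nothing to derivatives of any order.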
References
Berthet et al., Learning with Differentiable Perturbed Optimizers, 2020
Foerster et al., DiCE: The Infinitely Differentiable Monte Carlo Estimator, 2018
Salimans et al., Evolution Strategies as a Scalable Alternative to Reinforcement Learning, 2017
See also
Perturbed optimizers example.