Control Variates

control_delta_method(function)

The control delta covariate method.

control_variates_jacobians(function, ...[, ...])

Obtain jacobians using control variates.

moving_avg_baseline(function[, decay, ...])

A moving average baseline.

Control delta method

optax.monte_carlo.control_delta_method(function)

The control delta covariate method.

Control variate obtained by performing a second-order Taylor expansion of the cost function f around the mean of the input distribution.

Only implemented for Gaussian random variables.

For details, see: https://icml.cc/2012/papers/687.pdf

Parameters:

function (Callable[[chex.Array], float]) – The function for which to compute the control variate. The function takes in one argument (a sample from the distribution) and returns a floating point value.

Return type:

ControlVariate

Returns:

A tuple of three functions: one to compute the control variate, one to compute the expected value of the control variate, and one to update the control variate state.
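To make the construction concrete, here is a minimal NumPy sketch of the three pieces for a Gaussian with diagonal covariance. The names delta_method_cv, grad_f and hess_f are hypothetical, the derivatives are supplied by hand to keep the sketch dependency-free (in JAX one would typically obtain them with jax.grad and jax.hessian), and the signatures do not match the optax ControlVariate functions. Because E[x − μ] = 0 and E[(x − μ)(x − μ)ᵀ] = diag(σ²), the expectation of the expansion reduces to f(μ) + ½ Σ_i H_ii σ_i².

```python
import numpy as np


def delta_method_cv(f, grad_f, hess_f):
    """Hypothetical (h, expected_h, update_state) triple for a diagonal Gaussian."""

    def h(x, mean):
        # Second-order Taylor expansion of f around the mean of the distribution.
        delta = x - mean
        return f(mean) + grad_f(mean) @ delta + 0.5 * delta @ hess_f(mean) @ delta

    def expected_h(mean, stddev):
        # The linear term has zero mean; the quadratic term contributes
        # 0.5 * trace(H @ diag(stddev**2)) = 0.5 * sum(diag(H) * stddev**2).
        return f(mean) + 0.5 * np.sum(np.diag(hess_f(mean)) * stddev**2)

    def update_state(state):
        # The delta method is stateless, so the state passes through unchanged.
        return state

    return h, expected_h, update_state


# Example: for a quadratic f the expansion is exact, so expected_h equals E[f].
def f(x):
    return x @ x + 3.0 * x[0]


def grad_f(x):
    return 2.0 * x + np.array([3.0, 0.0])


def hess_f(x):
    return 2.0 * np.eye(x.shape[0])


h, expected_h, update_state = delta_method_cv(f, grad_f, hess_f)

mean = np.array([1.0, -2.0])
stddev = np.array([0.5, 0.1])
# E[f] = |mean|^2 + 3 * mean[0] + sum(stddev**2) for this f.
print(expected_h(mean, stddev))
```

For a non-quadratic f, an estimator built this way averages f(x) − h(x) over samples and adds expected_h in closed form; the closer f is to quadratic near the mean, the more variance the correction typically removes.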

Control variates Jacobians

optax.monte_carlo.control_variates_jacobians(function, control_variate_from_function, grad_estimator, params, dist_builder, rng, num_samples, control_variate_state=None, estimate_cv_coeffs=False, estimate_cv_coeffs_num_samples=20)

Obtain jacobians using control variates.

Using a control variate h(x; θ), the gradient ∇_θ E_{p(x; θ)} f(x) is split into three terms, and each term is computed individually. The first term uses stochastic gradient estimation. The second term is computed using Monte Carlo estimation and automatic differentiation to compute ∇_θ h(x; θ). The third term is computed using automatic differentiation, as we restrict ourselves to control variates whose expectation E_{p(x; θ)} h(x; θ) is available in closed form.
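For reference, the three terms come from the following identity, obtained by adding and subtracting h(x; θ) and using the fact that f does not depend on θ (a sketch of the standard argument, not a transcription of the optax source):

```latex
\begin{aligned}
\nabla_\theta \mathbb{E}_{p(x;\theta)}[f(x)]
  &= \nabla_\theta \mathbb{E}_{p(x;\theta)}[f(x) - h(x;\theta)]
     + \nabla_\theta \mathbb{E}_{p(x;\theta)}[h(x;\theta)] \\
  &= \underbrace{\int \nabla_\theta p(x;\theta)\,\bigl(f(x) - h(x;\theta)\bigr)\,dx}_{\text{stochastic gradient estimation}}
     \;-\; \underbrace{\mathbb{E}_{p(x;\theta)}[\nabla_\theta h(x;\theta)]}_{\text{Monte Carlo + autodiff}}
     \;+\; \underbrace{\nabla_\theta \mathbb{E}_{p(x;\theta)}[h(x;\theta)]}_{\text{closed form + autodiff}}
\end{aligned}
```

The second line follows from the product rule: ∇_θ ∫ p (f − h) dx = ∫ ∇_θ p (f − h) dx + ∫ p ∇_θ(f − h) dx, and ∇_θ f = 0.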

This function updates the state of the control variate once, before computing the control variate coefficients.

Parameters:
  • function (Callable[[chex.Array], float]) – Function f(x) for which to estimate grads_{params} E_dist f(x). The function takes in one argument (a sample from the distribution) and returns a floating point value.

  • control_variate_from_function (Callable[[Callable[[chex.Array], float]], ControlVariate]) – The control variate to use to reduce variance. See control_delta_method and moving_avg_baseline for examples.

  • grad_estimator (Callable[…, jnp.ndarray]) – The gradient estimator to be used to compute the gradients. Note that not all control variates will reduce variance for all estimators. For example, the moving_avg_baseline will make no difference to the measure valued or pathwise estimators.

  • params (optax.Params) – A tuple of jnp arrays. The parameters for which to construct the distribution and for which we want to compute the jacobians.

  • dist_builder (Callable[…, Any]) – A constructor which builds a distribution given the input parameters specified by params. dist_builder(params) should return a valid distribution.

  • rng (chex.PRNGKey) – A JAX PRNG key.

  • num_samples (int) – The number of samples used to compute the gradients.

  • control_variate_state (CvState) – The control variate state. This is used for control variates which keep states (such as the moving average baselines).

  • estimate_cv_coeffs (bool) – Whether or not to estimate the optimal control variate coefficient via estimate_control_variate_coefficients.

  • estimate_cv_coeffs_num_samples (int) – The number of samples to use to estimate the optimal coefficient. These need to be new samples to ensure that the objective is unbiased.
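Coefficient estimation rests on a standard result: for an estimator g adjusted as g − c(h − E[h]), the variance is minimised at c* = Cov(g, h) / Var(h). The scalar NumPy sketch below estimates c* from one batch of samples and then applies it to a fresh batch, mirroring the requirement that the coefficient samples be new. The name optimal_cv_coefficient is hypothetical, and optax's estimate_control_variate_coefficients may compute the coefficients differently (for example, per parameter).

```python
import numpy as np


def optimal_cv_coefficient(g, h):
    """Estimate c* = Cov(g, h) / Var(h), which minimises Var(g - c * (h - E[h]))."""
    g_centered = g - g.mean()
    h_centered = h - h.mean()
    return np.sum(g_centered * h_centered) / np.sum(h_centered**2)


rng = np.random.default_rng(0)

# Target: E[exp(x)] for x ~ N(0, 1), with control variate h(x) = x and E[h] = 0.
expected_h = 0.0
pilot = rng.normal(size=10_000)  # samples used only to estimate the coefficient
coeff = optimal_cv_coefficient(np.exp(pilot), pilot)

x = rng.normal(size=100_000)  # fresh samples for the actual estimate
plain = np.exp(x)
adjusted = plain - coeff * (x - expected_h)
print(coeff, plain.mean(), adjusted.mean())
print(plain.var(), adjusted.var())
```

Reusing the pilot samples for the final estimate would correlate the coefficient with the estimate it corrects, which is why fresh samples are drawn.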

Returns:

  • A tuple with one element per entry of params. Each element is a jacobian of shape (num_samples, *param.shape) containing the gradient estimate obtained from each sample.

    The mean over the sample axis is the gradient estimate with respect to that parameter and can be used for learning. The full jacobian can be used to assess estimator variance.

  • The updated CV state.

Return type:

A tuple of size two
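The shape and meaning of the returned jacobians can be illustrated with a self-contained NumPy sketch of the simplest case: the score-function estimator for the mean of a Gaussian with a fixed isotropic standard deviation, using a constant baseline b as the control variate. Because a constant h has ∇_θ h = 0 and ∇_θ E[h] = 0, only the first term survives. The function score_function_jacobians and its signature are hypothetical and are not part of optax.

```python
import numpy as np


def score_function_jacobians(f, mean, stddev, num_samples, rng, baseline=0.0):
    """Per-sample score-function estimates of grad_mean E[f(x)], x ~ N(mean, stddev**2 I).

    f maps an array of shape (num_samples, dim) to shape (num_samples,).
    `baseline` is a constant control variate b, so its correction terms are zero.
    """
    noise = stddev * rng.normal(size=(num_samples, mean.shape[0]))
    x = mean + noise
    # grad_mean log N(x; mean, stddev**2 I) = (x - mean) / stddev**2
    score = noise / stddev**2
    # One row per sample: (f(x) - b) * score, shape (num_samples, dim).
    return (f(x) - baseline)[:, None] * score


def f(x):
    # E[f] = |mean|^2 + dim * stddev**2 and grad_mean E[f] = 2 * mean.
    return np.sum(x**2, axis=-1)


rng = np.random.default_rng(0)
mean = np.array([1.0, -2.0])
stddev = 0.5
exact_grad = 2.0 * mean

plain = score_function_jacobians(f, mean, stddev, 100_000, rng)
# Here b is set to the exact E[f] = 5.5; in practice it would be a running estimate.
with_cv = score_function_jacobians(f, mean, stddev, 100_000, rng, baseline=5.5)
print(plain.mean(axis=0), with_cv.mean(axis=0), exact_grad)
print(plain.var(axis=0), with_cv.var(axis=0))
```

Both jacobians average to roughly 2 · mean, while the per-sample variance read off the jacobian is lower with the baseline.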

Moving average baseline

optax.monte_carlo.moving_avg_baseline(function, decay=0.99, zero_debias=True, use_decay_early_training_heuristic=True)

A moving average baseline.

It has no effect on the pathwise or measure valued estimators.

Parameters:
  • function (Callable[[chex.Array], float]) – The function for which to compute the control variate. The function takes in one argument (a sample from the distribution) and returns a floating point value.

  • decay (float) – The decay rate for the moving average.

  • zero_debias (bool) – Whether or not to use zero debiasing for the moving average.

  • use_decay_early_training_heuristic (bool) – Whether or not to use a heuristic that overrides the decay value early in training with min(decay, (1.0 + i) / (10.0 + i)), where i is the number of updates performed so far. This stabilises training and was adapted from the TensorFlow codebase.

Return type:

ControlVariate

Returns:

A tuple of three functions: one to compute the control variate, one to compute the expected value of the control variate, and one to update the control variate state.
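A minimal sketch of the state update follows, in plain Python and NumPy. It assumes the baseline is an exponential moving average of the mean of f over each batch of samples, uses the early-training decay heuristic described above, and applies zero debiasing by dividing by one minus the product of the decays used so far (the weight still attached to the zero initialisation). BaselineState, update_baseline and baseline_value are hypothetical names, and optax's own bookkeeping may differ in detail.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class BaselineState:
    ema: float = 0.0  # raw exponential moving average, initialised at zero
    decay_product: float = 1.0  # weight still on the zero initialisation
    step: int = 0  # number of updates performed so far


def update_baseline(state, f_values, decay=0.99, use_decay_early_training_heuristic=True):
    """Fold the mean of a new batch of function values into the moving average."""
    if use_decay_early_training_heuristic:
        # Early on, (1 + i) / (10 + i) is small, so new batches dominate.
        step_decay = min(decay, (1.0 + state.step) / (10.0 + state.step))
    else:
        step_decay = decay
    ema = step_decay * state.ema + (1.0 - step_decay) * float(np.mean(f_values))
    return BaselineState(ema, state.decay_product * step_decay, state.step + 1)


def baseline_value(state, zero_debias=True):
    """Read the baseline, optionally removing the bias from the zero initialisation."""
    if zero_debias and state.step > 0:
        return state.ema / (1.0 - state.decay_product)
    return state.ema


# Feeding a constant: the debiased baseline recovers it, the raw average lags below.
state = BaselineState()
for _ in range(5):
    state = update_baseline(state, np.full(8, 3.0))
print(baseline_value(state), baseline_value(state, zero_debias=False))
```

The baseline value does not depend on the sample x, so subtracting it only changes the score-function term; this is why it leaves pathwise and measure valued gradients unchanged.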