Stochastic Gradient Estimators#

measure_valued_jacobians(function, params, ...)

Measure valued gradient estimation.

pathwise_jacobians(function, params, ...)

Pathwise gradient estimation.

score_function_jacobians(function, params, ...)

Score function gradient estimation.

Measure valued Jacobians#

optax.monte_carlo.measure_valued_jacobians(function, params, dist_builder, rng, num_samples, coupling=True)[source]#

Measure valued gradient estimation.

Approximates:

nabla_{theta} E_{p(x; theta)} f(x)

With:

(1/c) (E_{p1(x; theta)} f(x) - E_{p2(x; theta)} f(x)), where p1 and p2 are measures which depend on p and c is a normalizing constant.

Currently only supports computing gradients of expectations of Gaussian RVs.

Parameters:
  • function (Callable[[chex.Array], float]) – Function f(x) for which to estimate grads_{params} E_dist f(x). The function takes in one argument (a sample from the distribution) and returns a floating point value.

  • params (optax.Params) – A tuple of jnp arrays. The parameters for which to construct the distribution.

  • dist_builder (Callable[…, Any]) – A constructor that builds a distribution from the parameters in params; dist_builder(params) must return a valid distribution.

  • rng (chex.PRNGKey) – A PRNGKey used for sampling.

  • num_samples (int) – The number of samples used to compute the gradients.

  • coupling (bool) – Whether to couple the positive and negative samples. Recommended: True, since coupling reduces estimator variance.

Return type:

Sequence[chex.Array]

Returns:

A tuple with the same structure as params; each element is a num_samples x param.shape Jacobian array containing the gradient estimate obtained from each sample. The mean over the sample axis is the gradient with respect to the parameters and can be used for learning; the full Jacobian can be used to assess estimator variance.
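As a concrete illustration, the Gaussian-mean case can be sketched in plain NumPy. This is a minimal sketch, not the optax implementation: the helper name measure_valued_grad_mean, the choice f(x) = x**2, and the direct use of Rayleigh draws are illustrative assumptions (optax instead builds the distribution via dist_builder). For the mean of N(mean, std**2), the derivative of the density splits into two shifted Rayleigh(1) measures with normalizing constant c = std * sqrt(2 * pi):

```python
import numpy as np

def measure_valued_grad_mean(f, mean, std, rng, num_samples, coupling=True):
    """Sketch: measure-valued estimate of d/d(mean) E_{N(mean, std^2)} f(x).

    For the Gaussian mean, the density derivative decomposes into a
    positive measure p1 (samples mean + std * W) and a negative measure
    p2 (samples mean - std * W), with W ~ Rayleigh(1) and normalizing
    constant c = std * sqrt(2 * pi).
    """
    w_pos = rng.rayleigh(scale=1.0, size=num_samples)
    # With coupling, the same Rayleigh draws serve both measures,
    # which cancels much of the sampling noise and reduces variance.
    w_neg = w_pos if coupling else rng.rayleigh(scale=1.0, size=num_samples)
    c = std * np.sqrt(2.0 * np.pi)
    return (f(mean + std * w_pos) - f(mean - std * w_neg)) / c

rng = np.random.default_rng(0)
jac = measure_valued_grad_mean(lambda x: x**2, mean=1.0, std=1.0,
                               rng=rng, num_samples=10_000)
print(jac.mean())  # close to the true gradient d/d(mean) E[x^2] = 2*mean = 2
```

The per-sample values in jac play the role of the returned Jacobian rows: their mean is the learning signal, and their spread estimates the variance of the estimator.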

Pathwise Jacobians#

optax.monte_carlo.pathwise_jacobians(function, params, dist_builder, rng, num_samples)[source]#

Pathwise gradient estimation.

Approximates:

nabla_{theta} E_{p(x; theta)} f(x)

With:
E_{p(epsilon)} nabla_{theta} f(g(epsilon, theta))

where x = g(epsilon, theta) and g depends on the distribution p.

Requires p to be reparametrizable, with the reparametrization implemented in tensorflow_probability. Applicable to continuous random variables; f needs to be differentiable.

Parameters:
  • function (Callable[[chex.Array], float]) – Function f(x) for which to estimate grads_{params} E_dist f(x). The function takes in one argument (a sample from the distribution) and returns a floating point value.

  • params (optax.Params) – A tuple of jnp arrays. The parameters for which to construct the distribution.

  • dist_builder (Callable[…, Any]) – A constructor that builds a distribution from the parameters in params; dist_builder(params) must return a valid distribution.

  • rng (chex.PRNGKey) – A PRNGKey used for sampling.

  • num_samples (int) – The number of samples used to compute the gradients.

Return type:

Sequence[chex.Array]

Returns:

A tuple with the same structure as params; each element is a num_samples x param.shape Jacobian array containing the gradient estimate obtained from each sample. The mean over the sample axis is the gradient with respect to the parameters and can be used for learning; the full Jacobian can be used to assess estimator variance.
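The reparametrization can also be sketched in plain NumPy for the Gaussian case. This is an illustrative sketch, not the optax implementation: the helper name pathwise_grad_mean is an assumption, and the derivative f'(x) = 2*x is written analytically where autodiff (or tensorflow_probability's reparametrized sampling) would normally supply it. With x = mean + std * eps and eps ~ N(0, 1), the per-sample gradient with respect to the mean is simply f'(x), since dx/d(mean) = 1:

```python
import numpy as np

def pathwise_grad_mean(df, mean, std, rng, num_samples):
    """Sketch: pathwise estimate of d/d(mean) E_{N(mean, std^2)} f(x).

    Uses the reparametrization x = g(eps, mean) = mean + std * eps with
    eps ~ N(0, 1), so d f(x)/d(mean) = f'(x) * dx/d(mean) = f'(x).
    """
    eps = rng.standard_normal(num_samples)
    x = mean + std * eps
    return df(x)  # per-sample gradient estimates

rng = np.random.default_rng(0)
# f(x) = x**2, so its derivative is f'(x) = 2*x.
jac = pathwise_grad_mean(lambda x: 2.0 * x, mean=1.0, std=1.0,
                         rng=rng, num_samples=10_000)
print(jac.mean())  # close to the true gradient d/d(mean) E[x^2] = 2*mean = 2
```

Note that only f' appears here: the pathwise estimator needs f to be differentiable, which is exactly the requirement stated above.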

Score function Jacobians#

optax.monte_carlo.score_function_jacobians(function, params, dist_builder, rng, num_samples)[source]#

Score function gradient estimation.

Approximates:

nabla_{theta} E_{p(x; theta)} f(x)

With:

E_{p(x; theta)} f(x) nabla_{theta} log p(x; theta)

Requires p to be differentiable with respect to theta. Applicable to both continuous and discrete random variables; no requirements on f.

Parameters:
  • function (Callable[[chex.Array], float]) – Function f(x) for which to estimate grads_{params} E_dist f(x). The function takes in one argument (a sample from the distribution) and returns a floating point value.

  • params (optax.Params) – A tuple of jnp arrays. The parameters for which to construct the distribution.

  • dist_builder (Callable[…, Any]) – A constructor that builds a distribution from the parameters in params; dist_builder(params) must return a valid distribution.

  • rng (chex.PRNGKey) – A PRNGKey used for sampling.

  • num_samples (int) – The number of samples used to compute the gradients.

Return type:

Sequence[chex.Array]

Returns:

A tuple with the same structure as params; each element is a num_samples x param.shape Jacobian array containing the gradient estimate obtained from each sample. The mean over the sample axis is the gradient with respect to the parameters and can be used for learning; the full Jacobian can be used to assess estimator variance.
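The score-function identity above can be sketched in plain NumPy for a Gaussian. This is an illustrative sketch, not the optax implementation: the helper name score_function_grad_mean and the closed-form Gaussian score are assumptions (optax obtains the score by differentiating the distribution's log-probability). For N(mean, std**2), the score with respect to the mean is (x - mean) / std**2, and f itself is only evaluated, never differentiated:

```python
import numpy as np

def score_function_grad_mean(f, mean, std, rng, num_samples):
    """Sketch: score-function estimate of d/d(mean) E_{N(mean, std^2)} f(x).

    Uses f(x) * d/d(mean) log p(x; mean, std); for a Gaussian, the score
    with respect to the mean is (x - mean) / std**2. Because f is only
    evaluated, it may be non-differentiable or defined on discrete x.
    """
    x = rng.normal(loc=mean, scale=std, size=num_samples)
    score = (x - mean) / std**2
    return f(x) * score  # per-sample gradient estimates

rng = np.random.default_rng(0)
jac = score_function_grad_mean(lambda x: x**2, mean=1.0, std=1.0,
                               rng=rng, num_samples=100_000)
print(jac.mean())  # close to the true gradient d/d(mean) E[x^2] = 2*mean = 2
```

Compared with the pathwise and measure-valued sketches, the per-sample values here are typically noisier, which is why variance-reduction techniques (such as control variates) are often paired with this estimator.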