Stochastic Gradient Estimators#
Warning
This module has been deprecated and will be removed in optax 0.3.0.

This module implements three families of stochastic gradient estimators:

Measure valued gradient estimation.

Pathwise gradient estimation.

Score function gradient estimation.
Measure valued Jacobians#
optax.monte_carlo.measure_valued_jacobians(function: Callable[[chex.Array], float], params: optax.Params, dist_builder: Callable[[...], Any], rng: Array, num_samples: int, coupling: bool = True) → Sequence[chex.Array] [source]#
Measure valued gradient estimation.
 Approximates:
nabla_{theta} E_{p(x; theta)} f(x)
 With:
1/c (E_{p1(x; theta)} f(x) - E_{p2(x; theta)} f(x)) where p1 and p2 are measures which depend on p, and c is a constant determined by the decomposition of p.
Currently only supports computing gradients of expectations of Gaussian RVs.
 Parameters:
function – Function f(x) for which to estimate grads_{params} E_dist f(x). The function takes in one argument (a sample from the distribution) and returns a floating point value.
params – A tuple of jnp arrays. The parameters for which to construct the distribution.
dist_builder – a constructor which builds a distribution given the input parameters specified by params. dist_builder(params) should return a valid distribution.
rng – a jax.random.PRNGKey.
num_samples – Int, the number of samples used to compute the grads.
coupling – A boolean. Whether or not to use coupling for the positive and negative samples. Recommended: True, as this reduces variance.
 Returns:
A tuple of the same length as params; each element is a num_samples x param.shape Jacobian array containing the gradient estimate obtained from each sample.
The mean over the sample axis is the gradient with respect to the parameters and can be used for learning; the full Jacobian array can be used to assess estimator variance.
Deprecated since version 0.2.4: This function will be removed in 0.3.0
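As an illustration of the decomposition above (a minimal NumPy sketch, not the optax API), the measure-valued derivative for the mean of a Gaussian N(mu, sigma^2) has a known closed form: the positive and negative measures are mu ± sigma * W with W ~ Weibull(scale=sqrt(2), shape=2), and c = sigma * sqrt(2 * pi). The helper name below is hypothetical.

```python
import numpy as np

def mvd_gaussian_mean_grad(f, mu, sigma, num_samples, seed=0):
    """Measure-valued estimate of d/dmu E_{N(mu, sigma^2)}[f(x)].

    Positive measure: mu + sigma * W; negative measure: mu - sigma * W,
    with W ~ Weibull(scale=sqrt(2), shape=2) and constant c = sigma * sqrt(2*pi).
    Coupling: the same W is reused for both measures, which reduces variance.
    """
    rng = np.random.default_rng(seed)
    # Standard Weibull with shape 2, rescaled to scale sqrt(2).
    w = np.sqrt(2.0) * rng.weibull(2.0, size=num_samples)
    pos = f(mu + sigma * w)  # samples under the positive measure p1
    neg = f(mu - sigma * w)  # samples under the negative measure p2 (coupled)
    return np.mean(pos - neg) / (sigma * np.sqrt(2.0 * np.pi))

# f(x) = x^2 has E[f] = mu^2 + sigma^2, so the true gradient wrt mu is 2*mu.
grad = mvd_gaussian_mean_grad(lambda x: x ** 2, mu=1.0, sigma=0.5,
                              num_samples=100_000)
```

Note that f is never differentiated here, which is why this estimator also applies to non-differentiable f.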
Pathwise Jacobians#
optax.monte_carlo.pathwise_jacobians(function: Callable[[chex.Array], float], params: optax.Params, dist_builder: Callable[[...], Any], rng: Array, num_samples: int) → Sequence[chex.Array] [source]#
Pathwise gradient estimation.
 Approximates:
nabla_{theta} E_{p(x; theta)} f(x)
 With:
 E_{p(epsilon)} nabla_{theta} f(g(epsilon, theta))
where x = g(epsilon, theta). g depends on the distribution p.
 Requires: p to be reparametrizable and the reparametrization to be implemented
in tensorflow_probability. Applicable to continuous random variables. f needs to be differentiable.
 Parameters:
function – Function f(x) for which to estimate grads_{params} E_dist f(x). The function takes in one argument (a sample from the distribution) and returns a floating point value.
params – A tuple of jnp arrays. The parameters for which to construct the distribution.
dist_builder – a constructor which builds a distribution given the input parameters specified by params. dist_builder(params) should return a valid distribution.
rng – a jax.random.PRNGKey.
num_samples – Int, the number of samples used to compute the grads.
 Returns:
A tuple of the same length as params; each element is a num_samples x param.shape Jacobian array containing the gradient estimate obtained from each sample.
The mean over the sample axis is the gradient with respect to the parameters and can be used for learning; the full Jacobian array can be used to assess estimator variance.
Deprecated since version 0.2.4: This function will be removed in 0.3.0
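A minimal NumPy sketch of the pathwise (reparameterization) estimator for a Gaussian, again not the optax API: x = g(eps, theta) = mu + sigma * eps with eps ~ N(0, 1), and the path derivative d/dmu f(g(eps, mu)) is averaged over samples. For f(x) = x^2 the derivative is written out by hand here; in practice it would come from autodiff (e.g. jax.grad).

```python
import numpy as np

def pathwise_gaussian_mean_grad(mu, sigma, num_samples, seed=0):
    """Pathwise estimate of d/dmu E_{N(mu, sigma^2)}[f(x)] for f(x) = x^2.

    Reparameterize x = g(eps, theta) = mu + sigma * eps, eps ~ N(0, 1),
    then average the per-sample path derivative d/dmu f(g(eps, mu)).
    """
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal(num_samples)  # base noise, independent of theta
    x = mu + sigma * eps                    # x = g(eps, theta)
    # For f(x) = x^2: d/dmu f(mu + sigma*eps) = 2 * (mu + sigma*eps) = 2*x.
    return np.mean(2.0 * x)

# True gradient: d/dmu (mu^2 + sigma^2) = 2*mu = 2.0.
grad = pathwise_gaussian_mean_grad(mu=1.0, sigma=0.5, num_samples=100_000)
```

Because the derivative flows through f, this estimator requires f to be differentiable and the distribution to be reparametrizable, matching the requirements stated above.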
Score function Jacobians#
optax.monte_carlo.score_function_jacobians(function: Callable[[chex.Array], float], params: optax.Params, dist_builder: Callable[[...], Any], rng: Array, num_samples: int) → Sequence[chex.Array] [source]#
Score function gradient estimation.
 Approximates:
nabla_{theta} E_{p(x; theta)} f(x)
 With:
E_{p(x; theta)} f(x) nabla_{theta} log p(x; theta)
Requires: p to be differentiable wrt theta. Applicable to both continuous
and discrete random variables. No requirements on f.
 Parameters:
function – Function f(x) for which to estimate grads_{params} E_dist f(x). The function takes in one argument (a sample from the distribution) and returns a floating point value.
params – A tuple of jnp arrays. The parameters for which to construct the distribution.
dist_builder – a constructor which builds a distribution given the input parameters specified by params. dist_builder(params) should return a valid distribution.
rng – a jax.random.PRNGKey.
num_samples – Int, the number of samples used to compute the grads.
 Returns:
A tuple of the same length as params; each element is a num_samples x param.shape Jacobian array containing the gradient estimate obtained from each sample.
The mean over the sample axis is the gradient with respect to the parameters and can be used for learning; the full Jacobian array can be used to assess estimator variance.
Deprecated since version 0.2.4: This function will be removed in 0.3.0
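A minimal NumPy sketch of the score-function (REINFORCE) estimator for the Gaussian mean, not the optax API: the expectation gradient is approximated by f(x) times the score, d/dmu log p(x; mu, sigma) = (x - mu) / sigma^2, so f itself is never differentiated.

```python
import numpy as np

def score_function_gaussian_mean_grad(f, mu, sigma, num_samples, seed=0):
    """Score-function (REINFORCE) estimate of d/dmu E_{N(mu, sigma^2)}[f(x)].

    Averages f(x) * d/dmu log p(x; theta) over samples x ~ p(x; theta).
    For a Gaussian, the score wrt the mean is (x - mu) / sigma^2.
    """
    rng = np.random.default_rng(seed)
    x = rng.normal(mu, sigma, size=num_samples)  # samples from p(x; theta)
    score = (x - mu) / sigma ** 2                # d/dmu log p(x; theta)
    return np.mean(f(x) * score)

# f(x) = x^2: true gradient is 2*mu = 2.0. Note the per-sample variance is
# much higher than for the pathwise or measure-valued estimators on this f,
# hence the larger sample count.
grad = score_function_gaussian_mean_grad(lambda x: x ** 2, mu=1.0, sigma=0.5,
                                         num_samples=500_000)
```

This high variance is the usual motivation for control variates (baselines) when using score-function estimators in practice.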