optax.schedules.inject_hyperparams

optax.schedules.inject_hyperparams(inner_factory: Callable[..., base.GradientTransformation], static_args: str | Iterable[str] = (), hyperparam_dtype: jnp.dtype | None = None) → Callable[..., base.GradientTransformationExtraArgs]

Wrapper that injects stateful hyperparameters into GradientTransformations.

This wrapper allows you to pass schedules (i.e. functions that return a numeric value given a step count) instead of constants for hyperparameters. You may only schedule numeric hyperparameters (e.g. boolean flags cannot be scheduled).

This function supports both simple schedules that are functions exclusively of the step count and stateful schedules that rely on a complex internal state. The state update can rely on additional information fed to gradient transformations via extra_args.

For example, to use optax.scale_by_adam() with a piecewise linear schedule for beta_1 (b1) and a constant for beta_2 (b2):

>>> import optax
>>> import jax.numpy as jnp
>>> # create a piecewise linear schedule for b1 that starts at 0.1; the value
... # at step 100 is init_value multiplied by the scale given for that boundary
>>> linear_schedule = optax.piecewise_interpolate_schedule(
...    'linear', init_value=0.1, boundaries_and_scales={100: 1.})
>>> scheduled_adam = optax.inject_hyperparams(optax.scale_by_adam)(
...     b1=linear_schedule, b2=0.99)

You may manually change numeric hyperparameters that were not scheduled through the hyperparams dict in the InjectHyperparamsState:

>>> params, grads = jnp.array(0.), jnp.array(0.)
>>> state = scheduled_adam.init(params)
>>> updates, state = scheduled_adam.update(grads, state)
>>> state.hyperparams['b2'] = 0.95
>>> updates, state = scheduled_adam.update(updates, state)  # uses b2 = 0.95

Manually overriding scheduled hyperparameters will have no effect (e.g. in the code sample above, you cannot manually adjust b1).

For mixed precision training, you may want hyperparameters to match a specific dtype. To avoid automatic casting to the parameter’s highest precision dtype, pass the hyperparameter as a strongly typed JAX array:

>>> scheduled_adam = optax.inject_hyperparams(optax.scale_by_adam)(
...     b1=linear_schedule, b2=jnp.array(0.99, jnp.float32))

Parameters:
  • inner_factory – a function that returns the inner optax.GradientTransformation with dynamic hyperparameters.

  • static_args – a string or iterable of strings specifying which callable parameters are not schedules. inject_hyperparams treats all callables as schedules by default, so if a hyperparameter is a non-schedule callable, you must specify that using this argument.

  • hyperparam_dtype – Optional datatype override. If specified, all float hyperparameters will be cast to this type.

Returns:

A callable that returns an optax.GradientTransformationExtraArgs. This callable accepts the same arguments as inner_factory, except that you may provide schedules in place of the constant arguments.

Changed in version 0.1.9: Added the hyperparam_dtype parameter. The returned callable now outputs a GradientTransformationExtraArgs instead of a GradientTransformation.