optax.contrib.dpsgd

optax.contrib.dpsgd(learning_rate: base.ScalarOrSchedule, l2_norm_clip: jax.typing.ArrayLike, noise_multiplier: jax.typing.ArrayLike, key: jax.Array | int | None = None, momentum: jax.typing.ArrayLike | None = None, nesterov: bool = False, *, seed: int | None = None) → base.GradientTransformation

The DPSGD optimizer.

Differential privacy is a standard for the privacy guarantees of algorithms that learn from aggregate databases which may contain sensitive information. DPSGD offers protection against a strong adversary with full knowledge of the training mechanism and access to the model’s parameters.

Parameters:
  • learning_rate – A global scaling factor, either a fixed scalar or a schedule.

  • l2_norm_clip – Maximum L2 norm of the per-example gradients.

  • noise_multiplier – Ratio of the noise standard deviation to the clipping norm.

  • key – Random generator key used for noise generation.

  • momentum – Decay rate used by the momentum term. When set to None, momentum is not used.

  • nesterov – Whether Nesterov momentum is used.

  • seed – Deprecated; use key instead.

Returns:

An optax.GradientTransformation.
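To make the roles of l2_norm_clip and noise_multiplier concrete, here is a minimal sketch of the clip-and-noise aggregation step in the spirit of Abadi et al., 2016. It is an illustration in plain JAX and not the library's implementation: the function name dp_aggregate_sketch is hypothetical, and it assumes a single 2-D array of per-example gradients rather than a general pytree.

```python
import jax
import jax.numpy as jnp


def dp_aggregate_sketch(per_example_grads, l2_norm_clip, noise_multiplier, key):
    """Hypothetical helper: clip each row, sum, add Gaussian noise, average."""
    batch_size = per_example_grads.shape[0]
    # L2 norm of each example's gradient (one norm per row).
    norms = jnp.linalg.norm(per_example_grads, axis=1)
    # Scale down only the gradients whose norm exceeds l2_norm_clip.
    scale = jnp.minimum(1.0, l2_norm_clip / jnp.maximum(norms, 1e-12))
    clipped_sum = jnp.sum(per_example_grads * scale[:, None], axis=0)
    # The noise standard deviation is noise_multiplier * l2_norm_clip.
    noise_std = noise_multiplier * l2_norm_clip
    noise = noise_std * jax.random.normal(key, clipped_sum.shape)
    return (clipped_sum + noise) / batch_size


# Two per-example gradients with L2 norms 5.0 and 0.5.
grads = jnp.array([[3.0, 4.0], [0.3, 0.4]])

# With noise disabled, only the first row is rescaled (to norm 1.0),
# so the result is approximately [0.45, 0.6].
no_noise = dp_aggregate_sketch(
    grads, l2_norm_clip=1.0, noise_multiplier=0.0, key=jax.random.PRNGKey(0)
)
```

Setting noise_multiplier to 0.0 here only isolates the clipping behaviour; it gives no privacy protection.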

References

Abadi et al., Deep Learning with Differential Privacy, 2016

Warning

This optax.GradientTransformation expects input updates to have a batch dimension on the 0th axis. That is, this function expects per-example gradients as input (which are easy to obtain in JAX using jax.vmap).

Warning

Generic gradient aggregation tools like optax.MultiSteps or optax.apply_every() won’t work correctly with this transformation, because it already aggregates the per-example gradients in its own specific way.