optax.contrib.dpsgd
- optax.contrib.dpsgd(learning_rate: base.ScalarOrSchedule, l2_norm_clip: jax.typing.ArrayLike, noise_multiplier: jax.typing.ArrayLike, key: jax.Array | int | None = None, momentum: jax.typing.ArrayLike | None = None, nesterov: bool = False, *, seed: int | None = None) -> base.GradientTransformation
The DPSGD optimizer.
Differential privacy is a standard for privacy guarantees of algorithms learning from aggregate databases including potentially sensitive information. DPSGD offers protection against a strong adversary with full knowledge of the training mechanism and access to the model's parameters.
- Parameters:
learning_rate – A global scaling factor, either fixed or given by a schedule.
l2_norm_clip – Maximum L2 norm of the per-example gradients.
noise_multiplier – Ratio of the noise standard deviation to the clipping norm.
key – Random generator key used for noise generation.
momentum – Decay rate used by the momentum term. When set to None, momentum is not used at all.
nesterov – Whether Nesterov momentum is used.
seed – Deprecated; use key instead.
- Returns:
A base.GradientTransformation.
References
Abadi et al., Deep Learning with Differential Privacy, 2016
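To make the roles of l2_norm_clip and noise_multiplier concrete, the following is a minimal sketch of the aggregation step described by Abadi et al.: clip each example's gradient to a maximum global L2 norm, sum the clipped gradients, add Gaussian noise with standard deviation l2_norm_clip * noise_multiplier, and divide by the batch size. The helper name dp_aggregate is hypothetical and written only for illustration; it is not Optax's own implementation and may differ from it in details.

```python
import jax
import jax.numpy as jnp


def dp_aggregate(per_example_grads, l2_norm_clip, noise_multiplier, key):
    """Hypothetical helper: clip per example, sum, add noise, then average."""
    leaves, treedef = jax.tree_util.tree_flatten(per_example_grads)
    batch_size = leaves[0].shape[0]

    # Global L2 norm of each example's gradient, taken across all leaves.
    sq_norms = sum(
        jnp.sum(jnp.square(g.reshape(batch_size, -1)), axis=1) for g in leaves
    )
    norms = jnp.sqrt(sq_norms)

    # Scale down only the examples whose norm exceeds l2_norm_clip.
    scale = 1.0 / jnp.maximum(norms / l2_norm_clip, 1.0)

    noise_std = l2_norm_clip * noise_multiplier
    keys = jax.random.split(key, len(leaves))
    noised = []
    for g, k in zip(leaves, keys):
        # Broadcast the per-example scale over the remaining axes of the leaf.
        bshape = (batch_size,) + (1,) * (g.ndim - 1)
        clipped_sum = jnp.sum(g * scale.reshape(bshape), axis=0)
        noise = noise_std * jax.random.normal(k, clipped_sum.shape, clipped_sum.dtype)
        noised.append((clipped_sum + noise) / batch_size)
    return jax.tree_util.tree_unflatten(treedef, noised)


# Two examples: the first has norm 5 and is clipped to norm 1,
# the second has norm 0.5 and is left unchanged.
grads = {"w": jnp.array([[3.0, 4.0], [0.3, 0.4]])}
out = dp_aggregate(
    grads, l2_norm_clip=1.0, noise_multiplier=0.0, key=jax.random.PRNGKey(0)
)
# With noise_multiplier=0.0 the result is the mean of the clipped gradients,
# approximately [0.45, 0.6].
```

Setting noise_multiplier to zero, as in the usage lines, is only a way to inspect the clipping behaviour in isolation; it provides no privacy guarantee.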
Warning
This optax.GradientTransformation expects input updates to have a batch dimension on the 0th axis. That is, this function expects per-example gradients as input (which are easy to obtain in JAX using jax.vmap).
Warning
Generic gradient aggregation tools like optax.MultiSteps or optax.apply_every() won't work correctly with this transformation, since the whole point of this transformation is to aggregate gradients in a specific way.