optax.contrib.differentially_private_aggregate
- optax.contrib.differentially_private_aggregate(l2_norm_clip: jax.typing.ArrayLike, noise_multiplier: jax.typing.ArrayLike, key: Array | int | None = None, *, seed: int | None = None) → optax.GradientTransformation
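Aggregates per-example gradients according to the DPSGD algorithm. It clips each example's gradient, sums the clipped gradients, adds Gaussian noise, and averages the result over the batch.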
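The aggregation step of Abadi et al. (2016) can be written in terms of this function's parameters. The notation is as follows: C is `l2_norm_clip` and σ is `noise_multiplier`. B is the batch size, and g_i is the gradient of example i, flattened across all parameters. This is a paraphrase of the paper's algorithm, not a transcription of Optax's source. Noise is added once to the sum of clipped gradients, not to each example separately.

```latex
\bar{g}_i = g_i \cdot \min\!\left(1, \frac{C}{\lVert g_i \rVert_2}\right),
\qquad
\tilde{g} = \frac{1}{B}\left(\sum_{i=1}^{B} \bar{g}_i + \mathcal{N}\!\left(0,\, \sigma^2 C^2 \mathbf{I}\right)\right)
```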
- Parameters:
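l2_norm_clip: maximum L2 norm of the per-example gradients; each example's gradient is rescaled so its norm does not exceed this value.
noise_multiplier: ratio of the standard deviation of the added Gaussian noise to the clipping norm, i.e. the noise standard deviation is noise_multiplier * l2_norm_clip.
key: random key (or integer seed) used to generate the noise.
seed: deprecated; use key instead.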
- Returns:
An optax.GradientTransformation.
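The roles of `l2_norm_clip` and `noise_multiplier` can be illustrated with a minimal plain-JAX sketch. It assumes a single array of per-example gradients with shape `[batch, ...]`. `dp_aggregate_sketch` is a hypothetical helper, not part of Optax; the actual transformation accepts arbitrary pytrees of per-example gradients.

```python
import jax
import jax.numpy as jnp


def dp_aggregate_sketch(per_example_grads, l2_norm_clip, noise_multiplier, key):
  """Hypothetical helper illustrating DP-SGD aggregation (not an Optax API).

  Assumes `per_example_grads` is a single array of shape [batch, ...].
  """
  batch_size = per_example_grads.shape[0]
  flat = per_example_grads.reshape(batch_size, -1)
  # L2 norm of each example's gradient.
  norms = jnp.linalg.norm(flat, axis=1)
  # Rescale each example so its norm is at most l2_norm_clip
  # (the maximum() guard avoids dividing by zero for all-zero gradients).
  scale = jnp.minimum(1.0, l2_norm_clip / jnp.maximum(norms, 1e-12))
  clipped_sum = jnp.sum(flat * scale[:, None], axis=0)
  # Gaussian noise with standard deviation noise_multiplier * l2_norm_clip.
  noise = noise_multiplier * l2_norm_clip * jax.random.normal(key, clipped_sum.shape)
  # Average over the batch and restore the per-example parameter shape.
  return ((clipped_sum + noise) / batch_size).reshape(per_example_grads.shape[1:])


grads = jnp.array([[3.0, 4.0], [0.3, 0.4]])  # two examples, two parameters
agg = dp_aggregate_sketch(grads, l2_norm_clip=1.0, noise_multiplier=0.0,
                          key=jax.random.PRNGKey(0))
# With noise_multiplier=0.0 this is the mean of the clipped gradients: [0.45, 0.6]
```

With `noise_multiplier=0.0` the sketch reduces to clipping plus averaging. The first example (norm 5) is scaled down to norm 1, while the second (norm 0.5) passes through unchanged.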
References
Abadi et al., 2016. Deep Learning with Differential Privacy.
Warning
Unlike most other transformations, differentially_private_aggregate expects the input updates to have a batch dimension along axis 0. That is, it expects per-example gradients as input, which are easy to obtain in JAX with jax.vmap. Because it removes this batch dimension, it can be composed with other transformations only if it comes first in the chain; later transformations then receive ordinary aggregated gradients.
Warning
Generic gradient aggregation tools such as optax.MultiSteps or optax.apply_every() won't work correctly with this transformation, since this transformation is itself designed to aggregate gradients in a specific way.