jax_privacy.optimizers
Augmented gradient transformations for differentially private training.
Gradient-based DP training algorithms may need to pre-process per-example gradients before those gradients are clipped and noised. Because this pre-processing is tightly coupled to the optimizer, this module provides AugmentedGradientTransformation, which binds a pre-processing function (the pre_clipping_transform that can be passed to jax_privacy.clipped_grad) to the core optimizer update.
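The bundling described above can be pictured as a small record that carries the usual init/update pair plus a third function that maps the current optimizer state to a per-example transform. The sketch below is framework-free (NumPy only), and every name in it (AugmentedTransformSketch, sgd_sketch, with_identity_pre_clipping) is hypothetical; it illustrates the shape of the idea and is not the jax_privacy API.

```python
from typing import Any, Callable, NamedTuple, Tuple

import numpy as np

Array = np.ndarray


class AugmentedTransformSketch(NamedTuple):
    """Hypothetical stand-in for an augmented gradient transformation."""
    init: Callable[[Array], Any]
    update: Callable[..., Tuple[Array, Any]]
    # Maps optimizer state to a function applied to each per-example
    # gradient *before* clipping.
    pre_clipping_transform: Callable[[Any], Callable[[Array], Array]]


def sgd_sketch(learning_rate: float):
    """A stateless SGD optimizer expressed as an (init, update) pair."""
    def init(params):
        return ()  # Plain SGD keeps no state.

    def update(grads, state, params=None):
        return -learning_rate * grads, state

    return init, update


def with_identity_pre_clipping(init, update) -> AugmentedTransformSketch:
    """Hypothetical wrapper: attach a pre-clipping transform that does nothing."""
    return AugmentedTransformSketch(
        init=init,
        update=update,
        pre_clipping_transform=lambda state: (lambda g: g),
    )


opt = with_identity_pre_clipping(*sgd_sketch(0.1))
state = opt.init(np.zeros(3))
g = np.array([1.0, -2.0, 3.0])
transform = opt.pre_clipping_transform(state)  # depends on state in general
updates, state = opt.update(transform(g), state, None)
print(updates)
```

Because pre_clipping_transform takes the optimizer state as input, a stateful optimizer can change how gradients are pre-processed from one step to the next, which is why the transform is rebuilt inside the training loop rather than fixed once.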
- The primary use case is the “scale-then-privatize” technique from:
Ganesh, McMahan, Thakurta. “On Design Principles for Private Adaptive Optimizers.” arXiv:2507.01129.
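As a rough illustration of where scaling sits in the pipeline, the sketch below rescales each per-example gradient by a diagonal factor taken from optimizer state, then clips each row and adds Gaussian noise to the sum. The specific preconditioner (an Adam-style second-moment estimate nu, used as 1 / (sqrt(nu) + eps)), the helper names, and the choice to return an unnormalized sum are all assumptions made for this sketch; they are not taken from the paper or from jax_privacy, and what the optimizer does with the privatized result is not modeled here.

```python
import numpy as np


def scale_pre_clipping(nu: np.ndarray, eps: float = 1e-8):
    """Hypothetical pre-clipping transform: divide by sqrt(nu) + eps elementwise."""
    scale = 1.0 / (np.sqrt(nu) + eps)

    def transform(per_example_grad: np.ndarray) -> np.ndarray:
        return per_example_grad * scale

    return transform


def clip_rows(grads: np.ndarray, l2_clip_norm: float) -> np.ndarray:
    """Clip each row (one example's gradient) to L2 norm at most l2_clip_norm."""
    norms = np.linalg.norm(grads, axis=1, keepdims=True)
    factors = np.minimum(1.0, l2_clip_norm / np.maximum(norms, 1e-12))
    return grads * factors


def privatize(per_example_grads, pre_clipping_transform, l2_clip_norm,
              noise_multiplier, rng):
    """Transform, clip, sum, then add Gaussian noise.

    The noise stddev is noise_multiplier * l2_clip_norm, the L2 sensitivity
    of a sum of clipped vectors under add/remove-one adjacency.
    """
    transformed = np.stack([pre_clipping_transform(g) for g in per_example_grads])
    clipped = clip_rows(transformed, l2_clip_norm)
    total = clipped.sum(axis=0)
    noise = rng.normal(scale=noise_multiplier * l2_clip_norm, size=total.shape)
    return total + noise


rng = np.random.default_rng(0)
per_example_grads = rng.normal(size=(10, 3))
nu = np.array([1.0, 4.0, 0.25])  # assumed second-moment estimate
noisy_sum = privatize(per_example_grads, scale_pre_clipping(nu), 1.0, 0.0, rng)
print(noisy_sum)
```

The key design point the sketch shows is ordering: the scaling happens before clipping, so the clip norm and the noise are applied in the rescaled coordinates rather than the raw gradient coordinates.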
Example usage (clipping but no noise):

>>> from jax_privacy import clipped_grad, noise_addition
>>> from jax_privacy.optimizers import scale_then_privatize
>>> import optax
>>> import jax.numpy as jnp
>>> loss_fn = lambda params, batch: 0.5 * jnp.mean((params - batch) ** 2)
>>> optimizer = scale_then_privatize(optax.adamw(1e-3))
>>> params = jnp.ones(3)
>>> data = jnp.ones((10, 3))
>>> state = optimizer.init(params)
>>> noise_multiplier = 0.0
>>> noise_state = noise_addition.gaussian_privatizer(stddev=0.0).init(params)
>>> for _ in range(5):
...   grad_fn = clipped_grad(
...       loss_fn,
...       l2_clip_norm=1,
...       pre_clipping_transform=optimizer.pre_clipping_transform(state)
...   )
...   stddev = grad_fn.sensitivity() * noise_multiplier
...   noise_fn = noise_addition.gaussian_privatizer(stddev=stddev)
...   clipped_grads = grad_fn(params, data)
...   noisy_grads, noise_state = noise_fn.update(clipped_grads, noise_state)
...   updates, state = optimizer.update(noisy_grads, state, params)
...   params = optax.apply_updates(params, updates)
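For the quadratic loss used in the example, each per-example gradient has a simple closed form: the gradient of 0.5 * mean((p - x) ** 2) with respect to p is (p - x) / d, where d is the number of coordinates. The NumPy check below compares that formula against central finite differences. It assumes clipped_grad evaluates loss_fn on one row of data at a time, and it does not call jax_privacy; it only shows what the quantities being clipped look like. Note that with params and data both set to ones, as in the example, every per-example gradient is exactly zero.

```python
import numpy as np


def loss_fn(params, example):
    # Same quadratic loss as the example, evaluated on a single example.
    return 0.5 * np.mean((params - example) ** 2)


def analytic_grad(params, example):
    # d/dp of 0.5 * mean((p - x)^2) is (p - x) / d, with d = params.size.
    return (params - example) / params.size


def finite_diff_grad(f, params, example, h=1e-6):
    # Central differences, one coordinate at a time.
    grad = np.zeros_like(params)
    for i in range(params.size):
        step = np.zeros_like(params)
        step[i] = h
        grad[i] = (f(params + step, example) - f(params - step, example)) / (2 * h)
    return grad


rng = np.random.default_rng(0)
params = rng.normal(size=3)
data = rng.normal(size=(10, 3))
per_example = np.stack([analytic_grad(params, x) for x in data])
numeric = np.stack([finite_diff_grad(loss_fn, params, x) for x in data])
print(np.max(np.abs(per_example - numeric)))
```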
Functions
| Wraps a plain Optax optimizer with an identity pre-clipping transform. |
scale_then_privatize | Constructs an AugmentedGradientTransformation for scale-then-privatize. |
Classes
AugmentedGradientTransformation | A gradient transformation augmented with a pre-clipping transform. |
| A function that applies a transformation to a pytree of updates. |