jax_privacy.optimizers

Augmented gradient transformations for differentially private training.

Gradient-based DP training algorithms may need to pre-process the per-example gradients before clipping and noising. Because this pre-processing is tightly linked to the optimizer, we provide an AugmentedGradientTransformation that bundles a pre-processing function (the pre_clipping_transform that can be passed into jax_privacy.clipped_grad) with the core optimizer update.
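To make the bundling concrete, here is a minimal pure-Python sketch of the idea. It is not the library's definition: SketchAugmentedTransformation and sketch_sgd_with_identity are hypothetical names, and the only things taken from this page are the three attributes used in the example (init, update, and pre_clipping_transform, which is called with the optimizer state and returns a function applied to gradients).

```python
from typing import Any, Callable, NamedTuple


# Hypothetical illustration only, NOT the jax_privacy class: a container that
# keeps a pre-clipping transform next to the optimizer's init/update functions.
class SketchAugmentedTransformation(NamedTuple):
    init: Callable[[Any], Any]  # params -> optimizer state
    update: Callable[[Any, Any, Any], tuple]  # (grads, state, params) -> (updates, state)
    pre_clipping_transform: Callable[[Any], Callable[[Any], Any]]  # state -> (grads -> grads)


def sketch_sgd_with_identity(learning_rate: float) -> SketchAugmentedTransformation:
    """Plain SGD on lists of floats, paired with an identity pre-clipping transform."""

    def init(params):
        return {"step": 0}

    def update(grads, state, params=None):
        updates = [-learning_rate * g for g in grads]
        return updates, {"step": state["step"] + 1}

    def pre_clipping_transform(state):
        # Identity: gradients reach the clipping step unchanged.
        return lambda grads: grads

    return SketchAugmentedTransformation(init, update, pre_clipping_transform)


opt = sketch_sgd_with_identity(learning_rate=0.5)
state = opt.init([1.0, 2.0])
transform = opt.pre_clipping_transform(state)  # built from the current state
grads = transform([0.5, -1.0])                 # would run before clipping
updates, state = opt.update(grads, state, [1.0, 2.0])
```

Because the transform is built from the optimizer state, a stateful optimizer can change its pre-processing from step to step, which is why the example on this page rebuilds the clipped gradient function inside the training loop.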

The primary use case is the “scale-then-privatize” technique from:

Ganesh, McMahan, Thakurta. “On Design Principles for Private Adaptive Optimizers.” arXiv:2507.01129.
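The following NumPy sketch shows the general shape of rescaling per-example gradients before they are clipped and noised. It assumes an Adam-style per-coordinate scaling by 1 / (sqrt(v) + eps), where v is a second-moment estimate; the exact scaling used in the paper and in scale_then_privatize may differ, and scale_clip_noise is a hypothetical helper rather than part of the library.

```python
import numpy as np


def scale_clip_noise(per_example_grads, second_moment, l2_clip_norm,
                     noise_multiplier, rng, eps=1e-8):
    """Rescale, clip per example, sum, add Gaussian noise, then average."""
    # Assumed pre-clipping step: per-coordinate rescaling by 1 / (sqrt(v) + eps).
    scaled = per_example_grads / (np.sqrt(second_moment) + eps)
    # Clip each example's rescaled gradient to L2 norm at most l2_clip_norm.
    norms = np.linalg.norm(scaled, axis=1, keepdims=True)
    factors = np.minimum(1.0, l2_clip_norm / np.maximum(norms, 1e-12))
    clipped = scaled * factors
    # Noise standard deviation is proportional to the clipping norm.
    summed = clipped.sum(axis=0)
    noise = rng.normal(scale=noise_multiplier * l2_clip_norm, size=summed.shape)
    return (summed + noise) / per_example_grads.shape[0]


rng = np.random.default_rng(0)
grads = np.array([[3.0, 0.04], [0.3, 0.0]])  # two examples, two coordinates
v = np.array([1.0, 1e-4])                    # assumed second-moment estimate
out = scale_clip_noise(grads, v, l2_clip_norm=1.0, noise_multiplier=0.0, rng=rng)
```

In this toy input the second coordinate has a small raw gradient but also a small second moment, so after rescaling it contributes most of the first example's norm. Clipping therefore acts on the rescaled gradient rather than on the raw one.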

Example Usage (clipping but no noise):

    >>> from jax_privacy import clipped_grad, noise_addition
    >>> from jax_privacy.optimizers import scale_then_privatize
    >>> import optax
    >>> import jax.numpy as jnp
    >>> loss_fn = lambda params, batch: 0.5 * jnp.mean((params - batch) ** 2)
    >>> optimizer = scale_then_privatize(optax.adamw(1e-3))
    >>> params = jnp.ones(3)
    >>> data = jnp.ones((10, 3))
    >>> state = optimizer.init(params)
    >>> noise_multiplier = 0.0
    >>> noise_state = noise_addition.gaussian_privatizer(stddev=0.0).init(params)
    >>> for _ in range(5):
    ...     grad_fn = clipped_grad(
    ...         loss_fn,
    ...         l2_clip_norm=1,
    ...         pre_clipping_transform=optimizer.pre_clipping_transform(state)
    ...     )
    ...     stddev = grad_fn.sensitivity() * noise_multiplier
    ...     noise_fn = noise_addition.gaussian_privatizer(stddev=stddev)
    ...     clipped_grads = grad_fn(params, data)
    ...     noisy_grads, noise_state = noise_fn.update(clipped_grads, noise_state)
    ...     updates, state = optimizer.update(noisy_grads, state, params)
    ...     params = optax.apply_updates(params, updates)

Functions

as_augmented_optimizer(optimizer)

Wraps a plain Optax optimizer with an identity pre-clipping transform.

scale_then_privatize(base_optimizer[, eps, ...])

Constructs an AugmentedGradientTransformation for scale-then-privatize.

Classes

AugmentedGradientTransformation(init, ...)

A gradient transformation augmented with a pre-clipping transform.

PreClippingTransform(*args, **kwargs)

A function that applies a transformation to a pytree of updates.