jax_privacy.noise_addition
Implementations of optax.GradientTransformations that add noise to gradients.
This module implements optax.GradientTransformations, which we informally call privatizers, that take clipped and aggregated gradients and add noise to them. These noise-addition schemes are stateful: the noise added to one gradient may depend, in various ways, on the noise added to previous gradients. In the simplest case, where i.i.d. Gaussian noise is added to each gradient, the state is nothing more than a pseudo-random key: each call to update uses this key to generate fresh noise and splits it into a new key for future steps.
Example usage:
>>> import jax
>>> privatizer = gaussian_privatizer(stddev=1.0, prng_key=jax.random.key(0))
>>> model = grad = jax.numpy.zeros(10)
>>> noise_state = privatizer.init(model)
>>> for _ in range(4):
...   noisy_grad, noise_state = privatizer.update(
...       sum_of_clipped_grads=grad, noise_state=noise_state
...   )
More powerful privatizers, such as those based on matrix factorization, have richer state representations, but this is abstracted away from the user by the optax.GradientTransformation interface. Different privatizers are fully swappable: in the pattern above, only the line that constructs the privatizer needs to change.
As optax.GradientTransformations, these privatizers can be composed with other transformations, e.g. via optax.chain(privatizer, optimizer). By the post-processing property of differential privacy, the composed transformations enjoy the same privacy guarantees as the privatizer alone.
Functions
Creates a gradient privatizer that adds correlated noise to gradients.
Classes
Supported strategies for generating intermediate noise.