jax_privacy.noise_addition.matrix_factorization_privatizer

jax_privacy.noise_addition.matrix_factorization_privatizer(noising_matrix, *, stddev, prng_key=None, dtype=None, intermediate_strategy=SupportedStrategies.DEFAULT)

Creates a gradient privatizer that adds correlated noise to gradients.

This implementation is described in Section 4.4 of Correlated Noise Mechanisms for Differentially Private Learning (https://arxiv.org/pdf/2506.08201). Which implementation is used depends on whether noising_matrix is a jax.Array or a StreamingMatrix. The dtype of the noise generated by this privatizer is determined by noising_matrix and the input gradients under standard JAX type promotion rules. The output of the privatizing transformation always has the same dtype as its input.
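The promotion-then-cast rule for dtypes can be sketched in a few lines of JAX. This is a minimal illustration of the behavior described above, not the privatizer's own code; add_noise_preserving_dtype is a hypothetical helper.

```python
import jax.numpy as jnp


def add_noise_preserving_dtype(grad, noise):
    """Hypothetical helper: add noise, then cast back to the gradient's dtype."""
    # The sum's dtype follows standard JAX type promotion
    # (e.g. bfloat16 + float32 -> float32).
    noised = grad + noise
    # Cast back so the caller always sees the original gradient dtype.
    return noised.astype(grad.dtype)


grad = jnp.ones((3,), dtype=jnp.bfloat16)
noise = jnp.full((3,), 0.5, dtype=jnp.float32)
out = add_noise_preserving_dtype(grad, noise)
```

Here the addition itself is carried out in float32, and only the final result is rounded back to bfloat16.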

For the naming conventions of these parameters, see ../matrix_factorization/README.md.

Parameters:
  • noising_matrix (Union[Array, ndarray, bool, number, bool, int, float, complex, StreamingMatrix]) – A matrix used to generate correlated noise. Noise samples will be distributed according to a multivariate Gaussian with covariance matrix noising_matrix.T @ noising_matrix.

  • stddev (float) – Standard deviation to use for the noise of this privatizer.

  • prng_key (Array | int | None) – An optional PRNG key, given as a key array or an integer seed, that serves as the source of randomness.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – The dtype to use for intermediate noise. If specified, noise is generated in this dtype, added to the input gradient under standard JAX type promotion rules, and the result is then cast back to the gradient dtype.

  • intermediate_strategy (SupportedStrategies) – Strategy to use for generating intermediate noise.

Return type:

GradientTransformation

Returns:

An optax.GradientTransformation that adds samples from a Gaussian correlated by noising_matrix (i.e., samples from a Gaussian with covariance noising_matrix.T @ noising_matrix), keyed by prng_key, to its stream of gradients.
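To make the stated covariance concrete, here is a toy, hypothetical stand-in written in plain JAX. It does not import jax_privacy or optax; it only mirrors optax's init/update shape. It assumes a dense noising matrix M, a single array gradient rather than a pytree, and the convention noise_t = stddev * sum_j M[j, t] * z_j with i.i.d. standard normal z_j. Under that convention the per-coordinate noise across steps has covariance stddev**2 * (M.T @ M). The library's actual indexing, key handling, and stddev scaling may differ, and the names toy_privatizer, ToyNoiseState, and correlated_noise are illustrative only.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyNoiseState(NamedTuple):
    """Hypothetical state: the current step and a fixed PRNG key."""
    step: jax.Array
    key: jax.Array


def correlated_noise(noising_matrix, step, key, shape, stddev):
    """Hypothetical: noise for one step, where column `step` of M mixes i.i.d. seeds.

    Assumed convention: noise_t = stddev * sum_j M[j, t] * z_j, which gives
    covariance stddev**2 * (M.T @ M) across steps, coordinate by coordinate.
    """
    num_steps = noising_matrix.shape[0]
    # Regenerating every seed from the same key keeps each z_j identical across steps.
    seeds = jax.random.normal(key, (num_steps,) + tuple(shape))
    weights = noising_matrix[:, step]
    return stddev * jnp.tensordot(weights, seeds, axes=1)


def toy_privatizer(noising_matrix, stddev, seed=0):
    """Hypothetical optax-style (init, update) pair; not the jax_privacy implementation."""

    def init(params):
        del params  # The toy state does not depend on the parameters.
        return ToyNoiseState(step=jnp.asarray(0), key=jax.random.PRNGKey(seed))

    def update(grad, state, params=None):
        del params  # Unused; kept to mirror optax's update signature.
        noise = correlated_noise(noising_matrix, state.step, state.key, grad.shape, stddev)
        # Promote for the addition, then cast back to the gradient dtype.
        noised = (grad + noise).astype(grad.dtype)
        return noised, ToyNoiseState(step=state.step + 1, key=state.key)

    return init, update
```

This dense sketch materializes one seed per step on every call, so its memory grows with the number of steps; it does not model the StreamingMatrix path.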