jax_privacy.matrix_factorization.toeplitz.banded_inverse_square_root_noising_coefs

jax_privacy.matrix_factorization.toeplitz.banded_inverse_square_root_noising_coefs(num_bands, workload_coef=None)

Returns Toeplitz noising coefficients for the BISR factorization.

This computes the first num_bands coefficients of the lower-triangular Toeplitz noising matrix $C^{-1}$ for the Banded Inverse Square Root (BISR) factorization introduced in https://arxiv.org/pdf/2505.12128. If workload_coef is not provided, this uses the default prefix-sum workload with all-ones Toeplitz coefficients. If workload_coef is provided, then it is treated as the Toeplitz coefficients of the workload; this can encode workload families such as those induced by SGD with momentum and weight decay. In that case, this function computes Toeplitz coefficients of the square root of the workload and then returns the first num_bands coefficients of its inverse.
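The arithmetic described above can be sketched in plain Python. This is an illustration of the math only, not the library's implementation: the helper names (`sqrt_toeplitz_coefs`, `inverse_toeplitz_coefs`, `bisr_coefs_sketch`) are hypothetical, and the sketch assumes that a lower-triangular Toeplitz matrix is identified with its coefficient sequence, so that matrix products become truncated convolutions of coefficients. It also assumes the first workload coefficient is positive and that at least num_bands workload coefficients are supplied.

```python
import math


def sqrt_toeplitz_coefs(a, n):
    """First n coefficients r such that the convolution (r * r)[k] == a[k]."""
    r = [math.sqrt(a[0])]  # assumes a[0] > 0
    for k in range(1, n):
        # a[k] = 2 * r[0] * r[k] + sum_{i=1}^{k-1} r[i] * r[k-i]
        cross = sum(r[i] * r[k - i] for i in range(1, k))
        r.append((a[k] - cross) / (2.0 * r[0]))
    return r


def inverse_toeplitz_coefs(r, n):
    """First n coefficients c such that (r * c) == [1, 0, 0, ...]."""
    c = [1.0 / r[0]]
    for k in range(1, n):
        # 0 = r[0] * c[k] + sum_{i=1}^{k} r[i] * c[k-i]
        acc = sum(r[i] * c[k - i] for i in range(1, k + 1))
        c.append(-acc / r[0])
    return c


def bisr_coefs_sketch(num_bands, workload_coef=None):
    """Hypothetical stand-in: square root of the workload, then its inverse."""
    if workload_coef is None:
        # Default prefix-sum workload: all-ones Toeplitz coefficients.
        workload_coef = [1.0] * num_bands
    if len(workload_coef) < num_bands:
        raise ValueError("this sketch needs at least num_bands workload coefficients")
    root = sqrt_toeplitz_coefs(workload_coef, num_bands)
    return inverse_toeplitz_coefs(root, num_bands)


coefs = bisr_coefs_sketch(4)
print(coefs)  # [1.0, -0.5, -0.125, -0.0625]
```

For the prefix-sum workload these values match the leading Taylor coefficients of $(1 - x)^{1/2}$, as expected for the inverse square root of a matrix whose coefficients are those of $1/(1 - x)$.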

Parameters:
  • num_bands (int) – The number of coefficients to return.

  • workload_coef (Array | None) – Optional Toeplitz coefficients of the workload.

Return type:

Array

Returns:

The first num_bands coefficients of the lower-triangular Toeplitz noising matrix $C^{-1}$.