jax_privacy.matrix_factorization.toeplitz.optimize_banded_inverse_toeplitz

jax_privacy.matrix_factorization.toeplitz.optimize_banded_inverse_toeplitz(n, min_sep, num_bands, *, noising_coef=None, strategy_coef=None, workload_coef=None, max_participations=None, max_optimizer_steps=1000, reduction_fn=<function mean>)

Optimize over banded inverse Toeplitz noising matrices for BandInvMF.

This function optimizes directly over the Toeplitz coefficients of the lower-triangular noising matrix $C^{-1}$ for a Toeplitz workload, following the BandInvMF construction introduced in https://arxiv.org/pdf/2505.12128. The objective is the product of two terms: the per-query squared error on the induced workload, reduced to a scalar by reduction_fn, and the squared sensitivity of the implied strategy matrix $C$ when contributions from the same user are at least min_sep iterations apart.
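To make the shape of this objective concrete, here is a minimal NumPy sketch for the default prefix-sum workload. It is not the library's implementation. The helper names toeplitz_lower and sketch_objective are hypothetical, max_participations is ignored, and the sensitivity term assumes that the worst case is a user participating at iterations 0, min_sep, 2*min_sep, and so on. That assumption is a property of this sketch, not something this page states.

```python
import numpy as np


def toeplitz_lower(coef, n):
    """Dense n x n lower-triangular Toeplitz matrix; coef is zero-padded to n."""
    col = np.zeros(n)
    k = min(len(coef), n)
    col[:k] = coef[:k]
    # Entry (i, j) depends only on i - j; entries above the diagonal are zero.
    idx = np.arange(n)[:, None] - np.arange(n)[None, :]
    return np.where(idx >= 0, col[np.clip(idx, 0, n - 1)], 0.0)


def sketch_objective(noising_coef, n, min_sep, reduction_fn=np.mean):
    """Illustrative objective: reduced per-query squared error times squared sensitivity."""
    c_inv = toeplitz_lower(noising_coef, n)   # noising matrix C^{-1}
    c = np.linalg.inv(c_inv)                  # implied strategy matrix C
    a = np.tril(np.ones((n, n)))              # prefix-sum workload
    # Squared row norms of A C^{-1} give the per-query squared error.
    per_query_sq_err = np.sum((a @ c_inv) ** 2, axis=1)
    # ASSUMPTION: worst-case participation pattern is 0, min_sep, 2*min_sep, ...
    participating_cols = c[:, ::min_sep]
    sens_sq = np.sum(participating_cols.sum(axis=1) ** 2)
    return reduction_fn(per_query_sq_err) * sens_sq


# Two bands [1, -1] invert the prefix-sum workload exactly, so C equals the workload.
mean_obj = sketch_objective([1.0, -1.0], n=4, min_sep=2)
max_obj = sketch_objective([1.0, -1.0], n=4, min_sep=2, reduction_fn=np.max)
last_obj = sketch_objective([1.0, -1.0], n=4, min_sep=2, reduction_fn=lambda v: v[-1])
```

Passing a different reduction_fn changes only how the per-query errors are collapsed to a scalar; the sensitivity factor is unaffected.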

Parameters:
  • n (int) – The number of iterations that defines the workload.

  • min_sep (int) – The minimum separation between contributions from the same user.

  • num_bands (int) – The number of Toeplitz coefficients of the noising matrix to optimize, including the diagonal.

  • noising_coef (Array | None) – Optional initialization for the noising coefficients. If not provided, initializes from strategy_coef if given, otherwise from banded_inverse_square_root_noising_coefs(workload_coef=…). If longer than num_bands, the extra coefficients are ignored.

  • strategy_coef (Array | None) – Optional initialization for the strategy coefficients. If provided, the corresponding noising coefficients are computed via inverse_coef.

  • workload_coef (Array | None) – Optional Toeplitz coefficients of the workload. If not provided, the default prefix-sum workload of all ones is used.

  • max_participations (int | None) – Optional cap on the number of participations.

  • max_optimizer_steps (int) – The maximum number of L-BFGS iterations.

  • reduction_fn (Callable[[Array], Array]) – A function that converts per-query squared errors to a scalar. Use jnp.mean to optimize mean squared error, jnp.max to optimize maximum squared error, or lambda v: v[-1] to optimize last-iterate squared error. Defaults to jnp.mean.
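The relationship between strategy_coef and noising_coef can be illustrated with a small NumPy sketch. The coefficients of the inverse of a lower-triangular Toeplitz matrix follow from a simple recurrence. The function sketch_inverse_coef below is a hypothetical stand-in written for this example, not the library's inverse_coef, and truncating the result to num_bands is an assumption about how such an initialization might be shortened.

```python
import numpy as np


def sketch_inverse_coef(coef, n):
    """First n Toeplitz coefficients of the inverse of a lower-triangular Toeplitz matrix."""
    c = np.zeros(n)
    m = min(len(coef), n)
    c[:m] = np.asarray(coef, dtype=float)[:m]
    inv = np.zeros(n)
    inv[0] = 1.0 / c[0]
    for k in range(1, n):
        # The k-th coefficient of the product C * C^{-1} must vanish for k >= 1:
        # sum_{j=0..k} c[j] * inv[k - j] = 0.
        inv[k] = -np.dot(c[1 : k + 1], inv[k - 1 :: -1]) / c[0]
    return inv


# Square-root strategy for the prefix-sum workload: coefficients of (1 - x)^(-1/2).
strategy_coef = np.array([1.0, 0.5, 0.375, 0.3125])
num_bands = 3
# ASSUMPTION: keep only the first num_bands coefficients as the initial guess.
noising_init = sketch_inverse_coef(strategy_coef, n=4)[:num_bands]
# noising_init is [1.0, -0.5, -0.125]
```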

Return type:

Array

Returns:

The optimized Toeplitz coefficients of the lower-triangular noising matrix $C^{-1}$.
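Because the returned array holds the first column of a banded lower-triangular Toeplitz matrix, multiplying $C^{-1}$ by a vector reduces to a truncated convolution. The sketch below shows this with NumPy. The coefficient values are placeholders chosen for illustration, not output from this function, and apply_banded_toeplitz is a hypothetical helper name.

```python
import numpy as np


def apply_banded_toeplitz(coef, z):
    """Compute y[t] = sum_j coef[j] * z[t - j], i.e. the product C^{-1} z."""
    coef = np.asarray(coef, dtype=float)
    z = np.asarray(z, dtype=float)
    # The full convolution truncated to len(z) outputs equals the
    # lower-triangular Toeplitz matrix-vector product.
    return np.convolve(z, coef)[: len(z)]


coef = np.array([1.0, -0.5, -0.125])   # placeholder banded coefficients
rng = np.random.default_rng(0)
z = rng.standard_normal(6)             # i.i.d. Gaussian noise, one value per iteration
correlated_noise = apply_banded_toeplitz(coef, z)
```

Each output depends on at most len(coef) past noise values, so the noise can be generated iteration by iteration while keeping only the last len(coef) - 1 samples.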