jax_privacy.matrix_factorization.toeplitz.optimize_banded_toeplitz

jax_privacy.matrix_factorization.toeplitz.optimize_banded_toeplitz(n, bands, strategy_coef=None, max_optimizer_steps=250, reduction_fn=<function mean>)

Optimize over the space of banded Toeplitz strategies on a Prefix workload.
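To illustrate the strategy space, the sketch below builds the dense n x n strategy matrix from a coefficient vector. It assumes the convention that a banded Toeplitz strategy C is lower triangular with C[i, j] = coef[i - j] for 0 <= i - j < bands and zero elsewhere. The helper banded_toeplitz_matrix is hypothetical and is not part of jax_privacy.

```python
import jax.numpy as jnp


def banded_toeplitz_matrix(coef: jnp.ndarray, n: int) -> jnp.ndarray:
  """Hypothetical helper: n x n lower-triangular C with C[i, j] = coef[i - j].

  Entries with i - j >= len(coef) are zero, so C has len(coef) bands
  (the main diagonal plus len(coef) - 1 subdiagonals).
  """
  # jnp.eye(n, k=-d) has ones exactly on the d-th subdiagonal.
  return sum(coef[d] * jnp.eye(n, k=-d) for d in range(min(coef.shape[0], n)))


coef = jnp.array([1.0, -0.5, 0.25])  # illustrative 3-band coefficients
C = banded_toeplitz_matrix(coef, n=5)
print(C)  # 1.0 on the diagonal, -0.5 and 0.25 on the two subdiagonals
```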

The banded Toeplitz strategies produced by this function can be used in both the single-participation and the multi-participation settings (including both the fixed_epoch_order and min_sep participation schemas; see README.md), as long as the (minimum) separation between contributions from the same user is at least the number of bands. See https://arxiv.org/abs/2306.08153 for more details.
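A minimal sketch of why the separation condition matters, under the same assumed lower-triangular convention: column j of a b-banded C is nonzero only in rows j through j + b - 1, so contributions at least b steps apart touch disjoint rows. The squared sensitivity for k such participations is then the sum of the squared column norms, which is at most k times the squared coefficient norm. The helper, coefficients, and participation pattern are illustrative assumptions.

```python
import jax.numpy as jnp


def banded_toeplitz_matrix(coef, n):
  # Hypothetical helper: C[i, j] = coef[i - j] for 0 <= i - j < len(coef), else 0.
  return sum(coef[d] * jnp.eye(n, k=-d) for d in range(min(coef.shape[0], n)))


n, bands = 12, 3
coef = jnp.array([1.0, 0.6, 0.3])
coef = coef / jnp.linalg.norm(coef)  # unit L2 norm, matching the documented return value
C = banded_toeplitz_matrix(coef, n)

# k = 4 participations exactly `bands` steps apart (the minimum allowed).
participations = jnp.array([0, 3, 6, 9])
cols = C[:, participations]

# Disjoint row supports make the Gram matrix of these columns diagonal.
gram = cols.T @ cols
off_diag = gram - jnp.diag(jnp.diag(gram))
print(float(jnp.abs(off_diag).max()))  # 0.0

# Squared norm of the summed contributions: k * ||coef||^2, i.e. about 4.
summed = cols.sum(axis=1)
print(float(summed @ summed))
```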

If the strategy is used with a different participation pattern (e.g., (k, b)-minsep where b is less than the number of bands), the sensitivity can be computed post hoc using, e.g., toeplitz.minsep_sensitivity_squared. However, this should not be necessary in centralized training regimes, where the exact participation pattern should be known in advance.
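The sketch below shows what such a post-hoc computation involves for a (k, b)-minsep pattern. brute_force_minsep_sensitivity_squared is a hypothetical exhaustive search, not toeplitz.minsep_sensitivity_squared (whose signature is not documented here). It assumes X = C^T C is elementwise non-negative (true for non-negative coefficients), in which case identical unit contributions attain the worst case. It is exponential in n and only meant for tiny examples.

```python
import itertools

import jax.numpy as jnp
import numpy as np


def banded_toeplitz_matrix(coef, n):
  # Hypothetical helper: C[i, j] = coef[i - j] for 0 <= i - j < len(coef), else 0.
  return sum(coef[d] * jnp.eye(n, k=-d) for d in range(min(coef.shape[0], n)))


def brute_force_minsep_sensitivity_squared(C, k, b):
  """Hypothetical exhaustive search over (k, b)-minsep participation patterns.

  Returns max over patterns P of sum_{i, j in P} X[i, j] with X = C^T C.
  Only valid when X is elementwise non-negative; exponential in n.
  """
  X = np.asarray(C.T @ C)
  n = X.shape[0]
  best = 0.0
  for size in range(1, k + 1):
    for pattern in itertools.combinations(range(n), size):
      # Consecutive participations must be at least b steps apart.
      if all(q - p >= b for p, q in zip(pattern, pattern[1:])):
        best = max(best, float(X[np.ix_(pattern, pattern)].sum()))
  return best


coef = jnp.array([1.0, 0.6, 0.3])
coef = coef / jnp.linalg.norm(coef)  # non-negative and unit norm
C = banded_toeplitz_matrix(coef, n=10)

print(brute_force_minsep_sensitivity_squared(C, k=3, b=3))  # separation >= bands: about k = 3
print(brute_force_minsep_sensitivity_squared(C, k=3, b=1))  # separation < bands: larger
```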

Parameters:
  • n (int) – The number of iterations that defines the workload.

  • bands (int) – The number of bands in the Toeplitz matrix.

  • strategy_coef (Array | None) – Optional Toeplitz coefficients used to initialize the optimization.

  • max_optimizer_steps (int) – The maximum number of L-BFGS iterations.

  • reduction_fn (Callable[[Array], Array]) – A function that reduces the per-query squared errors to a scalar. Use jnp.mean to optimize the mean squared error, jnp.max to optimize the maximum squared error, or lambda v: v[-1] to optimize the last-iterate squared error. Defaults to jnp.mean.
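To make reduction_fn concrete, here is a minimal sketch of per-query squared errors on the prefix workload. It assumes the standard matrix-mechanism error model: A is the lower-triangular all-ones prefix-sum matrix, B = A C^{-1}, and the error of query i is proportional to ||B[i, :]||^2 times the squared sensitivity (here the single-participation maximum squared column norm of C; the noise multiplier is omitted). per_query_squared_errors and banded_toeplitz_matrix are hypothetical helpers, not library functions.

```python
import jax.numpy as jnp


def banded_toeplitz_matrix(coef, n):
  # Hypothetical helper: C[i, j] = coef[i - j] for 0 <= i - j < len(coef), else 0.
  return sum(coef[d] * jnp.eye(n, k=-d) for d in range(min(coef.shape[0], n)))


def per_query_squared_errors(coef, n):
  """Hypothetical helper: squared error of each prefix-sum query, up to sigma^2."""
  A = jnp.tril(jnp.ones((n, n)))  # prefix-sum workload
  C = banded_toeplitz_matrix(coef, n)
  # Decoder B = A C^{-1}, obtained by solving C^T B^T = A^T.
  B = jnp.linalg.solve(C.T, A.T).T
  # Single-participation sensitivity^2: largest squared column norm of C.
  sens_squared = jnp.max(jnp.sum(C**2, axis=0))
  return jnp.sum(B**2, axis=1) * sens_squared


coef = jnp.array([1.0, 0.6, 0.3])
coef = coef / jnp.linalg.norm(coef)
errors = per_query_squared_errors(coef, n=8)

mse = jnp.mean(errors)               # reduction_fn=jnp.mean
max_se = jnp.max(errors)             # reduction_fn=jnp.max
last_se = (lambda v: v[-1])(errors)  # reduction_fn=lambda v: v[-1]
print(mse, max_se, last_se)
```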

Return type:

Array

Returns:

The coefficients of the optimal banded Toeplitz strategy, guaranteed to have L2 norm 1.
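One way to see why the unit-norm normalization is harmless, sketched with hypothetical helpers and the single-participation error model: scaling C by alpha > 0 scales the sensitivity by alpha and B = A C^{-1} by 1/alpha, so the per-query errors are unchanged. And because the first column of a lower-triangular banded Toeplitz C (assumed convention C[i, j] = coef[i - j]) contains every coefficient, unit norm means the maximum column norm, and hence the single-participation sensitivity, is 1.

```python
import jax.numpy as jnp


def banded_toeplitz_matrix(coef, n):
  # Hypothetical helper: C[i, j] = coef[i - j] for 0 <= i - j < len(coef), else 0.
  return sum(coef[d] * jnp.eye(n, k=-d) for d in range(min(coef.shape[0], n)))


def per_query_squared_errors(coef, n):
  # Hypothetical helper: single-participation prefix-sum errors, up to sigma^2.
  A = jnp.tril(jnp.ones((n, n)))
  C = banded_toeplitz_matrix(coef, n)
  B = jnp.linalg.solve(C.T, A.T).T
  return jnp.sum(B**2, axis=1) * jnp.max(jnp.sum(C**2, axis=0))


n = 16
raw = jnp.array([2.0, 1.0, 0.5])   # arbitrary scale
unit = raw / jnp.linalg.norm(raw)  # unit L2 norm

# The first column holds every coefficient, so the max column norm is ||unit|| = 1.
max_col_norm = float(jnp.max(jnp.linalg.norm(banded_toeplitz_matrix(unit, n), axis=0)))
print(max_col_norm)  # about 1.0

# Rescaling the coefficients leaves every per-query error unchanged.
same = bool(jnp.allclose(per_query_squared_errors(raw, n),
                         per_query_squared_errors(unit, n), rtol=1e-4))
print(same)  # True
```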