jax_privacy.matrix_factorization.toeplitz.optimize_banded_toeplitz
- jax_privacy.matrix_factorization.toeplitz.optimize_banded_toeplitz(n, bands, strategy_coef=None, max_optimizer_steps=250, reduction_fn=<function mean>)
Optimize over the space of banded Toeplitz strategies on a Prefix workload.
The banded Toeplitz strategies produced by this function can be used in both the single-participation and the multi-participation settings (including both the fixed_epoch_order and min_sep participation schemas; see README.md), as long as the (minimum) separation between contributions from the same user is at least the number of bands provided. See https://arxiv.org/abs/2306.08153 for more details.
If the strategy is used with a different participation pattern (e.g., (k, b)-minsep where b is less than the number of bands), sensitivity can be computed post hoc using, e.g., toeplitz.minsep_sensitivity_squared. However, this should not be necessary in centralized training regimes, where the exact participation pattern should be known in advance.
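As a small numerical check of the separation condition, the NumPy sketch below builds a banded Toeplitz matrix using an assumed layout, C[i, j] = coef[i - j] inside the band and zero elsewhere. Column j is nonzero only in rows j through j + bands - 1. Columns whose indices differ by at least `bands` therefore have disjoint supports and are orthogonal. In that case, the squared sensitivity contributed by k such participations is the sum of their squared column norms, which is at most k for unit-norm coefficients. With a smaller separation the columns overlap, which is the situation where a post-hoc sensitivity computation is needed. The helper `offdiag_overlap` and the chosen coefficients are illustrative only. This sketch does not reproduce toeplitz.minsep_sensitivity_squared.

```python
import numpy as np

n, bands = 8, 3
coef = np.array([1.0, 0.5, 0.25])
coef = coef / np.linalg.norm(coef)  # unit L2 norm, like the documented return value

# Assumed layout: C[i, j] = coef[i - j] for 0 <= i - j < bands, zero elsewhere.
C = sum(w * np.eye(n, k=-d) for d, w in enumerate(coef))
gram = C.T @ C  # gram[j, m] is the inner product of columns j and m


def offdiag_overlap(cols):
    """Sum of |inner products| between distinct columns in `cols` (illustrative helper)."""
    sub = gram[np.ix_(cols, cols)]
    return np.abs(sub - np.diag(np.diag(sub))).sum()


sep_ok = [0, bands, 2 * bands]               # separation == bands: columns 0, 3, 6
sep_short = [0, bands - 1, 2 * (bands - 1)]  # separation < bands: columns 0, 2, 4

print(offdiag_overlap(sep_ok))     # zero: disjoint supports, orthogonal columns
print(offdiag_overlap(sep_short))  # positive: columns 0 and 2 share row 2
# With orthogonal columns, the squared sensitivity is the sum of squared column
# norms, which is at most k = 3 here because each column norm is at most 1.
print(np.trace(gram[np.ix_(sep_ok, sep_ok)]))
```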
- Parameters:
  - n (int) – The number of iterations that defines the workload.
  - bands (int) – The number of bands in the Toeplitz matrix.
  - strategy_coef (Array | None) – Optional Toeplitz coefficients used to initialize the optimization.
  - max_optimizer_steps (int) – The maximum number of L-BFGS iterations.
  - reduction_fn (Callable[[Array], Array]) – A function that converts per-query squared errors to a scalar. Use jnp.mean to optimize mean squared error, jnp.max to optimize max squared error, or lambda v: v[-1] to optimize last-iterate squared error. Defaults to jnp.mean.
- Return type:
  Array
- Returns:
  The coefficients of the optimal banded Toeplitz strategy, guaranteed to have L2 norm 1.
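To illustrate how `reduction_fn` is applied, the JAX sketch below computes per-query squared errors of a banded Toeplitz strategy on the Prefix workload. It then reduces those errors with each of the three documented options. The sketch assumes the usual matrix-mechanism error model: released noise A C^{-1} z with z ~ N(0, I) and unit single-participation sensitivity, so the error of query i is the squared norm of row i of A C^{-1}. This is an assumption about the objective, not a description of the library's internals, and `per_query_squared_error` is a hypothetical helper.

```python
import jax.numpy as jnp


def per_query_squared_error(coef: jnp.ndarray, n: int) -> jnp.ndarray:
    """Hypothetical helper: per-query squared error of a banded Toeplitz strategy.

    Assumes noise A @ inv(C) @ z with z ~ N(0, I) and sensitivity 1, so the
    error of prefix query i is the squared L2 norm of row i of B = A @ inv(C).
    """
    coef = coef / jnp.linalg.norm(coef)  # unit L2 norm, so single-participation sensitivity is 1
    # Assumed layout: C[i, j] = coef[i - j] for 0 <= i - j < len(coef), zero elsewhere.
    C = sum(w * jnp.eye(n, k=-d) for d, w in enumerate(coef))
    A = jnp.tril(jnp.ones((n, n)))  # Prefix workload: row i sums inputs 0..i
    B = jnp.linalg.solve(C.T, A.T).T  # B = A @ inv(C), computed via a solve
    return jnp.sum(B**2, axis=1)


errors = per_query_squared_error(jnp.array([1.0, 0.5]), n=16)

mean_err = jnp.mean(errors)           # reduction_fn=jnp.mean (the default)
max_err = jnp.max(errors)             # reduction_fn=jnp.max
last_err = (lambda v: v[-1])(errors)  # reduction_fn=lambda v: v[-1]
```

With a single band (C equal to the identity), the sketch reduces to B = A, and the error of query i is simply i + 1.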