jax_privacy.matrix_factorization.banded.optimize

jax_privacy.matrix_factorization.banded.optimize(n, *, bands, C=None, A=None, max_optimizer_steps=100, reduction_fn=<function mean>, scan_fn=<function scan>, callback=<function <lambda>>)

Optimize a column-normalized banded strategy using a gradient-based method.
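To make the gradient-based procedure concrete, here is a minimal dense-matrix sketch in JAX. It assumes a small dense workload A (prefix sums), parametrizes a lower-triangular banded C by its first `bands` diagonals, normalizes its columns, and minimizes `reduction_fn` applied to the per-query squared errors of B = A C^{-1}. The noise scale and sensitivity are dropped on the assumption that, for column-normalized C, they only rescale the objective. The helpers `banded_strategy`, `per_query_squared_error`, and `optimize_banded_sketch`, and the simple accept/reject gradient step, are hypothetical illustrations. They are not the library's ColumnNormalizedBanded, StreamingMatrix, or optimizer.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

# Hypothetical sketch; not the jax_privacy implementation.


def banded_strategy(params):
    """Build a column-normalized, lower-triangular banded matrix.

    params[k, j] holds the (unnormalized) entry C[j + k, j], so params has
    shape (bands, n); entries with j + k >= n are ignored.
    """
    bands, n = params.shape
    C = sum(jnp.diag(params[k, : n - k], k=-k) for k in range(bands))
    return C / jnp.linalg.norm(C, axis=0, keepdims=True)


def per_query_squared_error(params, A):
    """Squared row norms of B = A C^{-1} (noise scale and sensitivity dropped)."""
    C = banded_strategy(params)
    # Solve B C = A, i.e. C^T B^T = A^T, using the triangular structure of C.
    B = solve_triangular(C.T, A.T, lower=False).T
    return jnp.sum(B**2, axis=1)


def optimize_banded_sketch(A, bands, steps=100, reduction_fn=jnp.mean, lr=0.1):
    n = A.shape[0]
    # Start from the identity strategy: unit diagonal, zero sub-diagonals.
    params = jnp.zeros((bands, n)).at[0].set(1.0)
    loss_fn = lambda p: reduction_fn(per_query_squared_error(p, A))
    value_and_grad = jax.jit(jax.value_and_grad(loss_fn))
    loss, grad = value_and_grad(params)
    for _ in range(steps):
        candidate = params - lr * grad
        new_loss, new_grad = value_and_grad(candidate)
        if new_loss < loss:  # accept the step (NaN losses are rejected here)
            params, loss, grad = candidate, new_loss, new_grad
            lr *= 1.2
        else:  # reject the step and shrink the step size
            lr *= 0.5
    return params, loss


A = jnp.tril(jnp.ones((16, 16)))  # prefix-sum workload, n = 16
params, loss = optimize_banded_sketch(A, bands=4, steps=50)
```

Starting from the identity strategy on this 16-step prefix-sum workload gives a mean squared error of exactly 8.5 (row i of A contains i + 1 ones), and the optimized banded strategy should reach a lower value.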

Note that this function benefits substantially from GPU acceleration. It is primarily supported to aid in reproducing results from https://arxiv.org/abs/2405.15913. In practice, we recommend using a banded Toeplitz strategy instead (see toeplitz.optimize_banded_toeplitz); such strategies are less than 0.5% suboptimal in the regimes of most interest (n >= 1000 iterations and b <= 32 bands).
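For comparison with the recommendation above, the sketch below builds a lower-triangular banded Toeplitz matrix, in which every entry on a given sub-diagonal is equal, so a b-band strategy has b free coefficients instead of roughly n·b. The helpers `banded_toeplitz` and `banded_free_params` are hypothetical and only illustrate the matrix structure; they do not reflect how toeplitz.optimize_banded_toeplitz parametrizes or optimizes its strategies.

```python
import jax.numpy as jnp


def banded_toeplitz(coefs, n):
    """Hypothetical helper: n x n lower-triangular Toeplitz matrix
    with coefs[k] on the k-th sub-diagonal (k = 0 is the main diagonal)."""
    return sum(c * jnp.eye(n, k=-k) for k, c in enumerate(coefs))


def banded_free_params(n, bands):
    # A general banded lower-triangular matrix has n - k free entries on sub-diagonal k.
    return sum(n - k for k in range(bands))


C = banded_toeplitz([1.0, 0.5, 0.25], n=4)
print(C)
print(banded_free_params(1000, 32), "vs", 32)  # 31504 vs 32
```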

The strategies produced by this procedure can be used in both single- and multi-participation settings, including both (k, b)-min-sep and (k, b)-fixed epoch order participation as described in https://arxiv.org/abs/2306.08153, as long as the number of bands in C is less than or equal to the (minimum) separation between contributions from the same user.
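The participation condition above can be checked mechanically. This sketch rests on an assumed input format: each user's contributions are given as a list of iteration indices. `min_separation` and `bands_are_valid` are hypothetical helpers, not jax_privacy functions.

```python
import math


def min_separation(participations):
    """Smallest gap between consecutive contributions of the same user.

    participations: one list of iteration indices per user (hypothetical format).
    Returns math.inf when no user contributes more than once.
    """
    gaps = []
    for steps in participations:
        ordered = sorted(steps)
        gaps.extend(later - earlier for earlier, later in zip(ordered, ordered[1:]))
    return min(gaps, default=math.inf)


def bands_are_valid(bands, participations):
    # The strategy may be used when bands <= the minimum separation.
    return bands <= min_separation(participations)


# Fixed epoch order over n = 12 iterations with users split into 4 groups:
# group u contributes at iterations u, u + 4, u + 8.
n, sep = 12, 4
epoch_order = [[u + e * sep for e in range(n // sep)] for u in range(sep)]
print(min_separation(epoch_order))  # 4
print(bands_are_valid(4, epoch_order), bands_are_valid(5, epoch_order))  # True False
```

In the single-participation case no user has a second contribution, so the minimum separation is unbounded and any number of bands passes this check.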

Parameters:
  • n (int) – The number of training iterations the strategy is configured for.

  • bands (int) – The number of bands in the strategy.

  • C (ColumnNormalizedBanded | None) – The initial strategy to be optimized.

  • A (StreamingMatrix | None) – The target workload.

  • max_optimizer_steps (int) – The maximum number of iterations to optimize for.

  • reduction_fn (Callable[[Array], Array]) – A function that converts per-query squared errors to a scalar. Use jnp.mean to optimize mean squared error, jnp.max to optimize maximum squared error, or lambda v: v[-1] to optimize last-iterate squared error.
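The three reductions listed above can be compared on a toy vector of per-query squared errors. This is an illustration only: the error values are arbitrary, and other differentiable JAX functions from a vector to a scalar should presumably also work, since the objective is minimized with gradients.

```python
import jax.numpy as jnp

# Arbitrary per-query squared errors for a 4-step workload (illustrative only).
errors = jnp.array([1.0, 2.0, 3.0, 4.0])

reductions = {
    "mean": jnp.mean,          # mean squared error
    "max": jnp.max,            # maximum squared error
    "last": lambda v: v[-1],   # last-iterate squared error
}
results = {name: float(fn(errors)) for name, fn in reductions.items()}
print(results)  # {'mean': 2.5, 'max': 4.0, 'last': 4.0}
```

Any of these callables can then be passed as the reduction_fn argument.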

  • scan_fn (Any) – Either ‘equinox’, ‘dinosaur’, or a function with the same signature as jax.lax.scan. Using ‘equinox’ or ‘dinosaur’ helps when optimizing strategies on GPUs for large n: it allows the scan function used internally by per_query_error to be checkpointed, avoiding out-of-memory (OOM) errors during backpropagation.
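A custom scan_fn only needs to accept the same arguments as jax.lax.scan. As a hedged sketch, the wrapper below applies jax.checkpoint to the loop body so that each step's intermediate values are recomputed during backpropagation rather than stored. `checkpointed_scan` is a hypothetical name. Because this per-step rematerialization still keeps every step's carry in memory, it may save considerably less memory than the built-in ‘equinox’ or ‘dinosaur’ options, whose internals are not described here.

```python
import jax
import jax.numpy as jnp


def checkpointed_scan(f, init, xs=None, length=None, **kwargs):
    """Hypothetical scan_fn with the same call signature as jax.lax.scan.

    jax.checkpoint(f) makes JAX recompute f's intermediates during the
    backward pass instead of storing them, trading extra compute for memory.
    """
    return jax.lax.scan(jax.checkpoint(f), init, xs, length=length, **kwargs)


def step(carry, x):
    # Running sum as the carry; emit the squared running sum at each step.
    carry = carry + x
    return carry, carry**2


def loss(xs, scan_fn):
    _, ys = scan_fn(step, jnp.zeros((), xs.dtype), xs)
    return jnp.sum(ys)


xs = jnp.arange(1.0, 6.0)
grad_plain = jax.grad(lambda v: loss(v, jax.lax.scan))(xs)
grad_ckpt = jax.grad(lambda v: loss(v, checkpointed_scan))(xs)
print(grad_ckpt)  # [70. 68. 62. 50. 30.]
```

Both scans return identical gradients; only the memory/compute trade-off during the backward pass differs. A function like this could then be passed as scan_fn.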

  • callback (Callable[[CallbackArgs], None | bool]) – A function to call after each optimization iteration. See optimization.optimize for details.

Return type:

ColumnNormalizedBanded

Returns:

An optimized strategy with the same structure as C.