jax_privacy.matrix_factorization.dense.optimize

jax_privacy.matrix_factorization.dense.optimize(n, *, epochs=1, bands=None, equal_norm=False, A=None, max_optimizer_steps=10000, reduction_fn=<function mean>, callback=<function pg_tol_termination_fn>)

Optimizes a strategy matrix C for a given reduction_fn and participation pattern.

Note: While the function accepts a reduction_fn keyword argument, it has been tuned and tested rigorously only for mean-squared error (i.e., reduction_fn=jnp.mean).

This function can optimize strategy matrices for the following participation settings:

  • Single-participation: [Denisov et al., 2022](https://arxiv.org/abs/2202.08312). This can be accomplished by running with default arguments.

  • Multi-participation with fixed-epoch order: [Choquette-Choo et al., 2022](https://arxiv.org/abs/2211.06530). This can be accomplished by setting epochs=k.

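  • Multi-participation with min-separation (useful for federated training scenarios). This can be accomplished by setting bands=min_sep and equal_norm=True, where min_sep is the minimum number of iterations between two participations of the same user or example.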

  • Multi-participation with amplification via subsampled fixed-epoch order: [Choquette-Choo et al., 2022](https://arxiv.org/abs/2211.06530). This can be accomplished by setting epochs=1, equal_norm=True, and bands to a value smaller than the separation between participations.
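The four participation settings above map onto keyword arguments of optimize roughly as follows. This is a minimal sketch: the values of n, k, min_sep, separation, and the band count used for amplification are hypothetical placeholders chosen for illustration, not recommended settings.

```python
# Hypothetical problem sizes, chosen only for illustration.
n = 1024         # number of iterations the strategy encodes (divisible by k here)
k = 4            # number of epochs under (k, b)-fixed-epoch order
min_sep = 256    # hypothetical minimum separation between participations
separation = 128 # hypothetical separation used for the amplified setting
amp_bands = 64   # hypothetical band count; must be smaller than `separation`

# Keyword arguments for each participation setting described above.
participation_configs = {
    # Single participation: all defaults.
    "single_participation": {},
    # Multi-participation with fixed-epoch order.
    "fixed_epoch_order": {"epochs": k},
    # Multi-participation with min-separation (e.g. federated training).
    "min_separation": {"bands": min_sep, "equal_norm": True},
    # Amplification via subsampled fixed-epoch order.
    "amplified": {"epochs": 1, "bands": amp_bands, "equal_norm": True},
}

# The amplified setting requires fewer bands than the separation.
assert participation_configs["amplified"]["bands"] < separation
```

With jax_privacy installed, each entry would then be unpacked into the call, e.g. optimize(n, **participation_configs["fixed_epoch_order"]).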

Parameters:
  • n (int) – The number of iterations the strategy should encode.

  • epochs (int) – The number of epochs the strategy should be calibrated for. Assumes (k, b)-fixed-epoch order participation with k = epochs, i.e., each example participates k times, once every b = n / epochs iterations.

  • bands (int | None) – The number of bands in the strategy.

  • equal_norm (bool) – Whether to constrain every column of C to have the same norm. Useful for BandMF. If epochs=1, this flag is a no-op, because the returned strategy is column-normalized either way.

  • A (Array | None) – The workload matrix (defaults to Prefix).

  • max_optimizer_steps (int) – The maximum number of L-BFGS steps to take.

  • reduction_fn (Callable[[Array], Array]) – A function that converts per-query squared errors to a scalar. Use jnp.mean to optimize mean squared error, jnp.max to optimize max squared error, or any other differentiable function written in JAX.

  • callback (Callable[[CallbackArgs], None | bool]) – An optional callback function to monitor optimization progress. The default callback terminates the optimization early if the projected gradient is near-zero.
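To make the matrix-valued parameters concrete, the sketch below builds (1) the Prefix workload that A defaults to, i.e. the lower-triangular all-ones matrix whose product with a vector of per-step gradients gives running sums, (2) the sparsity pattern of a strategy with a given number of bands, and (3) the column normalization that equal_norm refers to. It uses NumPy for portability; the helpers prefix_workload, band_mask, and normalize_columns are hypothetical and not part of jax_privacy. The banded pattern assumes the common BandMF convention that C is lower triangular with C[i, j] = 0 whenever i - j >= bands.

```python
import numpy as np


def prefix_workload(n: int) -> np.ndarray:
    """Lower-triangular all-ones matrix: (A @ g)[i] = g[0] + ... + g[i]."""
    return np.tril(np.ones((n, n)))


def band_mask(n: int, bands: int) -> np.ndarray:
    """Boolean mask of entries a lower-triangular `bands`-banded C may use.

    Assumption: entry (i, j) is allowed iff 0 <= i - j < bands.
    """
    i, j = np.indices((n, n))
    return (i - j >= 0) & (i - j < bands)


def normalize_columns(C: np.ndarray) -> np.ndarray:
    """Rescales every column of C to unit L2 norm, so all columns have equal norm."""
    return C / np.linalg.norm(C, axis=0, keepdims=True)


n = 6
A = prefix_workload(n)
g = np.arange(1.0, n + 1)        # stand-in scalar gradients 1, 2, ..., 6
print(A @ g)                     # running sums: 1, 3, 6, 10, 15, 21

mask = band_mask(n, bands=2)     # main diagonal plus one subdiagonal
C = normalize_columns(np.where(mask, A, 0.0))
print(np.linalg.norm(C, axis=0)) # every column now has norm 1
```

Note that the last column of the masked matrix has a single nonzero entry while the others have two, so normalization is what makes the column norms equal here.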

Return type:

Array

Returns:

The optimized strategy matrix C. With the default reduction_fn (jnp.mean), this C minimizes the expected total squared error.
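A returned strategy can be scored with the standard matrix-mechanism error formula: with B = A C^{-1} and Gaussian noise scaled by the sensitivity of C, the expected squared error of query i is sens(C)^2 * ||B[i, :]||^2, and reduction_fn turns that vector into a scalar. The sketch below assumes a noise multiplier of 1 and computes the fixed-epoch sensitivity under the additional assumption that C.T @ C is elementwise nonnegative; fixed_epoch_sensitivity and per_query_squared_error are hypothetical helpers, not jax_privacy functions.

```python
import numpy as np


def fixed_epoch_sensitivity(C: np.ndarray, epochs: int) -> float:
    """L2 sensitivity of C under (k, b)-fixed-epoch order with k = epochs.

    Assumes n is divisible by epochs and C.T @ C is elementwise nonnegative,
    in which case the worst case is an example contributing +1 at steps
    i, i + b, ..., i + (k - 1) * b. With epochs=1 this is the max column norm.
    """
    n = C.shape[1]
    b = n // epochs
    sens = 0.0
    for i in range(b):
        u = np.zeros(n)
        u[i::b] = 1.0  # the k steps in which this example participates
        sens = max(sens, float(np.linalg.norm(C @ u)))
    return sens


def per_query_squared_error(A: np.ndarray, C: np.ndarray, epochs: int = 1) -> np.ndarray:
    """Expected squared error of each query (row) of A for strategy C."""
    B = np.linalg.solve(C.T, A.T).T  # B = A @ inv(C)
    return fixed_epoch_sensitivity(C, epochs) ** 2 * np.sum(B**2, axis=1)


n = 4
A = np.tril(np.ones((n, n)))  # Prefix workload
for name, C in {"identity": np.eye(n), "prefix": A}.items():
    errs = per_query_squared_error(A, C)
    # Two reduction_fn choices: mean and max squared error.
    print(name, errs, np.mean(errs), np.max(errs))
```

Here C = I gives per-query errors 1, 2, 3, 4 (mean 2.5), while C = A gives 4 for every query (mean 4), so neither trivial factorization is optimal. Because the mean equals the total divided by n, minimizing the mean and minimizing the total squared error select the same strategy.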