jax_privacy.matrix_factorization.dense.optimize
- jax_privacy.matrix_factorization.dense.optimize(n, *, epochs=1, bands=None, equal_norm=False, A=None, max_optimizer_steps=10000, reduction_fn=<function mean>, callback=<function pg_tol_termination_fn>)
Optimizes a strategy matrix C for a given reduction_fn and participation pattern.
Note: While the function accepts a reduction_fn keyword argument, it has been tuned and tested rigorously only for mean-squared error (i.e., reduction_fn=jnp.mean).
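To make the role of reduction_fn concrete, here is a minimal NumPy sketch of the per-query squared errors it reduces to a scalar. It assumes the standard matrix-mechanism error model under single participation: for a factorization A = B C with B = A C^{-1}, Gaussian noise scaled to the sensitivity of C (its largest column norm) gives query i an expected squared error of sens(C)^2 · ‖B[i, :]‖^2. The helpers prefix_workload and per_query_squared_errors are hypothetical illustrations, not part of jax_privacy, and the clip norm and noise multiplier are both taken as 1.

```python
import numpy as np


def prefix_workload(n: int) -> np.ndarray:
    """Hypothetical helper: the prefix-sum workload, a lower-triangular matrix of ones."""
    return np.tril(np.ones((n, n)))


def per_query_squared_errors(A: np.ndarray, C: np.ndarray) -> np.ndarray:
    """Hypothetical helper: expected squared error of each query (row of A).

    Assumes single participation, clip norm 1, and noise multiplier 1, so the
    noise added to C @ x has standard deviation equal to the sensitivity of C.
    """
    B = np.linalg.solve(C.T, A.T).T          # decoder B = A C^{-1}
    sens = np.linalg.norm(C, axis=0).max()   # single participation: max column norm
    return sens**2 * np.sum(B**2, axis=1)    # sens^2 * squared row norms of B


n = 8
A = prefix_workload(n)
for name, C in [("C = I", np.eye(n)), ("C = A", A)]:
    errs = per_query_squared_errors(A, C)
    # np.mean and np.max play the role of reduction_fn here.
    print(f"{name}: mean={errs.mean():.2f}, max={errs.max():.2f}")
# C = I: mean=4.50, max=8.00
# C = A: mean=8.00, max=8.00
```

For these two trivial factorizations, mean squared error prefers C = I (whose error grows with the prefix length), while max squared error rates them equally, which is why the choice of reduction_fn can change the optimized strategy.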
This function can optimize strategies under the following participation schemes:
  - Single participation [Denisov et al., 2022](https://arxiv.org/abs/2202.08312): run with the default arguments.
  - Multi-participation with fixed-epoch order [Choquette-Choo et al., 2022](https://arxiv.org/abs/2211.06530): set epochs=k.
  - Multi-participation with min-separation (useful in federated training scenarios): set bands=min_sep and equal_norm=True.
  - Multi-participation with amplification via subsampled fixed-epoch order [Choquette-Choo et al., 2022](https://arxiv.org/abs/2211.06530): set epochs=1, bands < separation, and equal_norm=True.
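Why bands and equal_norm go together in the multi-participation settings can be checked numerically. The sketch below assumes a lower-triangular C whose Gram matrix C^T C is elementwise nonnegative (so the worst case adds the participating columns with the same sign) and computes its sensitivity under (k, b)-fixed-epoch order, where the example in slot j participates in iterations j, j + b, ..., j + (k - 1)b with b = n / k. If C has at most b nonzero diagonals, the columns of any two participations occupy disjoint rows and are orthogonal, so with equal column norms the sensitivity is exactly sqrt(k) times that norm; the same orthogonality argument applies to min-separation participation whenever bands does not exceed the minimum separation. The helpers fixed_epoch_sensitivity and banded_strategy are hypothetical, and banded_strategy builds an arbitrary banded matrix, not an optimized one.

```python
import numpy as np


def fixed_epoch_sensitivity(C: np.ndarray, epochs: int) -> float:
    """Hypothetical helper: L2 sensitivity of C under (k, b)-fixed-epoch order.

    With k = epochs and b = n // k, the example in slot j participates in
    iterations j, j + b, ..., j + (k - 1) * b. Assumes C.T @ C is elementwise
    nonnegative, so summing the participating columns attains the maximum.
    """
    n = C.shape[1]
    if n % epochs:
        raise ValueError("n must be divisible by epochs")
    b = n // epochs
    if np.any(C.T @ C < -1e-12):
        raise ValueError("sketch assumes a nonnegative Gram matrix C.T @ C")
    # C[:, j::b] holds the k columns touched by the example in slot j.
    return max(np.linalg.norm(C[:, j::b].sum(axis=1)) for j in range(b))


def banded_strategy(n: int, bands: int) -> np.ndarray:
    """Hypothetical helper: ones on the main diagonal and the next bands - 1
    subdiagonals, with every column rescaled to unit norm (equal norms)."""
    C = sum(np.eye(n, k=-d) for d in range(bands))
    return C / np.linalg.norm(C, axis=0)


n, epochs = 12, 3                     # so b = 4
narrow = banded_strategy(n, bands=4)  # bands <= b: participating columns are orthogonal
wide = banded_strategy(n, bands=6)    # bands > b: participating columns overlap
print(np.isclose(fixed_epoch_sensitivity(narrow, epochs), np.sqrt(epochs)))  # True
print(fixed_epoch_sensitivity(wide, epochs) > np.sqrt(epochs))               # True
```

With epochs=1 each slot touches a single column, so the helper reduces to the largest column norm of C, the single-participation sensitivity.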
- Parameters:
  - n (int) – The number of iterations the strategy should encode.
  - epochs (int) – The number of epochs the strategy should be calibrated for. Assumes (k, b)-fixed-epoch order participation.
  - bands (int | None) – The number of bands in the strategy.
  - equal_norm (bool) – If True, constrains every column of C to have the same norm. Useful for BandMF. If epochs=1, this flag is a no-op, since the returned strategy is column-normalized either way.
  - A (Array | None) – The workload matrix (defaults to Prefix).
  - max_optimizer_steps (int) – The maximum number of L-BFGS steps to take.
  - reduction_fn (Callable[[Array], Array]) – A function that reduces the per-query squared errors to a scalar. Use jnp.mean to optimize mean squared error, jnp.max to optimize max squared error, or any other differentiable function written in JAX.
  - callback (Callable[[CallbackArgs], None | bool]) – An optional callback function to monitor optimization progress. The default callback terminates the optimization early if the projected gradient is near zero.
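- Return type:
  Array
- Returns:
  The strategy matrix C that minimizes the expected error as aggregated by reduction_fn (with the default jnp.mean, this is the expected total squared error up to a factor of 1/n).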
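A returned strategy C is typically used in the matrix factorization mechanism: noise z is added to C x, and the result is decoded with B = A C^{-1}, which releases A x + B z. The NumPy sketch below checks that identity, with a random invertible lower-triangular matrix standing in for the optimized C. The noise calibration (a hypothetical noise multiplier sigma, clip norm 1, single participation) is an assumption for illustration, not jax_privacy's training loop.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 8, 3                                   # iterations, gradient dimension
A = np.tril(np.ones((n, n)))                  # Prefix workload: prefix sums
C = np.tril(rng.uniform(0.5, 1.5, (n, n)))    # stand-in for a strategy returned by optimize
B = np.linalg.solve(C.T, A.T).T               # decoder B = A C^{-1}

x = rng.normal(size=(n, d))                   # per-iteration (clipped) gradients
sens = np.linalg.norm(C, axis=0).max()        # single-participation sensitivity of C
sigma = 1.0                                   # hypothetical noise multiplier
z = rng.normal(size=(n, d)) * sigma * sens    # noise calibrated to the sensitivity of C

# Privatize C x, then decode with B: B (C x + z) = A x + B z.
released = B @ (C @ x + z)
print(np.allclose(released, A @ x + B @ z))   # True
```

Only C x + z is privatized; B is post-processing, so the optimizer's job is to choose C so that the correlated noise B z is as small as possible under reduction_fn.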