jax_privacy.matrix_factorization.banded.per_query_error

jax_privacy.matrix_factorization.banded.per_query_error(C, A=None, scan_fn=<function scan>)

Computes the expected per-query squared error of a strategy.

Specifically, this function computes the row-wise squared L2 norm (L2^2) of B = A C^{-1}, giving one error value per query.
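To make the definition concrete, here is a minimal dense sketch. It assumes C is materialized as an explicit n x n lower-triangular, invertible matrix and A as an explicit matrix with n columns; the library itself works with implicit ColumnNormalizedBanded / StreamingMatrix representations and a scan. The helper name dense_per_query_error is hypothetical and only illustrates the math.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def dense_per_query_error(C, A):
    """Row-wise squared L2 norms of B = A @ inv(C), for dense C and A.

    Hypothetical helper: assumes C is n x n, lower-triangular and invertible,
    and A has n columns.
    """
    # B C = A  <=>  C^T B^T = A^T, and C^T is upper-triangular.
    B = solve_triangular(C.T, A.T, lower=False).T
    # One squared norm per row of B, i.e. one error value per query.
    return jnp.sum(B ** 2, axis=1)


n = 4
A = jnp.tril(jnp.ones((n, n)))  # prefix-sum workload (illustrative choice)
C = jnp.eye(n)                  # identity strategy
errors = dense_per_query_error(C, A)
```

With C equal to the identity, B equals A, so for the prefix-sum workload the error of query i (0-indexed) is i + 1.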

Because C is column-normalized, this error can be used directly as a loss function: the sensitivity of a ColumnNormalizedBanded strategy is constant in both the single-participation and multi-participation settings, provided the number of bands in C is less than or equal to the (minimum) separation between contributions from the same user.
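The constant-sensitivity claim can be checked numerically with a small dense sketch. The helpers column_normalized_banded and participation_sensitivity are hypothetical illustrations, not part of jax_privacy, and assume C is a lower-triangular banded Toeplitz matrix whose columns are then rescaled to unit L2 norm. With b bands, column j is nonzero only in rows j through j + b - 1, so columns whose indices differ by at least b have disjoint supports, and the sensitivity for k participations is sqrt(k).

```python
import jax.numpy as jnp


def column_normalized_banded(band_coefs, n):
    """Dense n x n lower-triangular banded matrix with unit-norm columns.

    Hypothetical helper: band_coefs[k] is placed on the k-th subdiagonal
    (k = 0 is the main diagonal), so the number of bands is len(band_coefs).
    """
    C = sum(c * jnp.eye(n, k=-k) for k, c in enumerate(band_coefs))
    return C / jnp.linalg.norm(C, axis=0, keepdims=True)


def participation_sensitivity(C, participations):
    """L2 norm of the sum of the columns of C at the given 0-indexed steps."""
    return jnp.linalg.norm(C[:, jnp.array(participations)].sum(axis=1))


n = 12
C = column_normalized_banded(jnp.array([1.0, 0.5, 0.25]), n)  # 3 bands
col_norms = jnp.linalg.norm(C, axis=0)  # every column has unit norm
# Separation 3 equals the number of bands, so the columns are disjoint
# and the sensitivity is sqrt(4) = 2.
sens = participation_sensitivity(C, [0, 3, 6, 9])
```

Participations at steps 0 and 1 are closer together than the 3 bands, so those two columns overlap and the resulting sensitivity exceeds sqrt(2).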

If you need to backpropagate through this function, you can pass the equinox or dinosaur scan functions as scan_fn to make the scan checkpointed, which lets the scan run for large n without running out of accelerator memory.
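As an illustration of the scan_fn contract, the sketch below wraps jax.lax.scan so that the scan body is rematerialized with jax.checkpoint during the backward pass. checkpointed_scan is a hypothetical helper, not the equinox or dinosaur implementation: it still stores the carry at every step, whereas dedicated checkpointed scans may use multi-level schemes that store fewer carries. Because it accepts the same arguments as jax.lax.scan, it could be passed as scan_fn.

```python
import jax
import jax.numpy as jnp


def checkpointed_scan(f, init, xs=None, length=None, **kwargs):
    """Hypothetical drop-in for jax.lax.scan with a rematerialized body.

    The backward pass keeps only each step's inputs (carry and x) and
    recomputes the body's intermediates instead of storing them.
    """
    return jax.lax.scan(jax.checkpoint(f), init, xs, length=length, **kwargs)


def total(xs, scan_fn):
    """Toy objective computed with whichever scan implementation is given."""
    def body(carry, x):
        y = jnp.tanh(carry + x)
        return y, y ** 2

    _, ys = scan_fn(body, jnp.zeros((), dtype=xs.dtype), xs)
    return ys.sum()


xs = jnp.linspace(-1.0, 1.0, 8)
g_plain = jax.grad(lambda v: total(v, jax.lax.scan))(xs)
g_ckpt = jax.grad(lambda v: total(v, checkpointed_scan))(xs)
```

Both gradients agree; only the memory/recompute trade-off of the backward pass differs.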

Parameters:
  • C (ColumnNormalizedBanded) – The strategy matrix, represented implicitly.

  • A (StreamingMatrix | None) – The workload matrix, represented implicitly.

  • scan_fn (Any) – A function with the same signature as jax.lax.scan.

Return type:

Array

Returns:

The expected per-query squared error of the strategy on the workload, as an array of length n (one entry per query).
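Because the result is a length-n array, using it as a loss for optimization requires an explicit reduction to a scalar. The minimal sketch below assumes a dense banded strategy parameterized by its band coefficients; the helpers banded_strategy, per_query_errors, and loss are hypothetical and do not use the library's ColumnNormalizedBanded type. It averages the per-query errors and differentiates with respect to the band coefficients.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def banded_strategy(band_coefs, n):
    """Hypothetical dense banded lower-triangular strategy with unit-norm columns."""
    C = sum(c * jnp.eye(n, k=-k) for k, c in enumerate(band_coefs))
    return C / jnp.linalg.norm(C, axis=0, keepdims=True)


def per_query_errors(band_coefs, A):
    """Row-wise squared norms of A @ inv(C): a length-n array, one per query."""
    C = banded_strategy(band_coefs, A.shape[1])
    B = solve_triangular(C.T, A.T, lower=False).T
    return jnp.sum(B ** 2, axis=1)


def loss(band_coefs, A):
    # Mean over queries; jnp.max would instead target the worst-case query.
    return jnp.mean(per_query_errors(band_coefs, A))


n = 16
A = jnp.tril(jnp.ones((n, n)))  # prefix-sum workload (illustrative choice)
coefs = jnp.array([1.0, 0.5, 0.25])
value, grad = jax.value_and_grad(loss)(coefs, A)
```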