jax_privacy.matrix_factorization.toeplitz.per_query_error

jax_privacy.matrix_factorization.toeplitz.per_query_error(*, strategy_coef=None, noising_coef=None, n=None, workload_coef=None)

Expected per-query squared error for a (banded) Toeplitz mechanism.

This function returns the expected squared error of each query, i.e. one value per iteration. To compute the mean squared error or the maximum squared error, apply jnp.mean or jnp.max to its output. Note: for Toeplitz workloads and strategies, the maximum squared error equals the squared error of the last iterate, which may be more efficient to compute under JAX transformations.
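The note about the last iterate can be seen with a small sketch. It assumes that Toeplitz coefficients describe the first column of a lower-triangular Toeplitz matrix, that the error matrix is the workload matrix times the noising matrix (and is therefore lower-triangular Toeplitz too), and that the noise has unit variance. If b holds the error matrix's coefficients, row i contains b_i, ..., b_0, so query i's expected squared error is b_0^2 + ... + b_i^2. This is a running sum of nonnegative terms and never decreases, so the maximum is the last entry. The coefficient values below are made up for illustration.

```python
import jax.numpy as jnp

# Illustrative (made-up) first-column coefficients b_0, ..., b_{n-1} of a
# lower-triangular Toeplitz error matrix (assumed: workload @ noising).
b = jnp.array([1.0, 0.5, -0.25, 0.125])

# Row i holds b_i, b_{i-1}, ..., b_0 followed by zeros, so its squared norm
# is the running sum b_0^2 + ... + b_i^2.
per_query = jnp.cumsum(b**2)  # values: 1.0, 1.25, 1.3125, 1.328125

# Aggregate errors, as suggested in the docstring.
mean_squared_error = jnp.mean(per_query)
max_squared_error = jnp.max(per_query)

# A running sum of nonnegative terms is nondecreasing, so the last iterate
# already carries the maximum squared error.
last_iterate_error = per_query[-1]
```

Reading only the last entry avoids a reduction over the whole array, which is what makes it the cheaper choice under transformations.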

Exactly one of strategy_coef and noising_coef should be provided.

Parameters:
  • strategy_coef (Array | None) – Toeplitz coefficients of the strategy matrix.

  • noising_coef (Array | None) – Toeplitz coefficients of the noising matrix.

  • n (int | None) – The size of the implied matrices (defaults to the length of whichever of strategy_coef or noising_coef is provided).

  • workload_coef (Array | None) – Toeplitz coefficients of the workload matrix. Defaults to a vector of ones, corresponding to the prefix-sum matrix. If this is longer than n, the extra entries are ignored (even if n is inferred from the length of strategy_coef or noising_coef).
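The parameters above can be tied together in a small dense reference computation. This is a sketch under stated assumptions, not the library's implementation. It assumes that Toeplitz coefficients give the first column of a lower-triangular Toeplitz matrix, that the noising matrix is the inverse of the strategy matrix (so the error matrix is workload @ noising), that noise is i.i.d. standard normal, and that no sensitivity scaling is applied. Whether per_query_error itself includes any such scaling is not stated on this page. The helper names lower_toeplitz and per_query_error_sketch are hypothetical. Because it builds dense matrices and inverts one (O(n^3)), it is only suitable for checking small cases.

```python
import jax.numpy as jnp


def lower_toeplitz(coef, n):
    """Hypothetical helper: dense n x n lower-triangular Toeplitz matrix.

    The first column is coef; entries beyond n are ignored and missing
    entries are treated as zero.
    """
    coef = jnp.asarray(coef, dtype=jnp.float32)
    k = min(n, coef.shape[0])
    column = jnp.zeros(n, dtype=jnp.float32).at[:k].set(coef[:k])
    i = jnp.arange(n)
    lag = i[:, None] - i[None, :]  # entry (i, j) depends only on i - j
    return jnp.where(lag >= 0, column[jnp.clip(lag, 0, n - 1)], 0.0)


def per_query_error_sketch(*, strategy_coef=None, noising_coef=None, n=None,
                           workload_coef=None):
    """Hypothetical dense reference for per-query squared error.

    Assumes noising = inverse(strategy), unit-variance Gaussian noise, and
    no sensitivity scaling. Not the library implementation.
    """
    if (strategy_coef is None) == (noising_coef is None):
        raise ValueError("Provide exactly one of strategy_coef and noising_coef.")
    given = strategy_coef if strategy_coef is not None else noising_coef
    n = len(given) if n is None else n
    if noising_coef is None:
        noising = jnp.linalg.inv(lower_toeplitz(strategy_coef, n))
    else:
        noising = lower_toeplitz(noising_coef, n)
    # Default workload: all-ones coefficients, i.e. the prefix-sum matrix.
    if workload_coef is None:
        workload_coef = jnp.ones(n)
    workload = lower_toeplitz(workload_coef, n)
    # Query i receives noise (workload @ noising)[i] @ z with z ~ N(0, I), so
    # its expected squared error is the squared norm of that row.
    error_matrix = workload @ noising
    return jnp.sum(error_matrix**2, axis=1)


# Identity strategy (C = I): the i-th prefix sum accumulates i + 1 unit-variance
# noise terms, so the expected values are 1, 2, 3, 4.
identity_errors = per_query_error_sketch(strategy_coef=jnp.array([1.0]), n=4)
```

With the prefix-sum matrix as the strategy (strategy_coef=jnp.ones(n)), the noising matrix cancels the workload and every query's error is approximately 1, illustrating how the choice of strategy shifts error between iterations.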

Return type:

Array

Returns:

The expected per-query squared error, an array of length n.