jax_privacy.matrix_factorization.dense.get_orthogonal_mask

jax_privacy.matrix_factorization.dense.get_orthogonal_mask(n, epochs=1)

Computes a mask that imposes orthogonality constraints on the optimization.

This is specific to the fixed-epoch-order (k, b)-participation schema of https://arxiv.org/pdf/2211.06530.pdf, in which each example participates k = epochs times and consecutive participations are exactly b steps apart (that is, separated by b-1 other steps), with b = n / epochs.
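To make the schema concrete, the following minimal sketch lists the steps at which each example participates. It assumes 0-indexed steps and that n is a multiple of epochs; `participation_patterns` is a hypothetical helper written for illustration, not part of jax_privacy.

```python
def participation_patterns(n: int, epochs: int = 1) -> list[list[int]]:
    """Hypothetical helper (not in jax_privacy): 0-indexed steps at which each
    example participates under fixed-epoch-order (k, b)-participation,
    with k = epochs and b = n // epochs."""
    if epochs < 1 or n % epochs != 0:
        raise ValueError("n must be a positive multiple of epochs")
    b = n // epochs  # distance between consecutive participations
    # Example j participates at steps j, j + b, ..., j + (epochs - 1) * b.
    return [[j + t * b for t in range(epochs)] for j in range(b)]


# n = 6 steps over 2 epochs gives b = 3 patterns, each of length k = 2.
print(participation_patterns(6, epochs=2))  # [[0, 3], [1, 4], [2, 5]]
```

Every pair of steps inside one pattern differs by a multiple of b, which is exactly the set of index pairs the mask targets.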

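This mask sets entry M_{ij} = 0 if i ≠ j and i ≡ j (mod b), and M_{ij} = 1 otherwise; the diagonal is left unconstrained. For any matrix with 0s in these entries, the sensitivity is easy to calculate because it depends only on the diagonal: every off-diagonal entry linking two steps of the same participation pattern is zero, so no cross terms contribute. Moreover, the sensitivity is the same for all possible {-1,1} participation vectors.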
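The sensitivity claim can be written out as a short derivation. This assumes, as in the cited paper, that the mask constrains the Gram matrix X = C^T C of a strategy matrix C (so a zero X_{il} means columns i and l of C are orthogonal) and that sensitivity is the largest value of ||C u||_2 over participation vectors u. The symbols C, X, u, k and π_j = {j, j+b, ..., j+(k-1)b} (for 0 ≤ j < b) are notation introduced here. Take u_i ∈ {-1, 1} for i ∈ π_j and u_i = 0 elsewhere:

```latex
\|C u\|_2^2 = u^\top X u
  = \sum_{i \in \pi_j} \sum_{l \in \pi_j} u_i u_l X_{il}
  = \sum_{i \in \pi_j} u_i^2 X_{ii}
  = \sum_{t=0}^{k-1} X_{j+tb,\, j+tb},
\qquad
\operatorname{sens}(C) = \max_{0 \le j < b} \Bigl( \sum_{t=0}^{k-1} X_{j+tb,\, j+tb} \Bigr)^{1/2}.
```

The third equality uses X_{il} = 0 for distinct i, l ∈ π_j, since such indices are congruent mod b. The signs drop out because u_i^2 = 1, which is why every {-1,1} participation vector gives the same value. Because X_{ii} = ||C e_i||_2^2 ≥ 0, participating at every step of a pattern is the worst case.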

Parameters:
  • n (int) – the number of steps; the returned mask has shape (n, n)

  • epochs (int) – the number of epochs (k in the (k, b) notation); b = n / epochs

Return type:

Array

Returns:

A 0/1 mask
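To illustrate what the returned mask looks like, the following pure-Python sketch reproduces the documented rule and numerically checks the sensitivity property. It is not the library's implementation: the real function returns a JAX Array. The sketch also assumes that diagonal entries are left unconstrained (M_{ii} = 1). `orthogonal_mask_sketch` and `quad_form` are hypothetical names used only here.

```python
import itertools
import random


def orthogonal_mask_sketch(n: int, epochs: int = 1) -> list[list[int]]:
    """Hypothetical sketch of the documented rule (not the library code):
    M[i][j] = 0 if i != j and i == j (mod b), else 1, with b = n // epochs.
    The diagonal is assumed to stay 1 (unconstrained)."""
    if epochs < 1 or n % epochs != 0:
        raise ValueError("n must be a positive multiple of epochs")
    b = n // epochs
    return [[0 if (i != j and (i - j) % b == 0) else 1 for j in range(n)]
            for i in range(n)]


def quad_form(x: list[list[float]], u: list[int]) -> float:
    """Computes u^T X u for a dense matrix given as nested lists."""
    n = len(u)
    return sum(u[i] * x[i][j] * u[j] for i in range(n) for j in range(n))


n, epochs = 6, 2
b = n // epochs
mask = orthogonal_mask_sketch(n, epochs)

# A random matrix that is zero wherever the mask is zero.
rng = random.Random(0)
x = [[rng.uniform(-1.0, 1.0) * mask[i][j] for j in range(n)] for i in range(n)]

for j in range(b):
    pattern = [j + t * b for t in range(epochs)]  # steps j, j + b, ...
    diag_sum = sum(x[i][i] for i in pattern)
    # Every {-1, 1} sign assignment on the pattern yields the diagonal sum.
    for signs in itertools.product((-1, 1), repeat=epochs):
        u = [0] * n
        for i, s in zip(pattern, signs):
            u[i] = s
        assert abs(quad_form(x, u) - diag_sum) < 1e-12
```

With epochs=1 the sketch returns an all-ones mask (no constraints). With epochs=n (b = 1) it returns the identity, which forces a diagonal matrix.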