jax_privacy.matrix_factorization.dense.get_orthogonal_mask
- jax_privacy.matrix_factorization.dense.get_orthogonal_mask(n, epochs=1)
Computes a mask that imposes orthogonality constraints on the optimization.
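This mask is specific to the fixed-epoch-order (k, b)-participation schema of https://arxiv.org/pdf/2211.06530.pdf, in which b = n / epochs and consecutive participations of an example are separated by exactly b - 1 intervening steps (that is, they occur b steps apart), so each example participates k = epochs times.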
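The participation schema can be illustrated by listing, for each example, the steps at which it participates. This is a minimal NumPy sketch, assuming n is divisible by epochs and steps are indexed from 0; `participation_patterns` is a hypothetical helper for illustration, not a jax_privacy function.

```python
import numpy as np


def participation_patterns(n: int, epochs: int = 1) -> list[np.ndarray]:
    """Hypothetical helper (not part of jax_privacy): step indices at which each
    example participates under fixed-epoch-order (k, b)-participation,
    with k = epochs and b = n // epochs."""
    # Assumption: n is divisible by epochs, so b = n / epochs is an integer.
    if n % epochs != 0:
        raise ValueError("n must be divisible by epochs")
    b = n // epochs
    # The example with residue r participates at steps r, r + b, ..., r + (epochs - 1) * b,
    # so consecutive participations have exactly b - 1 steps between them.
    return [np.arange(r, n, b) for r in range(b)]


print(participation_patterns(6, epochs=2))
# [array([0, 3]), array([1, 4]), array([2, 5])]
```

With n = 6 and epochs = 2, b = 3, and each of the three examples participates twice, three steps apart.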
The mask sets entry M_{ij} = 0 if i ≡ j (mod b) and M_{ij} = 1 otherwise. For any matrix with zeros in these entries, the sensitivity is easy to compute because it depends only on the diagonal. Moreover, the sensitivity is the same for every participation vector whose nonzero entries lie in {-1, 1}.
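A minimal NumPy sketch of the mask as documented, followed by a numerical check of the sensitivity claim. It assumes the zero pattern is imposed on a Gram-style matrix X (for example X = CᵀC), so that the squared sensitivity for a participation vector u is uᵀXu. `orthogonal_mask_sketch` and `squared_sensitivity` are hypothetical helpers, not jax_privacy APIs, and the library function itself returns a JAX Array rather than a NumPy array.

```python
import numpy as np


def orthogonal_mask_sketch(n: int, epochs: int = 1) -> np.ndarray:
    """Hypothetical re-derivation of the documented mask:
    M[i, j] = 0 if i == j (mod b), else 1, with b = n // epochs."""
    if n % epochs != 0:
        raise ValueError("n must be divisible by epochs")
    b = n // epochs
    idx = np.arange(n)
    same_residue = (idx[:, None] % b) == (idx[None, :] % b)
    return np.where(same_residue, 0.0, 1.0)


def squared_sensitivity(X: np.ndarray, u: np.ndarray) -> float:
    """Squared sensitivity u^T X u (assumption: X plays the role of C^T C)."""
    return float(u @ X @ u)


n, epochs = 6, 2
b = n // epochs
M = orthogonal_mask_sketch(n, epochs)

# An arbitrary symmetric matrix, used only to illustrate the algebra (it need not be PSD).
rng = np.random.default_rng(0)
A = rng.normal(size=(n, n))
S = A + A.T
# Zero the masked off-diagonal entries but keep the diagonal (the mask is 0 there too).
X = S * M + np.diag(np.diag(S))

# Participation at steps r, r + b: every sign pattern yields the same value,
# namely the sum of the diagonal entries of X at those steps.
r = 1
steps = np.arange(r, n, b)
for signs in ([1, 1], [1, -1], [-1, 1], [-1, -1]):
    u = np.zeros(n)
    u[steps] = signs
    assert np.isclose(squared_sensitivity(X, u), np.diag(X)[steps].sum())
```

Because every off-diagonal entry X[i, j] with i ≡ j (mod b) is zero, the cross terms in uᵀXu vanish, leaving only the diagonal terms, which do not depend on the signs of u.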
- Parameters:
  - n (int) – The size of the mask.
  - epochs (int) – The number of epochs.
- Return type:
  Array
- Returns:
  A 0/1 mask.