jax_privacy.matrix_factorization.buffered_toeplitz.LossFn
- class jax_privacy.matrix_factorization.buffered_toeplitz.LossFn(error_for_inv, sensitivity_squared, n, min_sep, max_participations, penalty_strength=1e-08, penalty_multipliers=<factory>, max_second_coef=1.0, min_theta_gap=1e-12)[source]
Bases: object
Encapsulates the loss to be optimized for a specific setting.
This can represent the loss for both single participation and min-sep participation (which has single participation as a special case).
- Variables:
error_for_inv – Function for computing the error for the BLT representing C^{-1}, the noise correlating matrix.
sensitivity_squared – Function for computing the squared sensitivity for the BLT representing C, the strategy matrix.
n – The number of iterations the mechanism is optimized for.
min_sep – The minimum separation of participations.
max_participations – The effective maximum number of participations, taking into account n, min_sep, and max_participations.
penalty_strength – The multiplier applied to the sum of penalties for the loss.
penalty_multipliers – A dict of multipliers (default 1.0) applied to the individual penalties returned by compute_penalties.
max_second_coef – The maximum value of the second Toeplitz coefficient, which is equal to sum(output_scale).
min_theta_gap – The minimum gap between buf_decay parameters allowed by the theta_gap penalty.
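To make the roles of buf_decay, output_scale, and max_second_coef concrete, here is a minimal pure-Python sketch of the Toeplitz coefficients of a BLT. It assumes the parameterization c_0 = 1 and c_t = sum_i output_scale[i] * buf_decay[i]**(t-1) for t >= 1, which is consistent with the statement above that the second coefficient equals sum(output_scale). The helper name blt_toeplitz_coefs is hypothetical and is not part of jax_privacy.

```python
def blt_toeplitz_coefs(buf_decay, output_scale, n):
    """Illustrative only: first n Toeplitz coefficients of a BLT.

    Assumes c_0 = 1 and c_t = sum_i output_scale[i] * buf_decay[i]**(t - 1).
    """
    if len(buf_decay) != len(output_scale):
        raise ValueError("buf_decay and output_scale must have equal length")
    coefs = [1.0]
    for t in range(1, n):
        coefs.append(
            sum(w * theta ** (t - 1) for theta, w in zip(buf_decay, output_scale))
        )
    return coefs[:n]


# Two buffers; the second coefficient is sum(output_scale) = 0.5.
coefs = blt_toeplitz_coefs(buf_decay=[0.9, 0.5], output_scale=[0.3, 0.2], n=5)
second_coef = coefs[1]
```

Under this assumption, a constraint such as max_second_coef=1.0 bounds second_coef, i.e. the sum of the output scales.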
Methods
__init__
build_closed_form_single_participation – Construct a LossFn for single participation max-error.
build_min_sep – Construct a LossFn for min-sep participation.
compute_penalties – Computes penalties that help keep the optimization well-behaved.
loss – Returns the loss (not including penalties).
penalized_loss – Computes the total composite loss to be optimized.
Attributes
- error_for_inv: Callable[[BufferedToeplitz], Union[Array, ndarray, bool, number, float, int]]
- sensitivity_squared: Callable[[BufferedToeplitz], Union[Array, ndarray, bool, number, float, int]]
- n: int
- min_sep: int
- max_participations: int
- penalty_strength: float = 1e-08
- penalty_multipliers: dict[str, float]
- max_second_coef: float = 1.0
- min_theta_gap: float = 1e-12
- classmethod build_closed_form_single_participation(n, **kwargs)[source]
Construct a LossFn for single participation max-error.
This function uses the closed-form calculations for sensitivity and error from https://arxiv.org/abs/2404.16706, so optimization time is essentially independent of n. However, particularly for large n or large numbers of buffers, the optimal BLT may have a buf_decay theta very near 1, which leads to numerical issues in the closed forms. For max error, this function has been reasonably well tested up to n=10**7. Closed-form optimization of the mean loss is also possible, but this has not been well tested.
- Parameters:
n (int) – The number of iterations the mechanism is optimized for.
**kwargs – Optional additional arguments to pass to the constructor.
- Return type:
- Returns:
A LossFn for single participation.
- Raises:
ValueError – If error is not ‘max’ or ‘mean’.
- classmethod build_min_sep(n, error='max', min_sep=1, max_participations=None, **kwargs)[source]
Construct a LossFn for min-sep participation.
This LossFn computes loss and sensitivity by materializing the Toeplitz coefficients of C and C^{-1}, and then using the loss functions of toeplitz.py, as described in https://arxiv.org/abs/2408.08868. This is still significantly faster than computing the error directly from the Toeplitz coefficients of C, because
c_inv_coef = blt.inverse().toeplitz_coefs(n)
is orders of magnitude faster (on GPUs) than
c_inv_coef = toeplitz.inverse_coef(blt.toeplitz_coefs(n))
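For intuition about why inverting from the Toeplitz coefficients of C is the slower route, the following sketch shows the textbook recurrence for the coefficients of the inverse of a lower-triangular Toeplitz matrix. It is not the library's implementation of toeplitz.inverse_coef; the name naive_inverse_coef is hypothetical. The recurrence is sequential and takes O(n^2) work, whereas inverting the BLT first operates only on the small buffer parameters.

```python
def naive_inverse_coef(c):
    """Illustrative only: coefficients d of the inverse lower-triangular Toeplitz matrix.

    Uses d_0 = 1 / c_0 and d_t = -(1 / c_0) * sum_{j=1..t} c_j * d_{t-j}.
    """
    if not c or c[0] == 0:
        raise ValueError("c must be non-empty with a nonzero leading coefficient")
    d = [1.0 / c[0]]
    for t in range(1, len(c)):
        # Each new coefficient depends on all previous ones: O(t) work per step.
        acc = sum(c[j] * d[t - j] for j in range(1, t + 1))
        d.append(-acc / c[0])
    return d


# Geometric coefficients 0.5**t have the inverse [1, -0.5, 0, 0, ...].
c_inv_coef = naive_inverse_coef([1.0, 0.5, 0.25, 0.125])
```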
- Parameters:
n (int) – The number of iterations the mechanism is optimized for.
error (str) – Either ‘max’ or ‘mean’, indicating whether to optimize for the maximum or mean squared error, respectively.
min_sep (int) – The minimum separation of participations.
max_participations (int | None) – The maximum number of participations.
**kwargs – Optional additional arguments to pass to the constructor.
- Return type:
- Returns:
A LossFn for min-sep participation.
- Raises:
ValueError – If error is not ‘max’ or ‘mean’.
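As a rough illustration of the ‘max’ versus ‘mean’ choice, the sketch below computes both squared-error measures from the Toeplitz coefficients of C^{-1}. It assumes a prefix-sum workload A, so that B = A C^{-1} is lower-triangular Toeplitz with coefficients equal to the running sum of the C^{-1} coefficients; that workload assumption and the helper name prefix_sum_errors are illustrative and are not taken from this page.

```python
from itertools import accumulate


def prefix_sum_errors(c_inv_coef):
    """Illustrative only: max and mean squared row norms of B = A @ C^{-1}.

    Assumes A is the lower-triangular all-ones (prefix-sum) matrix.
    """
    if not c_inv_coef:
        raise ValueError("c_inv_coef must be non-empty")
    # Toeplitz coefficients of B are the cumulative sums of those of C^{-1}.
    b = list(accumulate(c_inv_coef))
    # Row i of B contains b_0..b_i, so its squared norm is a running sum of squares.
    row_sq_norms = list(accumulate(x * x for x in b))
    return {
        "max": row_sq_norms[-1],  # the last row has the largest squared norm
        "mean": sum(row_sq_norms) / len(row_sq_norms),
    }


errors = prefix_sum_errors([1.0, -0.5, 0.0, 0.0])
```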
- compute_penalties(blt, inv_blt)[source]
Computes penalties that help keep the optimization well-behaved.
These correspond to the conditions of Theorem 1 (part a) of “An Inversion Theorem for Buffered Linear Toeplitz (BLT) Matrices and Applications to Streaming Differential Privacy” (https://arxiv.org/abs/2504.21413), which restricts the optimization to a class of well-behaved BLTs. Note that the constraint pillutla_score < 1 of part (a) is not strictly necessary, but empirically, including it produces better results.
- Parameters:
blt (BufferedToeplitz) – The BLT representing C.
inv_blt (BufferedToeplitz) – The BLT representing C^{-1}.
- Return type:
dict[str, Union[Array, ndarray, bool, number, float, int]]
- Returns:
A dictionary of named penalties.
- penalized_loss(blt, inv_blt, normalize_by_approx_optimal_loss=True)[source]
Computes the total composite loss to be optimized.
- Parameters:
blt (BufferedToeplitz) – The BLT representing C.
inv_blt (BufferedToeplitz) – The BLT representing C^{-1}.
normalize_by_approx_optimal_loss (bool) – If True, the loss is normalized by the expected optimal loss, so the relative penalty strength remains somewhat consistent across n and k. This is the default, and recommended for optimization.
- Return type:
Union[Array, ndarray, bool, number, float, int]
- Returns:
The total composite loss to be optimized.
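The way penalty_strength and penalty_multipliers are described above suggests a composition along the following lines. This is a sketch under stated assumptions, not the library's code: the function name composite_loss, the penalty names, and the choice to divide the loss by the normalizer before adding penalties are all hypothetical.

```python
def composite_loss(loss, penalties, penalty_strength=1e-8,
                   penalty_multipliers=None, normalizer=1.0):
    """Illustrative only: loss plus a weighted sum of named penalties.

    Multipliers missing from penalty_multipliers default to 1.0.
    """
    multipliers = penalty_multipliers or {}
    weighted = sum(
        multipliers.get(name, 1.0) * value for name, value in penalties.items()
    )
    # Assumption: normalization is applied to the loss term only.
    return loss / normalizer + penalty_strength * weighted


# Hypothetical penalty names; "b" is weighted twice as heavily as "a".
total = composite_loss(
    loss=2.0,
    penalties={"a": 1.0, "b": 3.0},
    penalty_strength=0.1,
    penalty_multipliers={"b": 2.0},
)
```

With the default penalty_strength of 1e-08, the penalties act as a small regularizer relative to a normalized loss rather than as a dominant term.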
- loss(blt, skip_checks=False)[source]
Returns the loss (not including penalties).
This function is not intended to be jitted or used in optimization, but only in evaluation of the final BLT.
- Parameters:
blt (BufferedToeplitz) – The BLT to compute the loss of.
skip_checks (bool) – If True, do not check that the BLT is valid for min-sep sensitivity.
- Return type:
Union[Array,ndarray,bool,number,float,int]
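The validity check mentioned for skip_checks is not spelled out on this page. One common condition in the Toeplitz matrix-factorization literature is that the coefficients of C are non-negative and non-increasing, in which case the min-sep sensitivity is attained by participating at iterations 0, min_sep, 2*min_sep, and so on. The sketch below illustrates that calculation under this assumption; all function names are hypothetical, and the formula for the effective number of participations is likewise an assumption about how n, min_sep, and max_participations combine.

```python
def effective_max_participations(n, min_sep, max_participations=None):
    """Illustrative only: the most participations that fit in n iterations."""
    most_possible = (n - 1) // min_sep + 1
    if max_participations is None:
        return most_possible
    return min(max_participations, most_possible)


def min_sep_sensitivity_squared(coefs, min_sep, max_participations=None,
                                skip_checks=False):
    """Illustrative only: squared norm of the sum of columns 0, b, 2b, ... of C."""
    n = len(coefs)
    if not skip_checks:
        nonneg = all(c >= 0 for c in coefs)
        nonincreasing = all(coefs[t] >= coefs[t + 1] for t in range(n - 1))
        if not (nonneg and nonincreasing):
            raise ValueError("coefs must be non-negative and non-increasing")
    k = effective_max_participations(n, min_sep, max_participations)
    total = 0.0
    for i in range(n):
        # Entry (i, j) of lower-triangular Toeplitz C is coefs[i - j] when i >= j.
        row_sum = sum(
            coefs[i - j * min_sep] for j in range(k) if i - j * min_sep >= 0
        )
        total += row_sum ** 2
    return total


coefs = [1.0, 0.5, 0.25, 0.125]
sens_sq_min_sep_2 = min_sep_sensitivity_squared(coefs, min_sep=2)
sens_sq_single = min_sep_sensitivity_squared(coefs, min_sep=2, max_participations=1)
```

Setting max_participations=1 recovers the single-participation case, where the squared sensitivity is simply the squared norm of the first column.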