jax_privacy.matrix_factorization.buffered_toeplitz.BufferedToeplitz
- class jax_privacy.matrix_factorization.buffered_toeplitz.BufferedToeplitz(buf_decay, output_scale)[source]
Bases: Mapping
A lower-triangular Toeplitz C parameterized as a BLT.
BufferedToeplitz.build is the recommended way to construct a BLT.
- For background on Buffered Linear Toeplitz matrices and DP mechanisms, see:
If buf_decay = [d1, d2] and output_scale = [s1, s2], for n = 5 this class represents

    1   0   0   0   0          t0 = 1
    t1  1   0   0   0          t1 = s1 + s2
    t2  t1  1   0   0   where  t2 = s1*d1**1 + s2*d2**1
    t3  t2  t1  1   0          t3 = s1*d1**2 + s2*d2**2
    t4  t3  t2  t1  1          t4 = s1*d1**3 + s2*d2**3

These Toeplitz parameters are returned by toeplitz_coefs(n=5).
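The formula above can be checked with a small pure-Python sketch. The helpers toeplitz_coefs_sketch and dense_toeplitz are hypothetical names introduced here for illustration; they are not the library's toeplitz_coefs implementation and only mirror the formula t_k = sum_i s_i * d_i**(k - 1) for k >= 1, with t0 = 1.

```python
def toeplitz_coefs_sketch(buf_decay, output_scale, n):
    """Return [t0, ..., t_{n-1}] following the formula in the example above.

    Hypothetical helper for illustration; not the library's toeplitz_coefs.
    """
    if len(buf_decay) != len(output_scale):
        raise ValueError("buf_decay and output_scale must have equal length")
    coefs = []
    for k in range(n):
        if k == 0:
            coefs.append(1.0)  # t0 = 1: the diagonal of C is all ones
        else:
            # t_k = sum_i s_i * d_i**(k - 1)
            coefs.append(
                sum(s * d ** (k - 1) for d, s in zip(buf_decay, output_scale))
            )
    return coefs


def dense_toeplitz(coefs):
    """Materialize the lower-triangular Toeplitz matrix with first column coefs."""
    n = len(coefs)
    return [[coefs[i - j] if i >= j else 0.0 for j in range(n)] for i in range(n)]


# Example values chosen arbitrarily for illustration.
coefs = toeplitz_coefs_sketch(buf_decay=[0.5, 0.25], output_scale=[0.2, 0.1], n=5)
C = dense_toeplitz(coefs)
```

Materializing the dense matrix is only practical for small n; the point of the BLT parameterization is that C never needs to be stored explicitly.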
Methods
__init__
Returns a StreamingMatrix representing C.
Helper to convert arguments to jnp.arrays and canonicalize.
Returns a BufferedToeplitz with buf_decay in decreasing order.
Returns a BLT based on a rational approximation of 1/sqrt(1 - x).
get – Retrieves the corresponding layout by the string key.
Compute the BufferedToeplitz parameterization of C^{-1} from C.
Returns a StreamingMatrix representing C^{-1}.
D.items() -> a set-like object providing a view on D's items
D.keys() -> a set-like object providing a view on D's keys
Returns the 'Pillutla Score' of the BLT.
Returns the Toeplitz coefficients for C.
Validates basic properties of the BLT parameters.
D.values() -> an object providing a view on D's values
Attributes
- buf_decay: Array
- output_scale: Array
- classmethod build(buf_decay, output_scale, dtype=<class 'jax.numpy.float64'>)[source]
Helper to convert arguments to jnp.arrays and canonicalize.
- Parameters:
buf_decay (Any) – The buf_decay parameters of a BLT.
output_scale (Any) – The output_scale parameters of a BLT.
dtype (Union[str, type[Any], dtype, SupportsDType]) – The dtype to use for the BLT parameters. The default, jnp.float64, is strongly recommended for numerical stability. However, this requires either the global option jax.config.update('jax_enable_x64', True) or that build() and subsequent computations occur within a with jax.enable_x64(): context. See also check_float64_dtype.
- Return type:
- Returns:
A BufferedToeplitz with buf_decay in decreasing order.
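As a rough illustration of what "buf_decay in decreasing order" means, the sketch below sorts the buffers by decay while keeping each output_scale entry paired with its buf_decay entry. canonicalize_sketch is a hypothetical helper, and the pairing behavior is an assumption based on the fact that reordering buffers together leaves the Toeplitz coefficients unchanged; it does not reproduce the library's dtype handling or validation.

```python
def canonicalize_sketch(buf_decay, output_scale):
    """Sort (buf_decay, output_scale) pairs by decreasing buf_decay.

    Hypothetical helper: assumes each output_scale[i] stays attached to
    buf_decay[i], so the represented matrix C is unchanged by the reordering.
    """
    if len(buf_decay) != len(output_scale):
        raise ValueError("buf_decay and output_scale must have equal length")
    pairs = sorted(zip(buf_decay, output_scale), key=lambda p: p[0], reverse=True)
    sorted_decay = [float(d) for d, _ in pairs]
    sorted_scale = [float(s) for _, s in pairs]
    return sorted_decay, sorted_scale


decay, scale = canonicalize_sketch([0.25, 0.9, 0.5], [0.1, 0.3, 0.2])
```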
- classmethod from_rational_approx_to_sqrt_x(num_buffers, *, max_buf_decay=1.0, max_pillutla_score=None, buf_decay_scale=1.6, buf_decay_shift=-1)[source]
Returns a BLT based on a rational approximation of 1/sqrt(1 - x).
The optimal-for-max-loss Toeplitz coefficients (see toeplitz.optimal_max_error_strategy_coefs) correspond to the ordinary generating function 1/sqrt(1 - x). Thus, this method is used to produce a BLT that approximates these coefficients, but allows for a much more memory-efficient implementation of multiplication by the noising matrix $C^{-1}$.
The rational approximation is from https://arxiv.org/abs/2404.16706v2, see Proposition 4.5 in particular.
NOTE: The BLTs produced by this method are generally significantly inferior to those from buffered_toeplitz.optimize, which finds a numerically optimal BLT for a specific value of n. Hence, the primary use of this method is initializing numerical optimization, as well as providing an implementation of the “RA-BLT” method of https://arxiv.org/abs/2404.16706v2 for use in research comparisons.
- Parameters:
num_buffers (int) – The number of buffers to use in the BLT (equivalently, the degree of the rational function); must be >= 1.
max_buf_decay (float) – The maximum value of buf_decay to use. For large numbers of buffers, this routine can produce buf_decay values of 1.0 (up to float64 precision), or higher due to numerical issues. This parameter can be used to enforce that the largest buf_decay parameter is strictly less than one, which is useful for initializing optimization.
max_pillutla_score (float | None) – If not None, the maximum Pillutla score to allow. This is accomplished by first scaling the buf_decay parameters if needed based on max_buf_decay, and then scaling the output_scale parameters to ensure the Pillutla score is less than or equal to this value.
buf_decay_scale (float) – A factor that scales the dynamic range of the buf_decay parameters. Larger values indicate a coarser resolution, and hence the largest buf_decay will be closer to 1.0, and the smallest closer to 0.0.
buf_decay_shift (int) – A shift added to the range of the counter k in the construction of the rational approximation. The buf_decay parameters come from a discrete set indexed by k, with a resolution that depends on buf_decay_scale. A negative buf_decay_shift shifts the selected buf_decay parameters toward 1.0; a positive shift moves the selected set closer to 0.0. The default of -1 is recommended.
- Return type:
- Returns:
A BufferedToeplitz matrix generated by a rational function approximation of 1/sqrt(1 - x).
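For reference, the target coefficients that the rational approximation aims to match are the power-series coefficients of 1/sqrt(1 - x), namely binom(2k, k) / 4**k. The sketch below computes them exactly with the recurrence c_k = c_{k-1} * (2k - 1) / (2k). inv_sqrt_series_coefs is a hypothetical helper for illustration; it does not call toeplitz.optimal_max_error_strategy_coefs and does not construct a BLT.

```python
from fractions import Fraction


def inv_sqrt_series_coefs(n):
    """First n power-series coefficients of 1/sqrt(1 - x), as exact fractions.

    Hypothetical helper using c_0 = 1 and c_k = c_{k-1} * (2k - 1) / (2k).
    """
    coefs = []
    c = Fraction(1)
    for k in range(n):
        if k > 0:
            c = c * Fraction(2 * k - 1, 2 * k)
        coefs.append(c)
    return coefs


coefs = inv_sqrt_series_coefs(5)  # 1, 1/2, 3/8, 5/16, 35/128
```

A quick sanity check is that convolving the sequence with itself gives all ones, since (1/sqrt(1 - x))**2 = 1/(1 - x). A BLT from this method approximates this slowly decaying sequence with a sum of num_buffers geometric sequences.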
- property dtype: str | type[Any] | dtype | SupportsDType
- inverse(skip_checks=False)[source]
Compute the BufferedToeplitz parameterization of C^{-1} from C.
This is an alternative approach to https://arxiv.org/pdf/2404.16706 Lemma 5.2, along the lines of Proposition 5.6 (Representation of the reciprocal of a rational generating function), with slightly different parameterization.
- Parameters:
skip_checks (bool) – If True, skip error checks on the inputs and results (necessary in jitted contexts).
- Return type:
- Returns:
A Buffered Linear Toeplitz parameterization of the inverse.
- Raises:
RuntimeError – If skip_checks=False and the inverse calculation encounters numerical problems.
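One way to sanity-check an inverse for small n is to compare against a dense reference. The inverse of a lower-triangular Toeplitz matrix is again lower-triangular Toeplitz, and its coefficients follow from the power-series reciprocal recurrence. The sketch below is not the method used by inverse() (which works on the BLT parameters directly); inverse_toeplitz_coefs and convolve_prefix are hypothetical helpers, and the input coefficients are arbitrary example values.

```python
def inverse_toeplitz_coefs(coefs):
    """Coefficients of C^{-1} for a lower-triangular Toeplitz C.

    Hypothetical dense reference: solves sum_j coefs[j] * inv[k - j] = [k == 0]
    one coefficient at a time. Requires a non-empty coefs with coefs[0] != 0.
    """
    t0 = coefs[0]
    if t0 == 0:
        raise ZeroDivisionError("C is singular when its diagonal entry is zero")
    inv = [1.0 / t0]
    for k in range(1, len(coefs)):
        acc = sum(coefs[j] * inv[k - j] for j in range(1, k + 1))
        inv.append(-acc / t0)
    return inv


def convolve_prefix(a, b):
    """First min(len(a), len(b)) coefficients of the product of two series."""
    n = min(len(a), len(b))
    return [sum(a[j] * b[k - j] for j in range(k + 1)) for k in range(n)]


# Toeplitz coefficients for buf_decay=[0.5, 0.25], output_scale=[0.2, 0.1], n=5.
c_coefs = [1.0, 0.3, 0.125, 0.05625, 0.0265625]
inv_coefs = inverse_toeplitz_coefs(c_coefs)
# The product of the two series should be 1, 0, 0, 0, 0 up to rounding.
product = convolve_prefix(c_coefs, inv_coefs)
```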
- pillutla_score()[source]
Returns the ‘Pillutla Score’ of the BLT.
See Theorem 1 of “An Inversion Theorem for Buffered Linear Toeplitz (BLT) Matrices and Applications to Streaming Differential Privacy”, https://arxiv.org/abs/2504.21413. To avoid a negative buf_decay value in the noising matrix BLT (which produces an oscillating term), we enforce a pillutla_score < 1 during optimization.
Note that a BLT may have buf_decay == 0 values, which lead to a nan or inf Pillutla score. (In particular, the inverse of a BLT with pillutla_score=0 will have this property.)
- Return type:
Union[Array, ndarray, bool, number, float, int]
- Returns:
The Pillutla Score of the BLT, sum_i(output_scale[i] / buf_decay[i]).
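The score formula can be sketched in plain Python as follows. pillutla_score_sketch is a hypothetical helper, not the library method. Because plain Python raises on division by zero, the sketch handles buf_decay == 0 explicitly to imitate the inf or nan result that array division would give.

```python
import math


def pillutla_score_sketch(buf_decay, output_scale):
    """Compute sum_i(output_scale[i] / buf_decay[i]).

    Hypothetical helper. A zero buf_decay contributes inf (nonzero scale)
    or nan (zero scale), imitating IEEE floating-point array division.
    """
    if len(buf_decay) != len(output_scale):
        raise ValueError("buf_decay and output_scale must have equal length")
    total = 0.0
    for d, s in zip(buf_decay, output_scale):
        if d == 0:
            total += math.nan if s == 0 else math.copysign(math.inf, s)
        else:
            total += s / d
    return total


score = pillutla_score_sketch([0.5, 0.25], [0.2, 0.1])  # 0.2/0.5 + 0.1/0.25
ok_for_optimization = score < 1  # the constraint described above
```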
- from_tuple()
- items() -> a set-like object providing a view on D's items
- keys() -> a set-like object providing a view on D's keys
- replace(**kwargs)
- to_tuple()
- values() -> an object providing a view on D's values