jax_privacy.matrix_factorization.buffered_toeplitz.BufferedToeplitz

class jax_privacy.matrix_factorization.buffered_toeplitz.BufferedToeplitz(buf_decay, output_scale)

Bases: Mapping

A lower-triangular Toeplitz matrix C parameterized as a Buffered Linear Toeplitz (BLT) matrix.

BufferedToeplitz.build is the recommended way to construct a BLT.

For background on Buffered Linear Toeplitz matrices and DP mechanisms, see:

If buf_decay = [d1, d2] and output_scale = [s1, s2], for n = 5 this class represents

1   0  0  0  0             t0 = 1
t1  1  0  0  0             t1 = s1       + s2
t2 t1  1  0  0    where    t2 = s1*d1**1 + s2*d2**1
t3 t2 t1  1  0             t3 = s1*d1**2 + s2*d2**2
t4 t3 t2 t1  1             t4 = s1*d1**3 + s2*d2**3

These Toeplitz coefficients t0, ..., t4 are returned by toeplitz_coefs(n=5).
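
To make the coefficient formula concrete, here is a minimal NumPy sketch of the same arithmetic: t0 = 1 and t_k = sum_i s_i * d_i**(k-1) for k >= 1, placed on the diagonals of a lower-triangular matrix. The helper names blt_toeplitz_coefs and blt_dense are hypothetical illustrations, not jax_privacy functions; in the library, toeplitz_coefs(n) is the documented way to obtain these coefficients.

```python
import numpy as np


def blt_toeplitz_coefs(buf_decay, output_scale, n):
    """Hypothetical helper: t0 = 1, t_k = sum_i s_i * d_i**(k - 1) for k >= 1."""
    d = np.asarray(buf_decay, dtype=np.float64)
    s = np.asarray(output_scale, dtype=np.float64)
    k = np.arange(n - 1)  # exponents 0, 1, ..., n - 2
    tail = (s[None, :] * d[None, :] ** k[:, None]).sum(axis=1)
    return np.concatenate([[1.0], tail])


def blt_dense(buf_decay, output_scale, n):
    """Hypothetical helper: n x n lower-triangular Toeplitz C with C[i, j] = t[i - j]."""
    t = blt_toeplitz_coefs(buf_decay, output_scale, n)
    c = np.zeros((n, n))
    for i in range(n):
        c[i, : i + 1] = t[i::-1]  # row i holds t_i, t_{i-1}, ..., t_0
    return c


# Example with buf_decay = [0.5, 0.25], output_scale = [1.0, 2.0] and n = 5.
t = blt_toeplitz_coefs([0.5, 0.25], [1.0, 2.0], n=5)
# t holds [1.0, 3.0, 1.0, 0.375, 0.15625]
C = blt_dense([0.5, 0.25], [1.0, 2.0], n=5)
```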

Methods

__init__

as_streaming_matrix

Returns a StreamingMatrix representing C.

build

Helper to convert arguments to jnp.arrays and canonicalize.

canonicalize

Returns a BufferedToeplitz with buf_decay in decreasing order.

from_rational_approx_to_sqrt_x

Returns a BLT based on a rational approximation of 1/sqrt(1 - x).

from_tuple

get

Retrieves the corresponding layout by the string key.

inverse

Compute the BufferedToeplitz parameterization of C^{-1} from C.

inverse_as_streaming_matrix

Returns a StreamingMatrix representing C^{-1}.

items

D.items() -> a set-like object providing a view on D's items

keys

D.keys() -> a set-like object providing a view on D's keys

materialize

pillutla_score

Returns the 'Pillutla Score' of the BLT.

replace

to_tuple

toeplitz_coefs

Returns the Toeplitz coefficients for C.

validate

Validates basic properties of the BLT parameters.

values

D.values() -> an object providing a view on D's values

Attributes

dtype

buf_decay

output_scale

buf_decay: Array
output_scale: Array

validate()

Validates basic properties of the BLT parameters.

classmethod build(buf_decay, output_scale, dtype=<class 'jax.numpy.float64'>)

Helper to convert arguments to jnp.arrays and canonicalize.

Parameters:
  • buf_decay (Any) – The buf_decay parameters of a BLT.

  • output_scale (Any) – The output_scale parameters of a BLT.

  • dtype (Union[str, type[Any], dtype, SupportsDType]) – The dtype to use for the BLT parameters. The default, jnp.float64, is strongly recommended for numerical stability. However, it requires either setting the global option jax.config.update('jax_enable_x64', True) or calling build() and running subsequent computations inside a with jax.enable_x64(): context. See also check_float64_dtype.

Return type:

BufferedToeplitz

Returns:

A BufferedToeplitz with buf_decay in decreasing order.

classmethod from_rational_approx_to_sqrt_x(num_buffers, *, max_buf_decay=1.0, max_pillutla_score=None, buf_decay_scale=1.6, buf_decay_shift=-1)

Returns a BLT based on a rational approximation of 1/sqrt(1 - x).

The optimal-for-max-loss Toeplitz coefficients (see toeplitz.optimal_max_error_strategy_coefs) correspond to the ordinary generating function 1/sqrt(1 - x). This method produces a BLT that approximates these coefficients while allowing a much more memory-efficient implementation of multiplication by the noising matrix C^{-1}.
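
As a rough illustration of the target being approximated, the sketch below computes the coefficients of 1/sqrt(1 - x), which are binom(2k, k) / 4**k, and measures the largest coefficient gap between a BLT and that target. It does not reproduce the rational approximation of Proposition 4.5; the helper names (sqrt_target_coefs, blt_toeplitz_coefs, max_coef_error) are hypothetical, and the BLT coefficients follow the t_k formula from the class description.

```python
import math

import numpy as np


def sqrt_target_coefs(n):
    """Coefficients of 1/sqrt(1 - x) = sum_k binom(2k, k) / 4**k * x**k."""
    return np.array([math.comb(2 * k, k) / 4**k for k in range(n)])


def blt_toeplitz_coefs(buf_decay, output_scale, n):
    """Hypothetical helper: t0 = 1, t_k = sum_i s_i * d_i**(k - 1) for k >= 1."""
    d = np.asarray(buf_decay, dtype=np.float64)
    s = np.asarray(output_scale, dtype=np.float64)
    k = np.arange(n - 1)
    tail = (s[None, :] * d[None, :] ** k[:, None]).sum(axis=1)
    return np.concatenate([[1.0], tail])


def max_coef_error(buf_decay, output_scale, n):
    """Largest |t_k - target_k| over the first n coefficients (hypothetical helper)."""
    gap = blt_toeplitz_coefs(buf_decay, output_scale, n) - sqrt_target_coefs(n)
    return np.max(np.abs(gap))


target = sqrt_target_coefs(5)  # [1.0, 0.5, 0.375, 0.3125, 0.2734375]
err = max_coef_error([0.9], [0.5], n=4)  # about 0.0925 for this one-buffer BLT
```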

The rational approximation is from https://arxiv.org/abs/2404.16706v2, see Proposition 4.5 in particular.

NOTE: The BLTs produced by this method are generally significantly inferior to those from buffered_toeplitz.optimize, which finds a numerically optimal BLT for a specific value of n. Hence, the primary uses of this method are initializing numerical optimization and providing an implementation of the “RA-BLT” method of https://arxiv.org/abs/2404.16706v2 for research comparisons.

Parameters:
  • num_buffers (int) – The number of buffers to use in the BLT (equivalently, the degree of the rational function), must be >= 1.

  • max_buf_decay (float) – The maximum value of buf_decay to use. For large numbers of buffers, this routine can produce buf_decay values of 1.0 (up to float64 precision) or higher due to numerical issues. This parameter can be used to enforce that the largest buf_decay parameter is strictly less than one, which is useful when initializing optimization.

  • max_pillutla_score (float | None) – If not None, the maximum Pillutla score to allow. This is enforced by first scaling the buf_decay parameters, if needed, based on max_buf_decay, and then scaling the output_scale parameters so that the Pillutla score is less than or equal to this value.

  • buf_decay_scale (float) – A factor that scales the dynamic range of the buf_decay parameters. Larger values indicate a coarser resolution, and hence the largest buf_decay will be closer to 1.0, and the smallest closer to 0.0.

  • buf_decay_shift (int) – A shift added to the range of the counter k in the construction of the rational approximation. The buf_decay parameters come from a discrete set indexed by k, whose resolution depends on buf_decay_scale. A negative buf_decay_shift moves the selected buf_decay parameters toward 1.0; a positive shift moves them toward 0.0. The default of -1 is recommended.

Return type:

BufferedToeplitz

Returns:

A BufferedToeplitz matrix generated by a rational function approximation of 1/sqrt(1 - x).

canonicalize()

Returns a BufferedToeplitz with buf_decay in decreasing order.

Return type:

BufferedToeplitz
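
Because each Toeplitz coefficient is a sum over the (buf_decay[i], output_scale[i]) pairs, reordering the pairs does not change C. The sketch below assumes canonicalization amounts to such a reordering (sorting by decreasing buf_decay) and checks that the coefficients are unchanged; canonicalize_params is a hypothetical NumPy helper, not the library method.

```python
import numpy as np


def canonicalize_params(buf_decay, output_scale):
    """Hypothetical sketch: sort (buf_decay, output_scale) pairs by decreasing buf_decay."""
    d = np.asarray(buf_decay, dtype=np.float64)
    s = np.asarray(output_scale, dtype=np.float64)
    order = np.argsort(-d, kind="stable")  # indices of d in decreasing order
    return d[order], s[order]


d0 = np.array([0.25, 0.9, 0.5])
s0 = np.array([0.1, 0.3, 0.2])
d, s = canonicalize_params(d0, s0)
# d == [0.9, 0.5, 0.25], s == [0.3, 0.2, 0.1]

# Sums of s_i * d_i**k (k = 0..5) are invariant under permuting the pairs.
k = np.arange(6)
before = (s0 * d0[None, :] ** k[:, None]).sum(axis=1)
after = (s * d[None, :] ** k[:, None]).sum(axis=1)
```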

property dtype: str | type[Any] | dtype | SupportsDType

toeplitz_coefs(n)

Returns the Toeplitz coefficients for C.

Return type:

Array

materialize(n)

Return type:

Array

inverse(skip_checks=False)

Compute the BufferedToeplitz parameterization of C^{-1} from C.

This is an alternative to Lemma 5.2 of https://arxiv.org/pdf/2404.16706, along the lines of Proposition 5.6 (Representation of the reciprocal of a rational generating function), with a slightly different parameterization.

Parameters:

skip_checks (bool) – If True, skip error checks on the inputs and results (necessary in jitted contexts).

Return type:

BufferedToeplitz

Returns:

A Buffered Linear Toeplitz parameterization of the inverse.

Raises:

RuntimeError – If skip_checks=False and the inverse calculation encounters numerical problems.
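
One way to see where the inverse parameters come from uses generating functions. With Q(x) = prod_i (1 - d_i x), C corresponds to C(x) = 1 + x * sum_i s_i / (1 - d_i x) = N(x) / Q(x) with N(0) = 1, so C^{-1} corresponds to Q(x) / N(x): the reciprocals of the roots of N give the new buf_decay values, and partial-fraction residues give the new output_scale values. The NumPy sketch below implements this partial-fraction route, assuming the roots of N are distinct and nonzero. It is an illustration, not necessarily the algorithm inverse() uses, and it performs none of the numerical checks that inverse() does. For a single buffer it reduces to buf_decay d - s and output_scale -s.

```python
import numpy as np
from numpy.polynomial import polynomial as npoly


def blt_toeplitz_coefs(buf_decay, output_scale, n):
    """Hypothetical helper: t0 = 1, t_k = sum_i s_i * d_i**(k - 1) for k >= 1."""
    d = np.asarray(buf_decay)
    s = np.asarray(output_scale)
    k = np.arange(n - 1)
    tail = (s[None, :] * d[None, :] ** k[:, None]).sum(axis=1)
    return np.concatenate([[1.0], tail])


def blt_inverse_params(buf_decay, output_scale):
    """Hypothetical partial-fraction sketch of the BLT parameters of C^{-1}."""
    d = np.asarray(buf_decay, dtype=np.float64)
    s = np.asarray(output_scale, dtype=np.float64)
    m = len(d)

    def prod_except(skip):
        # prod_{j != skip} (1 - d_j x), coefficients in increasing degree.
        out = np.array([1.0])
        for j in range(m):
            if j != skip:
                out = npoly.polymul(out, [1.0, -d[j]])
        return out

    q = prod_except(-1)  # Q(x) = prod_j (1 - d_j x)
    # r(x) = -sum_i s_i prod_{j != i} (1 - d_j x), so C(x) = (Q(x) - x r(x)) / Q(x).
    r = np.zeros(1)
    for i in range(m):
        r = npoly.polyadd(r, -s[i] * prod_except(i))
    num = npoly.polysub(q, npoly.polymulx(r))  # N(x), with N(0) = 1
    roots = npoly.polyroots(num).astype(complex)
    new_decay = 1.0 / roots  # N(x) = prod_j (1 - d'_j x)
    # Q(x) / N(x) = 1 + x r(x) / N(x); s'_j is the residue of r / N at x = 1 / d'_j.
    new_scale = np.array([
        npoly.polyval(x0, r)
        / np.prod([1.0 - dk * x0 for k, dk in enumerate(new_decay) if k != j])
        for j, x0 in enumerate(roots)
    ])
    order = np.argsort(-new_decay.real)  # decreasing buf_decay
    return np.real_if_close(new_decay[order]), np.real_if_close(new_scale[order])


n = 8
params = ([0.9, 0.5], [0.2, 0.1])
t = blt_toeplitz_coefs(*params, n)
C = np.array([[t[i - j] if i >= j else 0.0 for j in range(n)] for i in range(n)])
d_inv, s_inv = blt_inverse_params(*params)
# The first column of C^{-1} holds the Toeplitz coefficients of the inverse BLT.
matches = np.allclose(np.linalg.inv(C)[:, 0], blt_toeplitz_coefs(d_inv, s_inv, n))
```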

pillutla_score()

Returns the ‘Pillutla Score’ of the BLT.

See Theorem 1 of “An Inversion Theorem for Buffered Linear Toeplitz (BLT) Matrices and Applications to Streaming Differential Privacy”, https://arxiv.org/abs/2504.21413. To avoid a negative buf_decay value in the noising matrix BLT (which produces an oscillating term), we enforce pillutla_score < 1 during optimization.

Note that a BLT may have buf_decay == 0 values, which lead to a nan or inf Pillutla score. (In particular, the inverse of a BLT with pillutla_score=1 has this property.)

Return type:

Union[Array, ndarray, bool, number, float, int]

Returns:

The Pillutla Score of the BLT, sum_i(output_scale[i] / buf_decay[i]).
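
The score itself is a one-line computation. The NumPy sketch below (pillutla_score here is a hypothetical helper mirroring the formula above, not the library method) also illustrates, for a single buffer, why the threshold 1 matters: one can check from the generating function that the inverse of a one-buffer BLT with buf_decay d > 0 and output_scale s has buf_decay d - s = d * (1 - s / d), which is positive exactly when the score s / d is below 1 and zero when it equals 1.

```python
import numpy as np


def pillutla_score(buf_decay, output_scale):
    """Hypothetical sketch of sum_i output_scale[i] / buf_decay[i]."""
    d = np.asarray(buf_decay, dtype=np.float64)
    s = np.asarray(output_scale, dtype=np.float64)
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.sum(s / d)  # inf or nan if some buf_decay is 0


d, s = 0.8, 0.3
score = pillutla_score([d], [s])  # about 0.375
inverse_decay = d - s  # buf_decay of C^{-1} for a one-buffer BLT
# inverse_decay equals d * (1 - score), so it stays positive while score < 1.

# With score == 1 (s == d), the inverse has buf_decay 0 and output_scale -d,
# so its own score is infinite.
degenerate_score = pillutla_score([0.0], [-0.8])
```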

as_streaming_matrix()

Returns a StreamingMatrix representing C.

Return type:

StreamingMatrix
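
One way to multiply by C using memory that does not grow with n is a simple recurrence: keep one buffer per buf_decay entry, output y_t = x_t + sum_i s_i * b_i, then update b_i <- d_i * b_i + x_t. The NumPy sketch below implements that recurrence and checks it against a dense C. blt_stream_multiply is a hypothetical helper and does not reflect the StreamingMatrix interface, whose methods are not documented here. Running the same recurrence with the parameters returned by inverse() would multiply by C^{-1} instead.

```python
import numpy as np


def blt_stream_multiply(buf_decay, output_scale, xs):
    """Hypothetical sketch: computes C @ xs one row at a time with one buffer per decay."""
    d = np.asarray(buf_decay, dtype=np.float64)
    s = np.asarray(output_scale, dtype=np.float64)
    xs = np.asarray(xs, dtype=np.float64)
    # buffers[i] = sum_{j < t} d_i**(t - 1 - j) * x_j before step t.
    buffers = np.zeros((len(d),) + xs.shape[1:])
    decay = d.reshape((-1,) + (1,) * (xs.ndim - 1))  # broadcast over trailing dims
    outputs = []
    for x in xs:
        outputs.append(x + np.tensordot(s, buffers, axes=1))  # y_t = x_t + sum_i s_i b_i
        buffers = decay * buffers + x  # b_i <- d_i * b_i + x_t
    return np.stack(outputs)


rng = np.random.default_rng(0)
xs = rng.normal(size=(6, 3))  # 6 steps of 3-dimensional inputs
ys = blt_stream_multiply([0.9, 0.5], [0.2, 0.1], xs)

# Dense reference: C[i, j] = t_{i-j}, t_0 = 1, t_k = sum_i s_i * d_i**(k - 1).
t = [1.0] + [0.2 * 0.9 ** (k - 1) + 0.1 * 0.5 ** (k - 1) for k in range(1, 6)]
C = np.array([[t[i - j] if i >= j else 0.0 for j in range(6)] for i in range(6)])
```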

inverse_as_streaming_matrix()

Returns a StreamingMatrix representing C^{-1}.

Return type:

StreamingMatrix

from_tuple()
items() -> a set-like object providing a view on D's items
keys() -> a set-like object providing a view on D's keys
replace(**kwargs)
to_tuple()
values() -> an object providing a view on D's values