jax_privacy.matrix_factorization.buffered_toeplitz.Parameterization

class jax_privacy.matrix_factorization.buffered_toeplitz.Parameterization(params_from_blt, blt_and_inverse_from_params)

Bases: object

A parameterization of a BufferedToeplitz for optimization.

Used by optimize_loss to specify how parameters relate to the pair of BLTs representing C and C^{-1}.

Variables:
  • params_from_blt – Constructs the parameters to be optimized, initialized from a BLT.

  • blt_and_inverse_from_params – Constructs a tuple of BLTs representing the (strategy_matrix, noising_matrix) from parameters.

  • loss_fn – The loss function to optimize.
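The relationship between the two callables can be illustrated with a minimal sketch that does not use the library at all. ToyMatrix and ToyParameterization are hypothetical stand-ins (the "matrix" is just scale * I), chosen only to show the contract: one callable maps an initial matrix to parameters, and the other maps parameters back to a (strategy, noising) pair whose product is the identity.

```python
import dataclasses
import math
from typing import Any, Callable


@dataclasses.dataclass(frozen=True)
class ToyMatrix:
    """Hypothetical stand-in for BufferedToeplitz: the matrix scale * I."""
    scale: float


@dataclasses.dataclass(frozen=True)
class ToyParameterization:
    """Hypothetical mirror of the two documented fields."""
    params_from_blt: Callable[[ToyMatrix], Any]
    blt_and_inverse_from_params: Callable[[Any], tuple[ToyMatrix, ToyMatrix]]


# Optimize log(scale) instead of scale, so that every real-valued
# parameter maps to a positive (hence invertible) toy matrix.
log_scale = ToyParameterization(
    params_from_blt=lambda m: math.log(m.scale),
    blt_and_inverse_from_params=lambda p: (
        ToyMatrix(math.exp(p)),   # stand-in for the strategy matrix C
        ToyMatrix(math.exp(-p)),  # stand-in for the noising matrix C^{-1}
    ),
)

init = ToyMatrix(scale=2.0)
params = log_scale.params_from_blt(init)
strategy, noising = log_scale.blt_and_inverse_from_params(params)
```

Here the choice of parameters (log(scale) rather than scale) changes the space the optimizer searches without changing the set of matrices it can reach, which is the role a parameterization plays.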

Methods

  • __init__

  • buf_decay_pair – A parameterization where a pair of buf_decay parameters is optimized.

  • get_loss_fn – Returns a loss function for the parameterization.

  • strategy_blt – A parameterization in which the strategy BLT itself serves as the parameters.

Attributes

params_from_blt

blt_and_inverse_from_params

params_from_blt: Callable[[BufferedToeplitz], chex.ArrayTree]

blt_and_inverse_from_params: Callable[[Any], tuple[BufferedToeplitz, BufferedToeplitz]]

classmethod strategy_blt()

A parameterization in which the strategy BLT itself serves as the parameters.

Return type:

Parameterization

classmethod buf_decay_pair()

A parameterization where a pair of buf_decay parameters is optimized.

This parameterization is generally more numerically stable than the strategy_blt parameterization, and it is also marginally faster to compute because it does not require a singular-value decomposition. However, the current L-BFGS parameters are tuned for the strategy_blt parameterization, so buf_decay_pair may not converge as well with the default settings.

Return type:

Parameterization

Returns:

A Parameterization.

get_loss_fn(loss_fn)

Returns a loss function for the parameterization.

Return type:

Callable[[chex.ArrayTree], chex.Numeric]
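One plausible reading of get_loss_fn is that it composes the supplied loss with blt_and_inverse_from_params, so that the optimizer sees a function of the parameters alone. The sketch below shows that composition with plain floats. It assumes that loss_fn takes the (strategy, noising) pair, which the documentation above does not state, and toy_loss and the float "matrices" are hypothetical.

```python
from typing import Callable, Tuple


def blt_and_inverse_from_params(params: float) -> Tuple[float, float]:
    # Hypothetical stand-in: the "matrix" is the float c, its inverse is 1 / c.
    return params, 1.0 / params


def toy_loss(strategy: float, noising: float) -> float:
    # Illustrative objective only; it is smallest when strategy == 1.
    return strategy**2 + noising**2


def get_loss_fn(
    loss_fn: Callable[[float, float], float],
) -> Callable[[float], float]:
    """Assumed behavior: wrap loss_fn so that it accepts parameters."""

    def params_loss(params: float) -> float:
        strategy, noising = blt_and_inverse_from_params(params)
        return loss_fn(strategy, noising)

    return params_loss


params_loss = get_loss_fn(toy_loss)
value = params_loss(2.0)  # 2.0**2 + 0.5**2 = 4.25
```

The returned function has the shape given in the return type above: it maps parameters to a single number, which is what a gradient-based optimizer such as L-BFGS needs.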