# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for working with Buffered Linear Toeplitz (BLT) strategy matrices.
`BufferedToeplitz` is the main BLT class, with several helper functions
for error and sensitivity calculation, as well as optimization.
"""
from collections.abc import Callable
import dataclasses
import functools
from typing import Any, TypeAlias
from absl import logging
import chex
import jax
import jax.numpy as jnp
import numpy as np
import optax
from . import optimization
from . import sensitivity
from . import streaming_matrix
from . import toeplitz
# Disabling pylint invalid-name to allow mathematical notation including
# single-capital-letter variables for matrices.
# See README.md for notation conventions.
# pylint:disable=invalid-name
# We want to allow primitive types, lists of primitive types, np.array,
# jnp.array, etc.
# See https://jax.readthedocs.io/en/latest/jep/12049-type-annotations.html
# A pair of 1-D arrays. NOTE(review): presumably the buf_decay (theta) values
# of C and of C^{-1}, in that order -- confirm against the users of this alias.
ThetaPairType: TypeAlias = tuple[jax.Array, jax.Array]
# A scalar value: either a Python float or a chex numeric (e.g. a 0-D array).
# Used in this file as the annotation for losses, scores, and penalties.
ScalarFloat: TypeAlias = chex.Numeric | float
def _gpu_or_cpu_device() -> jax.Device:
# Use a GPU if possible, and avoid TPUs as we require float64:
try:
return jax.local_devices(backend='gpu')[0]
except RuntimeError:
return jax.local_devices(backend='cpu')[0]
[docs]
def check_float64_dtype(blt: 'BufferedToeplitz'):
  """Raises if `blt` holds float64 parameters while x64 mode is disabled.

  Args:
    blt: The BLT whose `dtype` is inspected.

  Raises:
    ValueError: If `blt.dtype` is float64 and the `jax_enable_x64` config
      option is not set.
  """
  if blt.dtype != jnp.float64:
    return
  if jax.config.read('jax_enable_x64'):
    return
  message = (
      'A BLT with dtype=jnp.float64 requires `jax_enable_x64`.\nMany BLT'
      ' operations require float64 for numerical stability, particularly'
      ' optimization, and buffered_toeplitz.optimize() will return a BLT with'
      ' float64 parameters. If you see this error, you likely need to do one'
      ' of the following:\n'
      '1. Locally, add `with '
      'jax.enable_x64():` around the code that triggered this '
      'error.\n'
      '2. Globally set `jax.config.update("jax_enable_x64", True)`.\n'
      '3. Explicitly cast the parameters of your BufferedToeplitz to '
      ' have dtype=jnp.float32.'
  )
  raise ValueError(message)
[docs]
@dataclasses.dataclass(kw_only=True, frozen=True)
class StreamingMatrixBuilder:
  """Turns BLT parameters into `StreamingMatrix` objects for C and C^{-1}.

  Attributes:
    buf_decay: A numpy array of shape `(num_buffers,)` holding the decay factor
      of each buffer.
    output_scale: A numpy array of shape `(num_buffers,)` holding the scale
      applied to each buffer when the output is read.
  """

  buf_decay: np.ndarray
  output_scale: np.ndarray

  @property
  def dtype(self) -> np.dtype:
    """Returns the dtype shared by the BLT parameters."""
    scale_dtype = self.output_scale.dtype
    assert scale_dtype == self.buf_decay.dtype
    return scale_dtype

  def _init(self, abstract_value):
    # Shared by the streaming representations of C and C^{-1}: the state is
    # just the buffer contents, one zero buffer per BLT buffer.
    assert len(self.buf_decay) == len(self.output_scale)
    # zeros_like keeps the shape and (if possible) the sharding of the input,
    # while the dtype is forced to the BLT parameter dtype.
    zero_buffer = jnp.zeros_like(abstract_value, dtype=self.dtype)
    stacked_shape = (self.buf_decay.shape[0],) + zero_buffer.shape
    return jnp.broadcast_to(zero_buffer, stacked_shape)

  def _read(self, state):
    # Shared by C and C^{-1}: weighted sum of the buffers over the leading axis.
    scales = jnp.asarray(self.output_scale)
    return jnp.tensordot(scales, state, axes=((0,), (0,)))

  def _update(self, state, next_rhs_value):
    # Shared by C and C^{-1}: decay every buffer, then add the new value to
    # each of them (the addition relies on broadcasting).
    decay_matrix = jnp.diag(jnp.asarray(self.buf_decay))
    decayed = jnp.tensordot(decay_matrix, state, axes=((0,), (0,)))
    return decayed + next_rhs_value

  def build(
      self,
  ) -> streaming_matrix.StreamingMatrix:
    """Returns a `StreamingMatrix` representing C.

    This implements Alg 2 of https://arxiv.org/pdf/2408.08868
    """

    def multiply_next(xi, state):
      # The output is read from the buffers *before* xi is pushed into them.
      yi = xi + self._read(state)
      return yi, self._update(state, xi)

    return streaming_matrix.StreamingMatrix.from_array_implementation(
        self._init, multiply_next
    )

  def build_inverse(
      self,
  ) -> streaming_matrix.StreamingMatrix:
    """Returns a `StreamingMatrix` representing C^{-1}.

    This implements Alg 3 of https://arxiv.org/pdf/2408.08868
    """

    def inv_multiply_next(yi, state):
      # Undo the buffer contribution, then feed the recovered xi back in.
      xi = yi - self._read(state)
      return xi, self._update(state, xi)

    return streaming_matrix.StreamingMatrix.from_array_implementation(
        self._init, inv_multiply_next
    )
[docs]
@chex.dataclass
class BufferedToeplitz:
  """A lower-triangular Toeplitz matrix C parameterized as a BLT.

  `BufferedToeplitz.build` is the recommended constructor: it converts the
  arguments to `jnp` arrays and canonicalizes the buffer order.

  For background on Buffered Linear Toeplitz matrices and DP mechanisms, see:
    - https://arxiv.org/abs/2404.16706
    - https://arxiv.org/abs/2408.08868

  If buf_decay = [d1, d2] and output_scale = [s1, s2], for n = 5
  this class represents
  ::

    1  0  0  0  0                t0 = 1
    t1 1  0  0  0                t1 = s1 + s2
    t2 t1 1  0  0     where      t2 = s1*d1**1 + s2*d2**1
    t3 t2 t1 1  0                t3 = s1*d1**2 + s2*d2**2
    t4 t3 t2 t1 1                t4 = s1*d1**3 + s2*d2**3

  These Toeplitz parameters are returned by `toeplitz_coefs(n=5)`.
  """

  buf_decay: jax.Array  # Shape (nbuf,)
  output_scale: jax.Array  # Shape (nbuf,)

  def validate(self):
    """Checks rank, shape, and dtype agreement of the BLT parameters.

    Raises:
      ValueError: If either array has more than one dimension, or if the two
        arrays differ in shape or dtype.
    """
    # These checks are intentionally not in __post_init__: the properties may
    # not hold when a BLT is tree_map'ed to hold other values (such as
    # gradients or dtypes).
    decay, scale = self.buf_decay, self.output_scale
    if decay.ndim > 1 or scale.ndim > 1:
      raise ValueError(
          'buf_decay and output_scale must be OD or 1D, but :'
          f'buf_decay.shape = {decay.shape} and '
          f'output_scale.shape = {scale.shape}'
      )
    if decay.shape != scale.shape:
      raise ValueError(
          'buf_decay and output_scale must have the same length, but '
          f'buf_decay.shape = {decay.shape} !='
          f' output_scale.shape = {scale.shape}'
      )
    if decay.dtype != scale.dtype:
      raise ValueError(
          'buf_decay and output_scale must have the same dtype, but '
          f'buf_decay.dtype = {decay.dtype} !='
          f' output_scale.dtype = {scale.dtype}'
      )

  @classmethod
  def build(
      cls,
      buf_decay: Any,
      output_scale: Any,
      dtype: jax.typing.DTypeLike = jnp.float64,
  ) -> 'BufferedToeplitz':
    """Converts the arguments to jnp arrays and canonicalizes the result.

    Args:
      buf_decay: The buf_decay parameters of a BLT.
      output_scale: The output_scale parameters of a BLT.
      dtype: The dtype to use for the BLT parameters. The default,
        `jnp.float64`, is strongly recommended for numerical stability.
        However, this requires either the global option
        `jax.config.update('jax_enable_x64', True)` or that build() and
        subsequent computations occur within a `with jax.enable_x64():`
        context. See also `check_float64_dtype`.

    Returns:
      A `BufferedToeplitz` with buf_decay in decreasing order.

    Raises:
      ValueError: If `dtype` is float64 while `jax_enable_x64` is off, or if
        the resulting parameters fail `validate()`.
    """
    if (dtype == jnp.float64) and not jax.config.read('jax_enable_x64'):
      raise ValueError(
          'dtype=jnp.float64 requires `jax_enable_x64`; see docstring.'
      )
    decay_array = jnp.array(buf_decay, dtype=dtype)
    scale_array = jnp.array(output_scale, dtype=dtype)
    unsorted_blt = cls(buf_decay=decay_array, output_scale=scale_array)
    # canonicalize() also runs validate().
    return unsorted_blt.canonicalize()

  @classmethod
  def from_rational_approx_to_sqrt_x(
      cls,
      num_buffers: int,
      *,
      max_buf_decay: float = 1.0,
      max_pillutla_score: float | None = None,
      buf_decay_scale: float = 1.6,
      buf_decay_shift: int = -1,
  ) -> 'BufferedToeplitz':
    """Returns a BLT based on a rational approximation of 1/sqrt(1 - x).

    The optimal-for-max-loss Toeplitz coefficients (see
    `toeplitz.optimal_max_error_strategy_coefs`) correspond to the ordinary
    generating function 1/sqrt(1 - x). Thus, this method is used to produce
    a BLT that approximates these coefficients, but allows for a much more
    memory-efficient implementation of multiplication by the noising matrix
    $C^{-1}$.

    The rational approximation is from https://arxiv.org/abs/2404.16706v2,
    see Proposition 4.5 in particular.

    NOTE: The BLTs produced by this method are generally significantly inferior
    to those from `buffered_toeplitz.optimize`, which finds a numerically
    optimal BLT for a specific value of `n`. Hence, the primary use of this
    method is initializing numerical optimization, as well as providing an
    implementation of the "RA-BLT" method of https://arxiv.org/abs/2404.16706v2
    for use in research comparisons.

    Args:
      num_buffers: The number of buffers to use in the BLT (equivalently, the
        degree of the rational function), must be >= 1.
      max_buf_decay: The maximum value of buf_decay to use. For large numbers
        of buffers, this routine can produce buf_decay values of 1.0 (up to
        float64 precision), or higher due to numerical issues. This parameter
        can be used to enforce that the largest buf_decay parameter is
        strictly less than one. This is useful for initializing optimization.
      max_pillutla_score: If not None, the maximum pillutla score to use. This
        is accomplished by first scaling the buf_decay parameters if needed
        based on max_buf_decay, and then scaling the output_scale parameters
        to ensure the pillutla score is less than or equal to this value.
      buf_decay_scale: A factor that scales the dynamic range of the buf_decay
        parameters. Larger values indicate a coarser resolution, and hence the
        largest buf_decay will be closer to 1.0, and the smallest closer to
        0.0.
      buf_decay_shift: A shift added to the range of the counter `k` in the
        construction of the rational approximation. The buf decay parameters
        come from a discrete set indexed by `k`, with resolution that depends
        on the `buf_decay_scale`. A negative `buf_decay_shift` shifts the
        selected buf_decay parameters toward 1.0; a positive shift moves the
        selected set closer to 0.0. The default of -1 is recommended.

    Returns:
      A `BufferedToeplitz` matrix generated by a rational function
      approximation of 1/sqrt(1 - x).

    Raises:
      ValueError: If `num_buffers < 1`.
    """
    # Following Prop 4.5, a BLT for C^{-1} is built from the rational
    # approximation of sqrt(1 - x); it is then inverted to obtain C.
    if num_buffers < 1:
      raise ValueError('num_buffers must be >= 1.')

    degree = num_buffers
    d1 = (degree + 1) // 2
    step = buf_decay_scale * np.pi / np.sqrt(2 * (d1 + 1))

    # Values taken by the counter `k`, shifted by `buf_decay_shift`.
    counters = np.arange(-d1 + 1, degree - d1 + 1, 1) + buf_decay_shift
    inv_decay = 1 / (1 + np.exp(2 * step * counters))  # Aka the `c`s.
    inv_scale = -(inv_decay**2) * np.exp(3 * step * counters)
    # Normalize so that the leading constant is 1.
    constant_term = np.sum(np.exp(step * counters) * inv_decay)
    inv_scale /= constant_term

    # (inv_decay, inv_scale) define a BLT for C^{-1}; invert it to get C.
    blt = cls.build(buf_decay=inv_decay, output_scale=inv_scale).inverse()
    buf_decay = blt.buf_decay
    output_scale = blt.output_scale

    # Scale / project to respect max_buf_decay and max_pillutla_score. All
    # buf_decay values are scaled together, so their ordering is unchanged and
    # no duplicate values are introduced.
    decay_shrink = jnp.minimum(1.0, max_buf_decay / buf_decay[0])
    buf_decay = buf_decay * decay_shrink
    assert buf_decay[0] <= max_buf_decay + 1e-12

    if max_pillutla_score is not None:
      # Shrink the output parameters if the pillutla score is above the cap.
      score = jnp.sum(output_scale / buf_decay)
      output_scale = output_scale * jnp.minimum(
          1.0, max_pillutla_score / score
      )

    blt = cls.build(
        buf_decay=buf_decay,
        output_scale=output_scale,
    )
    if max_pillutla_score is not None:
      assert blt.pillutla_score() <= max_pillutla_score + 1e-12
    return blt

  def canonicalize(self) -> 'BufferedToeplitz':
    """Returns a `BufferedToeplitz` with buf_decay in decreasing order."""
    self.validate()  # Make sure lengths match.
    descending = jnp.argsort(self.buf_decay)[::-1]
    return BufferedToeplitz(
        buf_decay=self.buf_decay[descending],
        output_scale=self.output_scale[descending],
    )

  @property
  def dtype(self) -> jax.typing.DTypeLike:
    """The dtype of `buf_decay`."""
    return self.buf_decay.dtype

  def toeplitz_coefs(self, n: int) -> jax.Array:
    """Returns the Toeplitz coefficients for `C`.

    The first coefficient is always 1; coefficient `i >= 1` is
    `sum_j(output_scale[j] * buf_decay[j] ** (i - 1))`.
    """
    exponents = jnp.arange(n - 1)
    # Row i holds the per-buffer contribution to coefficient i + 1.
    per_buffer = self.buf_decay ** exponents[:, None] * self.output_scale
    return jnp.append(1, jnp.sum(per_buffer, axis=1))

  def materialize(self, n: int) -> jax.Array:
    """Returns the matrix built by `toeplitz.materialize_lower_triangular`."""
    return toeplitz.materialize_lower_triangular(self.toeplitz_coefs(n))

  def inverse(self, skip_checks: bool = False) -> 'BufferedToeplitz':
    """Computes the BufferedToeplitz parameterization of C^{-1} from C.

    This is an alternative approach to https://arxiv.org/pdf/2404.16706 Lemma
    5.2, along the lines of Proposition 5.6 (Representation of the reciprocal
    of a rational generating function), with slightly different
    parameterization.

    Args:
      skip_checks: If True, skip error checks on the inputs and results
        (necessary in jitted contexts).

    Returns:
      A Buffered Linear Toeplitz parameterization of the inverse.

    Raises:
      ValueError: If skip_checks=False and two buf_decay values are closer
        than 1e-9, or if the float64 dtype check fails.
      RuntimeError: If skip_checks=False and the inverse calculation encounters
        numerical problems.
    """
    check_float64_dtype(self)
    blt = self  # Only the public API of the BLT is used below.
    nbuf = len(blt.buf_decay)

    if not skip_checks and nbuf > 1:
      gap = min_buf_decay_gap(blt.buf_decay)
      if gap < 1e-9:
        # Quick empirical tests show the closed-form eigenvector computation
        # returns nans for gaps < 1e-9.
        raise ValueError(
            'The input BLT has buf_decay values that are too close to each'
            f' other.Expected gap >= 1e-9, found {gap} for'
            f' buf_decay={blt.buf_decay}.'
        )

    omega = blt.output_scale
    alpha = jnp.ones(nbuf)
    # (Theta2, -omega) already produce the inverse, but Theta2 is not diagonal
    # and therefore not a BLT.
    Theta2 = jnp.diag(blt.buf_decay) - jnp.outer(omega, alpha)

    # Diagonalize Theta2 to derive an equivalent BLT. Ideally we would have a
    # proof that the eigenvalues are in fact all real.
    evals = jnp.linalg.eigvals(Theta2).real
    # Closed form for the eigenvectors:
    evecs = omega[:, jnp.newaxis] / (evals - blt.buf_decay[:, jnp.newaxis])
    einv = jnp.linalg.inv(evecs)

    if not skip_checks:
      Theta2_diagonalized = evecs @ jnp.diag(evals) @ einv
      if not np.allclose(Theta2_diagonalized, Theta2, atol=1e-7):
        raise RuntimeError(
            'Error computing inverse parameters:'
            f' {blt=}\n{evecs=}\n{einv=}\n{evals=}\n'
            f'{Theta2=}\n{Theta2_diagonalized=}'
        )

    omega3 = (einv @ (-omega)) * (evecs.T @ alpha)
    return BufferedToeplitz.build(
        buf_decay=evals, output_scale=omega3, dtype=blt.dtype
    )

  def pillutla_score(self) -> ScalarFloat:
    """Returns the 'Pillutla Score' of the BLT.

    See Theorem 1 of "An Inversion Theorem for Buffered Linear Toeplitz (BLT)
    Matrices and Applications to Streaming Differential Privacy",
    https://arxiv.org/abs/2504.21413. To avoid a negative buf_decay value in
    the noising matrix BLT (which produces an oscillating term), we enforce a
    pillutla_score < 1 during optimization.

    Note that a BLT may have `buf_decay == 0` values, which leads to a
    nan or inf pillutla score. (In particular, the inverse of a BLT with
    pillutla_score=0 will have this property).

    Returns:
      The Pillutla Score of the BLT, `sum_i(output_scale[i] / buf_decay[i]).`
    """
    ratios = self.output_scale / self.buf_decay
    return jnp.sum(ratios)

  @property
  def _num_buffers(self):
    return self.buf_decay.shape[0]

  def _streaming_matrix_builder(self):
    # The builder stores plain numpy copies of the parameters.
    np_dtype = np.dtype(self.dtype)
    return StreamingMatrixBuilder(
        buf_decay=np.array(self.buf_decay, dtype=np_dtype),
        output_scale=np.array(self.output_scale, dtype=np_dtype),
    )

  def as_streaming_matrix(
      self,
  ) -> streaming_matrix.StreamingMatrix:
    """Returns a `StreamingMatrix` representing C."""
    check_float64_dtype(self)
    builder = self._streaming_matrix_builder()
    return builder.build()

  def inverse_as_streaming_matrix(
      self,
  ) -> streaming_matrix.StreamingMatrix:
    """Returns a `StreamingMatrix` representing C^{-1}."""
    check_float64_dtype(self)
    builder = self._streaming_matrix_builder()
    return builder.build_inverse()

  def __str__(self):
    # Copy-pastable constructor call followed by the first ten coefficients.
    shown = [f'{c:.5f}' for c in self.toeplitz_coefs(10)]
    shown.append('...')
    coefs_str = ', '.join(shown)
    return (
        'BufferedToeplitz.build(\n'
        f'    buf_decay={self.buf_decay.tolist()},\n'
        f'    output_scale={self.output_scale.tolist()},\n'
        f'    dtype=jnp.{self.dtype})\n'
        f'Initial Toeplitz coefs=[{coefs_str}])'
    )
def _gt_zero_penalty(x: jax.Array) -> jax.Array:
"""Penalize values to enforce x > 0."""
return -jnp.log(x).sum()
def _lt_zero_penalty(x: jax.Array) -> jax.Array:
"""Penalize values to enforce x < 0."""
return -jnp.log(-x).sum()
def _lt_penalty(x: jax.Array, upper_bound: float) -> jax.Array:
"""Penalize values to enforce x < upper_bound."""
return -jnp.log(upper_bound - x).sum()
def _lt_one_penalty(x: jax.Array) -> jax.Array:
"""Penalize values to enforce x < 1."""
return -jnp.log(1 - x).sum()
[docs]
def min_buf_decay_gap(buf_decay: jax.Array) -> jax.Array:
  """Returns min_{i,j i!=j} abs(theta[i] - theta[j]).

  Much of the theory for BLTs, as well as the numerical functions in this file,
  require uniqueness of the buf_decay parameters theta. This function computes
  the smallest gap between two such thetas, without assuming they are sorted.

  Args:
    buf_decay: The buf_decay parameters of a BLT.

  Returns:
    min_{i,j i!=j} abs(theta[i] - theta[j]). For a single theta there is no
    off-diagonal pair and the result is inf.
  """
  theta = jnp.asarray(buf_decay)
  # Entry (i, j) is theta[i] - theta[j].
  pairwise_diff = theta[:, jnp.newaxis] - theta
  # The diagonal compares each theta with itself; mask it out with inf so it
  # can never be the minimum.
  rows, cols = jnp.diag_indices_from(pairwise_diff)
  masked = pairwise_diff.at[rows, cols].set(jnp.inf)
  return jnp.min(jnp.abs(masked))
[docs]
@dataclasses.dataclass(frozen=True)
class LossFn:
  """Encapsulates the loss to be optimized for a specific setting.

  This can represent the loss for both single participation and min-sep
  participation (which has single participation as a special case).

  Attributes:
    error_for_inv: Function for computing the error for the BLT representing
      C^{-1}, the noise correlating matrix.
    sensitivity_squared: Function for computing the sensitivity for the BLT
      representing C, the strategy matrix.
    n: The number of iterations the mechanism is optimized for.
    min_sep: The minimum separation of participations.
    max_participations: The effective maximum number of participations, taking
      into account n, min_sep, and max_participations.
    penalty_strength: The multiplier applied to the sum of penalties for the
      loss.
    penalty_multipliers: A dict of multipliers (default 1.0) applied to the
      individual penalties returned by `compute_penalties`.
    max_second_coef: The maximum value of the second Toeplitz coefficient, which
      is equal to sum(output_scale).
    min_theta_gap: The minimum gap between buf_decay parameters allowed by the
      theta_gap penalty.
  """

  error_for_inv: Callable[[BufferedToeplitz], ScalarFloat]
  sensitivity_squared: Callable[[BufferedToeplitz], ScalarFloat]
  n: int
  min_sep: int
  max_participations: int
  # Usually doesn't need to be changed:
  penalty_strength: float = 1e-8
  penalty_multipliers: dict[str, float] = dataclasses.field(
      default_factory=dict
  )
  max_second_coef: float = 1.0
  min_theta_gap: float = 1e-12

  @classmethod
  def build_min_sep(
      cls,
      n: int,
      error: str = 'max',
      min_sep: int = 1,
      max_participations: int | None = None,
      **kwargs,
  ) -> 'LossFn':
    """Construct a `LossFn` for min-sep participation.

    This LossFn computes loss and sensitivity by materializing the Toeplitz
    coefficients of C and C^{-1}, and then using the loss functions of
    `toeplitz.py`, as described in https://arxiv.org/abs/2408.08868. This is
    still significantly faster than computing the error directly from the
    Toeplitz coefficients of C, because
    ::

      c_inv_coef = blt.inverse().toeplitz_coefs(n)

    is orders of magnitude faster (on GPUs) than
    ::

      c_inv_coef = toeplitz.inverse_coef(blt.toeplitz_coefs(n))

    Args:
      n: The number of iterations the mechanism is optimized for.
      error: Either 'max' or 'mean', indicating whether to optimize for the
        maximum or mean squared error, respectively.
      min_sep: The minimum separation of participations.
      max_participations: The maximum number of participations.
      **kwargs: Optional additional arguments to pass to the constructor.

    Returns:
      A `LossFn` for min-sep participation.

    Raises:
      ValueError: If `error` is not 'max' or 'mean'.
    """
    if error == 'mean':
      error_fn = lambda inv_blt: jnp.mean(
          toeplitz.per_query_error(noising_coef=inv_blt.toeplitz_coefs(n))
      )
    elif error == 'max':
      error_fn = lambda inv_blt: jnp.max(
          toeplitz.per_query_error(noising_coef=inv_blt.toeplitz_coefs(n))
      )
    else:
      raise ValueError(f'Unknown error={error}')

    def minsep_sensitivity_squared(blt):
      return toeplitz.minsep_sensitivity_squared(
          strategy_coef=blt.toeplitz_coefs(n),
          min_sep=min_sep,
          n=n,
          max_participations=max_participations,
      )

    return cls(
        n=n,
        error_for_inv=error_fn,
        sensitivity_squared=minsep_sensitivity_squared,
        min_sep=min_sep,
        max_participations=sensitivity.minsep_true_max_participations(
            n=n,
            min_sep=min_sep,
            max_participations=max_participations,
        ),
        **kwargs,
    )

  def compute_penalties(
      self, blt: BufferedToeplitz, inv_blt: BufferedToeplitz
  ) -> dict[str, ScalarFloat]:
    """Computes penalties that help keep the optimization well-behaved.

    These correspond to the conditions of Theorem 1 (part a) of "An Inversion
    Theorem for Buffered Linear Toeplitz (BLT) Matrices and Applications to
    Streaming Differential Privacy" (https://arxiv.org/abs/2504.21413), which
    restricts the optimization to a class of well-behaved BLTs. Note the
    constraint `pillutla_score < 1` of part (a) is not strictly necessary, but
    empirically including it produces better results.

    Args:
      blt: The BLT representing C.
      inv_blt: The BLT representing C^{-1}.

    Returns:
      A dictionary of named penalties.
    """
    check_float64_dtype(blt)
    num_buffers = len(blt.buf_decay)
    # We could also impose the same penalties on inv_blt, but this does not seem
    # to be necessary.
    penalties = {
        # Conditions on `blt` parameters.
        'buf_decay>0': _gt_zero_penalty(blt.buf_decay),
        'buf_decay<1': _lt_one_penalty(blt.buf_decay),
        'output_scale>0': _gt_zero_penalty(blt.output_scale),
        # Conditions on `inv_blt` parameters.
        'inv_buf_decay>0': _gt_zero_penalty(inv_blt.buf_decay),
        'inv_buf_decay<1': _lt_one_penalty(inv_blt.buf_decay),
        'inv_output_scale<0': _lt_zero_penalty(inv_blt.output_scale),
    }

    if num_buffers > 1:
      # Optimizing for multiple participations can cause multiple buf_decay
      # parameters to converge to the same value, which will break
      # `blt.inverse()` numerically. We know "good" solutions should
      # have separated thetas, so we enforce some separation:
      penalties['theta_gap'] = _gt_zero_penalty(
          min_buf_decay_gap(blt.buf_decay) - self.min_theta_gap
      ) + _gt_zero_penalty(
          min_buf_decay_gap(inv_blt.buf_decay) - self.min_theta_gap
      )

    # The 2nd Toeplitz coefficient (after 1.0) is equal to sum(output_scale),
    # so to ensure decreasing coefficients it is also necessary that
    # sum(output_scale) < 1. When optimizing for single participation,
    # we know the optimal 2nd coefficient is 0.5, so we could also use
    # a value like max_second_coef = 0.5 + (small value).
    second_coef = blt.output_scale.sum()
    penalties['second_coef'] = _lt_penalty(second_coef, self.max_second_coef)

    # The final condition is the 'Pillutla Score' is less than 1.
    penalties['pillutla_score'] = _lt_one_penalty(blt.pillutla_score())
    return penalties

  def penalized_loss(
      self,
      blt: BufferedToeplitz,
      inv_blt: BufferedToeplitz,
      normalize_by_approx_optimal_loss: bool = True,
  ) -> ScalarFloat:
    """Computes the total composite loss to be optimized.

    Args:
      blt: The BLT representing C.
      inv_blt: The BLT representing C^{-1}.
      normalize_by_approx_optimal_loss: If True, the loss is normalized by the
        expected optimal loss, so the relative penalty strength remains somewhat
        consistent across n and k. This is the default, and recommended for
        optimization.

    Returns:
      The total composite loss to be optimized.

    Raises:
      ValueError: If `penalty_multipliers` names a penalty that
        `compute_penalties` did not return.
    """
    check_float64_dtype(blt)
    error = self.error_for_inv(inv_blt)
    sens_squared = self.sensitivity_squared(blt)
    penalty_dict = self.compute_penalties(blt, inv_blt)

    # Work on a copy so that consumed entries can be popped; whatever is left
    # afterwards did not match any computed penalty.
    multipliers = dict(self.penalty_multipliers)
    total_penalty = 0.0
    for k, v in penalty_dict.items():
      multiplier = multipliers.pop(k, 1.0)
      if multiplier != 0.0:
        # We want a multiplier of 0 to "turn off" the penalty even if
        # the penalty is NaN or inf, so we avoid multiplication in that case.
        total_penalty += multiplier * v
    total_penalty *= self.penalty_strength

    if multipliers:
      raise ValueError(
          f'Unrecognized penalty multipliers: {multipliers.keys()}'
      )

    loss = error * sens_squared
    if normalize_by_approx_optimal_loss:
      # Optimal sqrt(max_loss) for Toeplitz matrices scales like
      # 1 + ln(n)/pi; sqrt(mean_loss) can be a bit lower.
      # Similarly, sensitivity_squared scales by max_participations.
      # The loss is the product `error * sens_squared`, so the rough estimate
      # of the expected optimal loss is the product of the two factors. It is
      # used to normalize the computed loss so the relative penalty strength
      # remains somewhat consistent across n and k.
      approx_optimal_loss = (
          self.max_participations * (1 + jnp.log(self.n) / np.pi) ** 2
      )
      loss /= approx_optimal_loss

    return loss + total_penalty

  def loss(
      self, blt: BufferedToeplitz, skip_checks: bool = False
  ) -> ScalarFloat:
    """Returns the loss (not including penalties).

    This function is not intended to be jitted or used in optimization, but
    only in evaluation of the final BLT.

    Args:
      blt: The BLT to compute the loss of.
      skip_checks: Forwarded to `blt.inverse`; if True, the checks performed
        while computing the inverse are skipped.

    Returns:
      `error_for_inv(inverse) * sensitivity_squared(blt)`, or `jnp.inf` if
      computing the inverse raised a RuntimeError.
    """
    check_float64_dtype(blt)
    try:
      inv_blt = blt.inverse(skip_checks=skip_checks)
    except RuntimeError as e:
      logging.warning(
          'During loss computation, error computing inverse for'
          ' BLT\n%s\nReturning jnp.inf. If you really need a finite loss for'
          ' this BLT and n is small, consider computing the loss directly from'
          ' blt.toeplitz_coefs(n). Exception:\n%s',
          str(blt),
          str(e),
      )
      return jnp.inf
    error = self.error_for_inv(inv_blt)
    sens_squared = self.sensitivity_squared(blt)
    return error * sens_squared
[docs]
@dataclasses.dataclass(frozen=True)
class Parameterization:
  """A parameterization of a BufferedToeplitz for optimization.

  Used by `optimize_loss` to specify how the optimized parameters relate to the
  pair of BLTs representing C and C^{-1}.

  Attributes:
    params_from_blt: Constructs the parameters to be optimized, initialized
      from a BLT.
    blt_and_inverse_from_params: Constructs a tuple of BLTs representing the
      (strategy_matrix, noising_matrix) from parameters.
  """

  params_from_blt: Callable[[BufferedToeplitz], chex.ArrayTree]
  blt_and_inverse_from_params: Callable[
      [Any],
      tuple[BufferedToeplitz, BufferedToeplitz],
  ]

  @classmethod
  def strategy_blt(cls) -> 'Parameterization':
    """A parameterization where the strategy BLT is the parameterization."""

    def identity_params(blt):
      # The BLT itself is the parameter tree.
      return blt

    def blt_with_inverse(blt):
      # Checks are skipped for the inverse; `BufferedToeplitz.inverse` notes
      # this is necessary in jitted contexts.
      return (blt, blt.inverse(skip_checks=True))

    return cls(
        params_from_blt=identity_params,
        blt_and_inverse_from_params=blt_with_inverse,
    )

  @classmethod
  def buf_decay_pair(cls) -> 'Parameterization':
    """A parameterization where a pair of buf_decay parameters is optimized.

    The parameters are the buf_decay arrays of C and of C^{-1}; both BLTs are
    rebuilt from that pair by `blt_pair_from_theta_pair`.

    This parameterization is generally more numerically stable than the
    `strategy_blt` parameterization, as well as being negligibly faster to
    compute (the original note attributes this to not needing a matrix
    decomposition). However, the current L-BFGS parameters are tuned for the
    strategy_blt parameterization, so this parameterization may not converge as
    well with the default settings.

    Returns:
      A `Parameterization`.
    """

    def params_from_blt(blt: BufferedToeplitz) -> tuple[jax.Array, jax.Array]:
      # Checks are left on here: this runs once, on the initial BLT.
      inverse_decay = blt.inverse().buf_decay
      return (blt.buf_decay, inverse_decay)

    def blt_and_inverse_from_params(
        params: tuple[jax.Array, jax.Array],
    ) -> tuple[BufferedToeplitz, BufferedToeplitz]:
      theta, theta_hat = params[0], params[1]
      return blt_pair_from_theta_pair(theta, theta_hat)

    return cls(
        params_from_blt=params_from_blt,
        blt_and_inverse_from_params=blt_and_inverse_from_params,
    )

  def get_loss_fn(
      self, loss_fn: LossFn
  ) -> Callable[[chex.ArrayTree], ScalarFloat]:
    """Returns a function mapping parameters to `loss_fn.penalized_loss`."""

    def loss_of_params(params):
      blt_pair = self.blt_and_inverse_from_params(params)
      return loss_fn.penalized_loss(*blt_pair)

    return loss_of_params
[docs]
def get_init_blt(
    num_buffers: int = 3,
    init_blt: BufferedToeplitz | None = None,
) -> BufferedToeplitz:
  """Returns an initial BufferedToeplitz for initializing optimization.

  Currently, this defaults to `BufferedToeplitz.from_rational_approx_to_sqrt_x`,
  setting max_buf_decay and max_pillutla_score so that the solution is within
  the feasible set imposed by the optimization penalties in
  `LossFn.compute_penalties`. However, this initialization choice may be
  changed in the future.

  Args:
    num_buffers: The number of buffers to use in the initial BLT, greater than
      or equal to zero.
    init_blt: An initial BLT to use. If None, a default initialization is used.
      This is a convenience for callers who want to handle an optional explicit
      initialization and check that it has the correct number of buffers.

  Returns:
    A BufferedToeplitz with the requested number of buffers.

  Raises:
    ValueError: If the BLT (given or default) does not have exactly
      `num_buffers` buffers.
  """
  # Test for None explicitly rather than relying on the truthiness of a
  # dataclass instance.
  if init_blt is None:
    if num_buffers == 0:
      # Zero buffers: a BLT with empty parameter arrays.
      init_blt = BufferedToeplitz.build(buf_decay=[], output_scale=[])
    else:
      init_blt = BufferedToeplitz.from_rational_approx_to_sqrt_x(
          num_buffers=num_buffers,
          max_buf_decay=1 - 1e-6,
          max_pillutla_score=1 - 1e-6,
      )

  # Check we have the correct number of buffers:
  if len(init_blt.buf_decay) != num_buffers:
    raise ValueError(
        f'{num_buffers=} does not match {len(init_blt.buf_decay)=}'
    )
  return init_blt
@optimization.jax_enable_x64
def optimize_loss(
    loss_fn: LossFn,
    num_buffers: int = 3,
    init_blt: BufferedToeplitz | None = None,
    parameterization: Parameterization | None = None,
    **kwargs,
) -> tuple[BufferedToeplitz, ScalarFloat]:
  """Like optimize(), but more configurable and with a fixed number of buffers.

  Args:
    loss_fn: The loss to optimize.
    num_buffers: The number of buffers to optimize for; the default of 3 is a
      good choice in general; large number of buffers can cause numerical
      instability in the optimization.
    init_blt: An initial BLT to start the optimization from. If None, a default
      initialization is used.
    parameterization: The parameterization to use for optimization. If None, the
      default parameterization is used.
    **kwargs: Optional additional arguments to pass to optimization.optimize,
      such as `max_optimizer_steps`, `callback`, and `optimizer`. Note this
      function may supply different defaults for these arguments compared to
      `optimization.optimize`, so use with care.

  Returns:
    A tuple (blt, loss).

  Raises:
    RuntimeError: If the optimization produces an invalid BLT.
  """
  if num_buffers == 0:
    # Construct the identity strategy matrix and directly compute the loss.
    blt = BufferedToeplitz.build(buf_decay=[], output_scale=[])
    return blt, loss_fn.loss(blt)

  if parameterization is None:
    # NOTE(review): `Parameterization.buf_decay_pair` documents that the L-BFGS
    # settings below are tuned for the `strategy_blt` parameterization --
    # confirm that `buf_decay_pair` is the intended default here.
    parameterization = Parameterization.buf_decay_pair()

  default_optimizer = optax.lbfgs(
      # Empirical testing shows these parameters improve over the defaults.
      # A larger memory size can speed convergence, but also seems to result
      # in the optimizer getting stuck at a suboptimal solution.
      memory_size=15,
      linesearch=optax.scale_by_zoom_linesearch(
          max_linesearch_steps=100,
          curv_rtol=0.5,
          # This avoids some rare nan cases:
          max_learning_rate=1.0,
          # A small stepsize_precision is necessary when we need values of
          # theta very near 1.
          stepsize_precision=1e-20,
      ),
  )
  # Caller-supplied kwargs override the defaults chosen here.
  optimize_kwargs = {'max_optimizer_steps': 600, 'optimizer': default_optimizer}
  optimize_kwargs.update(kwargs)

  blt = get_init_blt(num_buffers=num_buffers, init_blt=init_blt)
  params = parameterization.params_from_blt(blt)
  # Combine the parameterization with the loss fn:
  loss_fn_to_optimize = parameterization.get_loss_fn(loss_fn)

  params = optimization.optimize(
      loss_fn_to_optimize,
      params,
      **optimize_kwargs,
  )

  blt, _ = parameterization.blt_and_inverse_from_params(params)
  blt = blt.canonicalize()
  # Report the un-penalized loss of the final BLT.
  loss = loss_fn.loss(blt)
  if not jnp.isfinite(loss):
    raise RuntimeError(
        f'Optimization produced BLT with non-finite loss {loss}:\n{blt}'
    )

  if jnp.any(jnp.abs(blt.output_scale) < 1e-8):
    logging.warning(
        'BLT has near-zero output_scale parameters, which '
        'means some buffers are ignored. Consider re-optimizing '
        'with a smaller number of buffers.\n%s',
        blt,
    )

  return blt, loss
def _optimize_increasing_nbuf(
opt_blt_and_loss_fn: Callable[[int], tuple[Any, float]],
min_buffers: int = 0,
max_buffers: int = 10,
rtol: float = 1.02,
) -> Any:
"""Optimizes w/ increasing num_buffers until the improvement is < rtol."""
prev_blt, prev_loss = opt_blt_and_loss_fn(min_buffers)
for nbuf in range(min_buffers + 1, max_buffers + 1):
try:
blt, loss = opt_blt_and_loss_fn(nbuf)
except NotImplementedError as err:
err.add_note(
'This error may be caused by trying to run `blt.inverse()` on a GPU.'
)
# This can happen if we try to run `jnp.linalg.eigvals` on a GPU,
# for example. This is a real configuration issue, not
# an optimization failure.
raise err
except RuntimeError as err:
logging.warning(
'Optimization failed for %d buffers with:\n%s',
nbuf,
str(err),
)
blt, loss = None, jnp.inf
if rtol * loss < prev_loss:
# Sufficient improvement, accept this BLT and maybe try more buffers:
prev_blt, prev_loss = blt, loss
else:
# Improvement was < rtol, return prev_blt
return prev_blt
return prev_blt
@optimization.jax_enable_x64
def optimize(
    *,
    n: int,
    min_sep: int = 1,
    max_participations: int | None = 1,
    error: str = 'max',
    min_buffers: int = 0,
    max_buffers: int = 10,
    rtol: float = 1.01,
    **kwargs,
) -> BufferedToeplitz:
  """Computes a good BLT for min-sep, choosing num_buffers dynamically.

  BLTs are optimized for `min_buffers`, `min_buffers + 1`, ... buffers, and
  the search stops once adding a buffer no longer improves the loss by the
  factor `rtol` (or `max_buffers` is reached).

  Internally this function uses `jax.jit` on the key optimization steps, but
  it is not intended to be called from a jitted context itself.

  For single-participation optimization of max error, this function utilizes
  the closed-form calculations for sensitivity and error from
  https://arxiv.org/abs/2404.16706, and hence optimization time is
  essentially independent of `n`. However, particularly for large `n` or
  large numbers of buffers, the optimal BLT may have a buf_decay theta very
  near 1, which leads to numerical issues in the closed forms. This function
  has been reasonably well tested for max loss up to n=10**7.

  Otherwise (for multiple-participations or mean loss optimization), this
  function computes loss and sensitivity by materializing the Toeplitz
  coefficients of C and C^{-1}, and then using the loss functions of
  `toeplitz.py`, as described in https://arxiv.org/abs/2408.08868. This is
  still significantly faster than directly optimizing Toeplitz mechanisms,
  because `blt.inverse()` is much faster than computing the Toeplitz
  coefficients of C^{-1} directly. This optimization benefits from GPUs for
  large n (say > 1000). We have observed some issues using TPUs, so avoid
  them for now.

  Args:
    n: The number of iterations the mechanism is optimized for.
    min_sep: The minimum separation of participations.
    max_participations: The maximum number of participations.
    error: Either 'max' or 'mean', indicating whether to optimize for the
      maximum or mean squared error, respectively.
    min_buffers: The minimum number of buffers to optimize for (inclusive).
    max_buffers: The maximum number of buffers to optimize for (inclusive).
      Values above 15 are rejected.
    rtol: The relative tolerance for the loss improvement in order to
      increase the number of buffers.
    **kwargs: Optional additional arguments to pass to
      optimization.optimize, such as `max_optimizer_steps` and `optimizer`.
      Note this function may supply different defaults for these arguments
      compared to `optimization.optimize`, so use with care. Further, the
      same arguments will be passed to each optimization (for different
      `num_buffers`), so for example a stateful `callback` function may not
      work as expected.

  Returns:
    A BLT that approximately minimizes the loss.

  Raises:
    ValueError: If `max_buffers > 15`.
    RuntimeError: If the optimization produces an invalid BLT.
  """
  if max_buffers > 15:
    raise ValueError(
        'In typical regimes 3 to 7 buffers should give the best utility (and'
        ' memory efficiency); 15+ buffers will almost certainly lead to'
        ' numerical issues in the optimization. Try setting a smaller value for'
        ' `max_buffers`.'
    )

  true_max_participations = sensitivity.minsep_true_max_participations(
      n=n, min_sep=min_sep, max_participations=max_participations
  )
  use_closed_form = true_max_participations == 1 and error == 'max'
  if use_closed_form:
    # Single participation max loss, use the closed forms:
    loss_fn = LossFn.build_closed_form_single_participation(n=n)
  else:
    loss_fn = LossFn.build_min_sep(
        n=n, error=error, min_sep=min_sep, max_participations=max_participations
    )

  def optimize_for_num_buffers(nbuf):
    # The same loss and user kwargs are reused for every buffer count.
    return optimize_loss(
        loss_fn=loss_fn,
        num_buffers=nbuf,
        parameterization=Parameterization.buf_decay_pair(),
        **kwargs,
    )

  with jax.default_device(_gpu_or_cpu_device()):
    return _optimize_increasing_nbuf(
        opt_blt_and_loss_fn=optimize_for_num_buffers,
        min_buffers=min_buffers,
        max_buffers=max_buffers,
        rtol=rtol,
    )
[docs]
@jax.jit
def geometric_sum(
    a: jax.Array, r: jax.Array, num: chex.Numeric = jnp.inf
) -> jax.Array:
  """Sum a + a*r + a*r**2 + ... + a*r**(num-1) (or limit if num=jnp.inf).

  For finite `num` the result is computed either directly or, when `r` is
  close to 1, with a quadratic Taylor polynomial around r = 1.

  Args:
    a: Scale factor (or vector of scale factors).
    r: ratio between successive terms, requires :math:`|r| < 1`.
    num: How many terms to add, or jnp.inf for the limit.

  Returns:
    The sum.
  """
  n = jnp.array(num, dtype=jnp.float64)

  def limit_sum(a, n, r):
    """Infinite-n case."""
    del n  # Unused; present so both branches share one signature.
    return a / (1 - r)

  def finite_sum(a, n, r):
    """Robustly handle the finite-n case, including r=1."""
    # We choose between a Taylor series approx near r=1 and the
    # direct calculation.
    #
    # The direct calculation is numerically fairly stable for the
    # value, but gets unstable earlier (r farther from 1) for the calculation
    # of the gradient (which we need). So this threshold rule was chosen to
    # minimize the difference between the direct computation and the series
    # approximation in terms of the gradient w.r.t. `r`.
    SLOPE = 0.53018965
    INTERCEPT = 3.33503185
    threshold_exponent = INTERCEPT + SLOPE * jnp.log(n)
    far_from_one = r < 1 - 10 ** (-threshold_exponent)

    # Quadratic Taylor polynomial approx at r = 1 from sympy:
    n_minus_one = n - 1
    r_offset = r - 1
    taylor_value = (
        (1 / 6)
        * a
        * n
        * (n_minus_one * r_offset**2 * (n - 2) + 3 * n_minus_one * r_offset + 6)
    )

    # Avoid nans in the untaken branch when r == 1, see
    # https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
    guarded_r = jnp.where(far_from_one, r, jnp.zeros_like(r))
    direct_value = a * (1 - guarded_r**num) / (1 - guarded_r)
    return jnp.where(far_from_one, direct_value, taylor_value)

  return jax.lax.cond(jnp.isinf(n), limit_sum, finite_sum, a, num, r)
def _inf_if(cond_fn, blt_fn: Callable[..., Any]) -> Any:
def new_fn(blt, *args, **kwargs):
return jax.lax.cond(
cond_fn(blt),
lambda: jnp.inf,
lambda: blt_fn(blt, *args, **kwargs),
)
return new_fn
[docs]
def require_buf_decay_less_eq_one(blt_fn: Callable[..., Any]) -> Any:
  """Decorator: yield inf when the first arg of blt_fn has any buf_decay > 1."""

  def any_buf_decay_above_one(blt):
    return jnp.any(blt.buf_decay > 1)

  return _inf_if(any_buf_decay_above_one, blt_fn)
[docs]
def require_output_scale_gt_zero(blt_fn: Callable[..., Any]) -> Any:
  """Decorator: yield inf when the first arg of blt_fn has any output_scale <= 0."""

  def any_output_scale_non_positive(blt):
    return jnp.any(blt.output_scale <= 0)

  return _inf_if(any_output_scale_non_positive, blt_fn)
def _poly_from_theta(theta: jax.Array) -> jax.Array:
"""Returns a polynomial (1 - theta[0]*x)*...*(1-theta[-1]*x).
Args:
theta: The multiplicative inverses of the roots of the polynomial.
Returns:
A polynomial of degree len(theta) -1 = d represented as an array under the
`np.poly1d` convention. If the returned array is `c`,then
p(x) = c[0]*x**d + c[1]*x**(d-1) + ... + c[-2]*x + c[-1]
"""
theta = jnp.asarray(theta)
z = jnp.prod(-theta)
p = jnp.poly(seq_of_zeros=1 / theta) * z
p = p.at[-1].set(1) # This should be approximately 1, set to make exact.
return p
[docs]
def blt_pair_from_theta_pair(
    theta: jax.Array, theta_hat: jax.Array
) -> tuple[BufferedToeplitz, BufferedToeplitz]:
  """Builds the inverse pair of BLTs (C, C_inv) determined by two theta arrays.

  This implements Lemma 5.2 of https://arxiv.org/abs/2404.16706 (See also
  Algorithm 5 of https://arxiv.org/abs/2408.08868). The simplified computation
  used in the implementation here appears as Theorem 2 of
  https://arxiv.org/abs/2504.21413.

  The computation divides by products of differences theta[i] - theta[j], and
  so can be numerically unstable if abs(theta[i] - theta[j]) is too small;
  hence, by default the Parameterization based on `blt.inverse()` should be
  used instead.

  Args:
    theta: Array of thetas for the denominator q of `r(x) = p(x; theta_hat) /
      q(x; theta)`.
    theta_hat: Array of thetas for the numerator of r(x).

  Returns:
    A tuple of BLTs (C, C_inv) where C.buf_decay = theta and C_inv.buf_decay =
    theta_hat, with the output_scale (omega) parameters for each computed to
    make these inverses.
  """
  theta = jnp.asarray(theta)
  theta_hat = jnp.asarray(theta_hat)

  def output_scale_for(own_decay, other_decay):
    # numer_i = prod_j (own_i - other_j)
    numer = jnp.prod(own_decay[:, jnp.newaxis] - other_decay, axis=1)
    # denom_i = prod_{j: i != j} (own_i - own_j); the diagonal (i == j) is
    # replaced by 1 so it does not contribute to the product.
    pairwise_diff = own_decay[:, jnp.newaxis] - own_decay
    pairwise_diff = pairwise_diff.at[
        jnp.diag_indices_from(pairwise_diff)
    ].set(1)
    denom = jnp.prod(pairwise_diff, axis=1)
    return numer / denom

  C = BufferedToeplitz.build(
      output_scale=output_scale_for(theta, theta_hat), buf_decay=theta
  )
  # Swapping the roles of theta and theta_hat yields the inverse.
  C_inv = BufferedToeplitz.build(
      output_scale=output_scale_for(theta_hat, theta), buf_decay=theta_hat  # pylint: disable=arguments-out-of-order
  )
  return (C, C_inv)
[docs]
@jax.jit
@require_buf_decay_less_eq_one
def sensitivity_squared(blt: BufferedToeplitz, n: chex.Numeric) -> float:
  """Computes sensitivity**2 for a BLT strategy matrix C.

  See https://arxiv.org/pdf/2404.16706 Lemma 5.3

  Args:
    blt: The Buffered Linear Toeplitz operator representing C().
    n: The number of iterations; the limit of sensitivity as n -> jnp.inf is
      returned if n is jnp.inf.

  Returns:
    The max-column-norm-squared of C.
  """
  output_scale = blt.output_scale
  buf_decay = blt.buf_decay
  # n - 1 terms are summed; this is still jnp.inf if n = jnp.inf.
  num_terms = jnp.array(n - 1)
  # All (j, k) pairs via JAX broadcasting; output_scale and buf_decay must
  # have the same length.
  scale_products = output_scale * output_scale[:, jnp.newaxis]
  decay_products = buf_decay * buf_decay[:, jnp.newaxis]
  pairwise_sums = geometric_sum(scale_products, decay_products, num=num_terms)
  return 1.0 + pairwise_sums.sum()
def _max_error_Gamma_j(
    omega: jax.Array, theta: jax.Array, n: jax.Array
) -> jax.Array:
  """Direct computation of Gamma_j for max-error."""
  # Closed-form computation of
  #   sum([geometric_sum(omega, theta, i) for i in range(1, n)]) / n
  # Note the division by n, and that this divides by (1 - theta).
  prefactor = omega / (1.0 - theta)
  mean_partial_sum = geometric_sum(1, theta, n) / n
  return prefactor * (1 - mean_partial_sum)
def _max_error_Gamma_j_series(
omega: jax.Array, theta: jax.Array, n: jax.Array
) -> jax.Array:
"""Taylor series approximation to _max_error_Gamma_j."""
# Auto-generated via sympy, see colab notebook
# robust_max_error_for_blts.ipynb.ipynb
# fmt: off
x0 = theta - 1
x1 = omega*(n - 2)*(n - 1)
return -omega*(1/2 - 1/2*n) + (1/24)*x0**2*x1*(n - 3) + (1/6)*x0*x1
# fmt: on
[docs]
def robust_max_error_Gamma_j(
    omega: jax.Array, theta: jax.Array, n: jax.Array
) -> jax.Array:
  """Robustly computes _max_error_Gamma_j.

  Selects elementwise between the direct computation (theta below a threshold
  that depends on n) and the series approximation (theta at or above it).
  """
  # See robust_max_error_for_blts.ipynb
  # for computation of these regression constants.
  J_SLOPE = 0.43877484
  J_INTERCEPT = 2.91215085
  threshold = 1 - 10 ** (-(J_INTERCEPT + J_SLOPE * jnp.log(n)))
  use_direct = theta < threshold

  # We need to avoid inf/nan in the direct value even where it is not
  # selected, see:
  # https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
  guarded_theta = jnp.where(use_direct, theta, jnp.zeros_like(theta))
  direct_value = _max_error_Gamma_j(omega, guarded_theta, n)
  series_value = _max_error_Gamma_j_series(omega, theta, n)
  return jnp.where(use_direct, direct_value, series_value)
def _max_error_Gamma_jk(
    omega1: jax.Array,
    theta1: jax.Array,
    omega2: jax.Array,
    theta2: jax.Array,
    n: jax.Array,
) -> jax.Array:
  """Direct computation of Gamma_jk for max-error."""
  # Closed-form computation of
  #   sum([geometric_sum(omega1, theta1, i) * geometric_sum(omega2, theta2, i)
  #        for i in range(1, n)]) / n
  # Note the division by n, and by (1 - theta1) * (1 - theta2).
  prefactor = omega1 * omega2 / ((1 - theta1) * (1 - theta2))
  mean_cross_term = (
      n
      - geometric_sum(1, theta1, n)
      - geometric_sum(1, theta2, n)
      + geometric_sum(1, theta1 * theta2, n)
  ) / n
  return prefactor * mean_cross_term
def _max_error_Gamma_jk_series_j(
    omega1: jax.Array,
    theta1: jax.Array,
    omega2: jax.Array,
    theta2: jax.Array,
    n: jax.Array,
) -> jax.Array:
  """Compute _max_error_Gamma_jk with a series approximation for theta1.

  Only theta1 enters through the offset (theta1 - 1), and it does so
  linearly; theta2 enters through explicit powers and through a division by
  (theta2 - 1)**4, so theta2 must not be equal to 1 here.

  Args:
    omega1: The omega value paired with theta1.
    theta1: The theta that the series is expanded in.
    omega2: The omega value paired with theta2.
    theta2: The other theta; must differ from 1.
    n: The value of n.

  Returns:
    The approximate value of Gamma_jk.
  """
  # Auto-generated via sympy, see colab notebook
  # robust_max_error_for_blts.ipynb
  # The expression is left exactly as generated; do not reorder terms.
  # fmt: off
  x0 = theta2 - 1
  x1 = theta2**(n + 1)
  x2 = -x1
  x3 = 6*x0
  x4 = theta2**n
  x5 = n - 1
  x6 = theta1 - 1  # The only place theta1 is used.
  x7 = theta2**(n + 2)
  return (-1/6*omega1*omega2*(
      n*x0**3*(3*n + x5*x6*(n - 2) - 3) + n*x3*(x2 + x4) - x3*(theta2 + x2)
      + 3*x6*(n*x0*(-x1*x5 + x4*x5) - 2*n*(x1 - x7)
      + 2*theta2**2 - 2*x7))/(n*x0**4))
  # fmt: on
def _max_error_Gamma_jk_series_jk(
omega1: jax.Array,
theta1: jax.Array,
omega2: jax.Array,
theta2: jax.Array,
n: jax.Array,
) -> jax.Array:
"""Compute _max_error_Gamma_jk with a series approximation for theta1."""
# Auto-generated via sympy, see colab notebook
# robust_max_error_for_blts.ipynb
# fmt: off
x0 = n**2
x1 = 3*n**3 + 9*n - 10*x0 - 2
return ((1/24)*omega1*omega2*(-12*n + 8*x0 + x1*(theta1 - 1)
+ x1*(theta2 - 1) + 4))
# fmt: on
[docs]
def robust_max_error_Gamma_jk(
    omega1: jax.Array,
    theta1: jax.Array,
    omega2: jax.Array,
    theta2: jax.Array,
    n: jax.Array,
) -> jax.Array:
  """Robustly computes _max_error_Gamma_jk.

  Robustly computes the Gamma_{j,k} term of the closed-form expression
  for max-error, by choosing either the direct computation or one of two
  series approximations, depending on which thetas are close to 1.

  Used to compute all pairs of Gamma_{j,k} via numpy broadcasting.

  Args:
    omega1: The omega_j argument (a output_scale value), typically a 1D array.
    theta1: The theta_j argument (a buf_decay value), typically a 1D array.
    omega2: The omega_k argument (a output_scale value), typically a 2D column
      vector.
    theta2: The theta_k argument (a buf_decay value), typically a 2D column
      vector.
    n: The value of n.

  Returns:
    The value of Gamma_jk.
  """
  # The _j series approximation expands in its first theta, so order each pair
  # such that theta_hi >= theta_lo. The omegas only enter as omega1 * omega2
  # in all three helpers, so they do not need to be reordered.
  theta_hi = jnp.maximum(theta1, theta2)
  theta_lo = jnp.minimum(theta1, theta2)

  JK_SLOPE = 0.35321577
  JK_INTERCEPT = 2.81518052
  threshold = 1 - 10 ** (-(JK_INTERCEPT + JK_SLOPE * jnp.log(n)))
  hi_is_far_from_one = theta_hi < threshold
  lo_is_far_from_one = theta_lo < threshold

  # We need to avoid inf/nan in every candidate value even if not selected,
  # see:
  # https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
  direct_theta_hi = jnp.where(
      hi_is_far_from_one, theta_hi, jnp.zeros_like(theta_hi)
  )
  direct_theta_lo = jnp.where(
      hi_is_far_from_one, theta_lo, jnp.zeros_like(theta_lo)
  )
  direct_value = _max_error_Gamma_jk(
      omega1, direct_theta_hi, omega2, direct_theta_lo, n
  )

  series_theta_lo = jnp.where(
      lo_is_far_from_one, theta_lo, jnp.zeros_like(theta_lo)
  )
  one_sided_series = _max_error_Gamma_jk_series_j(
      omega1, theta_hi, omega2, series_theta_lo, n
  )
  two_sided_series = _max_error_Gamma_jk_series_jk(
      omega1, theta_hi, omega2, theta_lo, n
  )

  # theta_hi is near 1 here; use the one-variable series if theta_lo is not
  # near 1, and the two-variable series if both are.
  near_one_value = jnp.where(
      lo_is_far_from_one, one_sided_series, two_sided_series
  )
  # Both thetas are not close to 1 (theta_lo <= theta_hi): use the direct fn.
  return jnp.where(hi_is_far_from_one, direct_value, near_one_value)
[docs]
@jax.jit
def iteration_error(inv_blt: BufferedToeplitz, i: chex.Array) -> jax.Array:
  """Computes the error on iteration `i` which is also the max error.

  That is, for a Buffered Linear Toeplitz matrix, the max error from
  iteration 0 through `i` is achieved on iteration `i`, so this equivalently
  computes the max error for `i+1` iterations.

  Here "error" is the squared error introduced in the `n`-th iterate (partial
  sum) assuming unit-variance noise. This generally scales as O(n), and
  so optimization routines might normalize by an additional factor of n.

  This implements https://arxiv.org/pdf/2404.16706 Lemma 5.4.

  Args:
    inv_blt: The Buffered Linear Toeplitz operator where inv_blt.C()
      represents C^{-1} in the matrix factorization mechanism.
    i: The iteration for which to compute error, 0-indexed. To compute the
      max error for an `n` iteration mechanism, one should thus pass i = n-1
      to this function.

  Returns:
    The squared-error (variance) on iteration `i`.
  """
  check_float64_dtype(inv_blt)
  # Note: The derivation in the paper is 1-indexed and uses `n`;
  # we follow that formula here, so we convert the 0-indexed `i`:
  num_iterations = i + 1
  output_scale = inv_blt.output_scale  # shape = (inv_blt._num_buffers,)
  buf_decay = inv_blt.buf_decay  # shape = (inv_blt._num_buffers,)

  # One term per buffer, (b,) -> scalar, where b = inv_blt._num_buffers.
  per_buffer = robust_max_error_Gamma_j(output_scale, buf_decay, num_iterations)
  single_sum = jnp.sum(per_buffer)

  # One term per pair of buffers via JAX broadcasting, (b, b) -> scalar.
  per_pair = robust_max_error_Gamma_jk(
      output_scale,
      buf_decay,
      output_scale[:, jnp.newaxis],
      buf_decay[:, jnp.newaxis],
      num_iterations,
  )
  pair_sum = jnp.sum(per_pair)

  return num_iterations * (1 + 2 * single_sum + pair_sum)
[docs]
@jax.jit
def max_error(inv_blt: BufferedToeplitz, n: chex.Array) -> jax.Array:
  """Returns the max squared error for any iteration 0, ..., n-1."""
  # Note: For a BLT, the iteration error is increasing in `i`, so the max is
  # attained on the last (0-indexed) iteration:
  last_iteration = n - 1
  return iteration_error(inv_blt, last_iteration)
[docs]
@jax.jit
def limit_max_error(
    inv_blt: BufferedToeplitz,
) -> jax.Array:
  """Computes limit_{n -> jnp.inf} (1/n)*max_error(blt, n).

  Args:
    inv_blt: The Buffered Linear Toeplitz operator where inv_blt.C()
      represents C^{-1} in the matrix factorization mechanism.

  Returns:
    The iteration error in the limit.
  """
  output_scale = inv_blt.output_scale  # shape = (inv_blt._num_buffers,)
  buf_decay = inv_blt.buf_decay  # shape = (inv_blt._num_buffers,)
  # Broadcasting a row against a column behaves like a double for-loop over
  # all (j, k) pairs; output_scale and buf_decay should have the same length.
  scale_products = output_scale * output_scale[:, jnp.newaxis]
  gap_products = (1 - buf_decay) * (1 - buf_decay[:, jnp.newaxis])
  pair_sum = jnp.sum(scale_products / gap_products)
  single_sum = jnp.sum(output_scale / (1 - buf_decay))
  return 1 + 2 * single_sum + pair_sum
[docs]
@jax.jit
@require_output_scale_gt_zero
def max_loss(blt: BufferedToeplitz, n: jax.Array) -> jax.Array:
  """Max squared error over n iterations, scaled by sensitivity**2.

  Args:
    blt: The BLT representing the strategy matrix C.
    n: The number of iterations.

  Returns:
    `max_error(C^{-1}, n) * sensitivity_squared(C, n)`.
  """
  check_float64_dtype(blt)
  squared_sensitivity = sensitivity_squared(blt, n)
  unscaled_error = max_error(blt.inverse(skip_checks=True), n)
  return unscaled_error * squared_sensitivity
[docs]
@jax.jit
@require_output_scale_gt_zero
def limit_max_loss(
    blt: BufferedToeplitz,
) -> jax.Array:
  """Limit of (1/n) max squared error scaled by sensitivity**2.

  Args:
    blt: The BLT representing the strategy matrix C.

  Returns:
    `limit_max_error(C^{-1}) * sensitivity_squared(C, n=jnp.inf)`.
  """
  check_float64_dtype(blt)
  squared_sensitivity = sensitivity_squared(blt, n=jnp.inf)
  unscaled_error = limit_max_error(blt.inverse(skip_checks=True))
  return unscaled_error * squared_sensitivity