# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class and instances for expressing and optimizing banded strategies."""
from collections.abc import Callable
import functools
from typing import Any
import chex
import jax
import jax.numpy as jnp
import numpy as np
from . import optimization
from . import sensitivity
from . import streaming_matrix
# Disabling pylint invalid-name to allow mathematical notation including
# single-capital-letter variables for matrices.
# See README.md for notation conventions.
# pylint:disable=invalid-name
[docs]
@chex.dataclass
class ColumnNormalizedBanded:
"""A column-normalized banded lower triangular n x n matrix.
This matrix class is parameterized by an arbitrary n x b matrix.
C(params) is obtained by setting the first b bands of C based on params.
The matrix is normalized to have sensitivity 1 under a single epoch,
by dividing each column by its respective norm.
Below we show how params relates to the matrix (before column normalization):
::
params = [a b c]
[d e f]
[g h i]
[j k -]
[m - -]
C = [a ]
[b d ]
[c e g ]
[ f h j ]
[ i k m]
"""
params: jnp.ndarray
@property
def n(self) -> int:
return self.params.shape[0]
@property
def bands(self) -> int:
return self.params.shape[1]
[docs]
@classmethod
def from_banded_toeplitz(
cls, n: int, coefs: jnp.ndarray
) -> 'ColumnNormalizedBanded':
"""Construct an instance of this object from banded toeplitz coefficients.
Args:
n: the number of training iterations.
coefs: an array of b toeplitz coefficients defining the strategy.
Returns:
A ColumnNormalizedBanded representation of the banded toeplitz matrix.
"""
bands = coefs.size
if bands > n or bands < 1:
raise ValueError(f'len(coefs) must be in the range [1, n], got {bands}')
coefs = coefs / jnp.linalg.norm(coefs)
params = jnp.broadcast_to(coefs, (n, bands))
params = jnp.tril(params[::-1])[::-1] # set the lower right triangle to 0
return cls(params=params)
[docs]
@classmethod
def default(cls, n: int, bands: int) -> 'ColumnNormalizedBanded':
"""Construct a default instance of this object given n and bands.
This object is initialized by using the fixed toeplitz strategy proposed
in [1; Algorithm 1], truncating to $b$ entries, and column normalizing.
It can act as a useful initialization for further optimization.
[1] https://proceedings.mlr.press/v202/fichtenberger23a/fichtenberger23a.pdf
Args:
n: the number of training iterations.
bands: the number of bands in the strategy.
Returns:
A ColumnNormalizedBanded object.
"""
k = jnp.arange(bands)
coefs = jnp.cumprod(((2 * k - 1) / (2 * k)).at[0].set(1))
return ColumnNormalizedBanded.from_banded_toeplitz(n, coefs)
[docs]
def materialize(self) -> jnp.ndarray:
I = jnp.arange(self.n)[:, None]
J = jnp.arange(self.n)[None]
D = I - J
indexer = (D + self.bands * J + 1) * (D >= 0) * (D < self.bands)
C = jnp.append(0, self.params.flatten())[indexer]
return C / jnp.linalg.norm(C, axis=0)
[docs]
def inverse_as_streaming_matrix(
self,
) -> streaming_matrix.StreamingMatrix:
"""Create $C^{-1}$ as a StreamingMatrix object."""
def init_fn(abstract_value):
dtype = jnp.promote_types(abstract_value.dtype, self.params.dtype)
# *_like preserves shape, dtype, and (if possible) sharding.
zero = jnp.zeros_like(abstract_value, dtype=dtype)
buffers = jnp.broadcast_to(zero, (self.bands,) + zero.shape)
return jnp.array(0), buffers
def next_fn(value, state):
index, bufs = state
if self.bands == 1:
return value, (index + 1, bufs)
k = index % self.bands
r = jnp.arange(self.bands)
row = self.params[index - r, r]
# Algorithm 9 from https://arxiv.org/abs/2306.08153
# Compute xi = (value - row[1:] @ bufs[k-r][1:]) / row[0]
inner = jnp.tensordot(row[1:], bufs[k - r][1:], axes=((0,), (0,)))
xi = (value - inner) / row[0]
col_norm = jnp.linalg.norm(self.params[index])
updated_state = (index + 1, bufs.at[k].set(xi))
return xi * col_norm, updated_state
return streaming_matrix.StreamingMatrix.from_array_implementation(
init_fn, next_fn
)
[docs]
def minsep_sensitivity_squared(
strategy: ColumnNormalizedBanded,
min_sep: int,
max_participations: int | None = None,
n: int | None = None,
skip_checks: bool = False,
) -> int:
"""Returns the sensitivity of the ColumnNormalizedBanded strategy.
With max_participations = 1 (and any min_sep, say min_sep = 1), this is the
same as single participation.
Args:
strategy: The strategy matrix defining the mechanism.
min_sep: The minimum separation between two participation of a worst-case
client/sample. Note that we use the definition in [(Amplified) Banded
Matrix Factorization: A unified approach to private
training](https://arxiv.org/abs/2306.08153). For a user participating on
iteration $i$ and then again on iteration $j$, the separation is $j -i$;
that is, a min_sep of 1 allows participation on every iteration.
max_participations: The maximum participation of a worst-case user. The
default value None allows the max number of possible participations.
n: Optional, the size of the matrix C (see `coef` above). If None, the size
of the matrix is equal to the number of coefficients.
skip_checks: If True, don't perform input verification which may not be
supported in jitted contexts.
Returns:
The sensitivity squared.
"""
bands = strategy.bands
n = n or strategy.n
max_participations = sensitivity.minsep_true_max_participations(
n, min_sep, max_participations
)
if not skip_checks:
if min_sep < bands:
raise ValueError(
f'{min_sep=} must be greater than or equal to {bands=}. This error is'
' usually indicative of a mis-configuration of the strategy for the'
' participation pattern. If it is intentional, please use'
' sensitivity.get_sensitivity_banded.'
)
if n > strategy.n:
raise ValueError(f'{n=} must be less than or equal to {strategy.n=}.')
return max_participations
def _equinox_scan_fn(n: int, bands: int, memory_limit_gb: float = 4):
"""Checkpointed scan function for memory-efficient backpropagation."""
# We do not want to take a hard dependence on equinox, so we import it here.
import equinox # pylint: disable=g-import-not-at-top, import-outside-toplevel
used = bands * n
limit = 2**30 * memory_limit_gb // 8
checkpoints = 2 ** int(np.log2(limit // used) - 1)
return functools.partial(
equinox.internal.scan,
kind='checkpointed',
checkpoints=checkpoints,
)
def _dinosaur_scan_fn(n: int, bands: int, memory_limit_gb: float = 4):
"""Checkpointed scan function for memory-efficient backpropagation."""
# We do not want to take a hard dependence on dinosaur, so we import it here.
from dinosaur import time_integration # pylint: disable=g-import-not-at-top, import-outside-toplevel
used = bands * n
limit = 2**30 * memory_limit_gb // 8
max_checkpoints = 2 ** int(np.log2(limit // used) - 1)
candidates = range(min(max_checkpoints, n), 0, -1)
num_checkpoints = next(filter(lambda d: n % d == 0, candidates))
nested_lengths = [num_checkpoints, n // num_checkpoints]
return functools.partial(
time_integration.nested_checkpoint_scan,
nested_lengths=nested_lengths,
)
# TODO: b/329444015 - document the definition of per query squared error.
[docs]
@functools.partial(jax.jit, static_argnums=[1, 2])
def per_query_error(
C: ColumnNormalizedBanded,
A: streaming_matrix.StreamingMatrix | None = None,
scan_fn: Any = jax.lax.scan,
) -> jnp.ndarray:
"""Computes expected per-query squared error of a strategy.
Specifically, this function computes the row-wise L2^2 norm of B = A C^{-1}.
this vector to a scalar via the reduction_fn.
Since C is column normalized, this error function can be used as
a loss function, since sensitivity is constant for ColumnNormalizedBanded
strategies for both single-participation and multi-participation settings,
as long as the number of bands in C is less than or equal to the (minimum)
separation between contributions from the same user.
If you need to backpropagate through this function, you can use the
`equinox` or `dinosaur` scan functions to make the scan checkpointed,
which allows the scan to be performed for large n without OOMing the
accelerator.
Args:
C: the strategy matrix, represented implicitly.
A: The workload matrix, represented implicitly.
scan_fn: A function with the same signature as jax.lax.scan.
Returns:
The per query expected squared error of the strategy on the workload,
represented as an array of length `n`.
"""
if scan_fn == 'equinox':
scan_fn = _equinox_scan_fn(C.n, C.bands)
elif scan_fn == 'dinosaur':
scan_fn = _dinosaur_scan_fn(C.n, C.bands)
A = A or streaming_matrix.prefix_sum()
B = A @ C.inverse_as_streaming_matrix()
return B.row_norms_squared(C.n, scan_fn=scan_fn)
# TODO: b/329444015 - rethink how the objective should be specified
[docs]
def optimize(
n: int,
*,
bands: int,
C: ColumnNormalizedBanded | None = None,
A: streaming_matrix.StreamingMatrix | None = None,
max_optimizer_steps: int = 100,
reduction_fn: Callable[[jnp.ndarray], jnp.ndarray] = jnp.mean,
scan_fn: Any = jax.lax.scan,
callback: optimization.CallbackFnType = lambda _: None,
) -> ColumnNormalizedBanded:
"""Optimize the strategy using a gradient-based method.
Note that this function benefits substantially from GPUs. This function
is primarily supported to aid in reproducing results from
https://arxiv.org/abs/2405.15913. In practice, we recommend using a
banded Toeplitz strategy instead (see toeplitz.optimize_banded_toeplitz),
which are <0.5% suboptimal in the regimes of most interest (n>=1000, b<=32).
The strategies produces by this procedure can be used in both single- and
multi-participation settings -- both (k, b)-min-sep and (k, b)-fixed epoch
order, as described in https://arxiv.org/abs/2306.08153, as long as the
number of bands in C is less than or equal to the (minimum) separation
between contributions from the same user.
Args:
n: The number of training iterations the strategy is configured for.
bands: The number of bands in the strategy.
C: The initial strategy to be optimized.
A: The target workload.
max_optimizer_steps: The maximum number of iterations to optimize for.
reduction_fn: A function that converts per query squared errors to a scalar.
Use jnp.mean to optimize mean-squared-error, jnp.max to optimize max
squared error, or lambda v: v[-1] to optimize last iterate squared error.
scan_fn: Either 'equinox', 'dinosaur', or a function with the same signature
as jax.lax.scan. Using 'equinox' or 'dinosaur' is helpful for doing
strategy optimization on GPUs for large n, since it allows the scan
function used internally by per_query_error to be checkpointed, avoiding
OOM errors during backpropagation.
callback: A function to call after each optimization iteration. See
optimization.optimize for details.
Returns:
An optimized strategy having the same structure as C.
"""
A = A or streaming_matrix.prefix_sum()
C = C or ColumnNormalizedBanded.default(n, bands)
loss_fn = lambda C: reduction_fn(per_query_error(C, A=A, scan_fn=scan_fn))
return optimization.optimize(
loss_fn, C, max_optimizer_steps=max_optimizer_steps, callback=callback
)