Source code for jax_privacy.matrix_factorization.dense

# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Optimization and error fns for dense (explicitly represented) strategies.

See `sensitivity.py` for sensitivity calculations for dense strategies.
"""

from collections.abc import Callable

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np

from . import _checks as checks
from . import optimization
from . import sensitivity


# Disabling pylint invalid-name to allow mathematical notation including
# single-capital-letter variables for matrices.
# See README.md for notation conventions.
# pylint:disable=invalid-name


[docs] def per_query_error( *, strategy_matrix: jax.Array | None = None, noising_matrix: jax.Array | None = None, workload_matrix: jax.Array | None = None, skip_checks: bool = False, ) -> jax.Array: """Expected per-query squared error for a general matrix mechanism. Exactly one of `strategy_matrix` and `noising_matrix` should be provided. Args: strategy_matrix: The (square) strategy matrix C defining the mechanism. noising_matrix: The (possibly non-square) noising matrix C^{-1}. workload_matrix: The workload matrix. Defaults to `jnp.tri`, the prefix sum workload matrix. skip_checks: If True, skip input verification that depends on array values (e.g. finiteness, triangularity). Checks on static attributes (shapes, None-ness) are always performed. Returns: The expected per-query squared error, an array of length n. """ if (strategy_matrix is None) == (noising_matrix is None): raise ValueError( 'Specify exactly one of strategy_matrix or noising_matrix.' ) if strategy_matrix is not None: C = strategy_matrix # We require square C, because for a non-square C, the choice of the # specific pseudoinverse C^{-1} determines the error. Hence, we require # the user to specify the C^{-1} explicitly and pass that in instead. checks.check_square(C, 'strategy_matrix') n = C.shape[1] A = workload_matrix if workload_matrix is not None else jnp.tri(n) # Solve B @ C = A for B B = jnp.linalg.solve(C.T, A.T).T if not skip_checks: checks.check(A=A, B=B, C=C) else: assert noising_matrix is not None C_inv = noising_matrix n = C_inv.shape[0] A = workload_matrix if workload_matrix is not None else jnp.tri(n) B = A @ C_inv if not skip_checks: checks.check(A=A, B=B) return jnp.sum(B * B, axis=1)
[docs] def get_orthogonal_mask(n: int, epochs: int = 1) -> jax.Array: """Computes a mask that imposes orthognality constraints on the optimization. This is specific to the fixed-epoch-order (k, b)-participation schema of https://arxiv.org/pdf/2211.06530.pdf, where participations are separated by exactly b-1 steps, and b = n / epochs. This mask sets entry M_{ij} = 0 if i == j (mod b) and M_{ij} = 1 otherwise. Sensitivity for any matrix with 0s in these entries is easy to calculate as only a function of the diagonal. Moreover, the sensitivity is equal for all possible {-1,1} participation vectors. Args: n: the size of the mask epochs: The number of epochs Returns: A 0/1 mask """ # We use numpy instead of Jax internally here because we are performing # in-place updates to mask. mask = np.ones((n, n)) b = n // epochs for i in range(b): mask[i::b, i::b] = np.eye(epochs) return jnp.array(mask)
def _mean_loss_and_gradient( X: jax.Array, A: jax.Array ) -> tuple[jax.Array, jax.Array]: r"""Computes the matrix mechanism total squared error loss and gradient. This function computes $\tr[A^T A X^{-1}]$ and the associated gradient $dX = -X^{-1} A^T A X^{-1}$. It assumes that $X$ is a symmetric positive definite matrix. For efficiency, no error is thrown if this assumption is not satisfied, but the returned loss or gradient may contain NaN's if this is the case. Args: X: The current iterate, an n x n symmetric positive definite matrix. A: The workload, an n x n matrix Returns: loss: a real-valued number gradient: the gradient of the loss w.r.t. X, an n x n matrix """ # It is significantly faster to compute the gradient ourselves rather than # rely on Jax autodiff here. For n=8192, difference is 550ms vs. 900ms on GPU. n = X.shape[0] H = jsp.linalg.solve(X, A.T, assume_a='pos') return jnp.trace(H @ A) / n, -H @ H.T / n
[docs] def strategy_from_X(X: jax.Array) -> jax.Array: """Return a lower triangular strategy matrix C from its Gram matrix. Args: X: A positive symmetric semidefinite matrix. Returns: A lower triangular matrix C satisfying X = C^T C. """ return jnp.linalg.cholesky(X[::-1, ::-1]).T[::-1, ::-1]
[docs] def pg_tol_termination_fn(step_info: optimization.CallbackArgs) -> bool: """Callback function that returns True if projected gradient is near-zero.""" return bool(jnp.abs(step_info.grad).max() <= 1e-3)
[docs] def optimize( n: int, *, epochs: int = 1, bands: int | None = None, equal_norm: bool = False, A: jax.Array | None = None, max_optimizer_steps: int = 10000, reduction_fn: Callable[[jax.Array], jax.Array] = jnp.mean, callback: optimization.CallbackFnType = pg_tol_termination_fn, ) -> jax.Array: """Optimizes a strategy matrix C for a given reduction_fn and participation. Note: While the function accepts a reduction_fn keyword argument, it has been tuned and tested rigorously only for mean-squared error (i.e., reduction_fn=jnp.mean). This function can be used to optimize matrices under * Single-participation: [Denisov et al., 2022](https://arxiv.org/abs/2202.08312). This can be accomplished by running with default arguments. * Multi-participation with fixed-epoch order: [Choquette-Choo et al., 2022](https://arxiv.org/abs/2211.06530). This can be accomplished by setting epochs=k. * Multi-participation with min-separation (useful for federated training scenarios). This can be accomplished by setting bands = min_sep and equal_norm = True. * Multi-participation with amplification via subsampled fixed-epoch order: [Choquette-Choo et al., 2022](https://arxiv.org/abs/2211.06530). This can be accomplished by setting epochs=1, bands<separation, and equal_norm=True. Args: n: the number of iterations the strategy should encode. epochs: The number of epochs the strategy should be calibrated for. Assumes (k, b)-fixed-epoch order participation. bands: The number of bands in the strategy. equal_norm: Flag to indicate that each column of C should have equal_norm. Useful for BandMF. If epochs=1, this flag is a no-op, as the returned strategy will be column normalized either way. A: The workload matrix (defaults to Prefix). max_optimizer_steps: The maximum number of LBFGS steps to take. reduction_fn: A function that converts per query squared errors to a scalar. Use jnp.mean to optimize mean-squared-error, jnp.max to optimize max squared error, or any other differentiable function writtten in Jax. callback: An optional callback function to monitor optimization progress. The default callback terminates the optimization early if the projected gradient is near-zero. Returns: The strategy matrix C that minimizes expected total squared error. """ A = jnp.tri(n) if A is None else A mask = get_orthogonal_mask(n, epochs) if bands is not None: mask = mask * sensitivity.banded_symmetric_mask(n, bands) @jax.value_and_grad def loss_and_grad(X): # It's better to calculate this w.r.t. X than go through per_query_error. H = jsp.linalg.solve(X, A.T, assume_a='pos') per_query = jnp.sum(A * H.T, axis=1) return reduction_fn(per_query) def loss_and_projected_grad(X): if reduction_fn is jnp.mean: # Use an exact analytical formula for jnp.mean for max performance. loss, dX = _mean_loss_and_gradient(X, A) else: loss, dX = loss_and_grad(X) if equal_norm: diag = 0 else: # normalizes sum_i dX[i,i] = 0, where sum is taken over iterations a # single user can participate in. This ensures that sum_i X[i,i] # remains equal to 1. dsum = jnp.diag(dX).reshape(epochs, -1).sum(axis=0) / epochs diag = jnp.diag(dX) - jnp.kron(jnp.ones(epochs), dsum) dX = dX.at[jnp.diag_indices(n)].set(diag) # sets dX[i,j] = 0 if i \neq j and a user can appear in both round i and j. # If banded constraints are given, sets dX[i,j] = 0 if |i-j| > # bands. return loss, dX * mask X = optimization.optimize( loss_and_projected_grad, jnp.eye(n, dtype=jnp.float64) / epochs, max_optimizer_steps=max_optimizer_steps, grad=True, callback=callback, ) return strategy_from_X(X)