Source code for jax_privacy.accounting.analysis

# coding=utf-8
# Copyright 2025 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Keeping track of the differential privacy guarantee."""

import abc
from collections.abc import Sequence
import dataclasses
import enum
import math
import numbers
import time
from typing import Protocol, TypeAlias

from absl import logging
import chex
import dp_accounting
from jax_privacy.accounting import accountants
import numpy as np
import optax
from scipy import stats


class SamplingMethod(enum.Enum):
  """Sampling scheme that the privacy analysis assumes for each batch.

  Members:
    POISSON: Every example joins the batch independently with probability
      `batch_size / num_samples`, so `batch_size` is the *expected* batch size.
      Compatible with add-or-remove and zero-out adjacencies.
    FIXED_BATCH_SIZE: Each batch is a uniformly random subset of exactly
      `batch_size` examples. Compatible with replace adjacency (and thus
      zero-out), assuming that `num_samples` is public knowledge. The reported
      DP guarantee also holds for add-or-remove adjacency, but since replace
      adjacency is considered stronger, that pairing is not recommended.
  """

  # Explicit values, identical to what `enum.auto()` would assign (1, 2).
  POISSON = 1
  FIXED_BATCH_SIZE = 2


def _to_python_int(value_or_array: chex.Numeric) -> int:
  """Returns the single distinct value of `value_or_array` as a Python int.

  The input may be a scalar or an array with any number of entries, but all of
  its entries must be equal. Otherwise, unpacking the distinct values fails with
  a `ValueError`.
  """
  distinct_values = np.unique(value_or_array)
  (only_value,) = distinct_values
  return int(only_value)


# The batch-size scale schedule allows changing the batch-size during training.
# It maps a step to a multiplicative scale factor. The factors accumulate: the
# schedule is evaluated with `optax.piecewise_constant_schedule` (see
# `_make_batch_size_boundaries`), which multiplies together the scales of all
# boundaries reached so far. E.g. if `bs_init = 16` and
# `bs_schedule = {100: 4, 200: 2}`, then the batch-size is set to 16 on steps
# 0-99, to 16*4 on steps 100-199, and to 16*4*2 from step 200 onwards.
# `None` (or an empty dict) keeps the batch-size constant.
BatchingScaleSchedule: TypeAlias = dict[int, int] | None


def _interleave_nm_and_bs(
    noise_multipliers: float | Sequence[tuple[int, float]],
    batch_sizes: int | Sequence[tuple[int, int]],
    num_steps: int,
) -> Sequence[tuple[int, float, int]]:
  """Splits `num_steps` into segments of constant noise and batch-size.

  Args:
    noise_multipliers: Either a single noise multiplier used at every step, or
      a sequence of `(t, nm)` pairs meaning the noise multiplier is set to `nm`
      from step `t` onwards. The earliest `t` must be 0.
    batch_sizes: Either a single batch-size used at every step, or a sequence
      of `(t, bs)` pairs with the same convention as `noise_multipliers`.
    num_steps: Total number of steps, i.e. steps `0, ..., num_steps - 1`.

  Returns:
    A list of `(num_steps_in_segment, nm, bs)` triples in chronological order.
    The step counts sum to `num_steps`.
  """
  # Scalars are normalized to the list format `[(0, value)]`. `numbers.Number`
  # (rather than `int`) also accepts numpy scalars such as `np.int64`.
  if isinstance(noise_multipliers, numbers.Number):
    noise_multipliers = [(0, noise_multipliers)]
  if isinstance(batch_sizes, numbers.Number):
    batch_sizes = [(0, batch_sizes)]

  # Make sure the time steps of changes are increasing.
  noise_multipliers = sorted(noise_multipliers, key=lambda pair: pair[0])
  batch_sizes = sorted(batch_sizes, key=lambda pair: pair[0])
  # Make sure the first time step is 0 in both sequences of hyper-parameters.
  assert noise_multipliers[0][0] == 0
  assert batch_sizes[0][0] == 0
  # Drop settings that never take effect. A setting starting at step
  # `t >= num_steps` would otherwise produce an empty (0-step) segment. The
  # initial setting (t == 0) is always kept, so the result is never empty.
  noise_multipliers = [
      (t, nm) for t, nm in noise_multipliers if t < num_steps or t == 0
  ]
  batch_sizes = [(t, bs) for t, bs in batch_sizes if t < num_steps or t == 0]

  # Interleave both sequences of hyper-parameters into a single one.
  nm_and_bs = _interleave(noise_multipliers, batch_sizes)
  # Each segment lasts until the next segment's start step, or until
  # `num_steps` for the last one.
  end_steps = [start for start, _, _ in nm_and_bs[1:]] + [num_steps]
  return [
      (end - start, nm, bs)
      for (start, nm, bs), end in zip(nm_and_bs, end_steps)
  ]


def _interleave(t_a, t_b):
  """Helper function to pair two timed sequences."""
  ts = [t for (t, _) in t_a] + [t for (t, _) in t_b]
  ts = list(set(ts))
  ts.sort()

  def _find_pair(t):
    a = [a for (s, a) in t_a if s <= t][-1]
    b = [b for (s, b) in t_b if s <= t][-1]
    return a, b

  return [(t, *_find_pair(t)) for t in ts]


def _make_batch_size_boundaries(
    batch_size: int,
    scale_schedule: BatchingScaleSchedule,
) -> Sequence[tuple[int, chex.Numeric]]:
  """Lists the `(step, batch_size)` change points implied by a scale schedule.

  Args:
    batch_size: Batch-size in effect from step 0.
    scale_schedule: Mapping from a boundary step to a batch-size scale factor.

  Returns:
    `[(0, batch_size)]` followed by one `(threshold, value)` pair per key of
    `scale_schedule`, in the dict's iteration order. Each `value` is the output
    of `optax.piecewise_constant_schedule` evaluated at `threshold + 1`.
  """
  scaled_batch_size = optax.piecewise_constant_schedule(
      init_value=batch_size,
      boundaries_and_scales=scale_schedule,
  )
  boundaries = [(0, batch_size)]
  for threshold in scale_schedule:
    # Evaluated one step past the boundary, presumably so that the boundary's
    # own scale is already included (confirm against optax's `b_i < s` rule).
    boundaries.append((threshold, scaled_batch_size(threshold + 1)))
  return boundaries


@dataclasses.dataclass(frozen=True, kw_only=True)
class DpParams:
  """Defines static parameters required for computing DP guarantees.

  Attributes:
    noise_multipliers: The noise multiplier, excluding the clipping norm and the
      batch-size. Or, list of pairs (t: int, nm: float) if the noise multiplier
      changes across steps. 't' indicates step where noise_multiplier is set to
      'nm'. May be None, but `DpTrainingAccountant` subclasses reject None when
      validating the parameters.
    num_samples: number of examples in the training set.
    delta: delta-value of DP guarantee.
    batch_size: batch-size used during training. When a
      `batch_size_scale_schedule` is given, this is the initial batch-size that
      the schedule scales.
    batch_size_scale_schedule: schedule for scaling the batch-size, mapping a
      step to a scale factor. None (or an empty dict) keeps the batch-size
      constant.
    is_finite_guarantee: Whether the DP guarantee can be expected to be finite.
      This may be False if the clipping norm is not finite for example.
    batch_sizes: (Read-only; not an `__init__` argument.) Either `batch_size`
      itself when there is no `batch_size_scale_schedule`, or a list of pairs
      (t: int, bs) where `t` indicates the step where the batch-size is set to
      `bs`. Computed in `__post_init__` from `batch_size` and
      `batch_size_scale_schedule`.
    examples_per_user: If multiple examples per user are used, this is the
      maximum number any user contributes to the training set.
    cycle_length: If using cyclic Poisson sampling with BandMF, the length of
      the cycle, i.e. the number of partitions formed for sampling. It is
      assumed the number of bands in BandMF is at most cycle_length.
    sampling_method: If our privacy analysis assumes sampling, which sampling
      method it should assume. See SamplingMethod enum for details on each
      sampling method and the adjacency definitions it assumes.
    truncated_batch_size: If using Poisson sampling, a limit on the batch size
      enforced by truncation. If None, we assume no truncation is used.
      Otherwise, we assume that a random subset of the sampled batch is used,
      and the remaining examples are discarded.
  """

  noise_multipliers: float | Sequence[tuple[int, float]] | None
  num_samples: int
  delta: float
  batch_size: int
  batch_size_scale_schedule: BatchingScaleSchedule = None
  is_finite_guarantee: bool = True
  batch_sizes: Sequence[tuple[int, chex.Numeric]] | int = dataclasses.field(
      init=False
  )
  examples_per_user: int | None = None
  cycle_length: int | None = None
  sampling_method: SamplingMethod = SamplingMethod.POISSON
  truncated_batch_size: int | None = None

  def __post_init__(self):
    # `batch_sizes` is derived state. The dataclass is frozen, so it must be set
    # through `object.__setattr__`.
    if self.batch_size_scale_schedule:
      batch_sizes = _make_batch_size_boundaries(
          self.batch_size, self.batch_size_scale_schedule
      )
    else:
      batch_sizes = self.batch_size
    object.__setattr__(self, 'batch_sizes', batch_sizes)


class TrainingAccountant(Protocol):
  """Structural interface for objects that report a training-run epsilon.

  Any object exposing a `compute_epsilon` method with this signature satisfies
  the protocol, e.g. `DpTrainingAccountant` and `CachedExperimentAccountant`.
  """

  def compute_epsilon(
      self,
      num_updates: chex.Numeric,
      dp_params: DpParams,
      allow_approximate_cache: bool = False,
  ) -> float:
    """Returns epsilon after `num_updates` updates under `dp_params`."""
    ...


class DpTrainingAccountant(metaclass=abc.ABCMeta):
  """Base class for privacy accounting of machine learning training runs.

  Subclasses implement `_compute_epsilon` and the `can_calibrate_*` queries,
  and may extend `_validate_dp_params` with additional requirements.
  """

  def __init__(
      self,
      dp_accountant_config: accountants.DpAccountantConfig,
  ):
    """Stores the configuration of the underlying DP accountant.

    Args:
      dp_accountant_config: Configuration for the DP accountant to use.
    """
    self._dp_accountant_config = dp_accountant_config

  @abc.abstractmethod
  def _compute_epsilon(
      self, num_updates: chex.Numeric, dp_params: DpParams
  ) -> float:
    """Returns epsilon for `num_updates` updates under `dp_params`.

    `compute_epsilon` only calls this after `_validate_dp_params` succeeded.
    """

  @abc.abstractmethod
  def can_calibrate_steps(self) -> bool:
    """Tells whether `num_steps` may be calibrated with this accountant."""

  @abc.abstractmethod
  def can_calibrate_batch_size(self) -> bool:
    """Tells whether `batch_size` may be calibrated with this accountant."""

  @abc.abstractmethod
  def can_calibrate_noise_multipliers(self) -> bool:
    """Tells whether `noise_multipliers` may be calibrated."""

  def _validate_dp_params(self, dp_params: DpParams):
    """Raises `ValueError` if this accountant cannot handle `dp_params`."""
    if dp_params.noise_multipliers is not None:
      return
    raise ValueError(f'{self.__class__.__name__} requires noise_multipliers.')

  def compute_epsilon(
      self,
      num_updates: chex.Numeric,
      dp_params: DpParams,
      allow_approximate_cache: bool = False,
  ) -> float:
    """Computes the DP epsilon reached after `num_updates` updates.

    Args:
      num_updates: Number of updates performed so far.
      dp_params: Parameters of the training run.
      allow_approximate_cache: Ignored; accepted for interface compatibility.

    Returns:
      0.0 if no update was performed, infinity if `dp_params` declares a
      non-finite guarantee, and otherwise the value of `_compute_epsilon`
      (after `dp_params` has been validated).
    """
    del allow_approximate_cache  # Always computed exactly; no cache here.

    if num_updates == 0:
      return 0.0
    if not dp_params.is_finite_guarantee:
      return float('inf')
    self._validate_dp_params(dp_params)
    return self._compute_epsilon(num_updates, dp_params)


[docs] class DpsgdTrainingAccountant(DpTrainingAccountant): """Defines privacy computations for Band-MF with Cyclic Poisson sampling. This includes DP-SGD style analysis as a special case. For accounting we follow the reduction in https://arxiv.org/abs/2306.08153. We assume that if num_samples % cycle_length != 0, then num_samples % cycle_length examples are discarded. """
[docs] def can_calibrate_steps(self) -> bool: return True
[docs] def can_calibrate_batch_size(self) -> bool: return True
[docs] def can_calibrate_noise_multipliers(self) -> bool: return True
def _validate_dp_params(self, dp_params: DpParams): super()._validate_dp_params(dp_params) if ( dp_params.examples_per_user is not None and dp_params.examples_per_user != 1 ): raise ValueError( 'DpsgdTrainingAccountant requires examples_per_user = 1 or None. ' 'Choose a different examples_per_user or use ' 'DpsgdTrainingUserLevelAccountant instead.' ) if dp_params.cycle_length is not None and dp_params.cycle_length != 1: if not isinstance(dp_params.batch_size, numbers.Number): raise ValueError( 'DpsgdTrainingAccountant with cycle_length != 1 requires a single' ' batch size.' ) if not isinstance(dp_params.noise_multipliers, numbers.Number): raise ValueError( 'DpsgdTrainingAccountant with cycle_length != 1 requires a single' ' noise multiplier.' ) if dp_params.batch_size * dp_params.cycle_length > dp_params.num_samples: raise ValueError( 'DpsgdTrainingAccountant with cycle_length != 1 requires batch_size' ' * cycle_length <= num_samples.' ) if ( dp_params.sampling_method is not SamplingMethod.POISSON and dp_params.sampling_method is not SamplingMethod.FIXED_BATCH_SIZE ): raise ValueError( 'DpsgdTrainingAccountant requires sampling_method = POISSON or ' 'FIXED_BATCH_SIZE.' ) if dp_params.truncated_batch_size is not None: if dp_params.sampling_method is not SamplingMethod.POISSON: raise ValueError( 'DpsgdTrainingAccountant does not support truncated_batch_size' ' unless using sampling_method = POISSON.' ) if not isinstance( self._dp_accountant_config, accountants.PldAccountantConfig ): raise ValueError( 'DpsgdTrainingAccountant with truncated_batch_size != None requires' ' a PLDAccountant.' 
) def _compute_epsilon( self, num_updates: chex.Numeric, dp_params: DpParams ) -> float: nms = dp_params.noise_multipliers batch_sizes = dp_params.batch_sizes num_samples = dp_params.num_samples sampling_method = dp_params.sampling_method cycle_length = dp_params.cycle_length if dp_params.cycle_length else 1 truncated_batch_size = dp_params.truncated_batch_size dp_accountant = self._dp_accountant_config.create_accountant() t_nm_and_bs = _interleave_nm_and_bs(nms, batch_sizes, num_updates) match sampling_method: case SamplingMethod.POISSON: sensitivity_multiplier = 1.0 case SamplingMethod.FIXED_BATCH_SIZE: # Fixed batch size sampling's privacy analysis reduces to Poisson # sampling with the same noise but the sensitivity doubled. sensitivity_multiplier = 2.0 for t, nm, bs in t_nm_and_bs: min_group_size = num_samples // cycle_length q = bs / float(min_group_size) if truncated_batch_size is None: event = dp_accounting.PoissonSampledDpEvent( q, dp_accounting.GaussianDpEvent(nm / sensitivity_multiplier) ) else: # This calculation involves a sum over num_samples terms corresponding # to the possible batch sizes before truncation. To save time and memory # we truncate this sum at a threshold chosen such that the terms in the # sum after the threshold are smaller than the precision of computation. threshold = truncated_batch_size while stats.binom.sf(threshold, min_group_size - 1, q) > 0.0: threshold = max(2 * threshold, min_group_size) sample_sizes = np.arange(truncated_batch_size, threshold) prob_2 = q * np.sum( stats.binom.pmf(sample_sizes, min_group_size - 1, q) * truncated_batch_size / (sample_sizes + 1) ) prob_1 = q * ( 1 - stats.binom.sf(truncated_batch_size, min_group_size - 1, q) ) prob_0 = 1 - prob_1 - prob_2 event = dp_accounting.dp_event.MixtureOfGaussiansDpEvent( nm, [0, 1, 2], [prob_0, prob_1, prob_2] ) dp_accountant.compose(event, math.ceil(t / cycle_length)) return dp_accountant.get_epsilon(target_delta=dp_params.delta)
class DpsgdTrainingUserLevelAccountant(DpTrainingAccountant): """Defines privacy computations for DP-SGD analysis with user-level DP. This class uses the calculations in https://arxiv.org/abs/2401.10294. """ def can_calibrate_steps(self) -> bool: return True def can_calibrate_batch_size(self) -> bool: return True def can_calibrate_noise_multipliers(self) -> bool: return True def _validate_dp_params(self, dp_params: DpParams): super()._validate_dp_params(dp_params) if dp_params.examples_per_user is None: raise ValueError( 'DpsgdTrainingUserLevelAccountant requires examples_per_user.' ) if dp_params.cycle_length is not None and dp_params.cycle_length != 1: raise ValueError( 'DpsgdTrainingUserLevelAccountant requires cycle_length = 1 or None.' ) if ( dp_params.sampling_method is not SamplingMethod.POISSON and dp_params.sampling_method is not SamplingMethod.FIXED_BATCH_SIZE ): raise ValueError( 'DpsgdTrainingUserLevelAccountant requires sampling_method = POISSON ' 'or FIXED_BATCH_SIZE.' ) if dp_params.truncated_batch_size is not None: raise ValueError( 'DpsgdTrainingUserLevelAccountant requires truncated_batch_size =' ' None.' ) if not isinstance( self._dp_accountant_config, accountants.PldAccountantConfig ): raise ValueError( 'DpsgdTrainingUserLevelAccountant requires a PLDAccountant.' ) def _compute_epsilon( self, num_updates: chex.Numeric, dp_params: DpParams ) -> float: nms = dp_params.noise_multipliers batch_sizes = dp_params.batch_sizes num_samples = dp_params.num_samples examples_per_user = dp_params.examples_per_user sampling_method = dp_params.sampling_method dp_accountant = self._dp_accountant_config.create_accountant() if not isinstance(dp_accountant, dp_accounting.pld.PLDAccountant): raise ValueError( 'DpsgdTrainingUserLevelAccountant requires a PLDAccountant.' 
) t_nm_and_bs = _interleave_nm_and_bs(nms, batch_sizes, num_updates) for t, nm, bs in t_nm_and_bs: match sampling_method: case SamplingMethod.POISSON: q = bs / float(num_samples) sensitivities = range(examples_per_user + 1) probs = [ stats.binom.pmf(x, examples_per_user, q) for x in sensitivities ] case SamplingMethod.FIXED_BATCH_SIZE: sensitivities = [2 * x for x in range(examples_per_user + 1)] sensitivity_rv = stats.hypergeom(num_samples, examples_per_user, bs) probs = [sensitivity_rv.pmf(x) for x in range(examples_per_user + 1)] event = dp_accounting.dp_event.MixtureOfGaussiansDpEvent( nm, sensitivities, probs ) dp_accountant.compose(event, t) return dp_accountant.get_epsilon(target_delta=dp_params.delta) class SingleReleaseTrainingAccountant(DpTrainingAccountant): """Defines privacy computations for single release analysis. This style of analysis is used for un-amplified DP-FTRL mechanisms, as detailed in https://arxiv.org/pdf/2211.06530. Unlike DP-SGD analysis, which relies on Poisson amplification, this analysis treats accounting as a single Gaussian DP event. 
""" def can_calibrate_steps(self) -> bool: return False def can_calibrate_batch_size(self) -> bool: return False def can_calibrate_noise_multipliers(self) -> bool: return False def _validate_dp_params(self, dp_params: DpParams): super()._validate_dp_params(dp_params) if ( dp_params.examples_per_user is not None and dp_params.examples_per_user != 1 ): raise ValueError( 'SingleReleaseTrainingAccountant requires examples_per_user = 1 or' ' None' ) def _compute_epsilon( self, num_updates: chex.Numeric, dp_params: DpParams ) -> float: nms = dp_params.noise_multipliers batch_sizes = dp_params.batch_sizes dp_accountant = self._dp_accountant_config.create_accountant() t_nm_and_bs = _interleave_nm_and_bs(nms, batch_sizes, num_updates) for _, nm, _ in t_nm_and_bs: event = dp_accounting.GaussianDpEvent(nm) dp_accountant.compose(event, 1) return dp_accountant.get_epsilon(target_delta=dp_params.delta) def _ceil_div(a: int, b: int) -> int: return (a + b - 1) // b class CachedExperimentAccountant: """Pre-computes and caches epsilon for different `num_updates` values.""" def __init__( self, training_accountant: DpTrainingAccountant, max_num_updates: int, num_cached_points: int = 100, ): """Creates the cached accoutant and computes the cached points and values. Args: training_accountant: Which training accountant to use for computing the results to be cached. max_num_updates: Maximum value for `num_updates` to be requested. num_cached_points: Number of points to pre-compute and cache. 
""" self._accountant = training_accountant self._max_num_updates = max_num_updates self._num_cached_points = num_cached_points self._cache_is_initialized = False def _maybe_initialize_cache(self, dp_params: DpParams): """Precomputes and caches the values of `num_cached_points` points.""" if self._cache_is_initialized: return logging.info('Pre-computing accounting cache...') start_clock = time.time() self._cached_points = [ _ceil_div(self._max_num_updates * j, self._num_cached_points) for j in range(self._num_cached_points + 1) ] self._cached_values = {} for i, x in enumerate(self._cached_points): self._cached_values[x] = self._accountant.compute_epsilon(x, dp_params) # Compute current duration in seconds. current_duration = time.time() - start_clock ten_minutes_threshold = 10 * 60 # Estimate the total duration by extrapolating (linearly). current_progress = (i + 1) / len(self._cached_points) expected_duration = current_duration / current_progress if expected_duration > ten_minutes_threshold: logging.warning( 'Accounting cache is being slow: total duration estimated to' ' {%.0fmin} (current progress: %.0f%%)', expected_duration / 60, 100.0 * current_progress, ) end_clock = time.time() logging.info( 'Accounting cache (took %.3fmin).', (end_clock - start_clock) / 60 ) self._cache_is_initialized = True def compute_epsilon( self, num_updates: chex.Numeric, dp_params: DpParams, allow_approximate_cache: bool = False, ) -> float: """Uses cached results to give an approximate (over-estimated) epsilon. The value returned should always be an over-approximation of the true epsilon: this method uses the closest `num_updates` in the cache that is equal to or greater than the requested `num_updates`. If such a value cannot be found, an indexing error will be raised. Args: num_updates: The number of updates to compute epsilon for. dp_params: Parameters required for computing the DP guarnatee. allow_approximate_cache: Whether to use the approximate cache. Returns: Value of epsilon. 
""" if allow_approximate_cache: self._maybe_initialize_cache(dp_params) closest_cached_point = self._cached_points[ _ceil_div( self._num_cached_points * num_updates, self._max_num_updates ) ] return self._cached_values[closest_cached_point] else: return self._accountant.compute_epsilon(num_updates, dp_params)