# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for clipping function outputs and aggregating across a batch."""
import collections
from collections.abc import Sequence
import dataclasses
import functools
import numbers
from typing import Any, Callable, TypeAlias
import chex
import dp_accounting
import jax
import jax.numpy as jnp
import optax
from optax import microbatching
# Alias used throughout this module for an arbitrary pytree of arrays.
PyTree: TypeAlias = chex.ArrayTree
# Container for the per-example side outputs assembled in `clipped_grad`.
# NOTE: the namedtuple's type name is 'Aux', so instances repr as `Aux(...)`.
AuxiliaryOutput = collections.namedtuple('Aux', ['values', 'grad_norms', 'aux'])
# Default neighboring relation for `BoundedSensitivityCallable.sensitivity`.
_REPLACE_SPECIAL = dp_accounting.NeighboringRelation.REPLACE_SPECIAL
[docs]
@dataclasses.dataclass(frozen=True)
class BoundedSensitivityCallable:
"""Callable with a sensitivity property.
If has_aux is False, the sensitivity guarantee holds for the entire output
which may be an arbitrary PyTree of JAX Arrays. If has_aux is True, the
output of the function is a pair `(value, aux)` and the sensitivity guarantee
only holds for `value` PyTree. The aux PyTree is returned on a per-example
basis (i.e., as a PyTree of arrays having a batch axis). The caller should
handle the aux output with care w.r.t. DP guarantees, should they be needed.
"""
fun: Callable[..., Any]
l2_norm_bound: float
has_aux: bool
[docs]
def __call__(self, *args, **kwargs):
return self.fun(*args, **kwargs)
[docs]
def sensitivity(
self,
neighboring_relation: dp_accounting.NeighboringRelation = _REPLACE_SPECIAL, # pylint: disable=line-too-long
):
"""Returns the L2 sensitivity of the Callable.
The L2 sensitivity is defined with respect to the given neighboring relation
and the unit of privacy implied by the function that created this instance.
Args:
neighboring_relation: The neighboring relation to consider.
Returns:
The L2 sensitivity of the Callable.
"""
match neighboring_relation:
case dp_accounting.NeighboringRelation.ADD_OR_REMOVE_ONE:
return self.l2_norm_bound
case dp_accounting.NeighboringRelation.REPLACE_ONE:
return 2 * self.l2_norm_bound
case dp_accounting.NeighboringRelation.REPLACE_SPECIAL:
return self.l2_norm_bound
case _:
raise ValueError(f'Unsupported {neighboring_relation=}')
def clip_pytree(
    pytree: PyTree,
    clip_norm: float,
    rescale_to_unit_norm: bool = False,
    nan_safe: bool = True,
    return_zero: bool = False,
):
  """Clips a PyTree of jax arrays based on its global L2 norm.

  Calculates the global L2 norm of the input PyTree. If the norm exceeds
  `clip_norm`, the PyTree is scaled down to have norm equal to `clip_norm`.
  If `rescale_to_unit_norm` is True, the PyTree is additionally scaled by
  `1.0 / clip_norm` (resulting in a norm of at most 1.0 no matter
  what clip_norm is). Handles cases where the original norm is zero,
  or the clip norm is 0 or infinity.

  Formal Guarantees:
    - The output PyTree will have norm at most `clip_norm` if
      `rescale_to_unit_norm` is False, and norm at most 1.0 if it is True.
    - The output PyTree will have the same structure+dtypes as the input PyTree.

  Edge Case Handling:

  ======================= ==================== =================================
  Case                    rescale_to_unit_norm Output
  ======================= ==================== =================================
  clip_norm = 0           False                Zero
  clip_norm = 0           True                 Input / norm, as clip_norm -> 0
  clip_norm = inf         False                Unchanged
  clip_norm = inf         True                 Zero
  clip_norm < 0 (static)  -                    Raises ValueError
  clip_norm < 0 (dynamic) -                    Same as clip_norm = 0
  pytree_norm = 0         -                    Unchanged
  ======================= ==================== =================================

  Args:
    pytree: The PyTree of arrays to clip.
    clip_norm: The maximum L2 norm allowed.
    rescale_to_unit_norm: If True, the output PyTree's norm is rescaled by `1.0
      / clip_norm` after potential clipping. If False, the output PyTree has
      norm at most `clip_norm`.
    nan_safe: If True, NaNs and +/- infs are converted to 0 before clipping.
      Must be True to preserve the formal guarantees in the presence of NaNs,
      although it does require potentially additional computation. If False, the
      NaNs in input PyTree will be preserved in the output PyTree. +/- infs will
      be converted to NaNs as well.
    return_zero: If True, the output PyTree is guaranteed to be zero no matter
      what the inputs are. Does not influence the formal guarantees.

  Returns:
    A tuple `(clipped_pytree, original_l2_norm)`, where `clipped_pytree` is the
    processed PyTree and `original_l2_norm` is the L2 norm of the input PyTree
    (computed after the NaN/inf replacement when `nan_safe` is True), cast to
    float32.

  Raises:
    ValueError: If `clip_norm` is a concrete Python real number that is
      negative.
  """
  # Only concrete Python numbers can be validated eagerly; traced or array
  # values are instead clamped to be non-negative below.
  if isinstance(clip_norm, numbers.Real) and clip_norm < 0:
    raise ValueError(f'clip_norm must be non-negative, got {clip_norm=}.')
  if nan_safe:
    nan_to_num = lambda x: jnp.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
    pytree = jax.tree.map(nan_to_num, pytree)
  clip_norm = jnp.maximum(clip_norm, 0.0)
  l2_norm = optax.tree.norm(pytree)
  scale = jnp.minimum(1.0, clip_norm / l2_norm)
  if rescale_to_unit_norm:
    # When clip_norm is 0 the ratio scale / clip_norm is undefined, so fall
    # back to normalizing the input by its own norm.
    scale = jax.lax.select(clip_norm > 0, scale / clip_norm, 1 / l2_norm)
  # If l2_norm is 0 or nan, set scale to 0.0.
  scale = jnp.nan_to_num(scale, nan=0.0, posinf=0.0)
  # Cast the scale to each leaf's dtype so the output dtypes match the input.
  clipped = jax.tree.map(lambda x: jnp.astype(scale, x.dtype) * x, pytree)
  maybe_zero = lambda x: jax.lax.select(return_zero, jnp.zeros_like(x), x)
  return jax.tree.map(maybe_zero, clipped), l2_norm.astype(jnp.float32)
def _adjusted_clip_norm(
    l2_clip_norm: float,
    grid_scale: int,
    num_params: int,
) -> float:
  """Computes a tighter clip norm to account for norm increase from rounding.

  When a real-valued vector is scaled by ``grid_scale / l2_clip_norm`` and
  rounded to the nearest integer, each coordinate changes by at most 0.5.
  This can increase the L2 norm by at most ``sqrt(num_params) * 0.5`` (in
  integer units), or equivalently ``sqrt(num_params) * l2_clip_norm /
  grid_scale / 2`` in the original units. We return a tighter clip norm so
  that after rounding, the integer L2 norm is at most ``grid_scale``.

  Args:
    l2_clip_norm: The desired post-rounding L2 norm bound (in original units).
    grid_scale: Number of integer grid steps corresponding to l2_clip_norm.
      Must be positive.
    num_params: Total number of scalar parameters in the gradient.

  Returns:
    The adjusted (tighter) clip norm in original units.

  Raises:
    ValueError: If `grid_scale` is a concrete non-positive number, or if the
      rounding error is at least as large as the clip norm.
  """
  # A zero grid_scale would divide by zero, and a negative one would flip the
  # sign of the rounding error and silently *loosen* the clip norm. Only
  # concrete Python numbers (not JAX tracers) can be checked here.
  if isinstance(grid_scale, numbers.Real) and grid_scale <= 0:
    raise ValueError(f'grid_scale must be positive, got {grid_scale=}.')
  grid_step = l2_clip_norm / grid_scale
  max_rounding_error = num_params**0.5 * grid_step * 0.5
  adjusted = l2_clip_norm - max_rounding_error
  # Only check when all values are concrete Python numbers (not JAX tracers).
  if isinstance(adjusted, numbers.Real) and adjusted <= 0:
    raise ValueError(
        f'Grid scale {grid_scale} is too small for {num_params} parameters. '
        f'The rounding error ({max_rounding_error:.4f}) exceeds the clip norm '
        f'({l2_clip_norm:.4f}). Increase grid_scale.'
    )
  return adjusted
def _check_x64_enabled():
  """Verifies that JAX is configured for 64-bit types.

  Raises:
    ValueError: If the `jax_enable_x64` config flag is off.
  """
  if jax.config.jax_enable_x64:
    return
  message = (
      'grid_scale requires 64-bit mode. Call '
      "jax.config.update('jax_enable_x64', True) before using "
      'grid_scale, otherwise jnp.int64 is silently truncated to '
      'int32.'
  )
  raise ValueError(message)
def clip_and_round_to_grid(
    gradient: PyTree,
    l2_clip_norm: float,
    grid_scale: int,
    *,
    nan_safe: bool = True,
    return_zero: bool = False,
) -> tuple[PyTree, jax.Array]:
  """Clips a gradient and rounds it to an integer grid.

  This function performs four operations on a single (unbatched) gradient
  pytree:

  1. Computes a tighter (adjusted) clip norm to account for the L2 norm
     increase caused by rounding.
  2. Clips the gradient to have L2 norm at most the adjusted clip norm.
  3. Scales the gradient by ``grid_scale / l2_clip_norm``.
  4. Rounds each coordinate to the nearest integer.

  After rounding, the integer vector has L2 norm at most ``grid_scale``.

  This function is designed to be used with ``jax.vmap`` to process a batch of
  per-example gradients.

  Args:
    gradient: A pytree of gradient arrays for a single example.
    l2_clip_norm: The desired L2 clip norm (in original units).
    grid_scale: Number of integer grid steps corresponding to l2_clip_norm.
    nan_safe: If True, NaNs and +/- infs are converted to 0 before clipping. See
      ``clip_pytree`` for details.
    return_zero: If True, the output is guaranteed to be zero regardless of
      inputs. See ``clip_pytree`` for details.

  Returns:
    A tuple ``(rounded, l2_norm)`` where ``rounded`` is a pytree of int64
    arrays representing the rounded gradient on the integer grid, and
    ``l2_norm`` is the L2 norm of the input gradient before clipping.

  Raises:
    ValueError: If JAX 64-bit mode is not enabled, or if grid_scale is too
      small relative to the number of parameters (the rounding error would
      exceed the clip norm).
  """
  # int64 outputs are only representable when x64 mode is on.
  _check_x64_enabled()
  num_params = sum(x.size for x in jax.tree.leaves(gradient))
  adj_clip_norm = _adjusted_clip_norm(l2_clip_norm, grid_scale, num_params)
  clipped, l2_norm = clip_pytree(
      gradient, adj_clip_norm, nan_safe=nan_safe, return_zero=return_zero
  )
  # The grid is defined relative to the *requested* clip norm, not the
  # tightened one, so that the rounded norm stays within grid_scale.
  scale = grid_scale / l2_clip_norm
  rounded = jax.tree.map(
      lambda g: jnp.round(g * scale).astype(jnp.int64), clipped
  )
  return rounded, l2_norm
# pylint: disable=g-bare-generic
def _with_extra_batch_axis(
fun: Callable, batch_argnums: int | Sequence[int]
) -> Callable:
"""Wraps a function to add an extra batch axis to the batch_argnums."""
if isinstance(batch_argnums, int):
batch_argnums = (batch_argnums,)
def wrapped_fun(*args, **kwargs):
args_with_group_axis = list(args)
for i in batch_argnums:
args_with_group_axis[i] = jax.tree.map(
lambda x: jnp.expand_dims(x, axis=1), args[i]
)
return fun(*args_with_group_axis, **kwargs)
return wrapped_fun
def _validate_batch_args(batch_argnums, args):
  """Checks that the batch arguments exist and share a leading axis size.

  Args:
    batch_argnums: Index (or sequence of indices) into `args` identifying the
      batched positional arguments.
    args: The positional arguments that were passed to the clipped function.

  Raises:
    ValueError: If the largest index in `batch_argnums` is out of range for
      `args`, or if the leaves of the selected arguments do not all have the
      same size along axis 0.
  """
  if isinstance(batch_argnums, int):
    batch_argnums = (batch_argnums,)
  max_argnum = max(batch_argnums)
  if max_argnum >= len(args):
    raise ValueError(
        f'Unable to find argnum={max_argnum}, was given {len(args)} args.'
    )
  # Collect the distinct leading-axis sizes over every leaf of every batch arg.
  batch_axis_sizes = {
      leaf.shape[0]
      for argnum in batch_argnums
      for leaf in jax.tree.leaves(args[argnum])
  }
  if len(batch_axis_sizes) > 1:
    raise ValueError(
        'Batch axis must have the same size for all inputs in batch_argnums, '
        f'got {batch_axis_sizes}.'
    )
def _normalize_fun_to_return_aux(fun, has_aux):
  """Adapts `fun` so that it always returns a `(value, aux)` pair.

  Args:
    fun: The function to adapt.
    has_aux: Whether `fun` already returns a `(value, aux)` pair.

  Returns:
    `fun` itself when `has_aux` is truthy; otherwise a wrapper that returns
    `(fun(*args, **kwargs), ())`, i.e. with an empty tuple as the aux output.
  """
  if has_aux:
    return fun
  with_empty_aux = lambda *args, **kwargs: (fun(*args, **kwargs), ())
  return with_empty_aux
def _num_real_microbatches(
is_padding_example: jax.Array,
microbatch_size: int | None,
) -> int | jax.Array:
"""Calculates the number of non-padding microbatches.
The returned result is 1 + the index of the last microbatch that contains at
least one non-padding example. This means that microbatches consisting of
all-padding examples that do not appear at the end will be treated as a real
microbatch.
Args:
is_padding_example: A 1D array of shape (num_examples,).
microbatch_size: Argument passed to `microbatch`.
Returns:
The `true` batch size, as a scalar jax array.
"""
if microbatch_size is None:
return is_padding_example.shape[0]
reshaped = microbatching.reshape_batch_axis(
is_padding_example, microbatch_size
)
# Ensure there is at least one True in the array.
is_real_batch = jnp.append(True, ~reshaped.all(axis=1))
# We want the last real microbatch, argmax returns the first True value,
# so we add increasing numbers from 0 to 1 to each index.
return jnp.argmax(is_real_batch + jnp.linspace(0, 1, is_real_batch.size))
def _maybe_squeeze_axis_1(x: jax.Array) -> jax.Array:
  """Drops axis 1 of `x` when that axis exists and has size 1.

  The guard is needed because some pytree leaves may be scalars or 1D arrays
  (e.g., when the model reduces over a size-1 batch dimension and returns a
  single loss scalar), for which `jnp.squeeze(axis=1)` raises an out-of-bounds
  error. `jnp.squeeze` also raises if the requested axis is not of size 1.
  Objects without a `shape` attribute are returned as-is.

  Note that auxiliary outputs whose natural shape is `(Batch, 1, ...)` are
  squeezed as well; callers should be aware of this side effect.

  Args:
    x: The input array.

  Returns:
    `x` with axis 1 removed if it has at least two dimensions and axis 1 has
    size 1; otherwise `x` unchanged.
  """
  if not hasattr(x, 'shape'):
    return x
  dims = x.shape
  has_singleton_axis_1 = len(dims) >= 2 and dims[1] == 1
  return jnp.squeeze(x, axis=1) if has_singleton_axis_1 else x
[docs]
def clipped_fun(
fun: Callable,
has_aux: bool = False,
*,
batch_argnums: int | Sequence[int] = 0,
keep_batch_dim: bool = True,
l2_clip_norm: float = 1.0,
rescale_to_unit_norm: bool = False,
normalize_by: float = 1.0,
return_norms: bool = False,
microbatch_size: int | None = None,
nan_safe: bool = True,
dtype: jax.typing.DTypeLike | None = None,
prng_argnum: int | None = None,
spmd_axis_name: str | None = None,
grid_scale: int | None = None,
) -> BoundedSensitivityCallable:
"""Transforms a function to clip its output and sum across a batch.
Example Usage:
>>> data = jnp.array([0, 1, 2, 3, 4, 5])
>>> clipped_mean = clipped_fun(jnp.mean, l2_clip_norm=1.0)
>>> clipped_mean(data)
Array(5., dtype=float32)
Formal Guarantees:
For the first function output:
The L2 sensitivity of the returned function with respect to the batch
arguments (specified by `batch_argnums`) under add/remove or zero-out
differential privacy definitions is guaranteed to be 1.0 if
`rescale_to_unit_norm` is True. Otherwise, the sensitivity is
`l2_clip_norm`. Under replace-one DP, the sensitivity is doubled
(2.0 or 2 * `l2_clip_norm`).
Extra auxiliary outputs (aux, norms) are per-example. This function
guarantees that per-example outputs only depend the data for the same
example. This allows maximum flexibility for the caller to aggregate
these as desired (possibly with a DP mean, median, quantile, or histogram
mechanism).
Args:
fun: The function to be clipped.
has_aux: If True, `fun` is expected to return a tuple `(value, aux)`. Only
the value will be clipped + aggregated, `aux` will be returned on a
per-example basis. Exercise caution when using this as the sensitivity
guarantees of the returned Callable are only provided w.r.t. `value`.
batch_argnums: Specifies which argument(s) of `fun` contain the batch
dimension. All arguments specified here must have the same size along the
0th axis.
keep_batch_dim: If True, batch inputs will be passed to `fun` with a leading
batch axis of size 1. If False, this size 1 axis will be dropped
(reducing the rank of the batch args by 1 before passing to `fun`).
l2_clip_norm: The maximum L2 norm allowed.
rescale_to_unit_norm: If True, the output PyTree's norm is rescaled by `1.0
/ clip_norm` after potential clipping. If False, the output PyTree has
norm at most `clip_norm`.
normalize_by: Divide the clipped output by this value before returning.
return_norms: If True, the returned Callable will return the l2_norms of the
per-example values before clipping. These values should be handled with
care, see the formal guarantees above.
microbatch_size: If set, the batch is split up into microbatches of this
size. These microbatches are then processed sequentially, with operations
on the groups within each microbatch being vectorized using `vmap`. This
can be used to reduce peak memory usage at the cost of increased
sequential computation.
nan_safe: If True, the formal guarantees of the returned Callable still
holds in the presence of NaNs and infs. See `clip_pytree` for more details
on this argument.
dtype: Optional dtype for the clipped+aggregated PyTree. If None, the dtype
will be the same as the dtypes of the function output. Can be useful to
avoid overflow issues when using low-precision dtypes as the transformed
function computes a sum over a potentially large batch.
prng_argnum: If set, specifies which argument of `fun` is a PRNG key. The
PRNG will be split to have a batch dimension and vmapped over.
spmd_axis_name: See jax.vmap.
grid_scale: If set, per-example outputs are additionally scaled and rounded
to an integer grid after clipping. Specifically, each clipped output is
multiplied by ``grid_scale / l2_clip_norm``, rounded to the nearest
integer, and cast to ``jnp.int64``. The clipping norm is tightened
automatically so that the integer L2 norm of each rounded output is at
most ``grid_scale``. This option is designed for use with the discrete
Gaussian mechanism. Incompatible with ``rescale_to_unit_norm=True`` and
``normalize_by != 1.0``. When set, ``dtype`` is ignored (output is always
``jnp.int64``).
Returns:
A new function `clip_fn` that clips the output of `fun` and sums across
the batch. `clip_fn` takes the same arguments as `fun`. The exact output
signature depends on `has_aux` and `return_norms`:
| `has_aux` | `return_norms` | `clipped_fn` returns |
| :-------- | :--------------| :-------------------- |
| `False` | `False` | `value` |
| `True` | `False` | `value, aux` |
| `False` | `True` | `value, norms` |
| `True` | `True` | `value, (aux, norms)` |
"""
if isinstance(batch_argnums, int):
batch_argnums = (batch_argnums,)
if grid_scale is not None:
if rescale_to_unit_norm:
raise ValueError(
'rescale_to_unit_norm is not compatible with grid_scale.'
)
if normalize_by != 1.0:
raise ValueError(
'normalize_by is not compatible with grid_scale. Normalization '
'should be applied after noise addition.'
)
_check_x64_enabled()
fun = _normalize_fun_to_return_aux(fun, has_aux)
def clipped_fn(*args, **kwargs):
_validate_batch_args(batch_argnums, args)
is_padding_example = kwargs.get('is_padding_example', None)
batch_size = jax.tree.leaves(args[batch_argnums[0]])[0].shape[0]
if is_padding_example is None:
is_padding_example = jnp.zeros(batch_size, dtype=jnp.bool_)
kwargs['is_padding_example'] = is_padding_example
def clipped_fun_one_group(*args, is_padding_example, **kwargs):
value, aux = fun(*args, **kwargs)
if grid_scale is not None:
clipped_value, l2_norm = clip_and_round_to_grid(
value,
l2_clip_norm,
grid_scale,
nan_safe=nan_safe,
return_zero=is_padding_example,
)
else:
value = optax.tree.cast(value, dtype)
clipped_value, l2_norm = clip_pytree(
value,
clip_norm=l2_clip_norm,
rescale_to_unit_norm=rescale_to_unit_norm,
nan_safe=nan_safe,
# See https://arxiv.org/pdf/2411.04205 for info on why this is
# useful.
return_zero=is_padding_example,
)
return clipped_value, aux, l2_norm
num_real_mb = _num_real_microbatches(is_padding_example, microbatch_size)
sum_ = microbatching.AccumulationType.SUM
concat = microbatching.AccumulationType.CONCAT
axes = [0 if i in batch_argnums else None for i in range(len(args))]
if prng_argnum is not None:
args = list(args)
rngs = args[prng_argnum]
split_rngs = jax.tree.map(lambda x: jax.random.split(x, batch_size), rngs)
args[prng_argnum] = split_rngs
axes[prng_argnum] = 0
microbatched_vmap_fun = microbatching.micro_vmap(
clipped_fun_one_group,
in_axes=axes,
microbatch_size=microbatch_size,
accumulator=(sum_, concat, concat),
num_real_microbatches=num_real_mb,
vmap_fn=functools.partial(jax.vmap, spmd_axis_name=spmd_axis_name),
)
clipped_values, aux, norms = microbatched_vmap_fun(*args, **kwargs)
if keep_batch_dim:
# If keep_batch_dim is True, we artificially added a dimension of size 1
# to the batch arguments before passing them to the vmap'ed function.
# While vmap and micro_vmap take a single example's slice, it preserves
# the output shapes of the inner function. If the inner function (e.g.
# your model) returns something that still contains this size 1 batch
# dimension, vmap will stack these outputs (not concatenate), resulting
# in a shape like (B, 1, ...). If the caller expects standard shapes
# (B, ...), this extra axis at index 1 can cause broadcasting issues
# downstream (e.g. in Keras metrics). To fix this, we squeeze the axis 1
# of aux outputs.
aux = jax.tree.map(_maybe_squeeze_axis_1, aux)
if normalize_by != 1.0:
clipped_values = jax.tree.map(lambda x: x / normalize_by, clipped_values)
match has_aux, return_norms:
case False, False:
return clipped_values
case False, True:
return clipped_values, norms
case True, False:
return clipped_values, aux
case True, True:
return clipped_values, (aux, norms)
if grid_scale is not None:
norm_bound = float(grid_scale)
else:
norm_bound = (1.0 if rescale_to_unit_norm else l2_clip_norm) / normalize_by
if keep_batch_dim:
clipped_fn = _with_extra_batch_axis(clipped_fn, batch_argnums)
callable_has_aux = has_aux or return_norms
return BoundedSensitivityCallable(clipped_fn, norm_bound, callable_has_aux)
def _validate_static_args(argnums, batch_argnums, normalize_by):
  """Validates the argnums and batch_argnums inputs are compatible.

  Args:
    argnums: Index or sequence of indices of the differentiated arguments.
    batch_argnums: Index or sequence of indices of the batched arguments.
    normalize_by: The normalization constant; must be strictly positive.

  Raises:
    ValueError: If `normalize_by` is not positive, if `batch_argnums` is empty,
      if any index is negative, or if `argnums` and `batch_argnums` overlap.
  """
  if normalize_by <= 0.0:
    raise ValueError(f'normalize_by must be > 0, got {normalize_by}.')
  # Normalize both inputs to tuples so that any Sequence (e.g. a list) is
  # accepted; concatenating a tuple with a list below would raise TypeError.
  argnums = (argnums,) if isinstance(argnums, int) else tuple(argnums)
  batch_argnums = (
      (batch_argnums,)
      if isinstance(batch_argnums, int)
      else tuple(batch_argnums)
  )
  if not batch_argnums:
    raise ValueError('Batch Argnums must not be empty.')
  if min(argnums + batch_argnums) < 0:
    raise ValueError(
        f'argnums={argnums} and batch_argnums={batch_argnums} must be >= 0.'
    )
  shared_argnums = set(argnums) & set(batch_argnums)
  if shared_argnums:
    raise ValueError(
        'Cannot compute clipped gradients for argnums that have a batch axis. '
        f'{argnums=} and {batch_argnums=} with overlap {list(shared_argnums)}.'
    )
def clipped_grad(
    fun: Callable,
    argnums: int | Sequence[int] = 0,
    has_aux: bool = False,
    *,
    l2_clip_norm: float,
    rescale_to_unit_norm: bool = False,
    normalize_by: float = 1.0,
    batch_argnums: int | Sequence[int] = 1,
    keep_batch_dim: bool = True,
    return_values: bool = False,
    return_grad_norms: bool = False,
    pre_clipping_transform: Callable[[PyTree], PyTree] = lambda x: x,
    microbatch_size: int | None = None,
    nan_safe: bool = True,
    dtype: jax.typing.DTypeLike | None = None,
    prng_argnum: int | None = None,
    spmd_axis_name: str | None = None,
    grid_scale: int | None = None,
) -> BoundedSensitivityCallable:
  """Create a function to compute the sum of clipped gradients of fun.

  This function acts as a transformation similar to `jax.grad`, but with added
  functionality for gradient clipping applied on a per-example (or per-group)
  basis before summation. It computes the gradient of `fun` with respect to
  `argnums`, calculates the L2 norm of the gradient for each example slice
  along the first axis of the `batch_argnums` args, clips each per-example
  gradient to have a norm of at most `l2_clip_norm`, and finally sums these
  clipped gradients.

  Non-grad outputs of the returned function (aux, values, grad_norms) may
  optionally be returned by setting the arguments `has_aux`,
  `return_values`, and/or `return_grad_norms` to True. These outputs are
  per-example, and hence have a batch axis. It is up to the caller to handle
  these as necessary. See the `Formal Guarantees` below for more details
  on this design choice.

  Example Usage:
    >>> import jax.numpy as jnp
    >>> f = lambda param, data: 0.5 * jnp.mean((data - param)**2)
    >>> g = clipped_grad(f, l2_clip_norm=jnp.inf)
    >>> g(3.0, jnp.array([0, 7, -2]))
    Array(4., dtype=float32)

  Example Usage (with Auxiliary Output):
    >>> g = clipped_grad(
    ...     f, l2_clip_norm=jnp.inf, return_values=True, return_grad_norms=True
    ... )
    >>> _, aux = g(3.0, jnp.array([0, 7, -2]))
    >>> aux.values
    Array([ 4.5,  8. , 12.5], dtype=float32)
    >>> aux.grad_norms
    Array([3., 4., 5.], dtype=float32)

  Example Usage (with Per-User Clipping):
    >>> f = lambda param, data: 0.5 * jnp.mean((data - param)**2)
    >>> g = clipped_grad(f, l2_clip_norm=jnp.inf, keep_batch_dim=False)
    >>> userA = jnp.array([1, -1])
    >>> userB = jnp.array([2, 2])
    >>> userC = jnp.array([0, 3])
    >>> g(3.0, jnp.array([userA, userB, userC]))
    Array(5.5, dtype=float32)

  Formal Guarantees:
    For the gradient output:
      The L2 sensitivity of the returned function with respect to the batch
      arguments (specified by `batch_argnums`) under add/remove or zero-out
      differential privacy definitions is guaranteed to be 1.0 if
      `rescale_to_unit_norm` is True. Otherwise, the sensitivity is
      `l2_clip_norm`. Under replace-one DP, the sensitivity is doubled
      (2.0 or 2 * `l2_clip_norm`).

    All auxiliary outputs (aux, values, grad_norms) are per-example. This
    function guarantees that per-example outputs only depend on the data for
    the same example. This allows maximum flexibility for the caller to
    aggregate these as desired (possibly with a DP mean, median, quantile, or
    histogram mechanism).

  Args:
    fun: The function to be differentiated, which should return a scalar loss
      value. If `has_aux` is True, it should return a tuple `(value, aux)`.
    argnums: Specifies which argument(s) of `fun` to differentiate with respect
      to. Can be an integer or a sequence of integers. These arguments should
      *not* have a batch dimension.
    has_aux: If True, `fun` is expected to return a tuple `(value, aux)`. The
      auxiliary data `aux` will be returned by the transformed function.
      Exercise caution when using this as no DP sensitivity guarantees are
      provided for the auxiliary data.
    l2_clip_norm: The maximum L2 norm for each per-example gradient. Gradients
      with a norm larger than this value will be scaled down.
    rescale_to_unit_norm: If True, clipped gradients are rescaled by `1.0 /
      l2_clip_norm`. This ensures the sensitivity is 1.0. If False, they are
      only scaled down if their norm exceeds `l2_clip_norm`, resulting in a
      sensitivity of `l2_clip_norm`. The motivation for setting this to True is
      to decouple the clipping norm from the learning rate for non-adaptive
      optimizers, as described in https://arxiv.org/abs/2204.13650.
    normalize_by: Divide the clipped output by this value before returning.
      Must be strictly positive.
    batch_argnums: Specifies which argument(s) of `fun` contain the batch
      dimension (usually the data and labels). Can be an integer or a sequence
      of integers. All arguments specified here must have the same size along
      their first dimension (the batch dimension). The default value of 1
      assumes the signature of fun is `fun(params, batch)`.
    keep_batch_dim: If True, batch inputs will be passed to `fun` with a leading
      batch axis of size 1. If False, this size 1 axis will be dropped
      (reducing the rank of the batch args by 1 before passing to `fun`). The
      default value of True assumes that `fun` expects inputs with a batch axis.
      Overriding this default can be useful if fun defines the loss function for
      a single example, or if clipping should be applied at the group or user
      level (in which case an extra batch axis is added to the inputs).
    return_values: If True, the transformed function will also return the
      per-example values, before clipping.
    return_grad_norms: If True, the transformed function will also return the
      per-example gradient norms, before clipping. These norms are computed on
      the raw gradient, i.e. before `pre_clipping_transform` is applied.
    pre_clipping_transform: An optional function to apply to the per-example
      gradients before clipping. The function should consume the gradient pytree
      for a single example and return a new pytree (possibly with different
      structure). Can be used to e.g., scale the leaves of the pytree to
      accommodate preconditioner clipping. Does not affect the sensitivity
      guarantee.
    microbatch_size: If set, input groups are formed into microbatches of this
      size. These microbatches are then processed sequentially, with operations
      on the groups within each microbatch being vectorized using `vmap`. This
      can be used to reduce peak memory usage at the cost of increased
      sequential computation. Microbatching will be at the level of
      users/groups. E.g., if there are 500 users, with 7 examples per user, and
      microbatch_size=100, then the input will be broken into 5 microbatches of
      100 users, and when processing a microbatch, `fun` will be invoked 100
      times (in parallel with vmap) on groups of 7 examples.
    nan_safe: If True, the formal guarantees of the returned Callable still
      hold in the presence of NaNs and infs. See `clip_pytree` for more details
      on this argument.
    dtype: Optional dtype for the returned gradient. If None, the dtype will be
      the same as the dtypes of the gradient function. Can be useful to avoid
      overflow issues when using low-precision dtypes as the returned function
      computes a sum over a potentially large batch.
    prng_argnum: If set, specifies which argument of `fun` is a PRNG key. The
      PRNG will be split to have a batch dimension and vmapped over.
    spmd_axis_name: See jax.vmap. Only relevant in distributed settings.
    grid_scale: If set, per-example grads are additionally scaled and rounded to
      an integer grid after clipping. Specifically, each clipped grad is
      multiplied by ``grid_scale / l2_clip_norm``, rounded to the nearest
      integer, and cast to ``jnp.int64``. The clipping norm is tightened
      automatically so that the integer L2 norm of each rounded output is at
      most ``grid_scale``. This option is designed for use with the discrete
      Gaussian mechanism. Incompatible with ``rescale_to_unit_norm=True`` and
      ``normalize_by != 1.0``. When set, ``dtype`` is ignored (output is always
      ``jnp.int64``).

  Returns:
    A new function `values_and_clipped_grad_fn` that computes the sum of clipped
    per-group gradients of `fun`. The returned function returns `grad`
    if return_values = return_grad_norms = has_aux = False. Otherwise, it
    returns a tuple of grad, AuxiliaryOutput, where AuxiliaryOutput is a
    namedtuple with optional fields (values, grad_norms, aux) containing the
    per-example values, gradient norms, and auxiliary data, respectively.

  Raises:
    ValueError: If `normalize_by` is not positive, `batch_argnums` is empty,
      any argnum is negative, or `argnums` overlaps `batch_argnums`.
  """
  _validate_static_args(argnums, batch_argnums, normalize_by)

  # Always differentiate a `(value, aux)`-returning function so that the code
  # below can treat both cases uniformly.
  fun = _normalize_fun_to_return_aux(fun, has_aux)
  value_and_grad_fn = jax.value_and_grad(fun, argnums, has_aux=True)

  def grad_fn(*args, **kwargs):
    value_and_aux, grad = value_and_grad_fn(*args, **kwargs)
    result = pre_clipping_transform(grad)
    if has_aux or return_values or return_grad_norms:
      aux = AuxiliaryOutput(
          values=value_and_aux[0] if return_values else None,
          # The norm is taken on the raw gradient, not the transformed one.
          grad_norms=optax.tree.norm(grad) if return_grad_norms else None,
          aux=value_and_aux[1] if has_aux else None,
      )
      return result, aux
    return result

  return clipped_fun(
      grad_fn,
      # Coerce to a real bool: this flag selects the output signature.
      has_aux=bool(has_aux or return_values or return_grad_norms),
      batch_argnums=batch_argnums,
      l2_clip_norm=l2_clip_norm,
      keep_batch_dim=keep_batch_dim,
      rescale_to_unit_norm=rescale_to_unit_norm,
      normalize_by=normalize_by,
      microbatch_size=microbatch_size,
      nan_safe=nan_safe,
      dtype=dtype,
      prng_argnum=prng_argnum,
      spmd_axis_name=spmd_axis_name,
      grid_scale=grid_scale,
  )