jax_privacy.clipping.clipped_grad

jax_privacy.clipping.clipped_grad(fun, argnums=0, has_aux=False, *, l2_clip_norm, rescale_to_unit_norm=False, normalize_by=1.0, batch_argnums=1, keep_batch_dim=True, return_values=False, return_grad_norms=False, pre_clipping_transform=<function <lambda>>, microbatch_size=None, nan_safe=True, dtype=None, prng_argnum=None, spmd_axis_name=None, grid_scale=None)

Create a function to compute the sum of clipped gradients of fun.

This function acts as a transformation similar to jax.grad, but with added functionality for gradient clipping applied on a per-example (or per-group) basis before summation. It computes the gradient of fun with respect to argnums, calculates the L2 norm of the gradient for each example slice along the first axis of the batch_argnums args, clips each per-example gradient to have a norm of at most l2_clip_norm, and finally sums these clipped gradients.
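The steps described above (per-example gradient, per-example L2 norm, clip, sum) can be sketched in plain JAX. This is only an illustration under simplifying assumptions, not the library's implementation: `clip_and_sum_sketch` is a hypothetical helper, the parameter is a scalar (so the L2 norm is an absolute value), and the batch is the second argument of `fun`.

```python
import jax
import jax.numpy as jnp


def clip_and_sum_sketch(fun, l2_clip_norm):
    """Illustrative stand-in for the clip-then-sum step (hypothetical helper)."""

    def per_example_grad(param, example):
        # Keep a leading batch axis of size 1, so fun sees a batch of one example.
        return jax.grad(fun)(param, example[None])

    def summed(param, batch):
        # One gradient per slice along the first axis of the batch.
        grads = jax.vmap(per_example_grad, in_axes=(None, 0))(param, batch)
        # The parameter is a scalar here, so each per-example L2 norm is |grad|.
        norms = jnp.abs(grads)
        # Scale down only the gradients whose norm exceeds l2_clip_norm.
        scale = jnp.minimum(1.0, l2_clip_norm / jnp.maximum(norms, 1e-12))
        return jnp.sum(grads * scale)

    return summed


f = lambda param, data: 0.5 * jnp.mean((data - param) ** 2)
batch = jnp.array([0.0, 7.0, -2.0])

# Per-example gradients at param=3.0 are 3, -4 and 5.
unclipped = clip_and_sum_sketch(f, jnp.inf)(3.0, batch)
# With a clipping norm of 3.5 they become 3, -3.5 and 3.5 before summation.
clipped = clip_and_sum_sketch(f, 3.5)(3.0, batch)
```

For a pytree of parameters, the norm would instead be taken over all leaves of each per-example gradient jointly.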

The returned function can also return non-gradient outputs (aux, values, grad_norms) when has_aux, return_values, and/or return_grad_norms are set to True. These outputs are per-example and therefore have a batch axis; it is up to the caller to aggregate them as needed. See Formal Guarantees below for more details on this design choice.

Example Usage:
>>> import jax.numpy as jnp
>>> f = lambda param, data: 0.5 * jnp.mean((data - param)**2)
>>> g = clipped_grad(f, l2_clip_norm=jnp.inf)
>>> g(3.0, jnp.array([0, 7, -2]))
Array(4., dtype=float32)

Example Usage (with Auxiliary Output):
>>> g = clipped_grad(
...   f, l2_clip_norm=jnp.inf, return_values=True, return_grad_norms=True
... )
>>> _, aux = g(3.0, jnp.array([0, 7, -2]))
>>> aux.values
Array([ 4.5,  8. , 12.5], dtype=float32)
>>> aux.grad_norms
Array([3., 4., 5.], dtype=float32)

Example Usage (with Per-User Clipping):
>>> f = lambda param, data: 0.5 * jnp.mean((data - param)**2)
>>> g = clipped_grad(f, l2_clip_norm=jnp.inf, keep_batch_dim=False)
>>> userA = jnp.array([1, -1])
>>> userB = jnp.array([2, 2])
>>> userC = jnp.array([0, 3])
>>> g(3.0, jnp.array([userA, userB, userC]))
Array(5.5, dtype=float32)

Formal Guarantees:
For the gradient output:

The L2 sensitivity of the returned function with respect to the batch arguments (specified by batch_argnums) under add/remove or zero-out differential privacy definitions is guaranteed to be 1.0 if rescale_to_unit_norm is True. Otherwise, the sensitivity is l2_clip_norm. Under replace-one DP, the sensitivity is doubled (2.0 or 2 * l2_clip_norm).

All auxiliary outputs (aux, values, grad_norms) are per-example. This function guarantees that each per-example output depends only on the data for that same example. This allows maximum flexibility for the caller to aggregate these as desired (possibly with a DP mean, median, quantile, or histogram mechanism).
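The sensitivity bounds stated above for the gradient output can be checked numerically on made-up data. The sketch below does not call the library: `clipped_sum` is a hypothetical helper that clips rows of an array of stand-in per-example gradients, and the random gradients are an assumption for illustration only.

```python
import jax
import jax.numpy as jnp


def clipped_sum(per_example_grads, l2_clip_norm, rescale_to_unit_norm=False):
    """Clip each row to L2 norm <= l2_clip_norm, then sum (hypothetical helper)."""
    norms = jnp.linalg.norm(per_example_grads, axis=1, keepdims=True)
    scale = jnp.minimum(1.0, l2_clip_norm / jnp.maximum(norms, 1e-12))
    clipped = per_example_grads * scale
    if rescale_to_unit_norm:
        # Rescaling by 1 / l2_clip_norm bounds each row's norm by 1.0.
        clipped = clipped / l2_clip_norm
    return clipped.sum(axis=0)


l2_clip_norm = 2.0
# Made-up per-example gradients: 8 examples, 5 parameters each.
grads = 10.0 * jax.random.normal(jax.random.PRNGKey(0), (8, 5))

# Add/remove adjacency: the neighbouring batch drops the last example.
full = clipped_sum(grads, l2_clip_norm)
neighbour = clipped_sum(grads[:-1], l2_clip_norm)
add_remove_gap = jnp.linalg.norm(full - neighbour)  # at most l2_clip_norm

# With rescale_to_unit_norm=True the same gap is at most 1.0.
unit_full = clipped_sum(grads, l2_clip_norm, rescale_to_unit_norm=True)
unit_neighbour = clipped_sum(grads[:-1], l2_clip_norm, rescale_to_unit_norm=True)
unit_gap = jnp.linalg.norm(unit_full - unit_neighbour)

# Replace-one adjacency: swap the last example for a different one.
replaced = grads.at[-1].set(-grads[-1])
replace_gap = jnp.linalg.norm(full - clipped_sum(replaced, l2_clip_norm))  # at most 2 * l2_clip_norm
```

Removing an example removes one clipped term from the sum, while replacing an example removes one term and adds another, which is why the replace-one bound is twice as large.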

Parameters:
  • fun (Callable) – The function to be differentiated, which should return a scalar loss value. If has_aux is True, it should return a tuple (value, aux).

  • argnums (int | Sequence[int]) – Specifies which argument(s) of fun to differentiate with respect to. Can be an integer or a sequence of integers. These arguments should not have a batch dimension.

  • has_aux (bool) – If True, fun is expected to return a tuple (value, aux). The auxiliary data aux will be returned by the transformed function. Exercise caution when using this as no DP sensitivity guarantees are provided for the auxiliary data.

  • l2_clip_norm (float) – The maximum L2 norm for each per-example gradient. Gradients with a norm larger than this value will be scaled down.

  • rescale_to_unit_norm (bool) – If True, clipped gradients are rescaled by 1.0 / l2_clip_norm. This ensures the sensitivity is 1.0. If False, they are only scaled down if their norm exceeds l2_clip_norm, resulting in a sensitivity of l2_clip_norm. The motivation for setting this to True is to decouple the clipping norm from the learning rate for non-adaptive optimizers, as described in https://arxiv.org/abs/2204.13650.

  • normalize_by (float) – Divide the clipped output by this value before returning.

  • batch_argnums (int | Sequence[int]) – Specifies which argument(s) of fun contain the batch dimension (usually the data and labels). Can be an integer or a sequence of integers. All arguments specified here must have the same size along their first dimension (the batch dimension). The default value of 1 assumes the signature of fun is fun(params, batch).

  • keep_batch_dim (bool) – If True, batch inputs will be passed to fun with a leading batch axis of size 1. If False, this size 1 axis will be dropped (reducing the rank of the batch args by 1 before passing to fun). The default value of True assumes that fun expects inputs with a batch axis. Overriding this default can be useful if fun defines the loss function for a single example, or if clipping should be applied at the group or user level (in which case an extra batch axis is added to the inputs).

  • return_values (bool) – If True, the transformed function will also return the per-example values, before clipping.

  • return_grad_norms (bool) – If True, the transformed function will also return the per-example gradient norms, before clipping.

  • pre_clipping_transform (Callable[[Union[Array, ndarray, bool, number, Iterable[chex.ArrayTree], Mapping[Any, chex.ArrayTree]]], Union[Array, ndarray, bool, number, Iterable[chex.ArrayTree], Mapping[Any, chex.ArrayTree]]]) – An optional function to apply to the per-example gradients before clipping. The function should consume the gradient pytree for a single example and return a new pytree (possibly with a different structure). It can be used, for example, to scale the leaves of the pytree to accommodate preconditioner clipping. It does not affect the sensitivity guarantee.

  • microbatch_size (int | None) – If set, input groups are formed into microbatches of this size. These microbatches are then processed sequentially, with operations on the groups within each microbatch being vectorized using vmap. This can be used to reduce peak memory usage at the cost of increased sequential computation. Microbatching will be at the level of users/groups. E.g., if there are 500 users, with 7 examples per user, and microbatch_size=100, then the input will be broken into 5 microbatches of 100 users, and when processing a microbatch, fun will be invoked 100 times (in parallel with vmap) on groups of 7 examples.

  • nan_safe (bool) – If True, the formal guarantees of the returned Callable still hold in the presence of NaNs and infs. See clip_pytree for more details on this argument.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – Optional dtype for the returned gradient. If None, the returned gradient keeps the dtype(s) produced by the gradient function. Setting this can help avoid overflow with low-precision dtypes, because the returned function computes a sum over a potentially large batch.

  • prng_argnum (int | None) – If set, specifies which argument of fun is a PRNG key. The PRNG will be split to have a batch dimension and vmapped over.

  • spmd_axis_name (str | None) – See jax.vmap. Only relevant in distributed settings.

  • grid_scale (int | None) – If set, per-example grads are additionally scaled and rounded to an integer grid after clipping. Specifically, each clipped grad is multiplied by grid_scale / l2_clip_norm, rounded to the nearest integer, and cast to jnp.int64. The clipping norm is tightened automatically so that the integer L2 norm of each rounded output is at most grid_scale. This option is designed for use with the discrete Gaussian mechanism. Incompatible with rescale_to_unit_norm=True and normalize_by != 1.0. When set, dtype is ignored (output is always jnp.int64).
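The microbatch_size behaviour described in the parameter list (sequential microbatches, vectorized within each microbatch) can be sketched in plain JAX as follows. This is an assumption about the general technique rather than the library's code: `clipped_grad_one` and `clipped_sum_microbatched` are hypothetical helpers, the parameter is a scalar, and the batch size is assumed to be divisible by the microbatch size.

```python
import jax
import jax.numpy as jnp


def loss(param, example):
    # Single-example loss; example is a scalar here.
    return 0.5 * (example - param) ** 2


def clipped_grad_one(param, example, l2_clip_norm):
    g = jax.grad(loss)(param, example)
    # Scalar parameter, so the L2 norm of the gradient is |g|.
    return g * jnp.minimum(1.0, l2_clip_norm / jnp.maximum(jnp.abs(g), 1e-12))


def clipped_sum_microbatched(param, batch, l2_clip_norm, microbatch_size):
    # Assumes the batch size is divisible by microbatch_size.
    microbatches = batch.reshape(-1, microbatch_size)

    def one_microbatch(mb):
        # Vectorize over the examples inside one microbatch.
        grads = jax.vmap(clipped_grad_one, in_axes=(None, 0, None))(
            param, mb, l2_clip_norm
        )
        return grads.sum()

    # jax.lax.map processes the microbatches one after another,
    # so only one microbatch of gradients is materialized at a time.
    return jax.lax.map(one_microbatch, microbatches).sum()


batch = jnp.arange(12.0)
all_at_once = clipped_sum_microbatched(3.0, batch, 2.0, microbatch_size=12)
in_chunks = clipped_sum_microbatched(3.0, batch, 2.0, microbatch_size=4)
```

Both calls compute the same sum of clipped gradients; the microbatch size only changes how much work is vectorized at once, trading peak memory for sequential steps.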

Return type:

BoundedSensitivityCallable

Returns:

A new function values_and_clipped_grad_fn that computes the sum of clipped per-group gradients of fun. The returned function returns grad if return_values = return_grad_norms = has_aux = False. Otherwise, it returns a tuple of grad, AuxiliaryOutput, where AuxiliaryOutput is a namedtuple with optional fields (values, grad_norms, aux) containing the per-example values, gradient norms, and auxiliary data, respectively.