jax_privacy.clipping.clipped_fun
- jax_privacy.clipping.clipped_fun(fun, has_aux=False, *, batch_argnums=0, keep_batch_dim=True, l2_clip_norm=1.0, rescale_to_unit_norm=False, normalize_by=1.0, return_norms=False, microbatch_size=None, nan_safe=True, dtype=None, prng_argnum=None, spmd_axis_name=None, grid_scale=None)
Transforms a function to clip its output and sum across a batch.
- Example Usage:
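>>> import jax.numpy as jnp
>>> from jax_privacy.clipping import clipped_fun
>>> data = jnp.array([0, 1, 2, 3, 4, 5])
>>> clipped_mean = clipped_fun(jnp.mean, l2_clip_norm=1.0)
>>> # Per-example means 0..5 are each clipped to norm <= 1, then summed: 0 + 5 * 1.
>>> clipped_mean(data)
Array(5., dtype=float32)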
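For intuition, the arithmetic in the example can be reproduced in plain Python. With keep_batch_dim=True, jnp.mean sees one example at a time, so the per-example means are 0 through 5. Clipping to norm 1.0 leaves 0 and 1 unchanged and maps 2 through 5 to 1, which gives a total of 0 + 5 * 1 = 5. This sketch covers scalar outputs only. clip_scalar and clip_and_sum are hypothetical helpers, not part of jax_privacy.

```python
import math


def clip_scalar(value: float, l2_clip_norm: float) -> float:
    """Clips a scalar so that its L2 norm (its absolute value) is at most l2_clip_norm."""
    return math.copysign(min(abs(value), l2_clip_norm), value)


def clip_and_sum(fun, batch, l2_clip_norm=1.0):
    """Hypothetical reference: apply fun to each example, clip its output, then sum."""
    # Mirrors keep_batch_dim=True: each example is passed to fun as a batch of size 1.
    outputs = [fun([example]) for example in batch]
    return math.fsum(clip_scalar(out, l2_clip_norm) for out in outputs)


def mean(xs):
    return sum(xs) / len(xs)


data = [0, 1, 2, 3, 4, 5]
print(clip_and_sum(mean, data, l2_clip_norm=1.0))  # 5.0
```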
- Formal Guarantees:
- For the first function output:
The L2 sensitivity of the returned function with respect to the batch arguments (specified by batch_argnums) is guaranteed to be at most 1.0 if rescale_to_unit_norm is True, and at most l2_clip_norm otherwise. This holds under both the add/remove and zero-out definitions of differential privacy. Under replace-one DP, the bound doubles (2.0 or 2 * l2_clip_norm).
- Extra auxiliary outputs (aux, norms) are per-example. This function guarantees that each per-example output depends only on the data for the same example. This gives the caller maximum flexibility to aggregate these outputs as desired (possibly with a DP mean, median, quantile, or histogram mechanism).
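The sensitivity bound follows directly from clipping:

- Adding or removing one example changes the sum by that example's clipped output, whose norm is at most l2_clip_norm. Zeroing out one example does the same.
- Replacing one example changes the sum by the difference of two clipped outputs. By the triangle inequality, that difference has norm at most 2 * l2_clip_norm.

The plain-Python sketch below checks these bounds numerically on random vectors. clip_vector and clipped_sum are hypothetical helpers, not jax_privacy functions, and a flat list stands in for the output PyTree.

```python
import math
import random


def l2_norm(vec):
    return math.sqrt(math.fsum(v * v for v in vec))


def clip_vector(vec, l2_clip_norm):
    """Scales vec down, if needed, so that its L2 norm is at most l2_clip_norm."""
    norm = l2_norm(vec)
    scale = min(1.0, l2_clip_norm / norm) if norm > 0 else 1.0
    return [v * scale for v in vec]


def clipped_sum(per_example_outputs, l2_clip_norm):
    """Clips each per-example vector, then sums them coordinate-wise."""
    clipped = [clip_vector(v, l2_clip_norm) for v in per_example_outputs]
    return [math.fsum(coords) for coords in zip(*clipped)]


def distance(a, b):
    return l2_norm([x - y for x, y in zip(a, b)])


rng = random.Random(0)
c = 1.0
batch = [[rng.gauss(0.0, 3.0) for _ in range(4)] for _ in range(8)]
new_example = [rng.gauss(0.0, 3.0) for _ in range(4)]

base = clipped_sum(batch, c)
added = clipped_sum(batch + [new_example], c)          # add/remove neighbor
zeroed = clipped_sum(batch[:-1] + [[0.0] * 4], c)      # zero-out neighbor
replaced = clipped_sum(batch[:-1] + [new_example], c)  # replace-one neighbor

print(distance(base, added) <= c + 1e-9)         # True
print(distance(base, zeroed) <= c + 1e-9)        # True
print(distance(base, replaced) <= 2 * c + 1e-9)  # True
```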
- Parameters:
fun (Callable) – The function to be clipped.
has_aux (bool) – If True, fun is expected to return a tuple (value, aux). Only value is clipped and aggregated; aux is returned on a per-example basis. Exercise caution when using this option, as the sensitivity guarantees of the returned Callable apply only to value.
batch_argnums (int | Sequence[int]) – Specifies which argument(s) of fun contain the batch dimension. All arguments specified here must have the same size along the 0th axis.
keep_batch_dim (bool) – If True, batch inputs are passed to fun with a leading batch axis of size 1. If False, this size-1 axis is dropped, reducing the rank of the batch arguments by 1 before they are passed to fun.
l2_clip_norm (float) – The maximum L2 norm allowed for each per-example output.
rescale_to_unit_norm (bool) – If True, the norm of the output PyTree is rescaled by 1.0 / l2_clip_norm after potential clipping, so it is at most 1.0. If False, the output PyTree has norm at most l2_clip_norm.
normalize_by (float) – Divide the clipped output by this value before returning.
return_norms (bool) – If True, the returned Callable also returns the per-example L2 norms of the values before clipping. These values should be handled with care; see the formal guarantees above.
microbatch_size (int | None) – If set, the batch is split into microbatches of this size. The microbatches are processed sequentially, and the per-example computations within each microbatch are vectorized using vmap. This reduces peak memory usage at the cost of more sequential computation.
nan_safe (bool) – If True, the formal guarantees of the returned Callable still hold in the presence of NaNs and infs. See clip_pytree for more details on this argument.
dtype (Union[str, type[Any], dtype, SupportsDType, None]) – Optional dtype for the clipped and aggregated PyTree. If None, the dtype matches the dtypes of the function output. This option can help avoid overflow with low-precision dtypes, since the transformed function computes a sum over a potentially large batch.
prng_argnum (int | None) – If set, specifies which argument of fun is a PRNG key. The key is split to have a batch dimension and is vmapped over.
spmd_axis_name (str | None) – See jax.vmap.
grid_scale (int | None) – If set, per-example outputs are additionally scaled and rounded to an integer grid after clipping. Specifically, each clipped output is multiplied by grid_scale / l2_clip_norm, rounded to the nearest integer, and cast to jnp.int64. The clipping norm is tightened automatically so that the integer L2 norm of each rounded output is at most grid_scale. This option is designed for use with the discrete Gaussian mechanism. It is incompatible with rescale_to_unit_norm=True and with normalize_by != 1.0. When it is set, dtype is ignored and the output is always jnp.int64.
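For a single per-example output, the scaling parameters compose in three steps:

1. Clip the output to l2_clip_norm.
2. If rescale_to_unit_norm is True, divide by l2_clip_norm.
3. Divide by normalize_by.

Dividing the output by normalize_by also divides its L2 sensitivity by the same factor. The plain-Python sketch below illustrates this ordering. A flat list stands in for the output PyTree, and clip_and_scale and summed are hypothetical helpers, not part of jax_privacy.

```python
import math


def l2_norm(vec):
    return math.sqrt(math.fsum(v * v for v in vec))


def clip_and_scale(vec, l2_clip_norm, rescale_to_unit_norm=False, normalize_by=1.0):
    """Hypothetical per-example post-processing mirroring the documented parameters."""
    norm = l2_norm(vec)
    # Clipping: scale down only when the norm exceeds l2_clip_norm.
    factor = min(1.0, l2_clip_norm / norm) if norm > 0 else 1.0
    if rescale_to_unit_norm:
        # Rescale by 1.0 / l2_clip_norm so the norm is at most 1.0.
        factor /= l2_clip_norm
    # Because summation is linear, dividing each clipped output by normalize_by
    # gives the same total as dividing the final sum.
    factor /= normalize_by
    return [v * factor for v in vec]


def summed(vectors):
    return [math.fsum(col) for col in zip(*vectors)]


batch = [[3.0, 4.0], [0.3, 0.4], [-6.0, 8.0]]
c = 2.0

unit = [clip_and_scale(v, c, rescale_to_unit_norm=True) for v in batch]
print([round(l2_norm(v), 6) for v in unit])  # [1.0, 0.25, 1.0]

# Setting normalize_by to the batch size turns the clipped sum into a clipped mean.
clipped_mean = summed(clip_and_scale(v, c, normalize_by=len(batch)) for v in batch)
print([round(x, 6) for x in clipped_mean])  # [0.1, 1.2]
```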
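The clipped sum is a sum of independent per-example terms. Processing the batch in sequential microbatches and accumulating partial sums therefore gives the same result, up to floating-point rounding, as a single pass. Only one microbatch of clipped outputs needs to be held at a time. The plain-Python sketch below checks this equivalence. microbatched_clipped_sum is a hypothetical helper, and a plain loop stands in for vmap. This page does not describe how uneven batch sizes are handled, so the sketch simply lets the last microbatch be smaller; the library may behave differently.

```python
import math


def clip_vector(vec, l2_clip_norm):
    """Scales vec down, if needed, so that its L2 norm is at most l2_clip_norm."""
    norm = math.sqrt(math.fsum(v * v for v in vec))
    scale = min(1.0, l2_clip_norm / norm) if norm > 0 else 1.0
    return [v * scale for v in vec]


def clipped_sum(batch, l2_clip_norm):
    """Clips every example and sums the results coordinate-wise in one pass."""
    return [math.fsum(col) for col in zip(*(clip_vector(v, l2_clip_norm) for v in batch))]


def microbatched_clipped_sum(batch, l2_clip_norm, microbatch_size):
    """Processes microbatches sequentially and accumulates their clipped sums."""
    total = [0.0] * len(batch[0])
    for start in range(0, len(batch), microbatch_size):
        # In JAX the examples inside one microbatch would be vectorized with vmap;
        # here the loop inside clipped_sum plays that role.
        partial = clipped_sum(batch[start:start + microbatch_size], l2_clip_norm)
        total = [t + p for t, p in zip(total, partial)]
    return total


batch = [[float(i), float(-2 * i), 0.5] for i in range(10)]
full = clipped_sum(batch, 1.0)
chunked = microbatched_clipped_sum(batch, 1.0, microbatch_size=4)
print(all(math.isclose(a, b, abs_tol=1e-12) for a, b in zip(full, chunked)))  # True
```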
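This page does not spell out how the clipping norm is tightened when grid_scale is set. The plain-Python sketch below uses one conservative rule, which may differ from what jax_privacy actually does. Rounding moves each of the d coordinates by at most 0.5, so it can increase the L2 norm by at most 0.5 * sqrt(d) grid units. Clipping to a norm that reserves that much headroom therefore keeps the rounded integer vector within grid_scale. clip_to_grid is a hypothetical helper, and Python ints stand in for jnp.int64.

```python
import math
import random


def l2_norm(vec):
    return math.sqrt(math.fsum(v * v for v in vec))


def clip_vector(vec, l2_clip_norm):
    """Scales vec down, if needed, so that its L2 norm is at most l2_clip_norm."""
    norm = l2_norm(vec)
    scale = min(1.0, l2_clip_norm / norm) if norm > 0 else 1.0
    return [v * scale for v in vec]


def clip_to_grid(vec, l2_clip_norm, grid_scale):
    """Hypothetical sketch: clip, multiply by grid_scale / l2_clip_norm, round to ints.

    Rounding moves each coordinate by at most 0.5, so it can increase the L2 norm
    by at most 0.5 * sqrt(d). Clipping to a tightened norm that leaves this much
    headroom keeps the integer vector's L2 norm at most grid_scale.
    """
    headroom = 0.5 * math.sqrt(len(vec))  # worst-case rounding error, in grid units
    if grid_scale <= headroom:
        raise ValueError("grid_scale is too small for this output dimension")
    tightened_norm = l2_clip_norm * (grid_scale - headroom) / grid_scale
    to_grid = grid_scale / l2_clip_norm
    # Python's round() rounds to the nearest integer (ties to even) and returns an int.
    return [round(v * to_grid) for v in clip_vector(vec, tightened_norm)]


rng = random.Random(0)
for _ in range(3):
    v = [rng.uniform(-10.0, 10.0) for _ in range(3)]
    q = clip_to_grid(v, l2_clip_norm=1.0, grid_scale=16)
    print(q, sum(x * x for x in q) <= 16 ** 2)
```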
- Return type:
Callable
- Returns:
A new function clipped_fn that clips the output of fun and sums it across the batch. clipped_fn takes the same arguments as fun. Its exact output signature depends on has_aux and return_norms:

| has_aux | return_norms | clipped_fn returns  |
|:--------|:-------------|:--------------------|
| False   | False        | value               |
| True    | False        | value, aux          |
| False   | True         | value, norms        |
| True    | True         | value, (aux, norms) |