jax_privacy.clipping.clip_and_round_to_grid
- jax_privacy.clipping.clip_and_round_to_grid(gradient, l2_clip_norm, grid_scale, *, nan_safe=True, return_zero=False)
Clips a gradient and rounds it to an integer grid.
This function performs four operations on a single (unbatched) gradient pytree:
1. Computes a tighter (adjusted) clip norm that accounts for the increase in L2 norm caused by rounding.
2. Clips the gradient to have L2 norm at most the adjusted clip norm.
3. Scales the gradient by grid_scale / l2_clip_norm.
4. Rounds each coordinate to the nearest integer.
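The four steps above can be sketched in plain JAX as follows. This is a minimal illustration under stated assumptions, not the library's implementation. The names `clip_and_round_sketch` and `clip_and_round_batch` are hypothetical, and `nan_safe` and `return_zero` are omitted. The adjustment in step 1 assumes that the worst-case rounding error is sqrt(d)/2 grid units for a gradient with d coordinates, because each coordinate moves by at most 1/2. Under that assumption, clipping to l2_clip_norm * (1 - sqrt(d) / (2 * grid_scale)) and scaling leaves a norm of at most grid_scale - sqrt(d)/2. Rounding then adds at most sqrt(d)/2, so the result stays within grid_scale.

```python
import functools
import math

import jax
import jax.numpy as jnp


def clip_and_round_sketch(gradient, l2_clip_norm, grid_scale):
  """Hypothetical sketch of clip-then-round; not jax_privacy's code."""
  leaves, treedef = jax.tree_util.tree_flatten(gradient)
  # d: total number of coordinates in one example (static under vmap).
  num_params = sum(leaf.size for leaf in leaves)
  # Assumed worst-case L2 rounding error in grid units: sqrt(d) / 2.
  max_rounding_error = math.sqrt(num_params) / 2
  if max_rounding_error >= grid_scale:
    raise ValueError('grid_scale is too small for the number of parameters.')
  # Step 1: tighter clip norm that leaves room for the rounding error.
  adjusted_clip_norm = l2_clip_norm * (1 - max_rounding_error / grid_scale)
  # L2 norm of the whole (unclipped) pytree, returned as the second output.
  l2_norm = jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in leaves))
  # Step 2: clip factor, guarding against division by zero.
  clip_factor = jnp.minimum(
      1.0, adjusted_clip_norm / jnp.maximum(l2_norm, 1e-12))
  # Step 3: map l2_clip_norm onto grid_scale integer grid steps.
  scale = clip_factor * grid_scale / l2_clip_norm
  # Step 4: round to the nearest integer. int32 is used here because JAX
  # disables 64-bit types by default; the documented function returns int64.
  rounded = [jnp.round(leaf * scale).astype(jnp.int32) for leaf in leaves]
  return jax.tree_util.tree_unflatten(treedef, rounded), l2_norm


def clip_and_round_batch(batch, l2_clip_norm, grid_scale):
  """Applies the per-example sketch to a batch of gradients via jax.vmap."""
  per_example = functools.partial(
      clip_and_round_sketch, l2_clip_norm=l2_clip_norm, grid_scale=grid_scale)
  return jax.vmap(per_example)(batch)
```

Binding `l2_clip_norm` and `grid_scale` with `functools.partial` keeps them as Python constants, so `jax.vmap` maps only over the leading batch axis of the gradient pytree.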
After rounding, the integer vector has L2 norm at most grid_scale.
This function is designed to be used with jax.vmap to process a batch of per-example gradients.
- Parameters:
  - gradient (Union[Array, ndarray, bool, number, Iterable[chex.ArrayTree], Mapping[Any, chex.ArrayTree]]) – A pytree of gradient arrays for a single example.
  - l2_clip_norm (float) – The desired L2 clip norm (in original units).
  - grid_scale (int) – Number of integer grid steps corresponding to l2_clip_norm.
  - nan_safe (bool) – If True, NaNs and +/- infs are converted to 0 before clipping. See clip_pytree for details.
  - return_zero (bool) – If True, the output is guaranteed to be zero regardless of inputs. See clip_pytree for details.
- Return type:
  tuple[Union[Array, ndarray, bool, number, Iterable[chex.ArrayTree], Mapping[Any, chex.ArrayTree]], Array]
- Returns:
  A tuple (rounded, l2_norm), where rounded is a pytree of int64 arrays representing the rounded gradient on the integer grid, and l2_norm is the L2 norm of the input gradient before clipping.
- Raises:
  ValueError – If grid_scale is too small relative to the number of parameters (the rounding error would exceed the clip norm).
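Because the ValueError depends only on grid_scale and the number of parameters, it can be anticipated before training. The sketch below assumes that the failure condition is sqrt(d)/2 >= grid_scale, where d is the total number of gradient coordinates. In words, the assumed condition is that the worst-case rounding error in grid units reaches the clip norm in grid units. This condition is an assumption, not taken from the library source, and the helpers `count_params` and `min_grid_scale` are hypothetical.

```python
import math

import jax
import jax.numpy as jnp


def count_params(params) -> int:
  """Hypothetical helper: total number of scalar entries in a pytree."""
  return sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))


def min_grid_scale(num_params: int) -> int:
  """Hypothetical helper: smallest integer g with sqrt(num_params) / 2 < g.

  Uses exact integer arithmetic, since the condition is num_params < 4 * g**2.
  """
  return math.isqrt(num_params) // 2 + 1


params = {'dense': jnp.zeros((1000, 1000)), 'bias': jnp.zeros((1000,))}
d = count_params(params)  # 1_001_000 coordinates
g = min_grid_scale(d)  # 501 for this example
```

Under this assumption, the minimum grid_scale grows with the square root of the parameter count, so larger models need a finer grid.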