jax_privacy.clipping.clip_and_round_to_grid

jax_privacy.clipping.clip_and_round_to_grid(gradient, l2_clip_norm, grid_scale, *, nan_safe=True, return_zero=False)

Clips a gradient and rounds it to an integer grid.

This function performs four operations on a single (unbatched) gradient pytree:

  1. Computes a tighter (adjusted) clip norm to account for the L2 norm increase caused by rounding.

  2. Clips the gradient to have L2 norm at most the adjusted clip norm.

  3. Scales the gradient by grid_scale / l2_clip_norm.

  4. Rounds each coordinate to the nearest integer.

After rounding, the integer vector has L2 norm at most grid_scale, which corresponds to l2_clip_norm in original units.
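The four steps above can be sketched in plain Python for a flat list of floats. This is a minimal illustration, not the library's implementation. The name clip_and_round_sketch is hypothetical, and the sketch ignores pytrees, nan_safe, and return_zero. The adjusted clip norm uses an assumed worst-case bound: rounding moves each coordinate by at most 0.5, so for d coordinates the norm grows by at most sqrt(d) / 2 grid steps. The formula jax_privacy actually uses may differ.

```python
import math


def clip_and_round_sketch(gradient, l2_clip_norm, grid_scale):
    """Hypothetical flat-vector sketch of the four documented steps."""
    d = len(gradient)
    # ASSUMPTION: worst-case rounding error of sqrt(d) / 2 grid steps.
    max_rounding_error = math.sqrt(d) / 2.0
    if max_rounding_error >= grid_scale:
        raise ValueError("grid_scale is too small for this many parameters")

    # Step 1: shrink the clip norm to leave room for the rounding error.
    adjusted_clip_norm = l2_clip_norm * (1.0 - max_rounding_error / grid_scale)

    # Step 2: clip to the adjusted norm (the norm is taken before clipping).
    l2_norm = math.sqrt(sum(g * g for g in gradient))
    factor = min(1.0, adjusted_clip_norm / l2_norm) if l2_norm > 0 else 1.0
    clipped = [g * factor for g in gradient]

    # Step 3: rescale so that l2_clip_norm maps to grid_scale grid steps.
    scaled = [g * grid_scale / l2_clip_norm for g in clipped]

    # Step 4: round each coordinate to the nearest integer.
    rounded = [round(s) for s in scaled]
    return rounded, l2_norm


rounded, l2_norm = clip_and_round_sketch([3.0, 4.0], l2_clip_norm=1.0, grid_scale=100)
# rounded == [60, 79], l2_norm == 5.0
```

In this example the rounded vector has L2 norm of about 99.2, which stays below grid_scale = 100 as the bound requires.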

This function is designed to be used with jax.vmap to process a batch of per-example gradients.

Parameters:
  • gradient (Union[Array, ndarray, bool, number, Iterable[chex.ArrayTree], Mapping[Any, chex.ArrayTree]]) – A pytree of gradient arrays for a single example.

  • l2_clip_norm (float) – The desired L2 clip norm (in original units).

  • grid_scale (int) – Number of integer grid steps corresponding to l2_clip_norm.

  • nan_safe (bool) – If True, NaN and ±inf entries are converted to 0 before clipping. See clip_pytree for details.

  • return_zero (bool) – If True, the output is guaranteed to be zero regardless of inputs. See clip_pytree for details.

Return type:

tuple[Union[Array, ndarray, bool, number, Iterable[chex.ArrayTree], Mapping[Any, chex.ArrayTree]], Array]

Returns:

A tuple (rounded, l2_norm) where rounded is a pytree of int64 arrays representing the rounded gradient on the integer grid, and l2_norm is the L2 norm of the input gradient before clipping.
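Because step 3 multiplies by grid_scale / l2_clip_norm, grid values can be mapped back to original units by multiplying by l2_clip_norm / grid_scale. The sketch below shows this for a flat list of integers. The helper name grid_to_original_units is hypothetical and is not part of jax_privacy, which may expose its own way to do this.

```python
import math


def grid_to_original_units(rounded, l2_clip_norm, grid_scale):
    """Hypothetical helper: inverts the grid_scale / l2_clip_norm scaling."""
    return [r * l2_clip_norm / grid_scale for r in rounded]


# Example grid values for l2_clip_norm=1.0 and grid_scale=100.
restored = grid_to_original_units([60, 79], l2_clip_norm=1.0, grid_scale=100)
restored_norm = math.sqrt(sum(x * x for x in restored))
# restored_norm is at most l2_clip_norm because the integer norm is at most grid_scale.
```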

Raises:

ValueError – If grid_scale is too small relative to the number of parameters (the rounding error would exceed the clip norm).