jax_privacy.clipping.clip_pytree
- jax_privacy.clipping.clip_pytree(pytree, clip_norm, rescale_to_unit_norm=False, nan_safe=True, return_zero=False)
Clips a PyTree of JAX arrays based on its global L2 norm.
Computes the global L2 norm of the input PyTree, i.e. the square root of the sum of squares over every entry of every leaf. If this norm exceeds clip_norm, the PyTree is scaled down so that its norm equals clip_norm. If rescale_to_unit_norm is True, the result is additionally scaled by 1.0 / clip_norm, so its norm is at most 1.0 regardless of the value of clip_norm. A zero input norm and a clip_norm of 0 or infinity are handled explicitly, as listed under Edge Case Handling.
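The core scaling rule can be sketched in a few lines of JAX. This is a minimal sketch, not the library's implementation: the helpers global_l2_norm and clip_by_global_norm are hypothetical names, the sketch assumes a finite, positive clip_norm, and it omits the nan_safe and return_zero options as well as the special cases from the edge-case table.

```python
import jax
import jax.numpy as jnp


def global_l2_norm(pytree):
    """Hypothetical helper: sqrt of the sum of squares over every entry of every leaf."""
    leaves = jax.tree_util.tree_leaves(pytree)
    return jnp.sqrt(sum(jnp.sum(jnp.square(x)) for x in leaves))


def clip_by_global_norm(pytree, clip_norm, rescale_to_unit_norm=False):
    """Hypothetical sketch of global-norm clipping; assumes 0 < clip_norm < inf."""
    norm = global_l2_norm(pytree)
    # Scale down only when the norm exceeds clip_norm; the inner where avoids dividing by 0.
    scale = jnp.where(norm > clip_norm, clip_norm / jnp.where(norm > 0, norm, 1.0), 1.0)
    if rescale_to_unit_norm:
        # Map the ball of radius clip_norm onto the unit ball.
        scale = scale / clip_norm
    # Cast back so each leaf keeps its original dtype.
    clipped = jax.tree_util.tree_map(lambda x: (x * scale).astype(x.dtype), pytree)
    return clipped, norm


tree = {"a": jnp.array([3.0, 0.0]), "b": jnp.array([4.0])}
clipped, norm = clip_by_global_norm(tree, 1.0)  # norm is 5.0, so every leaf is scaled by 0.2
```

Because a single scale factor multiplies every leaf, the direction of the concatenated vector is preserved and only its length changes.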
Formal Guarantees:
The output PyTree will have norm at most clip_norm if rescale_to_unit_norm is False, and norm at most 1.0 if it is True.
The output PyTree will have the same structure and dtypes as the input PyTree.
Edge Case Handling:
Case | rescale_to_unit_norm | Output
--- | --- | ---
clip_norm = 0 | False | Zero
clip_norm = 0 | True | Input / pytree_norm (the limit as clip_norm -> 0)
clip_norm = inf | False | Unchanged
clip_norm = inf | True | Zero
clip_norm < 0 (static) | Either | Raises ValueError
clip_norm < 0 (dynamic) | Either | Zero
pytree_norm = 0 | Either | Unchanged
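Each row of the table follows from a single scale factor applied to all leaves. Without rescaling, the factor is min(1, clip_norm / pytree_norm); with rescaling, it is that value divided by clip_norm, which simplifies to 1 / max(clip_norm, pytree_norm). The hypothetical scale_factor below is a sketch of that reasoning, not the library's code; in particular, treating a negative plain Python number as the "static" case and any array value as the "dynamic" case is an assumption.

```python
import jax.numpy as jnp


def scale_factor(pytree_norm, clip_norm, rescale_to_unit_norm):
    """Hypothetical sketch: the multiplier applied to every leaf, per the edge-case table."""
    # Assumption: a negative plain Python number counts as a "static" clip_norm.
    if isinstance(clip_norm, (int, float)) and clip_norm < 0:
        raise ValueError(f"clip_norm must be non-negative, got {clip_norm}")
    n = jnp.asarray(pytree_norm, dtype=jnp.float32)
    c = jnp.asarray(clip_norm, dtype=jnp.float32)
    if rescale_to_unit_norm:
        # min(1, c / n) / c == 1 / max(c, n): c = inf gives 0, c = 0 gives the limit 1 / n.
        factor = 1.0 / jnp.maximum(c, n)
    else:
        # min(1, c / n) == min(n, c) / n: c = 0 gives 0, c = inf gives 1.
        factor = jnp.minimum(n, c) / n
    # A zero input stays zero whatever the factor, so use 1 to sidestep 0 / 0.
    factor = jnp.where(n == 0, 1.0, factor)
    # A negative clip_norm known only as an array value zeroes the output.
    return jnp.where(c < 0, 0.0, factor)
```

For example, scale_factor(5.0, 0.0, True) is 1 / 5, so an input of norm 5 is mapped to unit norm, matching the clip_norm = 0, rescale_to_unit_norm = True row.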
- Parameters:
pytree (Union[Array, ndarray, bool, number, Iterable[chex.ArrayTree], Mapping[Any, chex.ArrayTree]]) – The PyTree of arrays to clip.
clip_norm (float) – The maximum L2 norm allowed.
rescale_to_unit_norm (bool) – If True, the output PyTree is additionally scaled by 1.0 / clip_norm after any clipping, so its norm is at most 1.0. If False, the output PyTree has norm at most clip_norm.
nan_safe (bool) – If True, NaNs and +/- infs are converted to 0 before clipping. This must be True for the formal guarantees to hold when the input may contain NaNs, at the cost of some additional computation. If False, NaNs in the input PyTree are preserved in the output PyTree, and +/- infs are converted to NaNs.
return_zero (bool) – If True, the output PyTree is all zeros regardless of the inputs. This does not affect the formal guarantees.
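To see why nan_safe matters, note that one NaN entry anywhere in the PyTree makes the global norm NaN, after which no bound on the output norm can be established. The sketch below shows the kind of sanitization the nan_safe description implies, using jnp.nan_to_num; the helper names sanitize and global_l2_norm are hypothetical, and whether the library uses this exact call is an assumption.

```python
import jax
import jax.numpy as jnp


def sanitize(pytree):
    """Hypothetical helper: replace NaN and +/-inf with 0 in every leaf."""
    return jax.tree_util.tree_map(
        lambda x: jnp.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0), pytree)


def global_l2_norm(pytree):
    """Square root of the sum of squares over all leaves."""
    leaves = jax.tree_util.tree_leaves(pytree)
    return jnp.sqrt(sum(jnp.sum(jnp.square(x)) for x in leaves))


tree = {"w": jnp.array([3.0, jnp.nan]), "b": jnp.array([jnp.inf, 4.0])}
raw_norm = global_l2_norm(tree)             # nan: one NaN entry poisons the global norm
safe_norm = global_l2_norm(sanitize(tree))  # 5.0: only the finite entries 3 and 4 remain
```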
- Returns:
A tuple (clipped_pytree, original_l2_norm), where clipped_pytree is the processed PyTree and original_l2_norm is the L2 norm of the input PyTree.
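Because the original norm is returned alongside the clipped PyTree, a caller can measure how much clipping took place. A common pattern in differentially private training is to clip each example's gradient separately with jax.vmap. The sketch below substitutes a hypothetical clip_sketch for clip_pytree (same return convention, but no edge-case or NaN handling) so that it runs without jax_privacy installed; the toy gradients are made up for illustration.

```python
import jax
import jax.numpy as jnp


def clip_sketch(pytree, clip_norm):
    """Hypothetical stand-in for clip_pytree: returns (clipped_pytree, original_l2_norm)."""
    leaves = jax.tree_util.tree_leaves(pytree)
    norm = jnp.sqrt(sum(jnp.sum(jnp.square(x)) for x in leaves))
    # min(1, clip_norm / norm), written to avoid 0 / 0 when the input is all zeros.
    scale = jnp.minimum(norm, clip_norm) / jnp.where(norm > 0, norm, 1.0)
    return jax.tree_util.tree_map(lambda x: x * scale, pytree), norm


# Toy per-example gradients: a batch of 2 examples, each a PyTree with one leaf.
grads = {"w": jnp.array([[3.0, 4.0], [0.3, 0.4]])}
# vmap over the batch axis of every leaf; clip_norm is shared by all examples.
clipped, norms = jax.vmap(clip_sketch, in_axes=(0, None))(grads, 1.0)
fraction_clipped = jnp.mean(norms > 1.0)  # 0.5: only the first example exceeded clip_norm
```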