jax_privacy.clipping

Utilities for clipping function outputs and aggregating across a batch.

Functions

clip_and_round_to_grid(gradient, ...[, ...])

Clips a gradient and rounds it to an integer grid.
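As a rough illustration of clip-then-round, here is a minimal pure-JAX sketch. The function name `clip_and_round_sketch` and the parameters `clip_norm` and `grid_step` are hypothetical. They are not taken from the signature of clip_and_round_to_grid, whose remaining arguments are elided above. Deterministic rounding is also an assumption, and the library may round differently (for example, stochastically).

```python
import jax
import jax.numpy as jnp


def clip_and_round_sketch(gradient: jax.Array, clip_norm: float,
                          grid_step: float) -> jax.Array:
  """Hypothetical sketch: L2-clip `gradient`, then round it onto an integer grid.

  `clip_norm` and `grid_step` are illustrative names, not necessarily the
  parameters of `clip_and_round_to_grid`.
  """
  norm = jnp.sqrt(jnp.sum(jnp.square(gradient)))
  # Scale down only when the norm exceeds clip_norm (guarding against 0 / 0).
  scale = jnp.minimum(1.0, clip_norm / jnp.maximum(norm, 1e-12))
  clipped = gradient * scale
  # Express the clipped vector in units of grid_step and round to integers
  # (deterministic rounding is an assumption of this sketch).
  return jnp.round(clipped / grid_step).astype(jnp.int32)
```

Deterministic rounding can push the norm of the result slightly above `clip_norm / grid_step`, by at most `sqrt(d) / 2` grid units for a `d`-dimensional vector, so a sensitivity analysis has to account for the rounded vector rather than the clipped one.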

clip_pytree(pytree, clip_norm[, ...])

Clips a PyTree of JAX arrays based on its global L2 norm.
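A minimal sketch of global-norm clipping over a PyTree, written in plain JAX rather than calling clip_pytree itself. The name `clip_pytree_sketch`, and returning the pre-clipping norm alongside the clipped tree, are assumptions for illustration; the library function's return value may differ.

```python
import jax
import jax.numpy as jnp


def clip_pytree_sketch(pytree, clip_norm: float):
  """Hypothetical sketch: rescale a PyTree so its global L2 norm is <= clip_norm.

  Returns the clipped PyTree and the global norm before clipping.
  """
  leaves = jax.tree_util.tree_leaves(pytree)
  # Global norm: square root of the sum of squares over every leaf.
  global_norm = jnp.sqrt(sum(jnp.sum(jnp.square(x)) for x in leaves))
  # One shared scale factor for the whole tree; 1.0 if already within bound.
  scale = jnp.minimum(1.0, clip_norm / jnp.maximum(global_norm, 1e-12))
  clipped = jax.tree_util.tree_map(lambda x: x * scale, pytree)
  return clipped, global_norm
```

Because the whole tree is scaled by a single factor, the relative magnitudes of the leaves are preserved. Clipping each leaf separately to `clip_norm` would instead only bound the global norm by `clip_norm` times the square root of the number of leaves.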

clipped_fun(fun[, has_aux, batch_argnums, ...])

Transforms a function to clip its output and sum across a batch.
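The sketch below shows the general pattern behind a transformation like clipped_fun: evaluate the function on each example with `jax.vmap`, clip each output to an L2 norm bound, then sum over the batch axis. The names `clipped_sum_sketch` and `clip_norm` are hypothetical. The sketch assumes array-valued outputs and a batch on axis 0, and it does not model the has_aux or batch_argnums options.

```python
import jax
import jax.numpy as jnp


def clipped_sum_sketch(fun, clip_norm: float):
  """Hypothetical sketch: evaluate `fun` per example, clip each output, sum."""

  def clip_one(out):
    # Rescale a single output so its L2 norm is at most clip_norm.
    norm = jnp.sqrt(jnp.sum(jnp.square(out)))
    return out * jnp.minimum(1.0, clip_norm / jnp.maximum(norm, 1e-12))

  def wrapped(batch):
    # vmap maps `fun` over the leading (batch) axis of `batch`.
    per_example = jax.vmap(fun)(batch)
    clipped = jax.vmap(clip_one)(per_example)
    # Aggregate the clipped per-example outputs across the batch.
    return jnp.sum(clipped, axis=0)

  return wrapped
```

Since every per-example output has L2 norm at most `clip_norm`, adding or removing one example changes the summed result by at most `clip_norm` in L2 norm.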

clipped_grad(fun[, argnums, has_aux, ...])

Creates a function that computes the sum of clipped gradients of fun.
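A comparable sketch for gradients: per-example gradients are computed with `jax.vmap(jax.grad(...))`, each gradient PyTree is clipped by its global L2 norm, and the clipped gradients are summed over the batch. The names `clipped_grad_sketch`, `loss_fn`, and `clip_norm` are hypothetical. The sketch differentiates only with respect to the first argument and ignores the argnums and has_aux options of clipped_grad.

```python
import jax
import jax.numpy as jnp


def clipped_grad_sketch(loss_fn, clip_norm: float):
  """Hypothetical sketch: sum of per-example clipped gradients of `loss_fn`.

  Assumes `loss_fn(params, example)` returns a scalar loss for one example.
  """

  def clip_tree(g):
    # Clip one example's gradient PyTree by its global L2 norm.
    leaves = jax.tree_util.tree_leaves(g)
    norm = jnp.sqrt(sum(jnp.sum(jnp.square(x)) for x in leaves))
    scale = jnp.minimum(1.0, clip_norm / jnp.maximum(norm, 1e-12))
    return jax.tree_util.tree_map(lambda x: x * scale, g)

  def summed_clipped_grad(params, batch):
    # Per-example gradients w.r.t. params: params shared, batch mapped on axis 0.
    per_example = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0))(params, batch)
    clipped = jax.vmap(clip_tree)(per_example)
    # Sum each leaf over the batch axis.
    return jax.tree_util.tree_map(lambda g: jnp.sum(g, axis=0), clipped)

  return summed_clipped_grad
```

For a linear loss such as `jnp.dot(params, x)`, each per-example gradient is just `x`, which makes the clipping easy to check by hand.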

Classes

AuxiliaryOutput

alias of Aux

BoundedSensitivityCallable(fun, ...)

Callable with a sensitivity property.
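To make the idea of a callable that carries a sensitivity bound concrete, here is a hypothetical stand-in written as a frozen dataclass. The class name `BoundedSensitivitySketch` and the `l2_norm_bound` field are assumptions, not the actual attributes of BoundedSensitivityCallable. The sketch also assumes add-or-remove-one adjacency; under replace-one adjacency, the bound for a sum of clipped outputs would be twice the clip norm.

```python
import dataclasses
from typing import Any, Callable


@dataclasses.dataclass(frozen=True)
class BoundedSensitivitySketch:
  """Hypothetical stand-in: a callable paired with an L2 sensitivity bound."""

  fun: Callable[..., Any]
  l2_norm_bound: float

  @property
  def sensitivity(self) -> float:
    # For a sum of outputs each clipped to L2 norm <= C, adding or removing
    # one example changes the sum by at most C.
    return self.l2_norm_bound

  def __call__(self, *args, **kwargs):
    # Delegate to the wrapped function unchanged.
    return self.fun(*args, **kwargs)
```

Keeping the bound next to the function lets a Gaussian mechanism set its noise standard deviation as a noise multiplier times `sensitivity`, without tracking the clip norm separately.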