jax_privacy.clipping
Utilities for clipping function outputs and aggregating across a batch.
Functions
Clips a gradient and rounds it to an integer grid.
Clips a PyTree of JAX arrays based on its global L2 norm.
Transforms a function to clip its output and sum across a batch.
Creates a function that computes the sum of the clipped gradients of fun.
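The function names in the summary above were lost in extraction, so the following is only a minimal sketch of the techniques the four descriptions name, written in plain JAX. Every name here (global_l2_norm, clip_by_global_norm, clip_and_sum, clipped_grad_sum, clip_and_stochastic_round, granularity) is hypothetical and is not the jax_privacy API. Clipping scales the whole PyTree by min(1, C / ||g||_2); the batch transform vmaps the function over examples, clips each output, and sums; the gradient variant applies that transform to per-example gradients. The rounding step assumes unbiased stochastic rounding onto a grid of step `granularity`, which is one common choice, not necessarily the library's.

```python
import jax
import jax.numpy as jnp


def global_l2_norm(tree):
    """L2 norm of all leaves of a PyTree, treated as one flat vector."""
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.sqrt(sum(jnp.sum(jnp.square(x)) for x in leaves))


def clip_by_global_norm(tree, clip_norm):
    """Scale the whole PyTree by min(1, clip_norm / ||tree||_2)."""
    norm = global_l2_norm(tree)
    # Guard against division by zero for an all-zero tree.
    scale = jnp.minimum(1.0, clip_norm / jnp.maximum(norm, 1e-12))
    return jax.tree_util.tree_map(lambda x: x * scale, tree)


def clip_and_sum(fun, clip_norm):
    """Hypothetical transform: apply fun per example, clip each output, sum."""
    def wrapped(*args):
        # Map fun over the leading (batch) axis of every argument.
        per_example = jax.vmap(fun)(*args)
        clipped = jax.vmap(lambda out: clip_by_global_norm(out, clip_norm))(per_example)
        # Sum the clipped per-example outputs across the batch axis.
        return jax.tree_util.tree_map(lambda x: jnp.sum(x, axis=0), clipped)
    return wrapped


def clipped_grad_sum(loss_fn, clip_norm):
    """Hypothetical: sum over examples of clipped grads of loss_fn(params, x, y)."""
    grad_fn = jax.grad(loss_fn)  # gradient w.r.t. the first argument (params)
    def wrapped(params, xs, ys):
        per_example = lambda x, y: grad_fn(params, x, y)
        return clip_and_sum(per_example, clip_norm)(xs, ys)
    return wrapped


def clip_and_stochastic_round(key, grad, clip_norm, granularity):
    """Clip to L2 norm clip_norm, then round to integer multiples of granularity."""
    scaled = clip_by_global_norm(grad, clip_norm) / granularity
    lower = jnp.floor(scaled)
    # Round up with probability equal to the fractional part, so the result
    # is an unbiased estimate of `scaled` (measured in grid steps).
    round_up = jax.random.bernoulli(key, scaled - lower)
    return (lower + round_up).astype(jnp.int32)


# Usage: a linear model with squared loss; at w = 0 and y = -1 each
# per-example gradient equals x, so the sum is [0.6, 0.8] + [0.1, 0.0].
def loss(w, x, y):
    return 0.5 * (jnp.dot(w, x) - y) ** 2

params = jnp.zeros(2)
xs = jnp.array([[3.0, 4.0], [0.1, 0.0]])
ys = jnp.array([-1.0, -1.0])
grad_sum = clipped_grad_sum(loss, clip_norm=1.0)(params, xs, ys)
```

Note that stochastic rounding moves each coordinate by less than one grid step, so the rounded vector's norm can slightly exceed clip_norm / granularity; a real implementation would have to account for that when computing sensitivity.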
Classes
alias of
Callable with a sensitivity property.
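The class name above was lost in extraction; the sketch below only illustrates the idea of a callable that carries its own sensitivity, so that noise can be calibrated from it. SensitiveFn, make_clipped_sum, replace_one and noise_multiplier are hypothetical names, not the jax_privacy API. The sensitivity rule it encodes is a general fact about clipped sums: if every per-example term has L2 norm at most C, adding or removing one example changes the sum by at most C, and replacing one example changes it by at most 2C.

```python
import dataclasses
from typing import Any, Callable

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class SensitiveFn:
    """Hypothetical wrapper: a callable that also reports its L2 sensitivity."""
    fn: Callable[..., Any]
    l2_clip_norm: float
    replace_one: bool = False  # assumed neighbouring-dataset notion

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

    @property
    def sensitivity(self) -> float:
        # Add/remove-one neighbours: C. Replace-one neighbours: 2C.
        return (2.0 if self.replace_one else 1.0) * self.l2_clip_norm


def make_clipped_sum(clip_norm, replace_one=False):
    """Build a row-wise clipped sum whose sensitivity matches its clip norm."""
    def clipped_sum(xs):
        norms = jnp.linalg.norm(xs, axis=1, keepdims=True)
        scale = jnp.minimum(1.0, clip_norm / jnp.maximum(norms, 1e-12))
        return jnp.sum(xs * scale, axis=0)
    return SensitiveFn(clipped_sum, clip_norm, replace_one)


# Usage: Gaussian noise with standard deviation proportional to sensitivity.
f = make_clipped_sum(1.0)
xs = jnp.array([[3.0, 4.0], [0.0, 0.5]])
noise_multiplier = 1.1  # hypothetical value
key = jax.random.PRNGKey(0)
noisy = f(xs) + noise_multiplier * f.sensitivity * jax.random.normal(key, (2,))
```

Keeping the clip norm and the sensitivity in one object means the noise scale cannot silently drift out of sync with the clipping bound.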