jax_privacy.clipping.BoundedSensitivityCallable
- class jax_privacy.clipping.BoundedSensitivityCallable(fun, l2_norm_bound, has_aux)
Bases: object
Callable with a sensitivity property.
If has_aux is False, the sensitivity guarantee holds for the entire output, which may be an arbitrary PyTree of JAX Arrays. If has_aux is True, the output of the function is a pair (value, aux), and the sensitivity guarantee holds only for the value PyTree. The aux PyTree is returned on a per-example basis (i.e., as a PyTree of arrays with a batch axis). The caller should handle the aux output with care w.r.t. DP guarantees, should they be needed.
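As an illustration only, the sketch below shows the kind of function such a wrapper might hold with has_aux=True. The helpers clip_to_norm and clipped_sum_with_aux are hypothetical and not part of jax_privacy. The sketch assumes per-example inputs stacked along a leading batch axis and clips each example's whole PyTree to a global L2 norm before summing.

```python
import jax
import jax.numpy as jnp


def clip_to_norm(tree, bound):
    """Hypothetical helper: scale one example's PyTree to global L2 norm <= bound."""
    leaves = jax.tree_util.tree_leaves(tree)
    # Global L2 norm taken jointly over every leaf of this example.
    norm = jnp.sqrt(sum(jnp.sum(jnp.square(x)) for x in leaves))
    # Multiply by min(1, bound / norm); jnp.maximum avoids division by zero.
    scale = jnp.minimum(1.0, bound / jnp.maximum(norm, 1e-12))
    return jax.tree_util.tree_map(lambda x: x * scale, tree), norm


def clipped_sum_with_aux(batch_tree, bound):
    """Hypothetical has_aux=True function returning (value, aux)."""
    # Clip each example independently by vmapping over the leading batch axis.
    clipped, norms = jax.vmap(lambda t: clip_to_norm(t, bound))(batch_tree)
    # value: sum over the batch; this is the part a sensitivity bound covers.
    value = jax.tree_util.tree_map(lambda x: jnp.sum(x, axis=0), clipped)
    # aux: per-example data with a batch axis; it carries no DP guarantee.
    aux = {"pre_clip_norms": norms}
    return value, aux


batch = {"w": jnp.array([[3.0, 4.0], [0.3, 0.4]]), "b": jnp.zeros(2)}
value, aux = clipped_sum_with_aux(batch, bound=1.0)
# value["w"] is approximately [0.9, 1.2]; aux["pre_clip_norms"] is about [5.0, 0.5].
```

Only value is covered by the bound here. The pre-clipping norms in aux expose per-example information directly, which is why aux must be handled separately when DP guarantees matter.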
Methods
- __init__(fun, l2_norm_bound, has_aux)
- sensitivity([neighboring_relation]) – Returns the L2 sensitivity of the Callable.
Attributes
- fun: Callable[..., Any]
- l2_norm_bound: float
- has_aux: bool
- sensitivity(neighboring_relation=NeighboringRelation.REPLACE_SPECIAL)
Returns the L2 sensitivity of the Callable.
The L2 sensitivity is defined with respect to the given neighboring relation and the unit of privacy implied by the function that created this instance.
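For reference, a standard way to write this quantity (a textbook formulation, not quoted from the library) for a function f that maps a dataset to a PyTree, with the norm taken jointly over all leaves and D ≃ D' ranging over datasets that are neighbors under neighboring_relation. When has_aux is True, f stands for the value component only.

```latex
\Delta_2(f) \;=\; \sup_{D \simeq D'} \big\lVert f(D) - f(D') \big\rVert_2,
\qquad
\lVert T \rVert_2 \;=\; \Big( \sum_{\ell \in \mathrm{leaves}(T)} \lVert \ell \rVert_F^2 \Big)^{1/2}
```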
- Parameters:
neighboring_relation (NeighboringRelation) – The neighboring relation to consider.
- Returns:
The L2 sensitivity of the Callable.
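The hypothetical sketch below spells out the values one would expect when the wrapped function is a sum of per-example terms, each clipped to l2_norm_bound, with each example contributing exactly one term. Relation is a local stand-in enum: only REPLACE_SPECIAL appears in the signature above, and the other member names are assumptions modeled on dp_accounting's NeighboringRelation. It further assumes the "special" element contributes a zero term. This is not the library's implementation; functions built around a different unit of privacy may report different values.

```python
import enum


class Relation(enum.Enum):
    """Local stand-in; member names other than REPLACE_SPECIAL are assumed."""
    ADD_OR_REMOVE_ONE = enum.auto()
    REPLACE_ONE = enum.auto()
    REPLACE_SPECIAL = enum.auto()


def clipped_sum_sensitivity(l2_norm_bound, relation=Relation.REPLACE_SPECIAL):
    """Hypothetical L2 sensitivity of a sum of per-example clipped terms.

    Assumes each example adds exactly one term of L2 norm <= l2_norm_bound.
    """
    if relation in (Relation.ADD_OR_REMOVE_ONE, Relation.REPLACE_SPECIAL):
        # One term appears or disappears (or becomes an assumed zero "special"
        # term), so the sum moves by at most one clipped vector.
        return l2_norm_bound
    if relation is Relation.REPLACE_ONE:
        # Clipped term v is swapped for clipped term u; by the triangle
        # inequality ||v - u|| <= ||v|| + ||u|| <= 2 * l2_norm_bound.
        return 2.0 * l2_norm_bound
    raise ValueError(f"Unsupported relation: {relation}")


print(clipped_sum_sensitivity(1.5))                        # default relation
print(clipped_sum_sensitivity(1.5, Relation.REPLACE_ONE))  # swap one example
```

The replace-one case is twice as large because both the removed and the inserted term can point in opposite directions at full norm.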