jax_privacy.clipping.BoundedSensitivityCallable

class jax_privacy.clipping.BoundedSensitivityCallable(fun, l2_norm_bound, has_aux)

Bases: object

A callable whose output has bounded L2 sensitivity, which is reported by the sensitivity() method.

If has_aux is False, the sensitivity guarantee holds for the entire output, which may be an arbitrary PyTree of JAX Arrays. If has_aux is True, the function returns a pair (value, aux), and the sensitivity guarantee holds only for the value PyTree. The aux PyTree is returned on a per-example basis (i.e., as a PyTree of arrays with a batch axis). If DP guarantees are needed, the caller must handle the aux output with care, because it carries no sensitivity bound.
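The has_aux contract can be illustrated with a minimal sketch. ToyBoundedSensitivityCallable and clipped_sum_with_norms are hypothetical names introduced here. The class only mirrors the three documented fields (fun, l2_norm_bound, has_aux) and is not the jax_privacy implementation. NumPy stands in for JAX so that the example runs on its own. The wrapped function clips each example (one row) to L2 norm at most l2_norm_bound and sums the results, so value carries the bounded-sensitivity guarantee. aux holds one pre-clipping norm per example and carries no such guarantee.

```python
import dataclasses
from typing import Any, Callable

import numpy as np


@dataclasses.dataclass(frozen=True)
class ToyBoundedSensitivityCallable:
    """Hypothetical stand-in mirroring the documented fields (not jax_privacy's class)."""
    fun: Callable[..., Any]
    l2_norm_bound: float
    has_aux: bool

    def __call__(self, *args, **kwargs):
        # Assumption: calling the wrapper simply delegates to the wrapped function.
        return self.fun(*args, **kwargs)


def clipped_sum_with_norms(batch: np.ndarray, clip_norm: float):
    """Clips each row (one example) to L2 norm <= clip_norm, then sums the rows.

    Returns (value, aux): value is the clipped sum (bounded sensitivity);
    aux holds the per-example pre-clipping norms (one entry per example).
    """
    norms = np.linalg.norm(batch, axis=1)
    # Per-example scale factor min(1, C / ||x||), guarding against division by zero.
    scale = np.minimum(1.0, clip_norm / np.maximum(norms, 1e-12))
    clipped = batch * scale[:, None]
    return clipped.sum(axis=0), norms


C = 1.0
fn = ToyBoundedSensitivityCallable(
    fun=lambda b: clipped_sum_with_norms(b, C), l2_norm_bound=C, has_aux=True
)

batch = np.array([[3.0, 4.0], [0.3, 0.4]])
value, aux = fn(batch)  # value: clipped sum; aux: per-example norms
```

Here aux is the raw per-example norm, which can reveal information about individual examples. This is exactly the kind of output the caller must treat with care when DP guarantees matter.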

Methods

__init__

sensitivity

Returns the L2 sensitivity of the Callable.

Attributes

fun

l2_norm_bound

has_aux

fun: Callable[..., Any]
l2_norm_bound: float
has_aux: bool
__call__(*args, **kwargs)

Calls the wrapped fun with the given arguments and returns its output.

sensitivity(neighboring_relation=NeighboringRelation.REPLACE_SPECIAL)

Returns the L2 sensitivity of the Callable.

The L2 sensitivity is defined with respect to the given neighboring relation and the unit of privacy implied by the function that created this instance.

Parameters:

neighboring_relation (NeighboringRelation) – The neighboring relation to consider.

Returns:

The L2 sensitivity of the Callable.
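As a rough illustration of why the neighboring relation matters, the sketch below checks the standard bounds for a sum of per-example vectors, each clipped to L2 norm C. Adding or removing one example changes the sum by at most C. Replacing one example changes it by at most 2C, by the triangle inequality. The helper toy_l2_sensitivity and its "add_remove"/"replace" labels are hypothetical and are not NeighboringRelation members. The values that sensitivity() actually returns, in particular for REPLACE_SPECIAL, depend on the function that created the instance.

```python
import numpy as np


def clip(x: np.ndarray, c: float) -> np.ndarray:
    """Scales x so that its L2 norm is at most c."""
    n = np.linalg.norm(x)
    return x if n <= c else x * (c / n)


def clipped_sum(rows: np.ndarray, c: float) -> np.ndarray:
    """Sums the per-example vectors (rows) after clipping each one to norm c."""
    return np.sum([clip(r, c) for r in rows], axis=0)


def toy_l2_sensitivity(l2_norm_bound: float, relation: str) -> float:
    """Worst-case L2 change of a clipped sum under a neighboring relation.

    Hypothetical helper: the relation labels are illustrative only.
    """
    if relation == "add_remove":
        return l2_norm_bound  # one clipped term appears or disappears
    if relation == "replace":
        return 2.0 * l2_norm_bound  # ||a - b|| <= ||a|| + ||b|| <= 2C
    raise ValueError(f"unknown relation: {relation}")


C = 1.0
rng = np.random.default_rng(0)
data = rng.normal(size=(8, 3)) * 5.0

# Add/remove: drop the last example.
delta_remove = np.linalg.norm(clipped_sum(data, C) - clipped_sum(data[:-1], C))

# Replace: swap the first example for its negation.
swapped = data.copy()
swapped[0] = -data[0]
delta_replace = np.linalg.norm(clipped_sum(data, C) - clipped_sum(swapped, C))
```

Replacing an example with its negation reaches the 2C worst case whenever that example's norm is at least C, which is why the replace relation doubles the bound.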