jax_privacy.optimizers.PreClippingTransform

class jax_privacy.optimizers.PreClippingTransform(*args, **kwargs)

Bases: Protocol

A callable that applies a transformation to a pytree of updates. As the name indicates, it is meant to be applied before the updates are clipped.
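Because PreClippingTransform derives from typing.Protocol, conformance is structural: any callable whose signature matches __call__(updates, inverse=False) satisfies it, with no subclassing required. The following is a minimal sketch assuming only jax is installed. PreClippingTransformLike (a local stand-in for the protocol) and scale_by_two are hypothetical names, not part of jax_privacy. The behavior of inverse=True (undoing the forward map) is inferred from the parameter name.

```python
from typing import Any, Protocol

import jax
import jax.numpy as jnp

PyTree = Any  # Stand-in for chex.ArrayTree, to avoid a chex dependency.


class PreClippingTransformLike(Protocol):
    """Hypothetical local mirror of the documented __call__ signature."""

    def __call__(self, updates: PyTree, inverse: bool = False) -> PyTree:
        ...


def scale_by_two(updates: PyTree, inverse: bool = False) -> PyTree:
    """Hypothetical transform: doubles every leaf, or halves it if inverse=True."""
    factor = 0.5 if inverse else 2.0
    return jax.tree_util.tree_map(lambda x: x * factor, updates)


# A plain function matches the protocol structurally; no subclassing is needed.
transform: PreClippingTransformLike = scale_by_two

grads = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}
scaled = transform(grads)
restored = transform(scaled, inverse=True)
```

A static type checker accepts the assignment to transform because parameter names and defaults line up with the protocol. At runtime, the annotation is not enforced.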

Methods

__init__

__call__(updates, inverse=False)

Apply the transformation to updates. When inverse=True, the inverse of the transformation is presumably applied instead.

Return type:

chex.ArrayTree, i.e. Union[Array, ndarray, numpy.bool_, numpy.number, bool, int, float, complex, Iterable[chex.ArrayTree], Mapping[Any, chex.ArrayTree]]
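The return type is an arbitrary pytree. The following stateful sketch keeps the input's tree structure and checks that inverse=True round-trips the forward transform. Preserving structure is a design choice of this example, not something the signature guarantees. DiagonalScaling and its scales field are hypothetical and not jax_privacy API. As before, the meaning of inverse is inferred from its name.

```python
import dataclasses
from typing import Any

import jax
import jax.numpy as jnp

PyTree = Any  # Stand-in for chex.ArrayTree.


@dataclasses.dataclass(frozen=True)
class DiagonalScaling:
    """Hypothetical transform that rescales each leaf by a fixed per-leaf factor.

    `scales` must have the same tree structure as the updates passed in.
    """

    scales: PyTree

    def __call__(self, updates: PyTree, inverse: bool = False) -> PyTree:
        if inverse:
            # Undo the forward map by dividing each leaf by its factor.
            return jax.tree_util.tree_map(lambda u, s: u / s, updates, self.scales)
        # Forward map: multiply each leaf by its factor.
        return jax.tree_util.tree_map(lambda u, s: u * s, updates, self.scales)


updates = {"w": jnp.full((2, 2), 3.0), "b": jnp.array([1.0, -1.0])}
transform = DiagonalScaling(scales={"w": 2.0, "b": 4.0})

forward = transform(updates)
round_trip = transform(forward, inverse=True)
```

Storing the factors on a frozen dataclass keeps the transform a pure function of its inputs, which suits jitted JAX code.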