jax_privacy.optimizers.AugmentedGradientTransformation
- class jax_privacy.optimizers.AugmentedGradientTransformation(init: Callable[[chex.ArrayTree], chex.ArrayTree], update: Callable[..., tuple[chex.ArrayTree, chex.ArrayTree]], pre_clipping_transform: Callable[[chex.ArrayTree], PreClippingTransform])
Bases: NamedTuple
A gradient transformation augmented with a pre-clipping transform.
This extends the standard optax.GradientTransformation interface with a pre_clipping_transform field. Given the current optimizer state, this field returns a PreClippingTransform that specifies how to transform per-example gradients before and after the clipping/noising step.
The update function expects gradients that have already been transformed by pre_clipping_transform(…), clipped, summed, and noised. It applies the inverse transform internally before delegating to the base optimizer’s update. See the module docstring for example usage.
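The data flow described above can be illustrated with a small NumPy sketch. This is not the jax_privacy implementation: the per-coordinate scaling transform and the helper names `forward`, `inverse`, `clip_to_norm`, and `private_step` are hypothetical, parameters are a single array rather than a general pytree, and the base optimizer is plain SGD. It shows why the update must undo the transform: clipping and noise are applied in the transformed space, and only the inverse maps the noisy sum back to parameter space.

```python
import numpy as np


def forward(grad, scale):
    # Hypothetical pre-clipping transform: per-coordinate scaling.
    return grad * scale


def inverse(grad, scale):
    # Exact inverse of `forward`; the augmented update applies this.
    return grad / scale


def clip_to_norm(grad, clip_norm):
    # Rescale one example's gradient so its L2 norm is at most clip_norm.
    norm = np.linalg.norm(grad)
    return grad * min(1.0, clip_norm / max(norm, 1e-12))


def private_step(params, per_example_grads, scale, clip_norm, noise_mult, lr, rng):
    # transform -> clip -> aggregate -> noise, all in the transformed space.
    transformed = [forward(g, scale) for g in per_example_grads]
    clipped = [clip_to_norm(g, clip_norm) for g in transformed]
    summed = np.sum(clipped, axis=0)
    noisy = summed + rng.normal(0.0, noise_mult * clip_norm, size=summed.shape)
    # Augmented update: undo the transform, then take a plain SGD step.
    return params - lr * inverse(noisy, scale)
```

With `noise_mult=0.0` and a `clip_norm` larger than every transformed gradient norm, the step reduces to ordinary SGD on the summed gradients, because the inverse exactly cancels the scaling.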
- Variables:
init – Initializes the optimizer state given initial parameters. Matches the optax.GradientTransformation.init API: init(params) -> state
update – Computes parameter updates from noisy gradients. The noisy gradients should be in the transformed (scaled) space, i.e. the result of transform -> clip -> aggregate -> noise. This function applies the inverse transform internally before calling the base optimizer’s update: update(updates, state, params=None) -> (updates, new_state)
pre_clipping_transform – Given the current optimizer state, returns a pre-clipping transform function intended for use with jax_privacy.clipped_grad. The returned function consumes a pytree whose structure matches the parameters and returns a transformed pytree, which may or may not have the same structure. The update function is responsible for mapping its input back to the original structure.
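The contract of the three fields can be sketched by wrapping a base optimizer. This is a minimal sketch, assuming a hypothetical stand-in `AugmentedGT` NamedTuple with the same field order, a hypothetical `scaled_sgd` factory, and a single NumPy array in place of a parameter pytree; it does not use the real jax_privacy or optax types. The state stores the scale, `pre_clipping_transform(state)` returns the function applied to each per-example gradient, and `update` inverts that scaling before the SGD step.

```python
from typing import Any, Callable, NamedTuple

import numpy as np


class AugmentedGT(NamedTuple):
    # Hypothetical stand-in with the documented field order.
    init: Callable[[Any], Any]
    update: Callable[..., tuple[Any, Any]]
    pre_clipping_transform: Callable[[Any], Callable[[Any], Any]]


class ScaledSGDState(NamedTuple):
    scale: np.ndarray  # per-coordinate scale applied before clipping


def scaled_sgd(learning_rate: float, scale: float) -> AugmentedGT:
    """Hypothetical example: SGD whose clipping happens in a scaled space."""

    def init(params):
        # Same shape as the optax contract: init(params) -> state.
        return ScaledSGDState(scale=np.full_like(params, scale, dtype=float))

    def pre_clipping_transform(state):
        # Returns the function applied to each per-example gradient.
        return lambda grad: grad * state.scale

    def update(updates, state, params=None):
        del params  # unused by plain SGD
        # `updates` arrive clipped, summed and noised in the scaled space;
        # undo the scaling before the base SGD step.
        unscaled = updates / state.scale
        return -learning_rate * unscaled, state

    return AugmentedGT(init, update, pre_clipping_transform)
```

Because `update` divides by the same scale that the transform multiplied by, a transformed gradient passed through without clipping or noise yields exactly the ordinary SGD update.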
Create new instance of AugmentedGradientTransformation(init, update, pre_clipping_transform)
Methods
- __init__
- count: Return number of occurrences of value.
- index: Return first index of value.
Attributes
- init: Alias for field number 0
- pre_clipping_transform: Alias for field number 2
- update: Alias for field number 1
- init: Callable[[chex.ArrayTree], chex.ArrayTree]
Alias for field number 0
- update: Callable[..., tuple[chex.ArrayTree, chex.ArrayTree]]
Alias for field number 1
- pre_clipping_transform: Callable[[chex.ArrayTree], PreClippingTransform]
Alias for field number 2
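Since the class is a NamedTuple, each "Alias for field number N" means the field is reachable both by name and by position, and instances unpack in that order. The sketch below uses a hypothetical stand-in class `AugmentedGT` with the same field order (it does not import jax_privacy) and trivial placeholder functions, to show positional access, unpacking, the inherited `count`/`index` methods, and `_replace` for swapping one field.

```python
from typing import Any, Callable, NamedTuple


class AugmentedGT(NamedTuple):
    # Hypothetical stand-in; field order matches the documented aliases:
    # 0 -> init, 1 -> update, 2 -> pre_clipping_transform.
    init: Callable[[Any], Any]
    update: Callable[..., tuple[Any, Any]]
    pre_clipping_transform: Callable[[Any], Callable[[Any], Any]]


def init(params):
    return ()  # stateless placeholder: empty state


def update(updates, state, params=None):
    return updates, state  # identity placeholder update


def identity_transform(state):
    return lambda grad: grad  # placeholder: no pre-clipping change


opt = AugmentedGT(init, update, identity_transform)

# Tuple unpacking follows the field numbers 0, 1, 2.
init_fn, update_fn, transform_fn = opt

# Inherited tuple methods.
print(opt.index(update))  # 1
print(opt.count(init))    # 1

# Instances are immutable; _replace builds a copy with one field swapped.
doubled = opt._replace(pre_clipping_transform=lambda state: (lambda g: 2 * g))
```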