jax_privacy.optimizers.AugmentedGradientTransformation

class jax_privacy.optimizers.AugmentedGradientTransformation(init: Callable[[chex.ArrayTree], chex.ArrayTree], update: Callable[..., tuple[chex.ArrayTree, chex.ArrayTree]], pre_clipping_transform: Callable[[chex.ArrayTree], PreClippingTransform])

Bases: NamedTuple

A gradient transformation augmented with a pre-clipping transform.

This extends the standard optax.GradientTransformation interface with a pre_clipping_transform field: a function that, given the current optimizer state, returns a PreClippingTransform specifying how to transform per-example gradients before and after the clipping/noising step.

The update function expects gradients that have already been transformed by pre_clipping_transform(…), then clipped, summed, and noised. It internally applies the inverse transform before delegating to the base optimizer's update. See the module docstring for example usage.
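The data flow described above (transform, clip, sum, noise, then invert inside update) can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not jax_privacy's implementation: the fixed per-coordinate SCALES, the forward/inverse helpers, clip_by_l2_norm, privatize, and sgd_update are all hypothetical stand-ins for what a PreClippingTransform, jax_privacy.clipped_grad, and the base optimizer would do.

```python
import jax
import jax.numpy as jnp

# Hypothetical fixed per-coordinate scales standing in for a PreClippingTransform.
SCALES = {"w": jnp.array([2.0, 0.5])}


def forward(grad):
    """Map one per-example gradient into the transformed (scaled) space."""
    return jax.tree_util.tree_map(lambda g, s: g * s, grad, SCALES)


def inverse(grad):
    """Map an aggregated gradient back to the parameter space."""
    return jax.tree_util.tree_map(lambda g, s: g / s, grad, SCALES)


def clip_by_l2_norm(tree, max_norm):
    """Rescale a pytree so that its global L2 norm is at most max_norm."""
    leaves = jax.tree_util.tree_leaves(tree)
    norm = jnp.sqrt(sum(jnp.sum(x**2) for x in leaves))
    factor = jnp.minimum(1.0, max_norm / jnp.maximum(norm, 1e-12))
    return jax.tree_util.tree_map(lambda x: x * factor, tree)


def privatize(per_example_grads, key, max_norm, noise_std):
    """transform -> clip -> aggregate -> noise, all in the transformed space."""
    clipped = [clip_by_l2_norm(forward(g), max_norm) for g in per_example_grads]
    summed = jax.tree_util.tree_map(lambda *xs: sum(xs), *clipped)
    # Draw independent Gaussian noise for every leaf of the summed pytree.
    leaves, treedef = jax.tree_util.tree_flatten(summed)
    keys = jax.random.split(key, len(leaves))
    noisy = [x + noise_std * jax.random.normal(k, x.shape) for x, k in zip(leaves, keys)]
    return jax.tree_util.tree_unflatten(treedef, noisy)


def sgd_update(noisy_grads, learning_rate):
    """Undo the transform first, then apply the base optimizer (plain SGD here)."""
    grads = inverse(noisy_grads)
    return jax.tree_util.tree_map(lambda g: -learning_rate * g, grads)


per_example_grads = [{"w": jnp.array([1.0, 4.0])}, {"w": jnp.array([0.2, -0.4])}]
noisy = privatize(per_example_grads, jax.random.PRNGKey(0), max_norm=1.0, noise_std=0.5)
updates = sgd_update(noisy, learning_rate=0.1)
```

Clipping happens in the transformed space, so the sensitivity bound applies to forward(g) rather than to g itself; the inverse is applied only to the already-noised aggregate, which is why the base optimizer can stay unaware of the transform.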

Variables:
  • init – Initializes the optimizer state given initial parameters. Matches the optax.GradientTransformation.init API: init(params) -> state

  • update – Computes parameter updates from noisy gradients. The noisy gradients must be in the transformed space, i.e. produced by transform -> clip -> aggregate -> noise. This function applies the inverse transform internally before calling the base optimizer's update: update(updates, state, params=None) -> (updates, new_state)

  • pre_clipping_transform – Given the current optimizer state, returns a PreClippingTransform intended to be used with jax_privacy.clipped_grad. The returned transform consumes a pytree whose structure matches the parameters and returns a transformed pytree, which need not have the same structure. The update function is responsible for mapping its input back to the original parameter structure.
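The contract between the three fields can be illustrated with a small, self-contained optimizer. This is a hypothetical sketch: AugmentedGT is a local NamedTuple mirroring the field order of AugmentedGradientTransformation, ScaleTransform (with forward/inverse fields) is an assumed stand-in for PreClippingTransform whose real interface is not documented here, and scaled_sgd is an invented example optimizer.

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp


class ScaleTransform(NamedTuple):
    """Hypothetical stand-in for PreClippingTransform: a forward map and its inverse."""
    forward: Callable[[Any], Any]
    inverse: Callable[[Any], Any]


class AugmentedGT(NamedTuple):
    """Local stand-in with the same three fields as AugmentedGradientTransformation."""
    init: Callable[..., Any]
    update: Callable[..., Any]
    pre_clipping_transform: Callable[[Any], ScaleTransform]


def scaled_sgd(learning_rate, scales):
    def init(params):
        # State holds a step counter and the (fixed) per-parameter scales.
        return {"count": jnp.zeros((), jnp.int32), "scales": scales}

    def make_transform(state):
        # Build the pre-clipping transform from the current optimizer state.
        s = state["scales"]
        return ScaleTransform(
            forward=lambda g: jax.tree_util.tree_map(jnp.multiply, g, s),
            inverse=lambda g: jax.tree_util.tree_map(jnp.divide, g, s),
        )

    def update(noisy_grads, state, params=None):
        # Map the noisy, transformed gradients back to parameter space first.
        grads = make_transform(state).inverse(noisy_grads)
        updates = jax.tree_util.tree_map(lambda g: -learning_rate * g, grads)
        return updates, {**state, "count": state["count"] + 1}

    return AugmentedGT(init=init, update=update, pre_clipping_transform=make_transform)


params = {"w": jnp.array([1.0, 1.0])}
opt = scaled_sgd(learning_rate=0.1, scales={"w": jnp.array([2.0, 0.5])})
state = opt.init(params)
transform = opt.pre_clipping_transform(state)
# In real training, forward would be applied per example before clipping and noising.
noisy = transform.forward({"w": jnp.array([1.0, 4.0])})
updates, state = opt.update(noisy, state, params)
```

Deriving the transform from the state (rather than fixing it at construction) is what lets an optimizer adapt the transform over training, for example from accumulated statistics; this sketch keeps the scales constant for simplicity.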

Create new instance of AugmentedGradientTransformation(init, update, pre_clipping_transform)

Methods

__init__

count

Return number of occurrences of value.

index

Return first index of value.

Attributes

init

Alias for field number 0

pre_clipping_transform

Alias for field number 2

update

Alias for field number 1

init: Callable[[chex.ArrayTree], chex.ArrayTree]

Alias for field number 0

update: Callable[..., tuple[chex.ArrayTree, chex.ArrayTree]]

Alias for field number 1

pre_clipping_transform: Callable[[chex.ArrayTree], PreClippingTransform]

Alias for field number 2
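Because the class is a NamedTuple, each attribute above is an alias for a tuple position, and the inherited count and index methods behave as for any tuple. The sketch below uses a local stand-in with the same field order (not the real class, so it runs without jax_privacy installed); init_fn, update_fn, and transform_fn are trivial placeholders.

```python
from typing import Any, Callable, NamedTuple


class AugmentedGT(NamedTuple):
    """Local stand-in with the same field order as AugmentedGradientTransformation."""
    init: Callable[..., Any]
    update: Callable[..., Any]
    pre_clipping_transform: Callable[..., Any]


def init_fn(params):
    return ()


def update_fn(updates, state, params=None):
    return updates, state


def transform_fn(state):
    return None


t = AugmentedGT(init_fn, update_fn, transform_fn)

# Field names alias tuple positions 0, 1 and 2.
assert t[0] is t.init and t[1] is t.update and t[2] is t.pre_clipping_transform

# Positional unpacking works like any tuple.
init, update, pre_clipping_transform = t

# Inherited tuple methods.
print(t.index(update_fn))  # 1
print(t.count(init_fn))    # 1

# _replace returns a new instance with one field swapped; t itself is unchanged.
t2 = t._replace(pre_clipping_transform=lambda state: "identity")
```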