jax_privacy.optimizers.scale_then_privatize

jax_privacy.optimizers.scale_then_privatize(base_optimizer, eps=1e-08, eps_root=0.0, extract_preconditioner_from_state_fn=None)

Constructs an AugmentedGradientTransformation for scale-then-privatize.

This implements Algorithm 8 from Ganesh, McMahan, and Thakurta (arXiv:2507.01129). The key idea is to use the base optimizer's second-moment estimate v_{t-1} from the previous step to define a non-isotropic geometry for clipping per-example gradients and adding noise. Specifically, the scaling vector s_t is computed element-wise as:

s_t = 1 / (sqrt(v_{t-1} + eps_root) + eps)
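The formula above applies leaf-by-leaf and element-by-element to the second-moment pytree. A minimal sketch in JAX, assuming v_{t-1} is available as a pytree of arrays; the helper name `scaling_vector` is hypothetical and not part of the jax_privacy API:

```python
import jax
import jax.numpy as jnp


def scaling_vector(v_prev, eps=1e-8, eps_root=0.0):
    """Hypothetical helper: s_t = 1 / (sqrt(v_{t-1} + eps_root) + eps), per leaf."""
    return jax.tree_util.tree_map(
        lambda v: 1.0 / (jnp.sqrt(v + eps_root) + eps), v_prev
    )


# Coordinates with a small second moment get a large scale factor;
# with v = 0 the factor is bounded only by 1 / eps.
s = scaling_vector({"w": jnp.array([4.0, 0.25, 0.0])})
```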

Before clipping, each per-example gradient g is transformed to s_t ⊙ g (element-wise product). After clipping, aggregation, and noise addition, the update function applies the inverse transformation (element-wise division by s_t) before passing the result to the base optimizer's update.
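The scale, clip, aggregate, noise, unscale sequence described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the jax_privacy implementation: it assumes per-example gradients arrive as a pytree whose leaves carry a leading batch axis, joint L2 clipping across all leaves, sum aggregation, and Gaussian noise with standard deviation `noise_multiplier * clip_norm`. The function `privatize_scaled` and its parameters are hypothetical names.

```python
import jax
import jax.numpy as jnp


def privatize_scaled(per_example_grads, s, clip_norm, noise_multiplier, key):
    """Hypothetical sketch of one scale-then-privatize aggregation step.

    per_example_grads: pytree whose leaves have a leading batch axis.
    s: pytree of scaling vectors with the same structure, without the batch axis.
    """
    # 1. Scale each per-example gradient, g -> s * g (broadcast over the batch axis).
    scaled = jax.tree_util.tree_map(lambda g, s_leaf: g * s_leaf, per_example_grads, s)

    # 2. Clip each example to L2 norm <= clip_norm, measured jointly over all leaves.
    leaves = jax.tree_util.tree_leaves(scaled)
    batch_size = leaves[0].shape[0]
    sq_norms = sum(jnp.sum(jnp.reshape(x, (batch_size, -1)) ** 2, axis=1) for x in leaves)
    factors = jnp.minimum(1.0, clip_norm / jnp.maximum(jnp.sqrt(sq_norms), 1e-12))
    clipped = jax.tree_util.tree_map(
        lambda x: x * factors.reshape((batch_size,) + (1,) * (x.ndim - 1)), scaled
    )

    # 3. Sum over the batch, then add isotropic Gaussian noise in the scaled space.
    summed = jax.tree_util.tree_map(lambda x: jnp.sum(x, axis=0), clipped)
    treedef = jax.tree_util.tree_structure(summed)
    keys = jax.tree_util.tree_unflatten(
        treedef, list(jax.random.split(key, treedef.num_leaves))
    )
    noised = jax.tree_util.tree_map(
        lambda x, k: x + noise_multiplier * clip_norm * jax.random.normal(k, x.shape, x.dtype),
        summed,
        keys,
    )

    # 4. Undo the scaling before the result is handed to the base optimizer.
    return jax.tree_util.tree_map(lambda x, s_leaf: x / s_leaf, noised, s)
```

Because the noise is isotropic in the scaled space, dividing by s_t afterwards makes the effective noise in the original coordinates non-isotropic: coordinates with a large scale factor receive proportionally less noise.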

Passing a large eps or eps_root here (but not in the adaptive optimizer's own scaling) causes all coordinates to be scaled nearly identically, effectively applying no pre-clipping transform. Conversely, eps or eps_root values matching the adaptive optimizer may add large noise in coordinates where the gradient is small. Ideally, these parameters should be tuned to trade off between the two regimes.
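To see the first regime numerically, one can measure how far the scaling vector is from uniform. The sketch below uses a hypothetical helper `scale_spread` and made-up second-moment values; a ratio max(s_t)/min(s_t) close to 1 means the transform is close to isotropic (a single global rescaling).

```python
import jax.numpy as jnp


def scale_spread(v, eps, eps_root=0.0):
    """Hypothetical helper: max(s) / min(s); 1.0 means a uniform (isotropic) scaling."""
    s = 1.0 / (jnp.sqrt(v + eps_root) + eps)
    return float(jnp.max(s) / jnp.min(s))


# Illustrative second moments spanning several orders of magnitude.
v = jnp.array([1e-8, 1e-4, 1.0])
for eps in (1e-8, 1e-3, 1e2):
    print(f"eps={eps:g}: spread={scale_spread(v, eps):.3f}")
```

With a tiny eps the spread is large (strongly non-isotropic clipping); with eps much larger than every sqrt(v) the spread approaches 1, matching the "no pre-clipping transform" regime.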

Parameters:
  • base_optimizer (GradientTransformation) – A standard optax.GradientTransformation, typically an adaptive optimizer like optax.adamw(…), optax.adam(…), or any chained transformation containing a scale_by_adam (or similar) component.

  • eps (float) – A small constant added to the denominator outside the square root when computing the scaling vector s_t, analogous to the eps parameter in Adam. It also acts as a stability constant that prevents excessively large scaling in coordinates where v is very small. Corresponds to ε_{s₁} in Algorithm 8 of the paper. See the note above on tuning this parameter.

  • eps_root (float) – A small constant added to v inside the square root, analogous to eps_root in optax.scale_by_adam. See the note above on tuning this parameter.

  • extract_preconditioner_from_state_fn (Optional[Callable[[chex.ArrayTree], chex.ArrayTree]]) – A function that takes the optimizer state and returns the second-moment estimate (v) pytree. If None, uses a default implementation that handles common optax adaptive optimizers (Adam, AMSGrad, RMSProp, AdaGrad).

Return type:

AugmentedGradientTransformation

Returns:

An AugmentedGradientTransformation for the scale-then-privatize technique.