jax_privacy.optimizers.scale_then_privatize
- jax_privacy.optimizers.scale_then_privatize(base_optimizer, eps=1e-08, eps_root=0.0, extract_preconditioner_from_state_fn=None)
Constructs an AugmentedGradientTransformation for scale-then-privatize.
This implements Algorithm 8 from Ganesh, McMahan, Thakurta (2507.01129). The key idea is to use the optimizer’s second-moment estimate v_{t-1} from the previous step to define a non-isotropic geometry for clipping and noising per-example gradients. Specifically:
s_t = 1 / (sqrt(v_{t-1} + eps_root) + eps)
Before clipping, each per-example gradient g is transformed to s_t ⊙ g (element-wise product). After clipping, aggregation, and noise addition, the update function applies the inverse transform (element-wise division by s_t) and then passes the result to the base optimizer's update.
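The following NumPy sketch illustrates the transform described above on flattened gradients. It is not the jax_privacy implementation. It assumes a single flat parameter vector, clipping to a fixed L2 norm, and Gaussian noise with standard deviation noise_multiplier * clip_norm added to the clipped sum. The names scale_then_privatize_step, clip_norm, and noise_multiplier are hypothetical.

```python
import numpy as np


def scale_then_privatize_step(per_example_grads, v_prev, clip_norm, noise_multiplier,
                              eps=1e-8, eps_root=0.0, rng=None):
    """Hypothetical sketch of one private aggregation step in the scaled geometry.

    per_example_grads: shape (batch, dim), one flattened gradient per row.
    v_prev: shape (dim,), the base optimizer's second-moment estimate v_{t-1}.
    Returns the noisy clipped sum mapped back to the original coordinates.
    """
    rng = np.random.default_rng(0) if rng is None else rng
    # s_t = 1 / (sqrt(v_{t-1} + eps_root) + eps), computed element-wise.
    s = 1.0 / (np.sqrt(v_prev + eps_root) + eps)
    # Move every per-example gradient into the scaled geometry: s_t ⊙ g.
    scaled = per_example_grads * s
    # Clip each scaled gradient to L2 norm at most clip_norm.
    norms = np.linalg.norm(scaled, axis=1, keepdims=True)
    clipped = scaled * np.minimum(1.0, clip_norm / np.maximum(norms, 1e-12))
    # Sum, then add isotropic Gaussian noise in the scaled space (assumed calibration).
    noisy_sum = clipped.sum(axis=0) + rng.normal(0.0, noise_multiplier * clip_norm, size=s.shape)
    # Undo the transform (divide by s_t) before handing off to the base optimizer.
    return noisy_sum / s
```

With no noise and a clip norm large enough that nothing is clipped, the round trip returns the plain gradient sum. This is a quick check that the inverse transform exactly undoes the scaling.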
A large eps or eps_root passed here (but not used in the adaptive optimizer's own scaling) causes all coordinates to be scaled nearly identically, effectively applying no pre-clipping transform beyond a uniform rescaling. Setting eps or eps_root to match the adaptive optimizer may add large noise in coordinates where the gradient is small. Ideally, these parameters should be tuned to trade off between these two regimes.
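To make the eps tradeoff concrete, this sketch measures how anisotropic s_t is, using the ratio of its largest to smallest entry, for a hypothetical v spanning eight orders of magnitude. A ratio near 1 means the transform is close to a uniform rescaling. The helper names scaling_vector and anisotropy are illustrative only.

```python
import numpy as np


def scaling_vector(v, eps, eps_root=0.0):
    # s = 1 / (sqrt(v + eps_root) + eps), element-wise.
    return 1.0 / (np.sqrt(v + eps_root) + eps)


def anisotropy(s):
    # Ratio of largest to smallest scale factor; 1.0 means a uniform (isotropic) transform.
    return float(s.max() / s.min())


# Hypothetical second-moment values for three coordinates.
v = np.array([1e-8, 1e-4, 1.0])
for eps in (1e-8, 1e-2, 10.0):
    print(f"eps={eps:g}: anisotropy={anisotropy(scaling_vector(v, eps)):.3g}")
```

A tiny eps leaves the full spread of sqrt(v) in s_t. An eps much larger than every sqrt(v) entry squeezes the ratio toward 1.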
- Parameters:
  - base_optimizer (GradientTransformation) – A standard optax.GradientTransformation, typically an adaptive optimizer such as optax.adamw(…), optax.adam(…), or any chained transformation containing a scale_by_adam (or similar) component.
  - eps (float) – A small constant added to the denominator outside the square root when computing the scaling vector s_t, analogous to the eps parameter in Adam. It also acts as a stability constant that prevents excessively large scaling in coordinates where v is very small. Corresponds to ε_{s₁} in Algorithm 8 of the paper. See the note above on tuning this parameter.
  - eps_root (float) – A small constant added to v inside the square root, analogous to eps_root in optax.scale_by_adam. See the note above on tuning this parameter.
  - extract_preconditioner_from_state_fn (Optional[Callable[[chex.ArrayTree], chex.ArrayTree]]) – A function that takes the optimizer state and returns the second-moment estimate (v) pytree. If None, a default implementation is used that handles common optax adaptive optimizers (Adam, AMSGrad, RMSProp, AdaGrad).
- Return type:
AugmentedGradientTransformation
- Returns:
An AugmentedGradientTransformation for the scale-then-privatize technique.