jax_privacy.experimental.training

End-to-end loop for differentially private (DP) training.

This module provides a general-purpose DP training loop driven by a DPExecutionPlan, supporting arbitrary mechanisms.

Functions

train(plan, dataset, loss_fn, params, optimizer)

Runs an end-to-end differentially private training loop.
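
For orientation, the sketch below shows the kind of update a DP training loop typically performs on each iteration, using DP-SGD as the example mechanism: compute per-example gradients, clip each one, add Gaussian noise to the sum, and take an averaged gradient step. It is written in plain JAX and is not the implementation of train. The names example_loss, clip_by_global_norm and dp_sgd_step, the linear model, and all hyperparameters are assumptions made for illustration. No DPExecutionPlan is involved.

```python
import jax
import jax.numpy as jnp


def example_loss(params, example):
    """Squared error of a linear model on ONE example (illustrative only)."""
    x, y = example
    pred = jnp.dot(x, params["w"]) + params["b"]
    return (pred - y) ** 2


def clip_by_global_norm(grads, max_norm):
    """Rescales one example's gradient pytree so its global L2 norm is <= max_norm."""
    leaves = jax.tree_util.tree_leaves(grads)
    norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    scale = jnp.minimum(1.0, max_norm / (norm + 1e-12))
    return jax.tree_util.tree_map(lambda g: g * scale, grads)


def dp_sgd_step(params, batch, key, *, l2_clip, noise_multiplier, learning_rate):
    """One hypothetical DP-SGD update; batch is a tuple (xs, ys) with a leading batch axis."""
    # Per-example gradients: map over the batch axis while sharing params.
    per_example_grads = jax.vmap(jax.grad(example_loss), in_axes=(None, 0))(params, batch)
    # Clip every example's gradient independently.
    clipped = jax.vmap(clip_by_global_norm, in_axes=(0, None))(per_example_grads, l2_clip)
    # Sum the clipped gradients over the batch axis.
    summed = jax.tree_util.tree_map(lambda g: jnp.sum(g, axis=0), clipped)
    # Add Gaussian noise with standard deviation noise_multiplier * l2_clip to each leaf.
    leaves, treedef = jax.tree_util.tree_flatten(summed)
    keys = jax.random.split(key, len(leaves))
    noisy = [
        g + noise_multiplier * l2_clip * jax.random.normal(k, g.shape, g.dtype)
        for g, k in zip(leaves, keys)
    ]
    # Average over the batch and apply a plain gradient-descent step.
    batch_size = batch[0].shape[0]
    noisy_mean = jax.tree_util.tree_unflatten(treedef, [g / batch_size for g in noisy])
    return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, noisy_mean)
```

A loop would call dp_sgd_step once per batch with a fresh PRNG key, for example one derived with jax.random.fold_in(key, step). Clipping bounds the influence of any single example, and the noise scale is tied to that bound. Together these are what allow a privacy guarantee to be stated for such an update.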

Classes

LossFn(*args, **kwargs)

Expected contract for loss functions used in DP training.

TrainingState(*, step, params, opt_state, ...)

Container for the state of the training loop.