jax_privacy.experimental.training.train

jax_privacy.experimental.training.train(plan, dataset, loss_fn, params, optimizer, padding_multiple=1, callback=None, rng=None)

Runs an end-to-end differentially private training loop.

Sharding: This function does not shard params or data itself. For multi-device training, pass params that already carry explicit sharding annotations, and configure spmd_axis_name through the plan’s PerformanceFlags. If the data also needs to be sharded, loss_fn is responsible for resharding its inputs using sharding-in-types.

Parameters:
  • plan (DPExecutionPlan) – A DPExecutionPlan specifying the DP mechanism.

  • dataset (chex.ArrayTree) – The training dataset, as a PyTree of arrays.

  • loss_fn (LossFn) – The per-example loss function. See LossFn.

  • params (chex.ArrayTree) – Initial parameter PyTree.

  • optimizer (AugmentedGradientTransformation | GradientTransformation) – An AugmentedGradientTransformation or a plain optax.GradientTransformation.

  • padding_multiple (int) – Batch sizes are padded up to a multiple of this value. The default of 1 leaves batch sizes unchanged.

  • callback (Callable[[int, TrainingState, Aux], None] | None) – Optional function called after each training step as callback(step, state, aux), where step is a Python int.

  • rng (Generator | int | None) – Optional random seed or numpy.random.Generator for reproducibility.

Return type:

TrainingState

Returns:

The final TrainingState after the last training step.
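The rounding implied by padding_multiple can be sketched as a ceiling-to-multiple operation. padded_batch_size is a hypothetical helper, not part of jax_privacy, and this page does not say what the padding slots contain. A common motivation, assumed here rather than confirmed, is that rounding variable batch sizes up to a multiple limits how many distinct array shapes jax.jit has to compile.

```python
def padded_batch_size(batch_size: int, padding_multiple: int = 1) -> int:
    """Hypothetical helper: round batch_size up to a multiple of padding_multiple.

    Mirrors the documented behaviour of the padding_multiple argument; it is
    not taken from the jax_privacy source.
    """
    if padding_multiple < 1:
        raise ValueError("padding_multiple must be a positive integer")
    # Ceiling division: -(-a // b) == ceil(a / b) for positive b.
    return -(-batch_size // padding_multiple) * padding_multiple


# With padding_multiple=64, a sampled batch of 1000 examples rounds up to 1024.
print(padded_batch_size(1000, 64))
# The default multiple of 1 leaves the size unchanged.
print(padded_batch_size(1000))
```

Under this sketch, a larger padding_multiple produces fewer distinct padded sizes. The cost is more padded slots per batch.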
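A callback only has to accept the documented arguments (step, state, aux). LossLogger below is a hypothetical example, and it assumes, purely for illustration, that aux is a mapping with a "loss" entry. What aux actually holds depends on your loss_fn.

```python
from typing import Any


class LossLogger:
    """Hypothetical callback with the documented signature callback(step, state, aux).

    Assumes (for illustration only) that aux is a mapping containing "loss".
    """

    def __init__(self, every: int = 100):
        self.every = every
        self.history: list[tuple[int, float]] = []

    def __call__(self, step: int, state: Any, aux: Any) -> None:
        # step is a plain Python int, so ordinary modulo arithmetic works.
        if step % self.every == 0:
            # float() converts a scalar NumPy or JAX array to a Python float;
            # for a jax.Array this waits for the value to be computed.
            self.history.append((step, float(aux["loss"])))


# Simulate a few steps by calling the logger directly.
logger = LossLogger(every=2)
for step in range(5):
    logger(step, None, {"loss": step * 0.5})
print(logger.history)
```

In a real run it would be passed as callback=LossLogger(every=100) in the call to train(plan, dataset, loss_fn, params, optimizer, ...). Recording only every N steps keeps per-step overhead low, because converting an on-device value with float() makes the host wait for it.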
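The rng parameter accepts an int seed, a numpy.random.Generator, or None. This page does not say how train normalizes these forms. The sketch below assumes, without confirmation from the source, that it behaves like numpy.random.default_rng, which seeds a new Generator from an int, returns a Generator unchanged, and draws fresh OS entropy for None. as_generator is a hypothetical helper.

```python
from typing import Optional, Union

import numpy as np


def as_generator(
    rng: Optional[Union[np.random.Generator, int]] = None,
) -> np.random.Generator:
    """Hypothetical helper: normalize an rng argument to a numpy Generator.

    Assumes train() treats rng the way np.random.default_rng does:
      * int       -> new Generator seeded deterministically from the int
      * Generator -> returned unchanged (its state is shared with the caller)
      * None      -> new Generator seeded from fresh OS entropy
    """
    return np.random.default_rng(rng)


# The same int seed yields the same stream, so runs are reproducible.
first = as_generator(0).integers(0, 1_000, size=5)
second = as_generator(0).integers(0, 1_000, size=5)
print(np.array_equal(first, second))

# A Generator is passed through as the same object, so draws made with it
# also advance the caller's generator.
shared = np.random.default_rng(1)
print(as_generator(shared) is shared)
```

Under this assumption, an int seed is the simplest way to get a reproducible run. Passing a Generator instead lets train share one random stream with the rest of your code.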