jax_privacy.experimental.training.train
- jax_privacy.experimental.training.train(plan, dataset, loss_fn, params, optimizer, padding_multiple=1, callback=None, rng=None)
Runs an end-to-end differentially private training loop.
Sharding: This function does not shard params or data. For multi-device training, provide
params with explicit sharding annotations and configure spmd_axis_name through the plan’s PerformanceFlags. If data sharding is needed, loss_fn should reshard its inputs using sharding-in-types.
- Parameters:
  - plan (DPExecutionPlan) – A DPExecutionPlan specifying the DP mechanism.
  - dataset (Union[Array, ndarray, bool, number, Iterable[chex.ArrayTree], Mapping[Any, chex.ArrayTree]]) – The training dataset, as a PyTree of arrays.
  - loss_fn (LossFn) – The per-example loss function. See LossFn.
  - params (Union[Array, ndarray, bool, number, Iterable[chex.ArrayTree], Mapping[Any, chex.ArrayTree]]) – Initial parameter PyTree.
  - optimizer (AugmentedGradientTransformation | GradientTransformation) – An AugmentedGradientTransformation or a plain optax.GradientTransformation.
  - padding_multiple (int) – If set, batch sizes are padded to a multiple of this value.
  - callback (Callable[[int, TrainingState, Aux], None] | None) – Called after each step as callback(step, state, aux). step is a Python int.
  - rng (Generator | int | None) – Optional random seed or numpy.random.Generator for reproducibility.
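The effect of padding_multiple can be illustrated with plain arithmetic. The sketch below is not the library's implementation: padded_batch_size is a hypothetical helper that only shows what "padded to a multiple of this value" means for a given batch size. A common reason for such padding is to limit the number of distinct batch shapes that have to be compiled; whether that is the motivation here is an assumption.

```python
def padded_batch_size(batch_size: int, padding_multiple: int = 1) -> int:
    """Round batch_size up to the nearest multiple of padding_multiple.

    Hypothetical helper for illustration only; not part of jax_privacy.
    """
    if padding_multiple < 1:
        raise ValueError("padding_multiple must be a positive integer")
    # Ceiling division, then scale back up to a full multiple.
    return -(-batch_size // padding_multiple) * padding_multiple


sizes = [97, 128, 130]
print([padded_batch_size(n, 32) for n in sizes])  # [128, 128, 160]
print([padded_batch_size(n) for n in sizes])      # [97, 128, 130]
```

With the default padding_multiple=1 every size is already a multiple, so nothing changes.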
- Return type: TrainingState
- Returns: Final TrainingState.
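The callback contract described under Parameters (invoked as callback(step, state, aux) after each step, returning None) can be sketched as follows. make_logging_callback and its every argument are hypothetical names introduced for illustration; state and aux are treated as opaque objects because their fields are not described on this page. The loop at the end is only a stand-in that simulates how a training loop might invoke the callback, and the assumption that steps start at 0 is not confirmed by the source.

```python
from typing import Any, Callable


def make_logging_callback(
    every: int = 100,
) -> tuple[Callable[[int, Any, Any], None], list[int]]:
    """Build a callback that logs every `every`-th step.

    Hypothetical helper for illustration only; not part of jax_privacy.
    Returns the callback and the list it appends logged steps to.
    """
    seen_steps: list[int] = []

    def callback(step: int, state: Any, aux: Any) -> None:
        # step is a Python int, so ordinary modular arithmetic works.
        if step % every == 0:
            seen_steps.append(step)
            print(f"step {step}: aux={aux!r}")

    return callback, seen_steps


callback, seen_steps = make_logging_callback(every=2)

# Stand-in for the training loop: positional call, mirroring
# callback(step, state, aux). state is None here as a placeholder.
for step in range(5):
    callback(step, None, {"loss": 1.0 / (step + 1)})
```

A callback built this way would be passed through the callback argument of train. Keeping the recorded steps in a list outside the callback makes it easy to inspect what was logged once training has returned.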