jax_privacy.experimental.training.LossFn
class jax_privacy.experimental.training.LossFn(*args, **kwargs)
Bases: Protocol
Expected contract for loss functions used in DP training.
Loss functions must accept params and a data batch, and return (loss, aux). They may optionally accept a PRNG key as a third positional argument for stochastic operations (e.g., dropout).

Any additional context the loss function needs (frozen parameters, model configuration, label smoothing constants, etc.) should be closed over before passing the function to train():

    frozen = model.freeze(some_params)

    def my_loss(params, data, prng):
        all_params = {**frozen, **params}
        logits = model.apply(all_params, data['x'], rngs={'dropout': prng})
        return cross_entropy(logits, data['y']), {'logits': logits}

    training.train(..., loss_fn=my_loss, ...)
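The closure pattern can also be written as a small factory function. The following is a minimal sketch in plain JAX, not code from this library: make_loss_fn, frozen_bias, and label_smoothing are hypothetical names chosen for illustration, the model is a toy linear classifier, and averaging the loss over the batch is an assumption about what the caller expects.

```python
import jax
import jax.numpy as jnp


def make_loss_fn(frozen_bias, label_smoothing):
    """Hypothetical factory: binds extra context and returns a 3-argument loss."""

    def loss_fn(params, data, prng):
        # Input dropout driven by the PRNG key passed as the third argument.
        keep = jax.random.bernoulli(prng, 0.9, data['x'].shape)
        x = jnp.where(keep, data['x'] / 0.9, 0.0)
        # frozen_bias is closed over, so it is not part of the trainable params.
        logits = x @ params['w'] + frozen_bias
        num_classes = logits.shape[-1]
        one_hot = jax.nn.one_hot(data['y'], num_classes)
        targets = one_hot * (1.0 - label_smoothing) + label_smoothing / num_classes
        # Assumption: a scalar loss averaged over the batch.
        loss = -jnp.mean(jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1))
        return loss, {'logits': logits}

    return loss_fn


params = {'w': jnp.zeros((4, 3))}
data = {'x': jnp.ones((2, 4)), 'y': jnp.array([0, 2])}
loss_fn = make_loss_fn(frozen_bias=jnp.zeros(3), label_smoothing=0.1)
loss, aux = loss_fn(params, data, jax.random.PRNGKey(0))
```

The returned loss_fn has only the (params, data, prng) arguments described above, so everything else it needs travels with the closure.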
Mutable state that persists across steps is intentionally unsupported by this signature. Patterns like batch-norm running statistics or online accumulators that carry state from one step to the next are generally incompatible with differential privacy unless extreme care is taken, and are therefore excluded by design. If you need such patterns, fold the state into params and manage it explicitly.

Example signature:
    def loss_fn(params, data, prng):
        ...
        return loss, aux
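As a quick sanity check of the (loss, aux) return contract, a deterministic loss that omits the optional PRNG argument can be differentiated with jax.value_and_grad and has_aux=True. This is a sketch in plain JAX with a toy least-squares model. It does not show how train() calls the loss internally, and the data layout (keys 'x' and 'y') is an assumption.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, data):
    # Deterministic loss: no stochastic ops, so the PRNG argument is omitted.
    pred = data['x'] @ params['w']
    residual = pred - data['y']
    loss = 0.5 * jnp.mean(residual ** 2)
    return loss, {'residual': residual}


params = {'w': jnp.array([1.0, -1.0])}
data = {'x': jnp.array([[1.0, 2.0], [3.0, 4.0]]), 'y': jnp.array([0.0, 0.0])}

# has_aux=True: the function returns (loss, aux); only loss is differentiated.
(loss, aux), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, data)
```

The gradients come back with the same pytree structure as params, and aux is passed through unchanged.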
Methods
__init__