jax_privacy.experimental.training.LossFn

class jax_privacy.experimental.training.LossFn(*args, **kwargs)

Bases: Protocol

Expected contract for loss functions used in DP training.

Loss functions must accept params and a data batch, and return (loss, aux). They may optionally accept a PRNG key as a third positional argument for stochastic operations (e.g., dropout).
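As a rough illustration of this contract, here is a minimal sketch of two loss functions for a toy linear model written in plain JAX. The names `linear_loss` and `deterministic_loss`, the dict keys `'w'`, `'b'`, `'x'`, `'y'`, and the 10% input-dropout rate are all hypothetical choices for the example, not part of jax_privacy. The first uses the PRNG key for dropout; the second needs no randomness and so omits the key.

```python
import jax
import jax.numpy as jnp


def linear_loss(params, data, prng):
    """MSE of a toy linear model with input dropout (hypothetical example).

    params: {'w': [d], 'b': scalar}; data: {'x': [n, d], 'y': [n]}.
    """
    rate = 0.1  # hypothetical dropout rate
    # Keep each input feature with probability 1 - rate and rescale survivors.
    keep = jax.random.bernoulli(prng, 1.0 - rate, data['x'].shape)
    x = jnp.where(keep, data['x'] / (1.0 - rate), 0.0)
    preds = x @ params['w'] + params['b']
    loss = jnp.mean((preds - data['y']) ** 2)
    # aux can be any pytree of arrays; here it carries the predictions.
    return loss, {'preds': preds}


def deterministic_loss(params, data):
    # A loss that needs no randomness may leave out the PRNG argument.
    preds = data['x'] @ params['w'] + params['b']
    return jnp.mean((preds - data['y']) ** 2), {'preds': preds}


params = {'w': jnp.zeros(3), 'b': jnp.array(0.0)}
data = {'x': jnp.ones((4, 3)), 'y': jnp.ones(4)}
loss, aux = linear_loss(params, data, jax.random.PRNGKey(0))
```

With all-zero parameters the predictions are zero whatever the dropout mask, so both functions return a loss of 1.0 on this batch.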

Any additional context the loss function needs (frozen parameters, model configuration, label-smoothing constants, etc.) should be captured in a closure before the function is passed to train():

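# `model`, `some_params`, and `cross_entropy` stand in for your own model and loss helpers.
frozen = model.freeze(some_params)  # non-trainable parameters, captured by the closure
def my_loss(params, data, prng):
    all_params = {**frozen, **params}  # merge frozen and trainable parameters
    logits = model.apply(all_params, data['x'], rngs={'dropout': prng})
    return cross_entropy(logits, data['y']), {'logits': logits}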

training.train(..., loss_fn=my_loss, ...)
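The closure pattern can be tried out in plain JAX without a model library. The sketch below uses a hypothetical factory `make_loss_fn` that captures a frozen feature extractor (`'embed'`) and a label-smoothing constant, and returns a function with the `(params, data, prng)` signature whose only trainable entry is a linear head (`'head'`). The toy architecture and all parameter names are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def make_loss_fn(frozen_params, label_smoothing):
    """Returns a (params, data, prng) -> (loss, aux) function (hypothetical helper).

    frozen_params and label_smoothing are captured by the closure, so the
    returned function exposes only the trainable params in its signature.
    """

    def loss_fn(params, data, prng):
        del prng  # This toy model is deterministic.
        all_params = {**frozen_params, **params}
        # Frozen feature extractor followed by a trainable linear head.
        features = jnp.tanh(data['x'] @ all_params['embed'])
        logits = features @ all_params['head']
        num_classes = logits.shape[-1]
        # Smoothed one-hot targets; each row still sums to 1.
        targets = jax.nn.one_hot(data['y'], num_classes)
        targets = targets * (1.0 - label_smoothing) + label_smoothing / num_classes
        loss = -jnp.mean(jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1))
        return loss, {'logits': logits}

    return loss_fn


frozen = {'embed': jnp.full((4, 8), 0.1)}
params = {'head': jnp.zeros((8, 3))}
data = {'x': jnp.ones((2, 4)), 'y': jnp.array([0, 2])}
loss_fn = make_loss_fn(frozen, label_smoothing=0.1)
(loss, aux), grads = jax.value_and_grad(loss_fn, has_aux=True)(
    params, data, jax.random.PRNGKey(0))
```

Because `frozen` never appears in the argument list, differentiating with respect to the first argument yields gradients only for `'head'` in this sketch. With a zero head the logits are uniform, so the loss equals log 3.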

Mutable state that persists across steps is intentionally unsupported by this signature. Patterns like batch-norm running statistics or online accumulators that carry state from one step to the next are generally incompatible with differential privacy unless extreme care is taken, and are therefore excluded by design. If you need such patterns, fold the state into params and manage it explicitly.

Example signature:

def loss_fn(params, data, prng):
    ...
    return loss, aux
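Because `LossFn` is a `typing.Protocol`, a plain function with this signature satisfies it structurally; there is nothing to subclass. The sketch below defines a local stand-in, `LossFnLike`, mirroring the documented `__call__(params, data, prng)` signature (it is not the library class), plus a hypothetical, non-private `sgd_step` that shows how a `(loss, aux)` pair fits `jax.value_and_grad(..., has_aux=True)`. It is not a description of how `train()` consumes the function internally.

```python
from typing import Any, Protocol

import jax
import jax.numpy as jnp


class LossFnLike(Protocol):
    """Local stand-in mirroring the documented __call__ signature."""

    def __call__(self, params: Any, data: Any, prng: jax.Array) -> tuple[jax.Array, Any]:
        ...


def loss_fn(params, data, prng):
    del prng  # unused by this deterministic loss
    residual = data['x'] * params['scale'] - data['y']
    return jnp.mean(residual ** 2), {'residual': residual}


def sgd_step(fn: LossFnLike, params, data, prng, lr=0.1):
    """Plain (non-private) gradient step, only to show how (loss, aux) is consumed."""
    (loss, aux), grads = jax.value_and_grad(fn, has_aux=True)(params, data, prng)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss, aux


params = {'scale': jnp.array(0.0)}
data = {'x': jnp.array([1.0, 2.0]), 'y': jnp.array([1.0, 2.0])}
new_params, loss, aux = sgd_step(loss_fn, params, data, jax.random.PRNGKey(0))
```

A static type checker accepts `loss_fn` where `LossFnLike` is expected because its parameter names and return shape match the protocol's `__call__`.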

Methods

__init__

__call__(params, data, prng)

Call self as a function.

Return type:

tuple[Array, Union[Array, ndarray, bool, number, Iterable[chex.ArrayTree], Mapping[Any, chex.ArrayTree]]]
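The return type pairs a loss `Array` with an arbitrary pytree of arrays as `aux`. One way to check a loss function against this shape before training, without running any real computation, is to trace it with `jax.eval_shape`. The helper `check_loss_fn` below is hypothetical, and its requirement that the loss be a scalar is an assumption of this sketch (it is what `jax.grad` needs to differentiate the loss directly); the annotation itself only says `Array`.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, data, prng):
    del prng
    logits = data['x'] @ params['w']
    loss = jnp.mean(jax.nn.logsumexp(logits, axis=-1))
    # aux may be any pytree of arrays: a dict, a tuple, a single array, ...
    return loss, (logits, {'max_logit': jnp.max(logits)})


def check_loss_fn(fn, params, data, prng):
    """Abstractly traces fn and checks for a (scalar loss, aux) pair (hypothetical helper)."""
    out = jax.eval_shape(fn, params, data, prng)
    if not (isinstance(out, tuple) and len(out) == 2):
        raise TypeError('loss function must return a (loss, aux) pair')
    loss_struct, aux_structs = out
    # Scalar loss is an assumption of this sketch, not of the protocol.
    if loss_struct.shape != ():
        raise ValueError(f'expected a scalar loss, got shape {loss_struct.shape}')
    return aux_structs  # pytree of jax.ShapeDtypeStruct


params = {'w': jnp.ones((3, 5))}
data = {'x': jnp.ones((4, 3))}
aux_structs = check_loss_fn(loss_fn, params, data, jax.random.PRNGKey(0))
```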