jax_privacy.keras_api.make_private

jax_privacy.keras_api.make_private(model, params)

Adds DP-SGD training to a Keras model without modifying its API.

This function mutates model in place, installs the DP-SGD hooks, and returns the same model instance. When params.poisson_sampling_in_fit is enabled, the wrapped fit() expects its inputs as random-access per-example arrays, or pytrees of such arrays (i.e. indexable by example), so that it can perform Poisson sampling internally.
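Poisson sampling is the reason fit() needs random access to the examples. Each example is included in a batch independently with some probability, so the batch is a random subset of row indices that must be gathered from every array in the input. The sketch below illustrates this idea with NumPy. It is not jax_privacy's implementation: the helper names `poisson_sample`, `_tree_map` and `_leaves` are hypothetical, and the toy pytree support covers only nested dicts, lists and tuples.

```python
import numpy as np


def _tree_map(fn, tree):
    """Apply fn to every array leaf of a nested dict/list/tuple (toy pytree)."""
    if isinstance(tree, dict):
        return {k: _tree_map(fn, v) for k, v in tree.items()}
    if isinstance(tree, list):
        return [_tree_map(fn, v) for v in tree]
    if isinstance(tree, tuple):
        return tuple(_tree_map(fn, v) for v in tree)
    return fn(tree)


def _leaves(tree):
    """Collect the array leaves of a nested dict/list/tuple."""
    if isinstance(tree, dict):
        return [leaf for v in tree.values() for leaf in _leaves(v)]
    if isinstance(tree, (list, tuple)):
        return [leaf for v in tree for leaf in _leaves(v)]
    return [tree]


def poisson_sample(data, sampling_prob, rng):
    """Hypothetical helper: draw one Poisson-sampled batch from per-example arrays.

    Every example is kept independently with probability `sampling_prob`,
    so the batch size is itself random. Gathering the kept rows requires
    random access along axis 0 of every leaf.
    """
    leaves = _leaves(data)
    num_examples = leaves[0].shape[0]
    # All leaves must agree on the number of examples (leading axis).
    if any(leaf.shape[0] != num_examples for leaf in leaves):
        raise ValueError("all leaves must share the same leading dimension")
    keep = rng.random(num_examples) < sampling_prob  # independent coin flips
    idx = np.flatnonzero(keep)
    # Use the same indices for every leaf so features and labels stay aligned.
    return _tree_map(lambda leaf: leaf[idx], data)


rng = np.random.default_rng(0)
data = {"x": np.arange(20).reshape(10, 2), "y": np.arange(10)}
batch = poisson_sample(data, sampling_prob=0.3, rng=rng)
```

Because the indices are shared across leaves, row i of `batch["x"]` still belongs to label `batch["y"][i]`. A generator or other streaming source cannot be indexed this way, which is consistent with the random-access requirement stated above.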

Parameters:
  • model (Model) – The Keras model to add DP-SGD training to.

  • params (DPKerasConfig) – The parameters for DP-SGD training.

Return type:

Model

Returns:

The same model instance that was passed in, with its methods overridden for DP-SGD training.
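The return contract (patch in place, return the same object) can be illustrated with a toy stand-in. In the sketch below, `ToyModel`, `make_private_sketch` and its `noise_multiplier` argument are all hypothetical and are not part of jax_privacy or Keras; the wrapper does no clipping or noising and only shows the object-identity semantics.

```python
class ToyModel:
    """Hypothetical stand-in for a Keras model; only fit() matters here."""

    def fit(self, x, y):
        return "non-private fit"


def make_private_sketch(model, noise_multiplier):
    """Hypothetical helper mirroring the documented contract: mutate, then return the same object."""
    original_fit = model.fit  # bound method of the not-yet-wrapped model

    def private_fit(x, y):
        # A real DP-SGD implementation would clip per-example gradients and
        # add noise here; this toy only delegates to the original method.
        result = original_fit(x, y)
        return f"DP wrapper (noise_multiplier={noise_multiplier}) around {result}"

    # Assigning an instance attribute shadows the class's fit() for this object only.
    model.fit = private_fit
    return model


model = ToyModel()
same = make_private_sketch(model, noise_multiplier=1.1)
```

Since the returned object is the input object, writing `model = make_private(model, params)` and calling `make_private(model, params)` only for its side effect leave `model` in the same state, and every other reference to that model also sees the DP-SGD methods.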