jax_privacy.keras_api

API for adding DP-SGD to a Keras model.

Example Usage:

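import os

# Select the JAX backend before Keras is imported.
os.environ["KERAS_BACKEND"] = "jax"

import keras
from jax_privacy import keras_api

# A minimal model with one scalar input and one dense output unit.
model = keras.Sequential([
    keras.Input(shape=(1,)),
    keras.layers.Dense(1),
])

# DP-SGD settings: privacy budget (epsilon, delta), per-example clipping
# norm, batching, and training-length parameters.
params = keras_api.DPKerasConfig(
    epsilon=1.0,
    delta=1e-5,
    clipping_norm=1.0,
    batch_size=8,
    gradient_accumulation_steps=1,
    train_steps=10,
    train_size=80,
    noise_multiplier=1.0,
)

# Wrap the model for DP-SGD training.
private_model = keras_api.make_private(model, params)

# Query the noise multiplier used for DP-SGD training.
keras_api.get_noise_multiplier(private_model)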

Functions

get_noise_multiplier(model)

Returns the noise multiplier used for DP-SGD training.
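In the standard DP-SGD formulation, the noise multiplier is the ratio between the standard deviation of the Gaussian noise added to the sum of clipped per-example gradients and the clipping norm. The standalone sketch below illustrates that relationship; it is not part of jax_privacy, the helpers noise_stddev and averaged_noise_stddev are hypothetical, and the assumption that the noise is scaled by clipping_norm comes from the textbook algorithm rather than from this page.

```python
def noise_stddev(noise_multiplier: float, clipping_norm: float) -> float:
    """Std of the Gaussian noise added to the sum of clipped gradients.

    Hypothetical helper assuming the standard DP-SGD convention
    sigma = noise_multiplier * clipping_norm.
    """
    if noise_multiplier < 0 or clipping_norm <= 0:
        raise ValueError("need noise_multiplier >= 0 and clipping_norm > 0")
    return noise_multiplier * clipping_norm


def averaged_noise_stddev(noise_multiplier: float, clipping_norm: float,
                          batch_size: int) -> float:
    """Effective noise std once the noisy sum is divided by the batch size."""
    return noise_stddev(noise_multiplier, clipping_norm) / batch_size


# Values from the DPKerasConfig example above.
print(noise_stddev(1.0, 1.0))              # 1.0
print(averaged_noise_stddev(1.0, 1.0, 8))  # 0.125
```

Doubling either the noise multiplier or the clipping norm doubles the absolute noise, while a larger batch dilutes it in the averaged gradient.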

make_private(model, params)

Adds DP-SGD training to a Keras model without modifying its API.
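The textbook DP-SGD update replaces the ordinary mini-batch gradient with a privatized one: clip each per-example gradient to L2 norm at most clipping_norm, sum the clipped gradients, add Gaussian noise with standard deviation noise_multiplier * clipping_norm, and divide by the batch size. The pure-Python sketch below shows that step on flat gradient vectors. It is an illustration of the algorithm under those assumptions, not the code make_private runs, and clip_to_norm and dp_sgd_gradient are hypothetical names.

```python
import math
import random
from typing import List, Optional, Sequence

Vector = List[float]


def clip_to_norm(grad: Sequence[float], clipping_norm: float) -> Vector:
    """Scale grad so its L2 norm is at most clipping_norm (unchanged if within)."""
    norm = math.sqrt(sum(g * g for g in grad))
    scale = min(1.0, clipping_norm / norm) if norm > 0 else 1.0
    return [g * scale for g in grad]


def dp_sgd_gradient(per_example_grads: Sequence[Sequence[float]],
                    clipping_norm: float,
                    noise_multiplier: float,
                    rng: Optional[random.Random] = None) -> Vector:
    """Clip each per-example gradient, sum, add Gaussian noise, then average."""
    if rng is None:
        rng = random.Random()
    batch_size = len(per_example_grads)
    total = [0.0] * len(per_example_grads[0])
    for grad in per_example_grads:
        for i, g in enumerate(clip_to_norm(grad, clipping_norm)):
            total[i] += g
    # Assumed noise std on the summed gradient: noise_multiplier * clipping_norm.
    sigma = noise_multiplier * clipping_norm
    return [(t + rng.gauss(0.0, sigma)) / batch_size for t in total]


# Usage: two per-example gradients with L2 norms 5.0 and 0.5.
grads = [[3.0, 4.0], [0.3, 0.4]]
noisy = dp_sgd_gradient(grads, clipping_norm=1.0, noise_multiplier=1.0,
                        rng=random.Random(0))
print(noisy)
```

Clipping bounds how much any single example can move the update, which is what lets the Gaussian noise, calibrated to clipping_norm, mask each example's contribution.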

Classes

DPKerasConfig(epsilon, delta, clipping_norm, ...)

Parameters for adding DP-SGD to a Keras model.
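Several DPKerasConfig fields combine into the quantities a DP-SGD privacy accountant typically works with: the per-step sampling rate and the number of steps. The sketch below assumes, as is common for implementations that support gradient accumulation, that each of the train_steps optimizer updates consumes batch_size * gradient_accumulation_steps examples; the library's exact semantics may differ. DPTrainingStats and training_stats are hypothetical names.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DPTrainingStats:
    effective_batch_size: int
    sampling_rate: float
    epochs: float


def training_stats(batch_size: int, gradient_accumulation_steps: int,
                   train_steps: int, train_size: int) -> DPTrainingStats:
    """Derive sampling rate and epoch count from config-style values.

    Assumes each optimizer step uses batch_size * gradient_accumulation_steps
    examples (an assumption, not confirmed library behavior).
    """
    if min(batch_size, gradient_accumulation_steps, train_steps, train_size) <= 0:
        raise ValueError("all arguments must be positive")
    effective = batch_size * gradient_accumulation_steps
    if effective > train_size:
        raise ValueError("effective batch size exceeds train_size")
    return DPTrainingStats(
        effective_batch_size=effective,
        sampling_rate=effective / train_size,
        epochs=train_steps * effective / train_size,
    )


# Values from the DPKerasConfig example at the top of this page.
stats = training_stats(batch_size=8, gradient_accumulation_steps=1,
                       train_steps=10, train_size=80)
print(stats)
# DPTrainingStats(effective_batch_size=8, sampling_rate=0.1, epochs=1.0)
```

Under this assumption, raising gradient_accumulation_steps increases the effective batch size and the sampling rate without changing the per-device batch_size.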