#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

DP-SGD tutorial using Flax Linen on MNIST

Copyright 2025 DeepMind Technologies Limited.

Welcome to Jax Privacy for Flax Linen! In this tutorial you will learn how to train a simple CNN model in a differentially-private (DP) way using the DP-SGD algorithm. We will train our model on the MNIST dataset.

This tutorial is based on the official MNIST example for Flax Linen. However, the code is organized in the same way as the official MNIST example for Flax NNX.

In the tutorial we highlight the changes we need to make in the official non-DP example to make the model differentially-private.

Install libraries

If any of the following libraries is not installed in your Python environment, run the cell below to install it from PyPI with pip (the cell uses IPython `!` shell commands, so it works as-is in Google Colab/Jupyter Notebook).

%%capture

# Flax Linen model code, typing helpers and DP accounting.
!pip install flax
!pip install jaxtyping
!pip install dp_accounting
# Needed by the data-loading cell: `tf.data` pipelines and the TFDS MNIST dataset.
!pip install tensorflow
!pip install tensorflow_datasets

!pip install git+https://github.com/google-deepmind/jax_privacy.git

Define hyper-parameters

First we define hyper-parameters for our training. These parameters are important for DP-SGD training. See the comments describing what each parameter means.

Note that in real applications, hyper-parameters that work well for non-DP training might not be optimal for DP training, so they may need to be tuned separately.

With this setup, expected DP-SGD result is ~92% accuracy on the test dataset and expected non-DP result is ~99%.

from jax_privacy.dp_sgd import grad_clipping

# Toggle DP-SGD training. Set this to False to train the same model without
# differential privacy and compare the results.
# Expected DP-SGD accuracy: ~92%
# Expected non-DP accuracy: ~99%.
use_dp = True

# Optimization hyper-parameters. DP-SGD may need different values than
# non-private training, so each regime has its own branch to tune.
if use_dp:
  # Number of optimization steps.
  train_steps = 5000
  # Examples per gradient update (this example does not accumulate
  # gradients). With DP-SGD, a bigger batch gives better model performance
  # for the same privacy budget.
  batch_size = 256
  # Optimizer learning rate.
  learning_rate = 0.1
  # Optimizer momentum.
  momentum = 0.9
else:
  train_steps = 5000
  batch_size = 128
  learning_rate = 0.1
  momentum = 0.9

# Evaluate model performance every `eval_every` training steps.
eval_every = 200

# DP-SGD privacy budget: the model is trained to be (epsilon, delta)-DP.
epsilon = 1.0
delta = 1e-5

# Maximum L2 norm of each example's gradient. Clipping is applied per
# example, before gradients are summed or averaged over the batch.
clipping_norm = 0.1

# How per-example gradients are computed and clipped. This only trades speed
# for memory; it does not change the results:
#   * VECTORIZED (used here) clips all examples in parallel with vmap:
#     faster, but needs more memory.
#   * UNROLLED sums and clips examples sequentially with lax.scan: slower,
#     but uses less memory.
per_example_grad_method = grad_clipping.VECTORIZED

Load the MNIST dataset

First, you need to load the MNIST dataset and then prepare the training and testing sets via Tensorflow Datasets (TFDS). You normalize image values, shuffle the data and divide it into batches, and prefetch samples to enhance performance.

No changes related to DP.

import tensorflow as tf  # TensorFlow / `tf.data` operations.
import tensorflow_datasets as tfds  # TFDS to download MNIST.

tf.random.set_seed(0)  # Set the random seed for reproducibility.


def _normalize(sample):
  """Casts the image to float32 and divides it by 255; keeps the label."""
  return {
      'image': tf.cast(sample['image'], tf.float32) / 255,
      'label': sample['label'],
  }


train_ds: tf.data.Dataset = tfds.load('mnist', split='train')
test_ds: tf.data.Dataset = tfds.load('mnist', split='test')

# The number of training examples is needed later to calibrate DP-SGD noise.
train_size = train_ds.cardinality().numpy()
print(f'Train size: {train_size}')
test_size = test_ds.cardinality().numpy()
print(f'Test size: {test_size}')

# Normalize both splits the same way.
train_ds = train_ds.map(_normalize)
test_ds = test_ds.map(_normalize)

# Training pipeline: repeat forever, shuffle through a 1024-element buffer,
# form full batches of `batch_size` (incomplete batches are dropped), keep
# exactly `train_steps` batches, and prefetch one batch ahead.
# NOTE(review): batches are drawn via a shuffle buffer with a fixed batch
# size; confirm this sampling scheme matches the assumptions of the privacy
# accountant used to calibrate the noise multiplier.
train_ds = (
    train_ds.repeat()
    .shuffle(1024)
    .batch(batch_size, drop_remainder=True)
    .take(train_steps)
    .prefetch(1)
)
# Test pipeline: full batches of `batch_size` (incomplete batch dropped),
# prefetching one batch ahead.
test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Cell In[4], line 1
----> 1 import tensorflow as tf  # TensorFlow / `tf.data` operations.
      2 import tensorflow_datasets as tfds  # TFDS to download MNIST.
      4 tf.random.set_seed(0)  # Set the random seed for reproducibility.

ModuleNotFoundError: No module named 'tensorflow'

Define the CNN model with Flax Linen

No changes related to DP.

from flax import linen as nn


class CNN(nn.Module):
  """A simple convolutional classifier.

  The network has two stages of 3x3 convolution, then ReLU, then 2x2 average
  pooling, with 32 and 64 feature maps. These are followed by a 256-unit
  ReLU dense layer and a 10-unit linear output layer. The output is
  unnormalized logits.
  """

  @nn.compact
  def __call__(self, x):
    # Submodules are created in the same order as an unrolled version, so
    # the auto-generated names (Conv_0, Conv_1, Dense_0, Dense_1) match.
    for num_features in (32, 64):
      x = nn.Conv(features=num_features, kernel_size=(3, 3))(x)
      x = nn.relu(x)
      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # Flatten all but the leading batch axis.
    x = nn.relu(nn.Dense(features=256)(x))
    return nn.Dense(features=10)(x)

Create model and Flax Linen train state

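No changes related to DP. We create a single model; whether it is trained with or without DP is controlled by the `use_dp` flag defined above, so you can rerun the notebook with `use_dp = False` to compare the performance.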

from flax.training import train_state
from jax import random
import jax.numpy as jnp
import optax


def create_train_state(rng):
  """Creates the initial `TrainState` for a freshly initialized `CNN`.

  Args:
    rng: JAX PRNG key used to initialize the model parameters.

  Returns:
    A `flax.training.train_state.TrainState` holding the CNN's `apply`
    function, its initial parameters, and an `optax.sgd` optimizer configured
    with the global `learning_rate` and `momentum`.
  """
  # Import under a private alias. The module-level name `train_state` is
  # rebound to the returned state just below, which shadows the
  # `flax.training.train_state` module and would break any later call to
  # this function if it looked the module up through that global name.
  from flax.training import train_state as flax_train_state

  cnn = CNN()
  # Initialize with a dummy batch of one 28x28 single-channel image.
  params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
  tx = optax.sgd(learning_rate, momentum)
  return flax_train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx
  )


train_state = create_train_state(random.key(0))

Create DP-SGD gradient computer

First, we find the noise multiplier (noise_multiplier, which sets the standard deviation of the Gaussian noise) that we have to add to achieve the required privacy guarantees (defined by epsilon and delta).

Then, using noise_multiplier, we create a DP-SGD gradient computer (DpsgdGradientComputer). We will use this object to clip the gradients and add noise to them.

from jax_privacy.accounting import accountants, analysis, calibrate
from jax_privacy.dp_sgd import gradients

# Find the noise multiplier (stddev) that lets `train_steps` updates, each
# using `batch_size` of the `train_size` training examples, satisfy the
# (epsilon, delta) privacy budget, as measured by a PLD accountant.
pld_config = accountants.PldAccountantConfig()
accountant = analysis.DpsgdTrainingAccountant(dp_accountant_config=pld_config)
noise_multiplier = calibrate.calibrate_noise_multiplier(
    target_epsilon=epsilon,
    target_delta=delta,
    accountant=accountant,
    num_samples=train_size,
    batch_sizes=batch_size,
    num_updates=train_steps,
)
print(f'Noise multiplier {noise_multiplier}')

# The gradient computer clips per-example gradients to `clipping_norm` and
# adds noise scaled by `noise_multiplier` to the aggregated gradients.
gradient_computer = gradients.DpsgdGradientComputer(
    clipping_norm=clipping_norm,
    noise_multiplier=noise_multiplier,
    per_example_grad_method=per_example_grad_method,
    # Simplifies learning-rate tuning, see https://arxiv.org/abs/2204.13650.
    rescale_to_unit_norm=True,
)

Define loss function

No changes related to DP.

def loss_fn(params, state, batch):
  """Computes the mean softmax cross-entropy of the model on `batch`.

  Args:
    params: Parameters to evaluate the model with. They are passed explicitly
      (rather than read from `state.params`) so the loss can be
      differentiated with respect to them.
    state: Train state whose `apply_fn` runs the model.
    batch: Dict with an 'image' input array and integer 'label' targets.

  Returns:
    A `(loss, logits)` tuple: the scalar mean loss and the model's logits.
  """
  logits = state.apply_fn({'params': params}, batch['image'])
  per_example_loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=batch['label']
  )
  return per_example_loss.mean(), logits

Define evaluation function

No changes related to DP.

import jax


@jax.jit
def eval_step(state, batch):
  """Returns the mean loss and accuracy of `state.params` on `batch`."""
  loss, logits = loss_fn(state.params, state, batch)
  predicted_labels = jnp.argmax(logits, -1)
  return loss, jnp.mean(predicted_labels == batch['label'])

Define non-DP train step

We will use it to train non-DP model for comparison.

@jax.jit
def non_dp_train_step(state, batch):
  """Runs one ordinary (non-private) SGD step on `batch`.

  Returns:
    `(new_state, metrics)` where metrics holds the batch 'loss' and the
    training 'accuracy' computed with the pre-update parameters.
  """
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params, state, batch)
  predicted_labels = jnp.argmax(logits, -1)
  step_metrics = {
      'loss': loss,
      'accuracy': jnp.mean(predicted_labels == batch['label']),
  }
  return state.apply_gradients(grads=grads), step_metrics

Define DP-SGD train step

This is the main function where DP-SGD magic happens.

First, we define a loss function (train_loss_fn) to pass to the gradient computer. This function must accept 4 arguments:

  • params: ArrayTree representing model weights

  • network_state: ArrayTree representing additional parameters that control the behavior of the network but are not updated via gradient, e.g. non-trainable parameters of Batch Norm (not used in our example).

  • rng_per_example: PRNGKey to generate random number per each example, e.g. to implement dropout or diffusion models (not used in our example).

  • inputs: ArrayTree representing the inputs (each leaf should contain an array with a batch dimension)

train_loss_fn is defined inside dp_sgd_train_step so that it can capture state and pass it to the standard loss_fn defined above. Calling loss_fn gives us the loss and the logits. We return the loss as is and put the logits into the per_example dictionary of the metrics. Per-example here means that the results will be stacked over the batch dimension. This is exactly what we need to make predictions and calculate the accuracy later. The other fields of metrics are scalars_avg (averaged over the batch dimension), scalars_sum (like scalars_avg, but summed) and scalars_last (the last of the per-example results).

Secondly, we calculate the clipped gradients by calling gradient_computer.loss_and_clipped_gradients. We pass our train_loss_fn followed by the arguments to be forwarded to train_loss_fn. The gradient computer splits the batch into per-example arguments (i.e. the batch size will be 1) and calls train_loss_fn for each example. The resulting gradients are clipped and then averaged. gradient_computer.loss_and_clipped_gradients returns the mean loss, the network state (ignored in this example) and the metrics, together with the clipped gradients.

Now we have clipped the gradients and are ready to make an optimizer step. However, to make it differentially private we also need to add noise to the gradient vector. We do that by calling gradient_computer.add_noise_to_grads. We pass it the gradient vector, a noise PRNGKey, the total batch size (i.e. the total number of examples accumulated in the gradient since the last optimizer step) and the noise state. gradient_computer.add_noise_to_grads returns three values:
- the noisy gradient vector,
- the standard deviation of the added noise, for monitoring purposes (ignored in our case),
- a new noise state that we have to save and pass to the next call.

The rest of the code is the same as in usual non-DP training: we update the model weights by applying the computed gradients with the optimizer, and we calculate the accuracy of the predictions for monitoring purposes.

from jax_privacy.dp_sgd import typing as jax_privacy_typing


@jax.jit
def dp_sgd_train_step(state, batch, noise_state, noise_rng):
  """Runs a single DP-SGD training step.

  The global `gradient_computer` computes clipped per-example gradients and
  adds noise to them; the noisy gradients are then applied by the optimizer.

  Args:
    state: Current `TrainState`.
    batch: Dict with 'image' and 'label' arrays sharing a leading batch axis.
    noise_state: Noise state returned by the previous call ({} initially).
    noise_rng: PRNG key used to sample this step's noise.

  Returns:
    `(new_state, new_noise_state, metrics)` where metrics holds the mean
    'loss' and the training 'accuracy' on `batch`.
  """

  def train_loss_fn(params, unused_network_state, unused_rng, inputs):
    # Invoked by the gradient computer on per-example slices of `batch`.
    loss, logits = loss_fn(params, state, inputs)
    metrics = jax_privacy_typing.Metrics(per_example={'logits': logits})
    return loss, (unused_network_state, metrics)

  # Number of examples contributing to this update (a static Python int
  # under jit). Used both for the noise and to reshape per-example logits.
  num_examples = batch['label'].shape[0]
  # ArrayTree representing additional state of the network (not used).
  unused_network_state = {}
  # PRNGKey to generate random number per each example (not used).
  unused_rng = random.PRNGKey(0)
  (loss, (_, metrics)), grads = gradient_computer.loss_and_clipped_gradients(
      loss_fn=train_loss_fn,
      params=state.params,
      network_state=unused_network_state,
      rng_per_local_microbatch=unused_rng,
      inputs=batch,
  )
  noisy_grads, _, new_noise_state = gradient_computer.add_noise_to_grads(
      grads, noise_rng, jnp.asarray(num_examples), noise_state
  )
  new_state = state.apply_gradients(grads=noisy_grads)
  # Per-example logits are stacked over the batch axis. Each example was
  # evaluated as a batch of one, so the stack may carry an extra singleton
  # axis (e.g. (B, 1, 10)). Flatten to (B, num_classes) so that argmax is
  # compared element-wise with labels of shape (B,) instead of broadcasting
  # to a (B, B) comparison.
  logits = metrics.per_example['logits'].reshape(num_examples, -1)
  accuracy = jnp.mean(jnp.argmax(logits, -1) == batch['label'])
  return new_state, new_noise_state, {'loss': loss, 'accuracy': accuracy}

Train

The only difference from usual non-DP training is that, at every step, we split the PRNGKey to get a fresh key for the noise, and we carry the noise state over to the next step.

Note that if you train with DP-SGD, you can't run the training loop on the same model a second time, because noise_multiplier was calibrated to meet the (epsilon, delta) budget for exactly train_steps training steps.

import time
import numpy as np

def _evaluate(state):
  """Evaluates `state` on every batch of `test_ds`.

  Args:
    state: `TrainState` whose parameters are evaluated.

  Returns:
    `(loss, accuracy)`: the per-batch `eval_step` results averaged over all
    test batches.
  """
  test_losses = []
  test_accuracies = []
  for test_batch in test_ds.as_numpy_iterator():
    loss, accuracy = eval_step(state, test_batch)
    test_losses.append(loss)
    test_accuracies.append(accuracy)
  return np.mean(test_losses), np.mean(test_accuracies)


# Train loop. Train metrics are averaged over the steps since the previous
# evaluation; test metrics are computed on the full test set at each
# evaluation.
metrics_history = {
    'train_loss': [],
    'train_accuracy': [],
    'test_loss': [],
    'test_accuracy': [],
}

# perf_counter is monotonic, so elapsed times are unaffected by clock changes.
checkpoint_start = time.perf_counter()
accumulated_losses = []
accumulated_accuracies = []
noise_state = {}
noise_rng = random.PRNGKey(42)
for step, batch in enumerate(train_ds.as_numpy_iterator()):
  if use_dp:
    # Use a fresh key for this step's noise; keep the other half for later.
    rng_grads, noise_rng = random.split(noise_rng)
    train_state, noise_state, step_metrics = dp_sgd_train_step(
        train_state, batch, noise_state, rng_grads
    )
  else:
    train_state, step_metrics = non_dp_train_step(train_state, batch)
  accumulated_losses.append(step_metrics['loss'])
  accumulated_accuracies.append(step_metrics['accuracy'])

  is_last_step = step == train_steps - 1
  if step > 0 and (step % eval_every == 0 or is_last_step):
    checkpoint_time = time.perf_counter() - checkpoint_start
    # Log the training metrics and start a new averaging window.
    metrics_history['train_loss'].append(np.mean(accumulated_losses))
    metrics_history['train_accuracy'].append(np.mean(accumulated_accuracies))
    accumulated_losses = []
    accumulated_accuracies = []

    # Compute and log the metrics on the test set.
    test_loss, test_accuracy = _evaluate(train_state)
    metrics_history['test_loss'].append(test_loss)
    metrics_history['test_accuracy'].append(test_accuracy)

    print(
        f' [elapsed time]: {checkpoint_time:.2f}\n',
        f'[train] step: {step}, '
        f'loss: {metrics_history["train_loss"][-1]}, '
        f'accuracy: {metrics_history["train_accuracy"][-1] * 100}',
    )
    print(
        f' [test] step: {step}, '
        f'loss: {metrics_history["test_loss"][-1]}, '
        f'accuracy: {metrics_history["test_accuracy"][-1] * 100}'
    )
    # Restart the timer after evaluation so eval time is not counted.
    checkpoint_start = time.perf_counter()