JAX & Flax API
The accounting and dp_sgd modules provide the API for raw JAX and Flax Linen. Flax NNX is not supported yet.
The accounting module contains the logic for computing the DP hyperparameters so that the privacy budget is not exceeded. The dp_sgd module contains the logic for computing clipped gradients and adding noise to them.
The main steps of using the API are:
1. Choose the accounting algorithm (RDP or PLD).
2. Create the accountant.
3. Use the accountant to compute the “unfixed” DP hyperparameter.
4. Create the gradient computer.
5. Create the loss function wrapper that computes the loss for a single example.
6. In the training step, use the gradient computer to compute the clipped gradients and add noise to them.
7. Apply the clipped, noisy gradients to the model.
Comments about step 1:
The PLD accounting algorithm is recommended because it uses the privacy budget more efficiently, so less noise is needed for the same guarantee. It is somewhat slower than RDP, but still reasonably fast.
Comments about step 3:
Two of the following three hyperparameters have to be fixed beforehand:
Noise multiplier (noise stddev = noise_multiplier * clipping_norm).
Number of updates.
Batch size (the total number of examples per update).
The third one can then be inferred from the two fixed ones. Computing the non-fixed hyperparameter is called “calibration”, and the calibrate module (jax_privacy.accounting.calibrate) provides convenience functions for it; see the API reference below for details. You also have to supply the other core DP parameters, such as the target epsilon and delta.
Usually the batch size and the number of updates are fixed, and the noise multiplier is calibrated.
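Calibration can be pictured as a one-dimensional search: epsilon shrinks as the noise multiplier grows, so the smallest noise multiplier that meets the budget can be bracketed and then bisected. The sketch below is illustrative only and rests on two loud assumptions: toy_epsilon is a hypothetical stand-in accountant based on the zCDP bound for full-batch Gaussian steps (it ignores subsampling, so it is far more pessimistic than RDP or PLD accounting), and the doubling-then-bisection strategy is only a guess at how parameters like initial_max_noise, initial_min_noise and tol could be used, not a description of calibrate_noise_multiplier's internals.

```python
import math
from typing import Callable


def toy_epsilon(noise_multiplier: float, num_updates: int, delta: float) -> float:
    """Hypothetical stand-in accountant (NOT the RDP/PLD accountants).

    A Gaussian step with sensitivity 1 and noise stddev `noise_multiplier` is
    rho-zCDP with rho = 1 / (2 * sigma^2); rho adds up over steps, and rho-zCDP
    implies (rho + 2 * sqrt(rho * ln(1 / delta)), delta)-DP. No subsampling.
    """
    rho = num_updates / (2.0 * noise_multiplier**2)
    return rho + 2.0 * math.sqrt(rho * math.log(1.0 / delta))


def calibrate_noise(
    epsilon_fn: Callable[[float], float],
    target_epsilon: float,
    initial_max_noise: float = 1.0,
    initial_min_noise: float = 0.0,
    tol: float = 0.01,
) -> float:
    """Returns a noise multiplier whose epsilon is at most target_epsilon."""
    lo, hi = initial_min_noise, initial_max_noise
    # Assumed strategy: double the upper bound until it satisfies the budget.
    while epsilon_fn(hi) > target_epsilon:
        lo, hi = hi, 2.0 * hi
    # Invariant: epsilon_fn(hi) <= target; lo is the initial minimum or violates it.
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if epsilon_fn(mid) > target_epsilon:
            lo = mid
        else:
            hi = mid
    return hi


sigma = calibrate_noise(
    lambda nm: toy_epsilon(nm, num_updates=10, delta=1e-5), target_epsilon=1.0
)
print(f"noise multiplier ~ {sigma:.3f}")
```

Swapping toy_epsilon for a real accountant changes the numbers but not the shape of the search.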
Comments about step 4:
DpsgdGradientComputer is the core of the API. The main parameters to supply are clipping_norm, the maximum L2 norm to which each per-example gradient is clipped, and noise_multiplier.
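To make the two parameters concrete, here is a minimal pure-Python sketch of clipping one gradient to clipping_norm by its global L2 norm, plus the optional rescale_to_unit_norm step described in the API reference. Gradients are modeled as a dict of float lists (a stand-in for a JAX pytree), and global_l2_norm and clip_gradient are hypothetical helpers, not the library's implementation.

```python
import math

Grad = dict[str, list[float]]  # toy stand-in for a pytree of arrays


def global_l2_norm(grad: Grad) -> float:
    # L2 norm over all leaves, as if they were concatenated into one vector.
    return math.sqrt(sum(v * v for leaf in grad.values() for v in leaf))


def clip_gradient(
    grad: Grad, clipping_norm: float, rescale_to_unit_norm: bool = False
) -> Grad:
    norm = global_l2_norm(grad)
    # Scale down only if the norm exceeds clipping_norm; never scale up.
    scale = min(1.0, clipping_norm / norm) if norm > 0 else 1.0
    if rescale_to_unit_norm:
        scale /= clipping_norm  # the clipped norm is then at most 1
    return {k: [v * scale for v in leaf] for k, leaf in grad.items()}


g = {"w": [3.0, 0.0], "b": [4.0]}  # global norm 5
clipped = clip_gradient(g, clipping_norm=1.0)
print(global_l2_norm(clipped))  # approximately 1.0
```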
Comments about step 5:
It is important to understand that the gradients have to be computed and clipped per example. This is one of the main differences between standard SGD training and DP-SGD training. Therefore you have to create a wrapper around your loss function that computes the loss for a single example (see the example below).
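The reason for clipping per example rather than per batch is that it bounds how much any single example can move the aggregated gradient: once each gradient is clipped to norm C, adding or removing one example changes the sum by at most C, and that bound is what the Gaussian noise is scaled to. Clipping only the aggregated gradient gives no such per-example bound. The sketch below checks this on plain Python vectors; the data and the helpers clip and sum_of_clipped are hypothetical.

```python
import math
import random


def l2_norm(v: list[float]) -> float:
    return math.sqrt(sum(x * x for x in v))


def clip(v: list[float], c: float) -> list[float]:
    # Scale v down so that its L2 norm is at most c.
    n = l2_norm(v)
    scale = min(1.0, c / n) if n > 0 else 1.0
    return [x * scale for x in v]


def sum_of_clipped(grads: list[list[float]], c: float) -> list[float]:
    # Clip every per-example gradient first, then sum coordinate-wise.
    clipped = [clip(g, c) for g in grads]
    return [sum(coords) for coords in zip(*clipped)]


rng = random.Random(0)
C = 1.0
batch = [[rng.gauss(0.0, 10.0) for _ in range(3)] for _ in range(8)]
outlier = [1000.0, -1000.0, 500.0]  # one example with a huge gradient

with_outlier = sum_of_clipped(batch + [outlier], C)
without_outlier = sum_of_clipped(batch, C)
change = l2_norm([a - b for a, b in zip(with_outlier, without_outlier)])
print(f"change caused by one example: {change:.4f} (bound C = {C})")
```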
In JAX Privacy the loss function wrapper must have a specific signature: it takes the model parameters, the network state, a random key and the input data (a tuple of x and y), and returns the loss together with a tuple of the new network state and the metrics. In simple cases the network state and the random key are not used and can be ignored; the network state is just a Python dictionary, so you can pass an empty dictionary.
In the returned metrics object you can add metrics and specify how each one is aggregated over the batch: stacked, averaged or summed. For example, you can stack logits to compute accuracy later in the code. The wrapper has such strict requirements because the library cannot otherwise know how to aggregate a metric or a piece of state over the batch dimension. See the API reference below for the exact types and signatures. This part might be simplified in future versions of JAX Privacy.
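The contract can be illustrated without JAX. The sketch below is a hypothetical linear-regression wrapper with the argument order used in the example below (params, network state, random key, inputs) and the documented return structure (loss, (network_state, metrics)). A plain dict stands in for the library's metrics object, and the Python loop stands in for the library's vectorized per-example evaluation; the three aggregations at the end show why the library needs to be told whether to stack, average or sum each metric.

```python
from typing import Any


def single_example_loss_fn(params: dict, network_state: dict, rng: Any, inputs: tuple):
    """Squared-error loss for ONE example; network_state and rng are unused."""
    x, y = inputs
    pred = sum(w * xi for w, xi in zip(params["w"], x)) + params["b"]
    loss = 0.5 * (pred - y) ** 2
    metrics = {"prediction": pred, "loss": loss}  # plain-dict stand-in for Metrics
    return loss, (network_state, metrics)


params = {"w": [2.0, -1.0], "b": 0.5}
xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
ys = [2.5, 0.0, 1.0]

# Evaluate the wrapper once per example (the library vectorizes this step).
outputs = [single_example_loss_fn(params, {}, None, (x, y)) for x, y in zip(xs, ys)]
per_example_metrics = [metrics for _, (_, metrics) in outputs]

stacked_predictions = [m["prediction"] for m in per_example_metrics]  # stack
mean_loss = sum(m["loss"] for m in per_example_metrics) / len(outputs)  # average
sum_loss = sum(m["loss"] for m in per_example_metrics)  # sum
print(stacked_predictions, mean_loss, sum_loss)
```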
The following example uses raw JAX to illustrate the steps above. There is also a Jupyter notebook that shows how to use this API with Flax Linen: DP-SGD tutorial using Flax Linen on MNIST.
Example Usage
```python
import jax
import jax.numpy as jnp
from jax import random
from jax_privacy.accounting import accountants, analysis, calibrate
from jax_privacy.dp_sgd import grad_clipping, gradients
from jax_privacy.dp_sgd import typing as jax_privacy_typing  # assumed home of Metrics

# load_data, init_model_params, loss_fn, updated_model_params and batch_dataset
# are user-defined helpers, and train_size, true_w, true_b, batch_size,
# num_epochs and use_dp are user-chosen settings; none of them is part of
# JAX Privacy.
x_train_full, y_train_full = load_data(train_size, true_w, true_b)
model_params = init_model_params()

gradient_computer = None
noise_rng = None
if use_dp:
    # Calculate noise_multiplier (stddev) given the privacy budget.
    accountant = analysis.DpsgdTrainingAccountant(
        dp_accountant_config=accountants.PldAccountantConfig()
    )
    noise_multiplier = calibrate.calibrate_noise_multiplier(
        target_epsilon=1.0,
        accountant=accountant,
        batch_sizes=batch_size,
        num_updates=num_epochs * train_size // batch_size,
        num_samples=train_size,
        target_delta=1e-5,
    )
    print(f"Noise multiplier {noise_multiplier}")
    # Create gradient computer that will clip grads and add noise to them.
    gradient_computer = gradients.DpsgdGradientComputer(
        clipping_norm=1.0,
        noise_multiplier=noise_multiplier,
        rescale_to_unit_norm=True,
        per_example_grad_method=grad_clipping.VECTORIZED,
    )
    noise_rng = random.key(42)


# Loss function wrapper that calculates loss per single example. Necessary
# because the gradients have to be calculated and clipped per single example.
# The signature of the loss function has to be exactly as here.
def single_example_loss_fn(
    model_params, unused_network_state, unused_rng, inputs
):
    x, y = inputs
    loss = loss_fn(model_params, x, y)
    return loss, (unused_network_state, jax_privacy_typing.Metrics())


@jax.jit
def dp_update_step(model_params, batch_x, batch_y, noise_rng):
    (loss, _), grads = gradient_computer.loss_and_clipped_gradients(
        loss_fn=single_example_loss_fn,
        params=model_params,
        network_state={},  # not used
        rng_per_local_microbatch=random.key(0),  # not used
        inputs=(batch_x, batch_y),
    )
    # Split off a fresh key so every step draws independent noise.
    rng_grads, noise_rng = random.split(noise_rng)
    # Assumes every batch contains exactly batch_size examples.
    noisy_grads, _, _ = gradient_computer.add_noise_to_grads(
        grads, rng_grads, jnp.asarray(batch_size), noise_state={}
    )
    model_params = updated_model_params(model_params, noisy_grads)
    return model_params, loss, noise_rng


@jax.jit
def update_step(model_params, batch_x, batch_y):
    # Non-private baseline: ordinary, unclipped and noise-free gradients.
    loss, grads = jax.value_and_grad(loss_fn, has_aux=False)(
        model_params, batch_x, batch_y
    )
    model_params = updated_model_params(model_params, grads)
    return model_params, loss


print("\nStarting training...")
for epoch in range(num_epochs):
    batched_dataset = batch_dataset(x_train_full, y_train_full, batch_size)

    epoch_loss = 0.0
    for batch_x_tf, batch_y_tf in batched_dataset:
        batch_x = jnp.asarray(batch_x_tf)
        batch_y = jnp.asarray(batch_y_tf)

        if use_dp:
            model_params, loss, noise_rng = dp_update_step(
                model_params, batch_x, batch_y, noise_rng
            )
        else:
            model_params, loss = update_step(model_params, batch_x, batch_y)
        epoch_loss += float(loss)
    print(f"Epoch {epoch}: summed batch loss {epoch_loss:.4f}")
```
API Reference
The main API classes and functions are documented below.
The available privacy accounting algorithms (accountant configs) are:
- class jax_privacy.accounting.accountants.RdpAccountantConfig(*, orders: collections.abc.Sequence[int] = <factory>)
Bases: DpAccountantConfig
Configuration for the RDP Accountant to use.
- class jax_privacy.accounting.accountants.PldAccountantConfig(*, value_discretization_interval: float = 0.0001)
Bases: DpAccountantConfig
Configuration for the PLD Accountant to use.
PLD uses the privacy budget more efficiently, so less noise needs to be added. It takes more time than RDP but is still reasonably fast, which is why it is the recommended choice.
Once you have chosen the accountant config, create the DP-SGD training accountant by passing that config to its constructor:
- class jax_privacy.accounting.analysis.DpsgdTrainingAccountant(dp_accountant_config: DpAccountantConfig)
Bases: DpTrainingAccountant
Defines privacy computations for Band-MF with Cyclic Poisson sampling.
This includes DP-SGD style analysis as a special case.
For accounting we follow the reduction in https://arxiv.org/abs/2306.08153. We assume that if num_samples % cycle_length != 0, then num_samples % cycle_length examples are discarded.
Initializes the accountant for Differential Privacy.
- Parameters:
dp_accountant_config – Configuration for the DP accountant to use.
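As a worked example of the discard rule above (hypothetical numbers): with num_samples = 60000 and cycle_length = 7, 60000 % 7 = 3 examples are discarded. The sketch additionally assumes, following the cyclic Poisson sampling scheme of the cited paper, that the remaining examples are split into cycle_length equal groups and that step t samples from group t % cycle_length; cyclic_partition and group_for_step are hypothetical helpers, and the library's actual group assignment is not described here.

```python
def cyclic_partition(num_samples: int, cycle_length: int) -> tuple[int, int]:
    """Returns (discarded examples, examples per group)."""
    discarded = num_samples % cycle_length
    group_size = (num_samples - discarded) // cycle_length
    return discarded, group_size


def group_for_step(step: int, cycle_length: int) -> int:
    # Assumed cyclic schedule: the groups are visited in turn.
    return step % cycle_length


print(cyclic_partition(60000, 7))  # (3, 8571)
print([group_for_step(t, 7) for t in range(9)])  # [0, 1, 2, 3, 4, 5, 6, 0, 1]
```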
Then use this accountant to compute whichever DP hyperparameter you have not fixed:
- jax_privacy.accounting.calibrate.calibrate_noise_multiplier(*, target_epsilon: float, accountant: DpTrainingAccountant, batch_sizes: int | Sequence[tuple[int, int]], num_updates: int, num_samples: int, target_delta: float, examples_per_user: int | None = None, cycle_length: int | None = None, truncated_batch_size: int | None = None, initial_max_noise: float = 1.0, initial_min_noise: float = 0.0, tol: float = 0.01) -> float
Computes the noise multiplier to achieve target_epsilon.
- Parameters:
target_epsilon – The desired final epsilon.
accountant – Method of computing the privacy guarantee.
batch_sizes – Batch size. Integer or list of pairs (t: int, bs: int) if the batch size changes across steps. ‘t’ indicates the step where batch_size is set to ‘bs’.
num_updates – Total number of iterations.
num_samples – Number of training examples.
target_delta – Desired delta for the returned epsilon.
examples_per_user – If multiple examples per user are used, this is the maximum number any user contributes to the training set.
cycle_length – If using cyclic Poisson sampling with BandMF, the length of the cycle.
truncated_batch_size – If using truncated Poisson sampling, the maximum batch size to truncate to.
initial_max_noise – An initial estimate of the noise multiplier.
initial_min_noise – Minimum noise multiplier.
tol – tolerance of the optimizer for the calibration.
- Returns:
Noise multiplier.
- jax_privacy.accounting.calibrate.calibrate_num_updates(*, target_epsilon: float, accountant: DpTrainingAccountant, noise_multipliers: float | Sequence[tuple[int, float]], batch_sizes: int | Sequence[tuple[int, int]], num_samples: int, target_delta: float, examples_per_user: int | None = None, cycle_length: int | None = None, truncated_batch_size: int | None = None, initial_max_updates: int = 4, initial_min_updates: int = 1, tol: float = 0.1) -> int
Computes the number of steps to achieve target_epsilon.
- Parameters:
target_epsilon – The desired final epsilon.
accountant – Method of computing the privacy guarantee.
noise_multipliers – Noise multiplier. Float or list of pairs (t: int, nm: float) if the noise multiplier changes across updates. ‘t’ indicates update where noise_multiplier is set to ‘nm’.
batch_sizes – Batch size. Integer or list of pairs (t: int, bs: int) if the batch size changes across updates. ‘t’ indicates the step where batch_size is set to ‘bs’.
num_samples – Number of training examples.
target_delta – Desired delta for the returned epsilon.
examples_per_user – If multiple examples per user are used, this is the maximum number any user contributes to the training set.
cycle_length – If using cyclic Poisson sampling with BandMF, the length of the cycle.
truncated_batch_size – If using truncated Poisson sampling, the maximum batch size to truncate to.
initial_max_updates – An initial estimate of the number of updates.
initial_min_updates – Minimum number of updates.
tol – tolerance of the optimizer for the calibration.
- Returns:
Number of updates.
- jax_privacy.accounting.calibrate.calibrate_batch_size(*, target_epsilon: float, accountant: DpTrainingAccountant, noise_multipliers: float | Sequence[tuple[int, float]], num_updates: int, num_samples: int, target_delta: float, examples_per_user: int | None = None, cycle_length: int | None = None, truncated_batch_size: int | None = None, initial_max_batch_size: int = 8, initial_min_batch_size: int = 1, tol: float = 0.01) -> int
Computes the batch size required to achieve target_epsilon.
- Parameters:
target_epsilon – The desired final epsilon.
accountant – Method of computing the privacy guarantee.
noise_multipliers – Noise multiplier. Float or list of pairs (t: int, nm: float) if the noise multiplier changes across steps. ‘t’ indicates step where noise_multiplier is set to ‘nm’.
num_updates – Total number of iterations.
num_samples – Number of training examples.
target_delta – Desired delta for the returned epsilon.
examples_per_user – If multiple examples per user are used, this is the maximum number any user contributes to the training set.
cycle_length – If using cyclic Poisson sampling with BandMF, the length of the cycle.
truncated_batch_size – If using truncated Poisson sampling, the maximum batch size to truncate to.
initial_max_batch_size – An initial estimate of the batch size.
initial_min_batch_size – Minimum batch size.
tol – tolerance of the optimizer for the calibration.
- Returns:
Batch size.
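The schedule form accepted by batch_sizes and noise_multipliers in the functions above is a list of (t, value) pairs, where value takes effect at step t. The hypothetical helper below looks up the value in effect at a given step, assuming the pairs are sorted by t and the first pair starts at step 0; it only illustrates the format and is not part of the library.

```python
from bisect import bisect_right


def value_at_step(schedule: list[tuple[int, float]], step: int) -> float:
    """Value in effect at `step` for a schedule [(t, value), ...] sorted by t."""
    starts = [t for t, _ in schedule]
    idx = bisect_right(starts, step) - 1  # last pair with t <= step
    if idx < 0:
        raise ValueError(f"schedule does not cover step {step}")
    return schedule[idx][1]


# Hypothetical schedule: batch size 256 for steps 0-499, then 512.
batch_sizes = [(0, 256), (500, 512)]
print(value_at_step(batch_sizes, 499), value_at_step(batch_sizes, 500))  # 256 512
```

Per the signatures above, such a list can be passed directly, e.g. batch_sizes=[(0, 256), (500, 512)].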
Then create a DpsgdGradientComputer to compute the clipped gradients and add noise to them. Its abstract base class, GradientComputer, is documented first.
- class jax_privacy.dp_sgd.gradients.GradientComputer(*, clipping_norm: float | None, noise_multiplier: float | None, rescale_to_unit_norm: bool, per_example_grad_method: jax_privacy.dp_sgd.grad_clipping.PerExampleGradMethod, rng_per_param_fn: typing.Callable[[jax.Array], jax.Array] = <function GradientComputer.<lambda>>, global_norm_fn: typing.Callable[[jax_privacy.dp_sgd.typing.ParamsT], jax.Array] = <function global_norm>)
Bases: ABC, Generic[InputsT, ParamsT, ModelStateT, NoiseStateT]
Computes (potentially) clipped and noisy gradients.
Initialises the gradient computation.
- Parameters:
clipping_norm – maximum L2 norm to which the input tree should be clipped.
noise_multiplier – standard deviation of the noise to add to the average of the clipped gradient to make it differentially private. It is multiplied by clipping_norm / total_batch_size before the noise is actually added.
rescale_to_unit_norm – If true, additionally rescale the clipped gradient by 1/clipping_norm so it has an L2 norm of at most one.
per_example_grad_method – Per-example gradient clipping method to use. Does not affect the results, but controls speed/memory trade-off.
rng_per_param_fn – Optional callable to allow gradient noise random keys to be specialised for different param slices.
global_norm_fn – function to compute the L2 norm of an ArrayTree.
- abstract add_noise_to_grads(grads: ParamsT, rng_per_batch: Array, total_batch_size: Array, noise_state: NoiseStateT) -> tuple[ParamsT, Array, NoiseStateT]
Adds noise to gradients.
- Parameters:
grads – gradients to privatize.
rng_per_batch – random number generation key.
total_batch_size – total batch-size once accumulated over devices and steps (i.e. as seen by the optimizer performing the update).
noise_state – additional state required to compute noise.
- Returns:
noisy_grads – gradients with the added noise.
std – standard deviation used for the noise (for monitoring purposes).
noise_state – (updated, if needed) state required to compute noise.
- clean_gradients(*, loss_fn: LossFn, params: ParamsT, network_state: ModelStateT, rng_per_local_microbatch: Array, inputs: InputsT) -> ParamsT
Computes unclipped gradients of the given loss function.
- Parameters:
loss_fn – Loss function whose gradients are required.
params – Trainable parameters.
network_state – Network state input to loss_fn.
rng_per_local_microbatch – Random number key for the batch. The caller must provide independent keys for different training steps, accumulation steps (if using microbatching), and model replicas (if invoked in a pmap).
inputs – Inputs to loss_fn.
- Returns:
Unclipped gradients.
- abstract init_noise_state(params: ParamsT) -> NoiseStateT
Returns a new noise_state to be used for adding noise to gradients.
- l2_loss(params: ParamsT) -> Array | ndarray | bool | number | float | int
Computes the squared L2 loss.
- Parameters:
params – model parameters for which the loss should be computed, assumed to be in haiku-like format.
- Returns:
Squared L2 loss.
- loss_and_clipped_gradients(*, loss_fn: jax_privacy.dp_sgd.typing.LossFn, params: jax_privacy.dp_sgd.typing.ParamsT, network_state: jax_privacy.dp_sgd.typing.ModelStateT, rng_per_local_microbatch: jax.Array, inputs: jax_privacy.dp_sgd.typing.InputsT, state_acc_strategies: jax_privacy.dp_sgd.grad_clipping_utils.StateAccumulationStrategy | collections.abc.Mapping[str, StateAccumulationStrategyTree] = <jax_privacy.dp_sgd.grad_clipping_utils.Reject object>) -> tuple[tuple[Array, tuple[ModelStateT, Metrics]], ParamsT]
Computes (potentially) clipped gradients of the given loss function.
- Parameters:
loss_fn – Loss function whose clipped gradients are required.
params – Trainable parameters.
network_state – Network state input to loss_fn.
rng_per_local_microbatch – Random number key for the batch. The caller must provide independent keys for different training steps, accumulation steps (if using microbatching), and model replicas (if invoked in a pmap).
inputs – Inputs to loss_fn.
state_acc_strategies – Prefix tree of network state accumulation strategies.
- Returns:
Tuple consisting of (loss-and-aux, clipped_grads) where loss-and-aux is as is returned by loss_fn (with the addition of the grad norm per example in the metrics).
- value_and_clipped_grad(value_and_grad_fn: jax_privacy.dp_sgd.typing.ValueAndGradFn, *, state_acc_strategies: jax_privacy.dp_sgd.grad_clipping_utils.StateAccumulationStrategy | collections.abc.Mapping[str, StateAccumulationStrategyTree] = <jax_privacy.dp_sgd.grad_clipping_utils.Reject object>) -> ValueAndGradFn
Creates the function computing (potentially) clipped gradients.
- Parameters:
value_and_grad_fn – Function that produces unclipped gradients. It is expected to have the following signature: (loss, aux), grad = grad_fn(params, network_state, rng_key, inputs).
state_acc_strategies – Prefix tree of network state accumulation strategies. The default is to raise an error if any network state is present, but this can be overridden, e.g. to average state across microbatches. CAUTION: any approach in which the state depends on the inputs and influences trainable parameters (as will be the case with batch normalisation) will invalidate the DP guarantees, as it bypasses the DP noise/clipping.
- Returns:
A function computing gradients that are potentially clipped per sample.
- class jax_privacy.dp_sgd.gradients.DpsgdGradientComputer(*, clipping_norm: float | None, noise_multiplier: float | None, rescale_to_unit_norm: bool, per_example_grad_method: jax_privacy.dp_sgd.grad_clipping.PerExampleGradMethod, rng_per_param_fn: typing.Callable[[jax.Array], jax.Array] = <function GradientComputer.<lambda>>, global_norm_fn: typing.Callable[[jax_privacy.dp_sgd.typing.ParamsT], jax.Array] = <function global_norm>)
Bases: GradientComputer[InputsT, ParamsT, ModelStateT, Mapping[str, Array]]
Gradient computer for DP-SGD.
Initialises the gradient computation.
- Parameters:
clipping_norm – maximum L2 norm to which the input tree should be clipped.
noise_multiplier – standard deviation of the noise to add to the average of the clipped gradient to make it differentially private. It is multiplied by clipping_norm / total_batch_size before the noise is actually added.
rescale_to_unit_norm – If true, additionally rescale the clipped gradient by 1/clipping_norm so it has an L2 norm of at most one.
per_example_grad_method – Per-example gradient clipping method to use. Does not affect the results, but controls speed/memory trade-off.
rng_per_param_fn – Optional callable to allow gradient noise random keys to be specialised for different param slices.
global_norm_fn – function to compute the L2 norm of an ArrayTree.
- add_noise_to_grads(grads: ParamsT, rng_per_batch: Array, total_batch_size: Array, noise_state: Mapping[str, Array]) -> tuple[ParamsT, Array, Mapping[str, Array]]
Adds noise to gradients.
- Parameters:
grads – gradients to privatize.
rng_per_batch – random number generation key.
total_batch_size – total batch-size once accumulated over devices and steps (i.e. as seen by the optimizer performing the update).
noise_state – additional state required to compute noise.
- Returns:
noisy_grads – gradients with the added noise.
std – standard deviation used for the noise (for monitoring purposes).
noise_state – (updated, if needed) state required to compute noise.
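Following the noise_multiplier description above, the noise added to the averaged clipped gradient has standard deviation noise_multiplier * clipping_norm / total_batch_size, which is the same as adding noise with stddev noise_multiplier * clipping_norm to the sum of clipped gradients and then dividing by the batch size. The pure-Python sketch below (hypothetical helper add_gaussian_noise, gradients as flat lists) mirrors the (noisy_grads, std) part of add_noise_to_grads' return value; it ignores noise_state, assumes rescale_to_unit_norm is off, and is not the library's implementation.

```python
import random


def add_gaussian_noise(
    avg_clipped_grad: list[float],
    noise_multiplier: float,
    clipping_norm: float,
    total_batch_size: int,
    rng: random.Random,
) -> tuple[list[float], float]:
    """Adds N(0, std^2) noise to each coordinate of an averaged clipped gradient."""
    # Noise on the average = noise on the sum (noise_multiplier * clipping_norm)
    # divided by the batch size.
    std = noise_multiplier * clipping_norm / total_batch_size
    noisy = [g + rng.gauss(0.0, std) for g in avg_clipped_grad]
    return noisy, std  # std is returned for monitoring, like add_noise_to_grads


noisy, std = add_gaussian_noise(
    [0.1, -0.2, 0.05],
    noise_multiplier=1.1,
    clipping_norm=1.0,
    total_batch_size=256,
    rng=random.Random(42),
)
print(noisy, std)
```

Under this formula, passing a smaller per-device batch size instead of the total batch size seen by the optimizer would inflate the noise, which is why total_batch_size is defined as the size accumulated over devices and steps.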