jax_privacy.keras_api.DPKerasConfig
- class jax_privacy.keras_api.DPKerasConfig(epsilon, delta, clipping_norm, batch_size, gradient_accumulation_steps, train_steps, train_size, poisson_sampling_in_fit=False, noise_multiplier=None, rescale_to_unit_norm=True, microbatch_size=None, seed=None)
Bases: object
Parameters for adding DP-SGD to a Keras model.
- Variables:
epsilon – The epsilon that defines the differential-privacy budget. It must lie in the range (0, +infinity). A value of 0 would mean a perfect privacy guarantee, which is not achievable in practice because it requires infinite noise. A value of +infinity means no privacy guarantee. A commonly used value is ln(3); smaller values mean more noise. Set this value before training, based only on the privacy guarantee you need to achieve. Do not increase epsilon just because the model performs poorly.
delta – The delta that defines the differential-privacy budget. Informally, it is the probability that the epsilon guarantee does not hold, for example through a full disclosure with no privacy. It should be in (0, 1] and as small as possible (e.g. 1e-5); smaller values mean more noise. Set this value before training, based only on the privacy guarantee you need to achieve. Do not increase delta just because the model performs poorly.
clipping_norm – The clipping norm for the per-example gradients.
batch_size – The batch size used by the DP optimizer. When poisson_sampling_in_fit=True, this is the expected batch size of the internal Poisson sampler. Otherwise it must match the batch size supplied to fit().
gradient_accumulation_steps – The number of gradient accumulation steps, i.e. the number of batches whose gradients are accumulated before noise is added and an optimizer step is performed. A value of 1 means no gradient accumulation: each optimizer step is performed after a single batch. This parameter defines the effective batch size = (physical) batch_size * gradient_accumulation_steps, i.e. the accumulated batch size actually used for each model update. DP training usually achieves better accuracy with a larger effective batch size, so it is recommended to set gradient_accumulation_steps to a value larger than 1. Memory constraints often prevent setting the physical batch size large enough, which is why gradient accumulation is so useful in DP training.
train_steps – The number of training steps (optimizer update steps). Training the model for more steps than this will fail. If you train by epochs, this is epochs * (train_size // batch_size). If you train until the dataset iterator is exhausted, it is the length of the dataset iterator.
train_size – The number of training examples in the dataset. If your dataset iterator repeats examples, this should be the number of training examples in the original dataset, before repetition.
poisson_sampling_in_fit – Whether fit() should internally resample random-access array inputs using Poisson sampling. Leave this as False for backwards-compatible behavior or when the user supplies a dataset iterator that already handles sampling.
noise_multiplier – The noise multiplier for the gradients. If None (recommended), the noise multiplier is calculated automatically from epsilon, delta, effective_batch_size, train_steps and train_size. Gaussian noise with mean 0 and standard deviation noise_multiplier * clipping_norm / effective_batch_size is added to the average of the gradients over each effective batch.
rescale_to_unit_norm – Whether to rescale the gradients to unit norm. Simplifies learning-rate tuning, see https://arxiv.org/abs/2204.13650.
seed – The seed for the random number generator. If None, a random seed is used. It must be an int64. Useful for reproducibility.
microbatch_size – The size of each microbatch. The device batch is split into microbatches of this size, which are processed sequentially in the forward/backward pass. Setting microbatch_size=batch_size runs the forward/backward pass once on the entire batch using jax.vmap. Setting microbatch_size=1 runs the forward/backward pass on each batch element individually, accumulating the gradients sequentially with jax.lax.scan. A value of batch_size gives the most parallelism, while 1 gives the lowest memory consumption; any value in between trades off memory consumption against parallel computation. This parameter is similar to gradient_accumulation_steps, but it operates entirely in device memory within a single jitted function, whereas gradient_accumulation_steps operates outside the jit boundary. The default value None means that no microbatching is used, which is equivalent to microbatch_size=batch_size.
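The parameters above combine into one DP-SGD update roughly as follows. This is a NumPy sketch for intuition, not jax_privacy's implementation: the linear model and the helper names (`per_example_grads`, `clipped_grad_sum`, `noisy_average_gradient`) are hypothetical. A Python loop stands in for `jax.lax.scan` and NumPy vectorization stands in for `jax.vmap`. Fixed-size batches are assumed (no Poisson sampling), and `rescale_to_unit_norm` is not modeled.

```python
import numpy as np


def per_example_grads(w, x, y):
    """Hypothetical model: gradients of 0.5 * (x_i @ w - y_i) ** 2 w.r.t. w, one row per example."""
    residual = x @ w - y            # shape (n,)
    return residual[:, None] * x    # shape (n, d)


def clipped_grad_sum(w, x, y, clipping_norm, microbatch_size=None):
    """Sum of per-example gradients, each clipped to L2 norm <= clipping_norm."""
    n = x.shape[0]
    size = n if microbatch_size is None else microbatch_size
    total = np.zeros_like(w)
    # Microbatches run one after another (the role lax.scan plays); examples
    # inside a microbatch are handled in one vectorized call (the role vmap plays).
    for start in range(0, n, size):
        g = per_example_grads(w, x[start:start + size], y[start:start + size])
        norms = np.linalg.norm(g, axis=1, keepdims=True)
        scale = np.minimum(1.0, clipping_norm / np.maximum(norms, 1e-12))
        total += (g * scale).sum(axis=0)
    return total


def noisy_average_gradient(w, batches, clipping_norm, noise_multiplier, rng,
                           microbatch_size=None):
    """One update direction from a list of gradient_accumulation_steps batches."""
    total = np.zeros_like(w)
    for x, y in batches:  # gradient accumulation across physical batches
        total += clipped_grad_sum(w, x, y, clipping_norm, microbatch_size)
    effective_batch_size = sum(x.shape[0] for x, _ in batches)
    # Noise is added once per update; after dividing by the effective batch
    # size its std is noise_multiplier * clipping_norm / effective_batch_size.
    noise = rng.normal(0.0, noise_multiplier * clipping_norm, size=w.shape)
    return (total + noise) / effective_batch_size


# Usage: batch_size=8, gradient_accumulation_steps=4, microbatch_size=2.
rng = np.random.default_rng(0)
w = np.zeros(3)
batches = [(rng.normal(size=(8, 3)), rng.normal(size=8)) for _ in range(4)]
update = noisy_average_gradient(w, batches, clipping_norm=1.0,
                                noise_multiplier=1.1, rng=rng, microbatch_size=2)
```

Changing microbatch_size only changes how the work is scheduled, not the result (up to floating-point rounding), which mirrors the memory-versus-parallelism trade-off described for that parameter.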
Methods
__init__Calculates the noise multiplier for the given DP training parameters.
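The noise-multiplier calculation mentioned above can be illustrated with a bisection search. This sketch is a stand-in built on assumptions, not the library's method. It uses the classic Rényi-DP bound for the Poisson-subsampled Gaussian mechanism at integer orders and converts it to (ε, δ) with ε = min over α of T·RDP(α) + log(1/δ)/(α - 1). That conversion is typically looser than modern accountants. The sketch also assumes a sampling rate q = effective_batch_size / train_size and T = train_steps. All function names are hypothetical.

```python
import math

# Integer Renyi orders to search over (an arbitrary but common choice).
ORDERS = tuple(range(2, 65)) + (128, 256)


def _log_binom(n, k):
    return math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)


def rdp_sampled_gaussian(q, sigma, alpha):
    """RDP at integer order alpha of one Poisson-subsampled Gaussian step (0 < q <= 1)."""
    if q == 1.0:
        return alpha / (2.0 * sigma ** 2)
    # log of sum_k C(alpha, k) (1-q)^(alpha-k) q^k exp((k^2 - k) / (2 sigma^2)),
    # computed with log-sum-exp to avoid overflow for small sigma.
    log_terms = [
        _log_binom(alpha, k)
        + (alpha - k) * math.log1p(-q)
        + k * math.log(q)
        + (k * k - k) / (2.0 * sigma ** 2)
        for k in range(alpha + 1)
    ]
    top = max(log_terms)
    log_a = top + math.log(sum(math.exp(t - top) for t in log_terms))
    return log_a / (alpha - 1)


def epsilon_from_rdp(sigma, q, steps, delta):
    """(epsilon, delta) guarantee after `steps` compositions, via the classic conversion."""
    return min(steps * rdp_sampled_gaussian(q, sigma, a) + math.log(1.0 / delta) / (a - 1)
               for a in ORDERS)


def calibrate_noise_multiplier(target_epsilon, delta, q, steps, lo=0.1, hi=100.0, tol=1e-3):
    """Smallest noise multiplier (within tol) whose epsilon is <= target_epsilon."""
    if epsilon_from_rdp(lo, q, steps, delta) <= target_epsilon:
        return lo
    if epsilon_from_rdp(hi, q, steps, delta) > target_epsilon:
        raise ValueError("target epsilon not reachable with noise_multiplier <= hi")
    # Invariant: eps(lo) > target and eps(hi) <= target; eps decreases with sigma.
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if epsilon_from_rdp(mid, q, steps, delta) <= target_epsilon:
            hi = mid
        else:
            lo = mid
    return hi


# Usage: effective_batch_size=1024, train_size=60000, train_steps=1000.
sigma = calibrate_noise_multiplier(math.log(3), 1e-5, q=1024 / 60_000, steps=1000)
```

Because ε only decreases as the noise multiplier grows, bisection returns the smallest noise multiplier (within tol) that meets the target under this accountant. A looser accountant would return a larger value for the same budget.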
Attributes
effective_batch_size – The effective batch size which is used for the model update.
- epsilon: float
- delta: float
- clipping_norm: float
- batch_size: int
- gradient_accumulation_steps: int
- train_steps: int
- train_size: int
- poisson_sampling_in_fit: bool = False
- noise_multiplier: float | None = None
- rescale_to_unit_norm: bool = True
- microbatch_size: int | None = None
- seed: int | None = None
- property effective_batch_size: int
The effective batch size which is used for the model update.
It equals batch_size * gradient_accumulation_steps.
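The following hypothetical stand-in dataclass illustrates the field list and the effective_batch_size property. It has the same fields, defaults, and property, but it is not the library's class. DPConfigSketch, train_steps_for_epochs, and the range checks in __post_init__ are illustrative assumptions derived from the parameter descriptions.

```python
import dataclasses
import math
from typing import Optional


def train_steps_for_epochs(epochs: int, train_size: int, batch_size: int) -> int:
    """Hypothetical helper: the epochs-based formula from the train_steps entry."""
    return epochs * (train_size // batch_size)


@dataclasses.dataclass(frozen=True)
class DPConfigSketch:
    """Hypothetical stand-in with the same fields as DPKerasConfig (not the real class)."""

    epsilon: float
    delta: float
    clipping_norm: float
    batch_size: int
    gradient_accumulation_steps: int
    train_steps: int
    train_size: int
    poisson_sampling_in_fit: bool = False
    noise_multiplier: Optional[float] = None
    rescale_to_unit_norm: bool = True
    microbatch_size: Optional[int] = None
    seed: Optional[int] = None

    def __post_init__(self) -> None:
        # Range checks taken from the parameter descriptions; whether the real
        # class validates its inputs this way is an assumption.
        if not 0 < self.epsilon < math.inf:
            raise ValueError("epsilon must be in (0, +infinity)")
        if not 0 < self.delta <= 1:
            raise ValueError("delta must be in (0, 1]")
        if self.gradient_accumulation_steps < 1:
            raise ValueError("gradient_accumulation_steps must be >= 1")

    @property
    def effective_batch_size(self) -> int:
        """batch_size * gradient_accumulation_steps, as documented above."""
        return self.batch_size * self.gradient_accumulation_steps


config = DPConfigSketch(
    epsilon=math.log(3),
    delta=1e-5,
    clipping_norm=1.0,
    batch_size=256,
    gradient_accumulation_steps=4,
    train_steps=train_steps_for_epochs(3, 60_000, 256),
    train_size=60_000,
)
```

With these values, each model update uses an effective batch of 1024 examples. Because the stand-in is frozen, changing a parameter means constructing a new config.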