jax_privacy.execution_plan.DPExecutionPlan

class jax_privacy.execution_plan.DPExecutionPlan(clipped_grad, batch_selection_strategy, noise_addition_transform, dp_event, neighboring_relation)

Bases: object

Class for defining a DP execution plan.

A DP execution plan consists of a collection of components that, when used together in the expected manner, determine the DP guarantee, along with a DpEvent that precisely quantifies it. If the plan is constructed via one of the ExecutionPlanConfig classes defined in this module, the dp_event can be trusted to have been formally verified by the JAX Privacy authors.
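To make "precisely quantifies" concrete, the sketch below turns the simplest possible event into an (ε, δ) guarantee: num_steps Gaussian releases with noise multiplier σ and no subsampling, roughly what dp_accounting describes as SelfComposedDpEvent(GaussianDpEvent(noise_multiplier), num_steps). It is a simplified illustration, not how dp_accounting analyzes a plan's dp_event: it uses the classic RDP-to-(ε, δ) conversion over a hand-picked grid of orders, ignores batch sampling and correlated noise, and the function names are hypothetical.

```python
import math

# Simplified illustration (NOT dp_accounting's implementation): the RDP bound
# for num_steps-fold composition of the Gaussian mechanism without subsampling.
# All names here are hypothetical.

DEFAULT_ORDERS = (1.25, 1.5, 2.0, 3.0, 4.0, 5.0, 6.0, 8.0, 10.0, 16.0, 32.0, 64.0)


def gaussian_rdp(noise_multiplier, order):
    """RDP at `order` of one Gaussian release with sensitivity 1 and noise
    std `noise_multiplier`: order / (2 * noise_multiplier**2)."""
    return order / (2.0 * noise_multiplier ** 2)


def epsilon_for_composed_gaussian(noise_multiplier, num_steps, target_delta,
                                  orders=DEFAULT_ORDERS):
    """Smallest epsilon over `orders` such that the composition is
    (epsilon, target_delta)-DP under the classic conversion."""
    best = math.inf
    for order in orders:
        # RDP composes additively across the num_steps releases.
        rdp = num_steps * gaussian_rdp(noise_multiplier, order)
        # Classic RDP -> (epsilon, delta) conversion: rdp + log(1/delta)/(order-1).
        epsilon = rdp + math.log(1.0 / target_delta) / (order - 1.0)
        best = min(best, epsilon)
    return best


print(epsilon_for_composed_gaussian(1.0, num_steps=1, target_delta=1e-5))
print(epsilon_for_composed_gaussian(1.0, num_steps=100, target_delta=1e-5))
```

Real accountants use tighter conversions and account for sampling, so the numbers from this sketch are only indicative of the shape of the calculation.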

In pseudo-code, the components of this dataclass should be used roughly as follows:

plan = ...  # Plan depending on the flavor of DP training you want.
noise_state = plan.noise_addition_transform.init(params)
batch_sampler = plan.batch_selection_strategy
grad_fn = plan.clipped_grad(loss_fn)  # Build once, outside the loop.
for indices in batch_sampler.batch_iterator(num_examples):
  batch = data.select(indices)
  clipped_grad_sum = grad_fn(params, batch, ...)

  dp_grad, noise_state = plan.noise_addition_transform.update(
      clipped_grad_sum, noise_state
  )
  # Sensitive, discard immediately after use.
  del indices, batch, clipped_grad_sum
  # Arbitrary post-processing of dp_grad (e.g., an optimizer step on params).
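The loop above can be exercised end to end with toy components. The sketch below is a pure-Python mock for mean estimation under the loss 0.5 * ||params - x||^2; clip_to_norm, make_clipped_grad, batch_iterator, add_gaussian_noise, and train are hypothetical stand-ins written for illustration, not jax_privacy functions. In particular, the mock uses fixed-size batches and independent Gaussian noise, whereas a real plan may use Poisson sampling or correlated (stateful) noise, and it carries no DpEvent.

```python
import math
import random

# Toy mock of the DP training loop. Every function here is a hypothetical
# stand-in for illustration; none of them is part of jax_privacy.


def clip_to_norm(vec, l2_clip_norm):
    """Scales `vec` down so that its L2 norm is at most `l2_clip_norm`."""
    norm = math.sqrt(sum(v * v for v in vec))
    if norm <= l2_clip_norm:
        return list(vec)
    return [v * (l2_clip_norm / norm) for v in vec]


def make_clipped_grad(per_example_grad, l2_clip_norm):
    """Hypothetical stand-in for plan.clipped_grad(loss_fn)."""
    def clipped_grad_sum(params, batch):
        total = [0.0] * len(params)
        for example in batch:
            g = clip_to_norm(per_example_grad(params, example), l2_clip_norm)
            total = [t + gi for t, gi in zip(total, g)]
        return total
    return clipped_grad_sum


def batch_iterator(num_examples, batch_size, num_batches, rng):
    """Hypothetical batch selection: fixed-size batches drawn without replacement."""
    for _ in range(num_batches):
        yield rng.sample(range(num_examples), batch_size)


def add_gaussian_noise(clipped_grad_sum, noise_multiplier, l2_clip_norm, rng):
    """Hypothetical stateless noise addition: independent Gaussian noise with
    std noise_multiplier * l2_clip_norm on every coordinate."""
    std = noise_multiplier * l2_clip_norm
    return [g + rng.gauss(0.0, std) for g in clipped_grad_sum]


def train(data, num_steps=50, batch_size=8, learning_rate=0.5,
          l2_clip_norm=1.0, noise_multiplier=1.0, seed=0):
    """DP-SGD-style mean estimation with loss(p, x) = 0.5 * ||p - x||^2."""
    rng = random.Random(seed)
    params = [0.0] * len(data[0])
    grad_fn = make_clipped_grad(
        lambda p, x: [pi - xi for pi, xi in zip(p, x)], l2_clip_norm)
    for indices in batch_iterator(len(data), batch_size, num_steps, rng):
        batch = [data[i] for i in indices]
        clipped_grad_sum = grad_fn(params, batch)
        dp_grad = add_gaussian_noise(
            clipped_grad_sum, noise_multiplier, l2_clip_norm, rng)
        # Sensitive, discard immediately after use.
        del indices, batch, clipped_grad_sum
        # Post-processing of dp_grad only: average, then take an SGD step.
        params = [p - learning_rate * g / batch_size
                  for p, g in zip(params, dp_grad)]
    return params


data = [(1.0, 1.0)] * 20
print(train(data, noise_multiplier=0.0))  # approximately [1.0, 1.0]
print(train(data, noise_multiplier=1.0))  # near [1.0, 1.0], perturbed by noise
```

All averaging and the parameter update operate on dp_grad after noise has been added; swapping in a different optimizer would only change that post-processing step.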

We recommend coupling the batch selection, clipped aggregation, and noise addition components as tightly as possible so that sensitive objects are not intercepted and used unintentionally. For example, it is critical that no modification (such as scaling) is applied to clipped_grad_sum before the noise_addition_transform is applied: the noise is calibrated to the sensitivity of clipped_grad_sum, so such modifications could invalidate the DP guarantee.
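The arithmetic behind this warning is short. Under the standard Gaussian calibration (assumed here), noise_std = noise_multiplier * l2_clip_norm, on the premise that the noised quantity has L2 sensitivity l2_clip_norm. Scaling clipped_grad_sum by a factor c before noise addition multiplies its sensitivity by c while the noise stays the same, so the effective noise multiplier drops to noise_multiplier / c even though the dp_event still reports the original value. The helper below is a hypothetical illustration of that ratio, not a jax_privacy function.

```python
def effective_noise_multiplier(noise_std, l2_clip_norm, pre_noise_scale=1.0):
    """Ratio of noise std to the actual L2 sensitivity of the noised quantity.

    Scaling the clipped sum by `pre_noise_scale` before adding noise scales
    its sensitivity by the same factor, while the noise std stays fixed.
    (Hypothetical helper for illustration.)
    """
    sensitivity = abs(pre_noise_scale) * l2_clip_norm
    return noise_std / sensitivity


# Noise calibrated for noise multiplier 1.0 at clip norm 1.0.
calibrated = effective_noise_multiplier(noise_std=1.0, l2_clip_norm=1.0)
# Doubling the clipped sum before noise halves the effective noise multiplier,
# although the dp_event (and hence the reported epsilon) still assumes 1.0.
tampered = effective_noise_multiplier(
    noise_std=1.0, l2_clip_norm=1.0, pre_noise_scale=2.0)
print(calibrated, tampered)  # 1.0 0.5
```

Scaling dp_grad after noise addition, for example dividing by a fixed, data-independent batch size, is post-processing and leaves the guarantee intact.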

Variables:
  • clipped_grad – A function with a signature similar to jax.value_and_grad, but which computes a sum of per-example clipped gradients.

  • batch_selection_strategy – Determines how batches are formed in each iteration.

  • noise_addition_transform – Stateful transformation that adds noise to clipped and aggregated gradients after each iteration.

  • dp_event – Characterizes the mechanism in terms of primitive building blocks that dp_accounting knows how to analyze.

  • neighboring_relation – The DP neighboring relation assumed by this mechanism.
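The clipped_grad and neighboring_relation attributes interact through sensitivity: because each per-example gradient is clipped to L2 norm at most l2_clip_norm, adding or removing one example changes the clipped sum by at most l2_clip_norm, while replacing one example changes it by at most 2 * l2_clip_norm. The sketch below checks both bounds on a toy set of 2-D gradients; the helpers are hypothetical pure-Python stand-ins, not jax_privacy's clipped_grad, and l2_clip_norm is an assumed parameter name.

```python
import math


def l2_norm(vec):
    return math.sqrt(sum(v * v for v in vec))


def clip_to_norm(vec, l2_clip_norm):
    """Scales `vec` down so that its L2 norm is at most `l2_clip_norm`."""
    norm = l2_norm(vec)
    if norm <= l2_clip_norm:
        return list(vec)
    return [v * (l2_clip_norm / norm) for v in vec]


def clipped_sum(per_example_grads, l2_clip_norm):
    """Sum of per-example gradients, each clipped to l2_clip_norm first."""
    total = [0.0] * len(per_example_grads[0])
    for grad in per_example_grads:
        total = [t + g for t, g in zip(total, clip_to_norm(grad, l2_clip_norm))]
    return total


def diff_norm(a, b):
    return l2_norm([x - y for x, y in zip(a, b)])


grads = [[3.0, 4.0], [0.1, 0.2], [-5.0, 0.0]]
clip = 1.0
full = clipped_sum(grads, clip)

# Add/remove-one neighbor: drop the first example. Distance is at most clip.
removed = clipped_sum(grads[1:], clip)
# Replace-one neighbor: swap the first example. Distance is at most 2 * clip.
replaced = clipped_sum([[-3.0, -4.0]] + grads[1:], clip)

print(diff_norm(full, removed))   # ~1.0
print(diff_norm(full, replaced))  # ~2.0
```

Both bounds are attained here because the changed example's gradient is large enough to be clipped, which is why the noise must be calibrated to the worst case implied by the chosen neighboring relation.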

Methods

__init__

Attributes

clipped_grad

batch_selection_strategy

noise_addition_transform

dp_event

neighboring_relation

clipped_grad: Callable[..., BoundedSensitivityCallable]
batch_selection_strategy: BatchSelectionStrategy
noise_addition_transform: GradientTransformation
dp_event: DpEvent
neighboring_relation: NeighboringRelation