Overview
The core of the library, written in JAX, has the following main components:
Accounting: logic for privacy budget accounting, e.g. splitting the privacy budget and determining the noise multiplier from a target epsilon and delta.
DP-SGD: classes, including the public API, for implementing DP-SGD in raw JAX and Flax Linen. The main class is GradientComputer, which implements public methods to calculate the clipped gradients and add noise to them.
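To illustrate what "determining the noise multiplier from epsilon and delta" involves, here is a minimal sketch that bisects over the noise multiplier until a simple accountant reports an epsilon at or below the target. It assumes a deliberately simple accountant: Rényi DP of the Gaussian mechanism with sensitivity 1, composed over `steps` releases, with no subsampling amplification, converted to (epsilon, delta) with the standard RDP bound. Real DP-SGD accounting includes subsampling and gives much tighter numbers. All function names here are hypothetical and are not the library's API.

```python
import math


def gaussian_rdp(alpha: float, noise_multiplier: float, steps: int) -> float:
    # RDP of order alpha for the Gaussian mechanism with sensitivity 1,
    # composed over `steps` releases (assumption: no subsampling amplification).
    return steps * alpha / (2.0 * noise_multiplier ** 2)


def rdp_to_epsilon(noise_multiplier: float, steps: int, delta: float,
                   orders=(1.25, 1.5, 2, 3, 4, 5, 6, 8, 16, 32, 64, 128, 256)) -> float:
    # Standard conversion: eps = min over alpha of rdp(alpha) + log(1/delta) / (alpha - 1).
    return min(
        gaussian_rdp(a, noise_multiplier, steps) + math.log(1.0 / delta) / (a - 1)
        for a in orders
    )


def calibrate_noise_multiplier(target_epsilon: float, delta: float, steps: int,
                               lo: float = 0.1, hi: float = 100.0,
                               tol: float = 1e-4) -> float:
    # Epsilon decreases monotonically as the noise multiplier grows, so bisection
    # finds (up to `tol`) the smallest noise multiplier meeting the target.
    if rdp_to_epsilon(hi, steps, delta) > target_epsilon:
        raise ValueError("target epsilon is unreachable with noise_multiplier <= hi")
    while hi - lo > tol:
        mid = (lo + hi) / 2.0
        if rdp_to_epsilon(mid, steps, delta) > target_epsilon:
            lo = mid  # too little noise: privacy target not met
        else:
            hi = mid  # target met: try less noise
    return hi  # invariant: epsilon(hi) <= target_epsilon


sigma = calibrate_noise_multiplier(target_epsilon=8.0, delta=1e-5, steps=1000)
print(f"noise multiplier: {sigma:.4f}")
```

Bisection is used because the accountant is treated as a black box that is only known to be monotone in the noise multiplier; the same loop works unchanged if a tighter accountant is swapped in for `rdp_to_epsilon`.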
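The computation a DP-SGD gradient computer performs can be sketched in a few lines of plain JAX: compute per-example gradients with `jax.vmap`, clip each one to a maximum global L2 norm, sum them, add Gaussian noise with standard deviation `noise_multiplier * clip_norm`, and divide by the batch size. This is a minimal sketch of the standard DP-SGD step, not the library's GradientComputer API; the linear model, `loss_fn`, `clip_by_global_norm`, and `dp_gradient` are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Hypothetical model: linear regression, squared error on a single example.
    pred = jnp.dot(x, params["w"]) + params["b"]
    return 0.5 * (pred - y) ** 2


def clip_by_global_norm(grads, clip_norm):
    # Rescale one example's gradient pytree so its global L2 norm is <= clip_norm.
    leaves = jax.tree_util.tree_leaves(grads)
    norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    scale = jnp.minimum(1.0, clip_norm / (norm + 1e-12))
    return jax.tree_util.tree_map(lambda g: g * scale, grads)


def dp_gradient(params, xs, ys, key, clip_norm, noise_multiplier):
    # Per-example gradients: map over the batch axis of xs and ys, share params.
    per_example = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))(params, xs, ys)
    # Clip every example's gradient independently.
    clipped = jax.vmap(clip_by_global_norm, in_axes=(0, None))(per_example, clip_norm)
    # Sum the clipped gradients over the batch.
    summed = jax.tree_util.tree_map(lambda g: jnp.sum(g, axis=0), clipped)
    # Add Gaussian noise calibrated to the clipping norm, one key per leaf.
    leaves, treedef = jax.tree_util.tree_flatten(summed)
    keys = jax.random.split(key, len(leaves))
    noisy = [
        g + noise_multiplier * clip_norm * jax.random.normal(k, g.shape, g.dtype)
        for g, k in zip(leaves, keys)
    ]
    # Average over the batch to get the noisy gradient estimate.
    batch_size = xs.shape[0]
    return jax.tree_util.tree_map(
        lambda g: g / batch_size, jax.tree_util.tree_unflatten(treedef, noisy)
    )


params = {"w": jnp.array([1.0, -2.0]), "b": jnp.array(0.5)}
xs = jnp.array([[1.0, 2.0], [3.0, -1.0], [0.5, 0.5]])
ys = jnp.array([1.0, 0.0, 2.0])
grads = dp_gradient(params, xs, ys, jax.random.PRNGKey(0),
                    clip_norm=1.0, noise_multiplier=1.1)
```

Clipping bounds each example's contribution to the summed gradient, which is what lets the added Gaussian noise be calibrated to `clip_norm` regardless of the data.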
On top of the core library, the following backend-specific public APIs are built:
These APIs hide some of the complexity and reduce the amount of code needed to implement DP training, at the cost of reduced flexibility.