jax_privacy.execution_plan.PerformanceFlags
- class jax_privacy.execution_plan.PerformanceFlags(dtype=<class 'numpy.float32'>, noise_seed=None, intermediate_strategy=SupportedStrategies.DEFAULT, microbatch_size=None, spmd_axis_name=None)
Bases: object
Performance-only flags that do not affect mechanism behavior or privacy.
These flags control implementation details such as numerical precision, sharding strategy, and memory/compute trade-offs. Changing them should not alter the privacy properties or the mathematical definition of the DP mechanism.
- Variables:
dtype – The dtype to use for noise generation and gradient aggregation.
noise_seed – A seed for the random number generator used for noise addition.
intermediate_strategy – Strategy for generating intermediate noise; it also controls sharding behavior for noise addition.
microbatch_size – If set, per-example gradient computation is split into sequential microbatches of this size, reducing peak memory at the cost of extra compute time.
spmd_axis_name – Axis name for distributed vmap in SPMD settings.
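The memory/compute trade-off behind `microbatch_size` can be illustrated with a small pure-Python sketch. It does not use jax_privacy: `clip` and `clipped_sum` are hypothetical helpers, and clip-then-sum (as in standard DP-SGD) is assumed as the per-example aggregation step. Processing examples in microbatches changes only how many clipped gradients are held at once and the order of floating-point additions, not the quantity being computed.

```python
import math
from typing import List, Optional, Sequence


def clip(grad: Sequence[float], l2_norm_clip: float) -> List[float]:
    """Rescale one per-example gradient so its L2 norm is at most l2_norm_clip."""
    norm = math.sqrt(sum(g * g for g in grad))
    scale = min(1.0, l2_norm_clip / norm) if norm > 0 else 1.0
    return [g * scale for g in grad]


def clipped_sum(
    per_example_grads: Sequence[Sequence[float]],
    l2_norm_clip: float,
    microbatch_size: Optional[int] = None,
) -> List[float]:
    """Sum clipped per-example gradients, optionally one microbatch at a time.

    microbatch_size=None handles every example in one pass. A positive integer
    materializes at most that many clipped gradients at once and accumulates
    the partial sums sequentially (assumes a non-empty batch).
    """
    n = len(per_example_grads)
    step = n if microbatch_size is None else microbatch_size
    total = [0.0] * len(per_example_grads[0])
    for start in range(0, n, step):
        # Only this microbatch's clipped gradients are held in memory.
        clipped = [clip(g, l2_norm_clip) for g in per_example_grads[start:start + step]]
        partial = [sum(column) for column in zip(*clipped)]
        total = [t + p for t, p in zip(total, partial)]
    return total


grads = [[3.0, 4.0], [0.3, 0.4], [6.0, 8.0], [1.0, 0.0], [0.0, 2.0]]
full = clipped_sum(grads, l2_norm_clip=1.0)
chunked = clipped_sum(grads, l2_norm_clip=1.0, microbatch_size=2)
print(full, chunked)  # equal up to floating-point rounding
```

Because the microbatch partial sums are added in a different order, results can differ in the last few bits, which is why a comparison should use a tolerance rather than exact equality.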
Methods
- __init__
Attributes
- dtype: alias of numpy.float32
- noise_seed: int | None = None
- intermediate_strategy: SupportedStrategies = SupportedStrategies.DEFAULT
- microbatch_size: int | None = None
- spmd_axis_name: str | None = None
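The attribute listing (annotated class attributes with defaults) suggests, but does not confirm, that PerformanceFlags is a dataclass. Under that assumption, the sketch below uses a hypothetical stand-in, `PerformanceFlagsSketch`, that mirrors the documented fields and defaults. String values replace `numpy.float32` and `SupportedStrategies.DEFAULT` so it runs without jax_privacy or numpy. It shows one way to derive a tuned configuration with `dataclasses.replace` while leaving the defaults untouched.

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass
class PerformanceFlagsSketch:
    """Hypothetical stand-in that mirrors the documented fields and defaults."""

    dtype: str = "float32"  # stand-in for numpy.float32 (avoids a numpy import)
    noise_seed: Optional[int] = None
    intermediate_strategy: str = "DEFAULT"  # stand-in for SupportedStrategies.DEFAULT
    microbatch_size: Optional[int] = None
    spmd_axis_name: Optional[str] = None


defaults = PerformanceFlagsSketch()
# Derive a variant with microbatching and a fixed noise seed; defaults stay unchanged.
tuned = dataclasses.replace(defaults, microbatch_size=32, noise_seed=0)
print(defaults)
print(tuned)
```

If the real class is a dataclass, the same `dataclasses.replace` call would apply to it directly; check the jax_privacy source before relying on that.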