jax_privacy.execution_plan.PerformanceFlags

class jax_privacy.execution_plan.PerformanceFlags(dtype=<class 'numpy.float32'>, noise_seed=None, intermediate_strategy=SupportedStrategies.DEFAULT, microbatch_size=None, spmd_axis_name=None)

Bases: object

Performance-only flags that do not affect mechanism behavior or privacy.

These flags control implementation details such as numerical precision, sharding strategy, and memory/compute trade-offs. Changing them should not alter the privacy properties or the mathematical definition of the DP mechanism.

Variables:
  • dtype – The dtype to use for noise generation and gradient aggregation.

  • noise_seed – A seed for the random number generator used for noise addition.

  • intermediate_strategy – Strategy for generating intermediate noise; it controls the sharding behavior of noise addition.

  • microbatch_size – If set, per-example gradient computation is split into sequential microbatches of this size, reducing peak memory at the cost of additional compute.

  • spmd_axis_name – Axis name for distributed vmap in SPMD settings.
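
To illustrate the trade-off behind microbatch_size, here is a minimal plain-Python sketch of the general idea. It is not the library's implementation: the toy linear model and the names clip, per_example_grad, and clipped_grad_sum are assumptions made up for this example. It shows that processing examples in sequential microbatches yields the same clipped gradient sum as processing the whole batch at once, while holding fewer per-example gradients in memory at any moment.

```python
import math


def clip(grad, max_norm):
    """Scale a gradient vector so that its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grad))
    scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [g * scale for g in grad]


def per_example_grad(w, x, y):
    """Gradient of 0.5 * (w.x - y)^2 with respect to w (toy model)."""
    residual = sum(wi * xi for wi, xi in zip(w, x)) - y
    return [residual * xi for xi in x]


def clipped_grad_sum(w, examples, max_norm, microbatch_size=None):
    """Sum of clipped per-example gradients, optionally in microbatches."""
    total = [0.0] * len(w)
    if not examples:
        return total
    size = microbatch_size if microbatch_size is not None else len(examples)
    for start in range(0, len(examples), size):
        # Only `size` per-example gradients are materialized at a time.
        grads = [per_example_grad(w, x, y)
                 for x, y in examples[start:start + size]]
        for grad in grads:
            for i, g in enumerate(clip(grad, max_norm)):
                total[i] += g
    return total


w = [1.0, -2.0]
examples = [
    ([1.0, 0.0], 0.5),
    ([0.0, 1.0], 1.0),
    ([3.0, 4.0], -1.0),
    ([1.0, 1.0], 0.0),
    ([2.0, -1.0], 2.0),
]
full = clipped_grad_sum(w, examples, max_norm=1.0)
micro = clipped_grad_sum(w, examples, max_norm=1.0, microbatch_size=2)
print(all(math.isclose(a, b) for a, b in zip(full, micro)))
```

The sketch does not model the extra compute cost. In an accelerator setting that cost typically comes from running the microbatches one after another rather than in a single parallel pass.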

Methods

__init__

Attributes

intermediate_strategy

microbatch_size

noise_seed

spmd_axis_name

dtype

alias of float32

noise_seed: int | None = None
intermediate_strategy: SupportedStrategies = SupportedStrategies.DEFAULT
microbatch_size: int | None = None
spmd_axis_name: str | None = None
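
As a usage sketch, the flags can be constructed with keyword arguments taken from the signature above. The fallback class in the except branch is a hypothetical stand-in, not the library's code. It exists only so the snippet runs when jax_privacy is not installed, and it mirrors the documented field names with placeholder defaults for dtype and intermediate_strategy.

```python
from dataclasses import dataclass
from typing import Optional

try:
    from jax_privacy.execution_plan import PerformanceFlags
except ImportError:
    # Hypothetical stand-in mirroring the documented fields only.
    @dataclass(frozen=True)
    class PerformanceFlags:
        dtype: object = float  # placeholder; documented default is numpy.float32
        noise_seed: Optional[int] = None
        intermediate_strategy: object = None  # placeholder for SupportedStrategies.DEFAULT
        microbatch_size: Optional[int] = None
        spmd_axis_name: Optional[str] = None

# Fix the noise seed and compute per-example gradients 8 examples at a time;
# the remaining flags keep their defaults.
flags = PerformanceFlags(noise_seed=0, microbatch_size=8)
print(flags.noise_seed, flags.microbatch_size, flags.spmd_axis_name)
```

Fields that are not passed keep their documented defaults, so only the settings that differ from the default configuration need to be spelled out.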