jax_privacy.batch_selection.BallsInBinsSampling
- class jax_privacy.batch_selection.BallsInBinsSampling(iterations, cycle_length)
Bases: BatchSelectionStrategy
Implements balls-in-bins sampling.
In balls-in-bins sampling, each example is independently assigned a ‘bin’ drawn uniformly at random from 0 to cycle_length - 1, and then appears in every round i such that i % cycle_length == bin. See https://arxiv.org/abs/2410.06266 and https://arxiv.org/abs/2412.16802 for more details.
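The procedure described above can be illustrated with a minimal standalone NumPy sketch. It is not the jax_privacy implementation: the function name balls_in_bins_batches and its signature (including passing a numpy.random.Generator as rng) are hypothetical choices made for illustration. It assigns bins once, then yields the members of bin i % cycle_length at round i.

```python
from typing import Iterator

import numpy as np


def balls_in_bins_batches(
    num_examples: int,
    iterations: int,
    cycle_length: int,
    rng: np.random.Generator,
) -> Iterator[np.ndarray]:
  """Hypothetical sketch: yields balls-in-bins batches of example indices."""
  # Assign every example an independent, uniformly random bin in
  # [0, cycle_length); `integers` excludes the upper bound by default.
  bins = rng.integers(0, cycle_length, size=num_examples)
  # Precompute the sorted indices of the examples in each bin.
  members = [np.flatnonzero(bins == b) for b in range(cycle_length)]
  for i in range(iterations):
    # Round i contains exactly the examples whose bin is i % cycle_length.
    # A copy is yielded so callers cannot mutate a bin shared across rounds.
    yield members[i % cycle_length].copy()


# Usage: 10 examples, 6 rounds, cycle of 3, so rounds 0 and 3 are identical.
rng = np.random.default_rng(0)
batches = list(balls_in_bins_batches(10, iterations=6, cycle_length=3, rng=rng))
```

Because the bins are fixed once, the first cycle_length batches partition the dataset, and every later batch repeats one of them.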
- Formal guarantees of the batch_iterator:
All batches consist of indices in the range [0, num_examples).
Each example appears in all batches with index i such that i % cycle_length == j, where j is drawn uniformly at random from [0, cycle_length), independently for each example, and in no other batches.
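The two guarantees listed above can be checked mechanically on any concrete list of batches. The helper below, satisfies_guarantees, is a hypothetical test utility (not part of jax_privacy); it assumes each batch is a 1D array of integer indices and treats an example that never appears as valid only when there are fewer batches than cycle_length, since only then can its bin have no round.

```python
from typing import Sequence

import numpy as np


def satisfies_guarantees(
    batches: Sequence[np.ndarray], num_examples: int, cycle_length: int
) -> bool:
  """Hypothetical helper: checks the two balls-in-bins guarantees."""
  num_batches = len(batches)
  appearances = [[] for _ in range(num_examples)]
  for i, batch in enumerate(batches):
    for idx in np.asarray(batch).tolist():
      # Guarantee 1: every index lies in [0, num_examples).
      if not 0 <= idx < num_examples:
        return False
      appearances[idx].append(i)
  for rounds in appearances:
    if not rounds:
      # Never appearing is only possible if the example's bin has no round
      # below num_batches, i.e. when there are fewer batches than bins.
      if num_batches >= cycle_length:
        return False
      continue
    # Guarantee 2: the bin j is fixed by the first round, and the example
    # must appear exactly once in each of rounds j, j + cycle_length, ...
    j = rounds[0] % cycle_length
    if rounds != list(range(j, num_batches, cycle_length)):
      return False
  return True
```

For example, with cycle_length=2 the batches [0, 2], [1, 3], [0, 2] pass, while [0, 2], [1, 3], [1, 3] fail because examples 0 and 2 are missing from round 2.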
- Variables:
iterations – The total number of iterations (batches) to generate.
cycle_length – The number of batches in a cycle, equivalently the separation between two consecutive appearances of the same example.
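The two parameters together determine how often each example is used and how large batches are on average. The small worked calculation below is an illustration derived from the definition above, not a library function; the names appearances_per_bin and expected_batch_size are hypothetical. An example in bin j appears in rounds j, j + cycle_length, ... below iterations, and each bin's size is Binomial(num_examples, 1 / cycle_length).

```python
from typing import List


def appearances_per_bin(iterations: int, cycle_length: int) -> List[int]:
  """Number of batches an example in bin j appears in, for each bin j."""
  # Bin j is used in rounds j, j + cycle_length, j + 2 * cycle_length, ...
  # strictly below `iterations`.
  return [len(range(j, iterations, cycle_length)) for j in range(cycle_length)]


def expected_batch_size(num_examples: int, cycle_length: int) -> float:
  """Mean batch size, since each bin's size is Binomial(n, 1/cycle_length)."""
  return num_examples / cycle_length


print(appearances_per_bin(iterations=10, cycle_length=4))  # [3, 3, 2, 2]
print(expected_batch_size(num_examples=60_000, cycle_length=100))  # 600.0
```

When iterations is not a multiple of cycle_length, examples in lower-numbered bins appear one extra time; the counts always sum to iterations because each round uses exactly one bin.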
Methods
__init__(iterations, cycle_length)
batch_iterator – Yields 1D batches of data indices.
Attributes
iterations: int
cycle_length: int