jax_privacy.batch_selection.BallsInBinsSampling

class jax_privacy.batch_selection.BallsInBinsSampling(iterations, cycle_length)

Bases: BatchSelectionStrategy

Implements balls-in-bins sampling.

In balls-in-bins sampling, each example is independently assigned a ‘bin’ drawn uniformly at random from 0 to cycle_length - 1, and then appears in every round i such that i % cycle_length == bin. See https://arxiv.org/abs/2410.06266 and https://arxiv.org/abs/2412.16802 for more details.
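The sampling rule above can be sketched in a few lines of NumPy. This is a hypothetical re-implementation for illustration only, not the library's batch_iterator: the function name balls_in_bins_batches and the integer seed parameter are assumptions. It draws one bin per example, then replays the same per-bin batch every cycle_length rounds.

```python
import numpy as np
from typing import Iterator, Optional


def balls_in_bins_batches(
    num_examples: int,
    iterations: int,
    cycle_length: int,
    seed: Optional[int] = None,
) -> Iterator[np.ndarray]:
    """Hypothetical sketch of balls-in-bins batch selection.

    Each example ("ball") is assigned a bin in [0, cycle_length) uniformly
    at random, independently of the others; batch i contains exactly the
    examples whose bin equals i % cycle_length.
    """
    rng = np.random.default_rng(seed)
    # One independent uniform draw per example, choosing its bin.
    bins = rng.integers(0, cycle_length, size=num_examples)
    # Precompute the sorted example indices that fall into each bin.
    members = [np.flatnonzero(bins == b) for b in range(cycle_length)]
    for i in range(iterations):
        # Round i reuses the batch belonging to bin i % cycle_length.
        yield members[i % cycle_length]
```

Because bin membership is fixed once up front, the first cycle_length batches partition the dataset, and every later batch is a repeat of one of them. Note that the sketch yields the same array object on repeated rounds, so callers should not mutate the batches in place.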

Formal guarantees of batch_iterator:
  • All batches consist of indices in the range [0, num_examples).

  • Each example appears in every batch with index i such that i % cycle_length == j, and in no other batches, where j is drawn uniformly at random from [0, cycle_length), independently for each example.
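The two guarantees can be checked mechanically on a concrete sequence of batches. The sketch below is a hypothetical helper (satisfies_balls_in_bins_guarantees is not part of jax_privacy); it verifies only the deterministic structure (index range and periodic appearances), not that the bins were drawn uniformly and independently, which no single sample can confirm.

```python
import numpy as np
from typing import Sequence


def satisfies_balls_in_bins_guarantees(
    batches: Sequence[np.ndarray], num_examples: int, cycle_length: int
) -> bool:
    """Hypothetical checker for the structural balls-in-bins guarantees."""
    # Guarantee 1: every index lies in [0, num_examples).
    for batch in batches:
        if batch.size and (batch.min() < 0 or batch.max() >= num_examples):
            return False
    # Record the rounds in which each example appears.
    appearances = [[] for _ in range(num_examples)]
    for i, batch in enumerate(batches):
        for idx in batch:
            appearances[int(idx)].append(i)
    n = len(batches)
    # Guarantee 2: an example with bin j appears exactly in rounds
    # j, j + cycle_length, j + 2 * cycle_length, ... below n.
    for rounds in appearances:
        if not rounds:
            # Only possible if the example's bin is never reached, which
            # requires fewer batches than cycle_length.
            if n >= cycle_length:
                return False
            continue
        j = rounds[0]
        if j >= cycle_length or rounds != list(range(j, n, cycle_length)):
            return False
    return True
```

For example, with cycle_length=2, the batches [0, 2], [1], [0, 2], [1] pass, while replacing the last two with [0], [1, 2] fails because example 2 then appears in rounds 0 and 3, which are not congruent modulo 2.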

Variables:
  • iterations – The total number of iterations (batches) to generate.

  • cycle_length – The number of batches in a cycle; equivalently, the separation between two consecutive appearances of the same example.
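The relationship between the two parameters can be made concrete with a little arithmetic. The helpers below are hypothetical illustrations (neither is a jax_privacy function): the first lists the rounds in which an example with a given bin appears, and the second gives the expected batch size, since each example lands in any particular bin with probability 1 / cycle_length.

```python
from typing import List


def appearance_rounds(bin_index: int, iterations: int, cycle_length: int) -> List[int]:
    """Rounds (batch indices) in which an example assigned to bin_index appears."""
    if not 0 <= bin_index < cycle_length:
        raise ValueError("bin_index must lie in [0, cycle_length)")
    # Consecutive appearances are exactly cycle_length rounds apart.
    return list(range(bin_index, iterations, cycle_length))


def expected_batch_size(num_examples: int, cycle_length: int) -> float:
    # Each example falls into a given bin with probability 1 / cycle_length.
    return num_examples / cycle_length
```

With iterations=10 and cycle_length=4, examples in bins 0 and 1 appear three times (for instance rounds 0, 4, 8 for bin 0) and examples in bins 2 and 3 appear twice; when iterations is a multiple of cycle_length, every example appears exactly iterations / cycle_length times.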

Methods

__init__

batch_iterator

Yields 1D batches of data indices.

Attributes

iterations

cycle_length

iterations: int
cycle_length: int
batch_iterator(num_examples, rng=None)

Yields one 1D batch of data indices per iteration (iterations batches in total), with all indices in the range [0, num_examples).

Return type:

Iterator[ndarray]