jax_privacy.batch_selection.RandomAllocationSampling

class jax_privacy.batch_selection.RandomAllocationSampling(total_participations, iterations)

Bases: BatchSelectionStrategy

Implements k-out-of-t random allocation (aka balanced-iteration sampling).

Each example independently selects exactly k distinct steps (out of t = iterations total) to participate in, uniformly at random. For k=1, this participation pattern is equivalent to BallsInBinsSampling. See the papers below for details about this strategy:

  • https://arxiv.org/abs/2206.03151 (k=1 only)

  • https://arxiv.org/abs/2410.06266 (k=1 only)

  • https://arxiv.org/abs/2412.16802 (k=1 only)

  • https://arxiv.org/abs/2502.08202 (k>1)

  • https://arxiv.org/abs/2503.03043 (k>1)

  • https://arxiv.org/abs/2601.21636 (k>1)

  • https://arxiv.org/abs/2602.17284 (k>1)

  • https://arxiv.org/abs/2605.07072 (k>1)
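The allocation described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the library's implementation: the function name random_allocation_batches and its seed argument are hypothetical, it uses the standard-library random module rather than whatever RNG batch_iterator accepts, and it yields Python lists rather than NumPy arrays. It draws every example's k steps up front, because the batch for any single step depends on the choices of all examples.

```python
import random
from typing import Iterator, List, Optional


def random_allocation_batches(
    num_examples: int,
    total_participations: int,
    iterations: int,
    seed: Optional[int] = None,
) -> Iterator[List[int]]:
    """Hypothetical k-out-of-t random allocation sketch (not the library code).

    Each example independently picks `total_participations` (k) distinct
    steps out of `iterations` (t), uniformly at random.
    """
    # Because this is a generator, the error surfaces on first iteration.
    if not 1 <= total_participations <= iterations:
        raise ValueError("need 1 <= total_participations <= iterations")
    rng = random.Random(seed)
    # One list of example indices per step.
    batches: List[List[int]] = [[] for _ in range(iterations)]
    for example in range(num_examples):
        # random.sample draws k distinct steps, i.e. without replacement.
        for step in rng.sample(range(iterations), total_participations):
            batches[step].append(example)
    # Examples were visited in increasing order, so each batch is sorted.
    for batch in batches:
        yield batch
```

With total_participations=1 every example lands in exactly one batch, so the batches partition range(num_examples), matching the k=1 equivalence with BallsInBinsSampling noted above.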

Formal guarantees of the batch_iterator:
  • All batches consist of indices in the range [0, num_examples).

  • Each example appears in exactly k of the t = iterations batches; its k batch indices are chosen uniformly at random without replacement from [0, iterations).

  • The allocation for each example is independent of all other examples.
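The deterministic guarantees above can be checked on any finite sequence of batches. The helper below is hypothetical (it is not part of jax_privacy); one might apply it to list(RandomAllocationSampling(2, 10).batch_iterator(100)). The independence guarantee is a statement about the distribution of allocations and cannot be verified from a single draw, so the helper does not test it.

```python
from typing import Iterable, Sequence


def check_random_allocation(
    batches: Iterable[Sequence[int]],
    num_examples: int,
    total_participations: int,
    iterations: int,
) -> None:
    """Hypothetical helper: raise AssertionError if a guarantee is violated.

    Explicit raises (instead of `assert`) keep the checks active under -O.
    """
    participations = [0] * num_examples
    num_batches = 0
    for batch in batches:
        num_batches += 1
        indices = [int(i) for i in batch]  # also accepts 1D NumPy arrays
        # Every index must lie in [0, num_examples).
        if any(not 0 <= i < num_examples for i in indices):
            raise AssertionError("index out of range")
        # Steps are drawn without replacement, so no repeats within a batch.
        if len(set(indices)) != len(indices):
            raise AssertionError("duplicate index within a batch")
        for i in indices:
            participations[i] += 1
    if num_batches != iterations:
        raise AssertionError(f"expected {iterations} batches, got {num_batches}")
    # Each example must appear in exactly k batches.
    if any(c != total_participations for c in participations):
        raise AssertionError("some example does not appear exactly k times")
```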

Variables:
  • total_participations – The number of steps each example participates in (k).

  • iterations – The total number of iterations / batches to generate (t).
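Writing n for num_examples and B_s for the batch yielded at step s (notation introduced here), the batch-size distribution follows directly from the guarantees and variables above. An example lies in B_s exactly when s is one of its k chosen steps, and examples choose independently:

```latex
\Pr[i \in B_s] = \frac{\binom{t-1}{k-1}}{\binom{t}{k}} = \frac{k}{t},
\qquad
|B_s| \sim \mathrm{Binomial}\!\left(n, \tfrac{k}{t}\right),
\qquad
\mathbb{E}\,|B_s| = \frac{nk}{t},
\qquad
\sum_{s=0}^{t-1} |B_s| = nk .
```

So individual batch sizes vary from step to step, while the total number of participations across all t batches is fixed at nk.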

Methods

  • __init__

  • batch_iterator – Yields 1D batches of data indices.

Attributes

  • total_participations

  • iterations

total_participations: int
iterations: int
batch_iterator(num_examples, rng=None)

Yields iterations 1D batches of data indices, one per step; every index lies in [0, num_examples).

Return type:

Iterator[ndarray]