jax_privacy.batch_selection.CyclicPoissonSampling

class jax_privacy.batch_selection.CyclicPoissonSampling(sampling_prob, iterations, truncated_batch_size=None, cycle_length=1, partition_type=PartitionType.EQUAL_SPLIT)

Bases: BatchSelectionStrategy

Implements Poisson sampling, possibly with cyclic sampling and truncation.

This generalizes several common sampling strategies [1,2,3,4].

References:
  [1] https://arxiv.org/abs/2211.06530
  [2] https://arxiv.org/abs/1607.00133
  [3] https://arxiv.org/abs/2306.08153
  [4] https://arxiv.org/abs/2411.04205

Formal guarantees of the batch_iterator:
  • All batches consist of indices in the range [0, num_examples).

  • Each example only appears in batches with index i such that i % cycle_length == j for some fixed j per example.

  • Without truncation, every index independently appears in each batch (where it is eligible to participate subject to the previous guarantee) with probability sampling_prob.

  • With truncation, if more than truncated_batch_size examples appear in a batch under the previous guarantee, truncated_batch_size of them are selected uniformly at random and the rest are discarded.

  • If partition_type is EQUAL_SPLIT, num_examples % cycle_length examples are discarded, i.e. never sampled.
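The guarantees above can be illustrated with a small NumPy sketch. This is not the jax_privacy implementation: the function name cyclic_poisson_batches, the seed argument, and the choice to form groups from a random permutation are assumptions made for illustration, and only the equal-split case is shown.

```python
import numpy as np


def cyclic_poisson_batches(num_examples, sampling_prob, iterations,
                           cycle_length=1, truncated_batch_size=None, seed=0):
    """Hypothetical generator mirroring the documented guarantees."""
    rng = np.random.default_rng(seed)
    group_size = num_examples // cycle_length
    # Equal split (assumed): shuffle, then drop the
    # num_examples % cycle_length leftover examples.
    permuted = rng.permutation(num_examples)[: group_size * cycle_length]
    groups = permuted.reshape(cycle_length, group_size)
    for i in range(iterations):
        # Only the group assigned to this position in the cycle is eligible.
        eligible = groups[i % cycle_length]
        # Poisson sampling: one independent coin flip per eligible example.
        mask = rng.random(group_size) < sampling_prob
        batch = eligible[mask]
        # Truncation: keep a uniformly random subset if the batch is too large.
        if truncated_batch_size is not None and batch.size > truncated_batch_size:
            batch = rng.choice(batch, size=truncated_batch_size, replace=False)
        yield batch


batches = list(cyclic_poisson_batches(
    num_examples=10, sampling_prob=0.5, iterations=6,
    cycle_length=3, truncated_batch_size=2))
```

With num_examples=10 and cycle_length=3, each group holds 3 examples and 1 example is never sampled; every batch has at most 2 indices, and a given index only ever shows up at one position in the cycle.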

Variables:
  • sampling_prob – The probability of sampling an example in rounds when it is eligible to participate. To achieve an average batch size of expected_batch_size, one should ideally set sampling_prob = expected_batch_size / (num_examples // cycle_length).

  • iterations – The number of total iterations / batches to generate.

  • truncated_batch_size – If not None, then after Poisson sampling, any batch containing more than truncated_batch_size examples is reduced by uniformly sampling truncated_batch_size of them and discarding the rest.

  • cycle_length – If > 1, we use cyclic Poisson sampling: we partition the examples into cycle_length groups, and do Poisson sampling from the groups in a round-robin fashion. cycle_length == 1 recovers standard Poisson sampling.

  • partition_type – How to partition the examples into groups before Poisson sampling. EQUAL_SPLIT is the default and is compatible only with the zero-out and replace-one adjacency notions, while INDEPENDENT is compatible with the add-remove adjacency notion.
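As a worked instance of the sampling_prob recommendation, the arithmetic below uses made-up dataset and batch sizes; only the formula itself comes from the description above.

```python
# Illustrative values (assumptions, not library defaults).
num_examples = 50_000
cycle_length = 4
expected_batch_size = 250

# Only one group of num_examples // cycle_length examples is eligible per batch.
eligible_per_batch = num_examples // cycle_length  # 12_500
sampling_prob = expected_batch_size / eligible_per_batch
print(sampling_prob)  # 0.02
```

When truncated_batch_size is set, truncation can only remove examples, so the realized average batch size may fall slightly below expected_batch_size.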

Methods

__init__

batch_iterator

Yields 1D batches of data indices.

Attributes

cycle_length

partition_type

truncated_batch_size

sampling_prob

iterations

sampling_prob: float
iterations: int
truncated_batch_size: int | None = None
cycle_length: int = 1
partition_type: PartitionType = PartitionType.EQUAL_SPLIT
batch_iterator(num_examples, rng=None)

Yields 1D batches of data indices.

Return type:

Iterator[ndarray]
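Because each yielded batch is a 1D array of indices rather than the data itself, the caller gathers the examples. A minimal sketch follows, using a hand-written list of index arrays as a stand-in for the iterator output; the features array and the index values are invented for illustration.

```python
import numpy as np

# Toy dataset: 8 examples with 3 features each.
features = np.arange(24, dtype=np.float32).reshape(8, 3)

# Stand-in for the iterator output: 1D integer index arrays whose
# lengths vary from batch to batch, possibly including an empty batch.
index_batches = [
    np.array([0, 5]),
    np.array([], dtype=np.int64),
    np.array([2, 3, 7]),
]

# Integer-array indexing selects the rows named by each index batch.
gathered = [features[indices] for indices in index_batches]
shapes = [g.shape for g in gathered]
print(shapes)  # [(2, 3), (0, 3), (3, 3)]
```

Since Poisson sampling produces variable-size batches, downstream code should not assume a fixed leading dimension.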