jax_privacy.batch_selection.FixedBatchSampling

class jax_privacy.batch_selection.FixedBatchSampling(batch_size, iterations, replace=False)

Bases: BatchSelectionStrategy

Implements fixed-size batch sampling.

Each batch contains batch_size examples sampled uniformly at random from the dataset. By default (replace=False), sampling is without replacement within a batch, so an example appears at most once per batch. Sampling is with replacement across batches, so the same example can appear in multiple iterations.

References: https://arxiv.org/abs/1807.01647 and https://arxiv.org/abs/1908.10530

Variables:
  • batch_size – The number of examples per batch.

  • iterations – The total number of iterations (batches) to generate.

  • replace – Whether to sample with replacement within each batch. Defaults to False.
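The sampling scheme described above can be sketched in plain NumPy. This is a minimal illustration, not the jax_privacy implementation: the function name fixed_batch_indices is hypothetical, and it assumes that rng may be None, an integer seed, or a NumPy Generator.

```python
import numpy as np


def fixed_batch_indices(num_examples, batch_size, iterations, replace=False, rng=None):
    """Hypothetical stand-in that mimics fixed-size batch sampling."""
    # default_rng accepts None, an int seed, or an existing Generator.
    rng = np.random.default_rng(rng)
    for _ in range(iterations):
        # Every batch is a fresh uniform draw over all example indices, so an
        # index can recur in later batches. With replace=False it appears at
        # most once inside any single batch.
        yield rng.choice(num_examples, size=batch_size, replace=replace)


batches = list(
    fixed_batch_indices(num_examples=10, batch_size=4, iterations=3, rng=0)
)
for step, indices in enumerate(batches):
    print(step, indices)
```

With replace=False, batch_size cannot exceed num_examples in this sketch, because NumPy cannot draw more distinct indices than there are examples.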

Methods

  • __init__

  • batch_iterator – Yields 1D batches of data indices.

Attributes

  • replace

  • batch_size

  • iterations

batch_size: int
iterations: int
replace: bool = False
batch_iterator(num_examples, rng=None)

Yields 1D batches of data indices.

Return type:

Iterator[ndarray]
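Because the iterator yields index arrays rather than data, the caller gathers the examples for each batch. The following sketch shows one way to do that with NumPy fancy indexing. The helper gather_batches and the toy arrays are hypothetical, and a hand-written list of index arrays stands in for the output of batch_iterator.

```python
import numpy as np


def gather_batches(features, labels, index_batches):
    """Hypothetical helper: turn 1D index batches into (features, labels) pairs."""
    for indices in index_batches:
        indices = np.asarray(indices)
        # Integer-array indexing selects the rows named by each index batch.
        yield features[indices], labels[indices]


# Toy dataset with 6 examples and 2 features each.
features = np.arange(12, dtype=np.float32).reshape(6, 2)
labels = np.arange(6)

# Stand-in for the iterator: two batches of size 2. Index 3 appears in both,
# which is allowed because batches are sampled with replacement across iterations.
index_batches = [np.array([0, 3]), np.array([5, 3])]

gathered = list(gather_batches(features, labels, index_batches))
for batch_features, batch_labels in gathered:
    print(batch_features.shape, batch_labels)
```

In real use, index_batches would presumably be the iterator returned by FixedBatchSampling(...).batch_iterator(num_examples), assuming the signature shown above.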