jax_privacy.batch_selection

API and implementations for batch selection strategies.

The BatchSelectionStrategy API specifies how batches should be formed from examples in a framework-agnostic manner. It produces indices into a dataset, not example elements themselves, and hence relies on having a way to access individual examples by index efficiently, such as with an in-memory list, array, or pytree of arrays. For datasets which do not fit in memory, we recommend using pygrain (https://github.com/google/grain), or using an offline job to reorder the data on disk before loading into your training pipeline.
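A strategy yields indices, not examples, so the caller does the gathering. Below is a minimal sketch of that step. It assumes the dataset is a flat dict of NumPy arrays that share a leading example axis. The names dataset and gather_batch are hypothetical. For nested pytrees, jax.tree_util.tree_map(lambda x: x[indices], data) does the same job.

```python
import numpy as np

# Hypothetical in-memory dataset: a flat dict ("pytree") of arrays that
# share a leading example dimension. Field names are illustrative only.
dataset = {
    "features": np.arange(24, dtype=np.float32).reshape(12, 2),
    "labels": np.arange(12) % 3,
}


def gather_batch(data, indices):
    """Selects the examples at `indices` from every array in `data`."""
    # Fancy indexing along axis 0 copies out just the requested rows.
    return {name: array[indices] for name, array in data.items()}


batch_indices = np.array([9, 2, 7])  # e.g., one batch produced by a strategy
batch = gather_batch(dataset, batch_indices)
print(batch["labels"].tolist())  # [0, 2, 1]
```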

The implementations in this file generally materialize a vector of indices of size num_examples and hence require that this vector fit in memory, which roughly means num_examples < 1e9.
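A back-of-the-envelope check of this bound follows. It assumes 64-bit integer indices, which is NumPy's default integer type on most 64-bit platforms. The helper index_vector_bytes is hypothetical and counts a single index vector only.

```python
import numpy as np


def index_vector_bytes(num_examples, dtype=np.int64):
    """Bytes needed to hold one index per example with the given dtype."""
    return num_examples * np.dtype(dtype).itemsize


# At the rough 1e9 limit, one int64 index vector already needs about 8 GB.
print(index_vector_bytes(10**9) / 1e9)  # 8.0
```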

The examples below demonstrate how to use the BatchSelectionStrategy API via the CyclicPoissonSampling implementation (all with an expected batch size of 3):

Example Usage (fixed order + multi-epoch) [1]:
>>> rng = np.random.default_rng(0)
>>> b = CyclicPoissonSampling(sampling_prob=1, iterations=8, cycle_length=4)
>>> print(*b.batch_iterator(12, rng=rng), sep=' ')
[9 2 7] [ 4  5 11] [0 3 6] [10  8  1] [9 2 7] [ 4  5 11] [0 3 6] [10  8  1]

Example Usage (standard Poisson sampling) [2]:
>>> b = CyclicPoissonSampling(sampling_prob=0.25, iterations=8)
>>> print(*b.batch_iterator(12, rng=rng), sep=' ')
[5 6 7] [5 8 3 7 2] [ 1  5 11] [0 3] [ 5  1  3  4 10] [2] [4 5 1 3] [6]

Example Usage (BandMF-style sampling) [3]:
>>> p = 0.5
>>> b = CyclicPoissonSampling(sampling_prob=p, iterations=6, cycle_length=2)
>>> print(*b.batch_iterator(12, rng=rng), sep=' ')
[2 4] [1 8 9] [2 7 5 4] [11  1  3] [10  2  5  0  4] [ 1 11  6]
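The three examples differ only in sampling_prob and cycle_length. In each case the expected batch size is sampling_prob * num_examples / cycle_length: 1 * 12 / 4, 0.25 * 12 / 1, and 0.5 * 12 / 2, all equal to 3.

The minimal NumPy sketch below shows one way such a sampler could work. It rests on two assumptions:
- The examples are shuffled once and split into cycle_length groups, which is consistent with the repeating batches in the first example.
- The default cycle_length is 1.

This is not the library's implementation and will not reproduce the outputs above.

```python
import numpy as np


def cyclic_poisson_batches(num_examples, sampling_prob, iterations,
                           cycle_length=1, rng=None):
    """Hypothetical sketch of cyclic Poisson sampling (not the library code).

    Examples are shuffled once and split into `cycle_length` groups. At
    iteration t, each example in group t % cycle_length is included
    independently with probability `sampling_prob`.
    """
    rng = np.random.default_rng(rng)
    groups = np.array_split(rng.permutation(num_examples), cycle_length)
    for t in range(iterations):
        group = groups[t % cycle_length]
        mask = rng.random(group.size) < sampling_prob
        yield group[mask]


# With sampling_prob=1 every group is taken whole, so batches repeat with
# period cycle_length, as in the fixed-order example.
batches = list(cyclic_poisson_batches(12, 1.0, 8, cycle_length=4, rng=0))
```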

References:
[1] https://arxiv.org/abs/2211.06530
[2] https://arxiv.org/abs/1607.00133
[3] https://arxiv.org/abs/2306.08153

Functions

pad_to_multiple_of(indices, multiple)

Pads the last dimension of indices so that its length is divisible by the multiple argument.
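Below is a minimal sketch of this kind of padding. The pad value of -1 is an assumption, since the summary does not say what pad_to_multiple_of fills with. pad_last_dim_to_multiple is a hypothetical name.

```python
import numpy as np


def pad_last_dim_to_multiple(indices, multiple, fill_value=-1):
    """Hypothetical sketch: pads the last axis of `indices` with `fill_value`
    (an assumed sentinel) so that its length is divisible by `multiple`."""
    indices = np.asarray(indices)
    remainder = indices.shape[-1] % multiple
    pad = 0 if remainder == 0 else multiple - remainder
    # Pad only at the end of the last axis; leave other axes untouched.
    widths = [(0, 0)] * (indices.ndim - 1) + [(0, pad)]
    return np.pad(indices, widths, constant_values=fill_value)


print(pad_last_dim_to_multiple(np.array([4, 5, 11, 0, 3]), 4).tolist())
# [4, 5, 11, 0, 3, -1, -1, -1]
```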

split_and_pad_global_batch(indices, ...[, ...])

Splits a global batch of indices into a list of fixed-size minibatches.
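Below is a minimal sketch of splitting a variable-size global batch into equal-size minibatches. It assumes a 1-D global batch, a minibatch_size parameter, and -1 as the pad value. None of these is confirmed by the elided signature above, and split_and_pad_sketch is a hypothetical name. One common motivation for fixed sizes is that a jitted JAX step recompiles for every new input shape.

```python
import numpy as np


def split_and_pad_sketch(indices, minibatch_size, fill_value=-1):
    """Hypothetical sketch: splits a 1-D global batch into minibatches of
    exactly `minibatch_size`, padding the last one with `fill_value`."""
    indices = np.asarray(indices)
    num_minibatches = -(-indices.size // minibatch_size)  # ceil division
    padded = np.full(num_minibatches * minibatch_size, fill_value,
                     dtype=indices.dtype)
    padded[:indices.size] = indices
    # Every row has the same shape, so a compiled step can reuse one shape.
    return list(padded.reshape(num_minibatches, minibatch_size))


print([m.tolist() for m in split_and_pad_sketch([5, 1, 3, 4, 10], 2)])
# [[5, 1], [3, 4], [10, -1]]
```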

Classes

BMinSepSampling(sampling_prob, iterations, ...)

Implements b-min-sep sampling.
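b-min-sep means that any two participations of the same example are at least b iterations apart. Below is a minimal sketch of one scheme with that property. It assumes two rules:
- An example that participates becomes ineligible for the next b - 1 iterations.
- Eligible examples join each batch independently with probability sampling_prob.

BMinSepSampling's actual rule and remaining arguments are not described here, and b_min_sep_batches is a hypothetical name.

```python
import numpy as np


def b_min_sep_batches(num_examples, sampling_prob, iterations, min_sep,
                      rng=None):
    """Hypothetical sketch of b-min-sep sampling (not the library code).

    An example that participates at step t is ineligible until step
    t + min_sep; eligible examples join each batch independently with
    probability `sampling_prob`.
    """
    rng = np.random.default_rng(rng)
    # Step at which each example becomes eligible again (0 = eligible now).
    next_eligible = np.zeros(num_examples, dtype=np.int64)
    for t in range(iterations):
        eligible = next_eligible <= t
        chosen = eligible & (rng.random(num_examples) < sampling_prob)
        next_eligible[chosen] = t + min_sep
        yield np.flatnonzero(chosen)


batches = list(b_min_sep_batches(20, 0.5, 30, min_sep=3, rng=0))
```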

BallsInBinsSampling(iterations, cycle_length)

Implements balls-in-bins sampling.
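In balls-in-bins sampling, each example (a "ball") is placed into one of cycle_length batches (the "bins") uniformly at random, so it is used exactly once per cycle. Below is a minimal sketch that assumes a fresh random assignment at the start of every cycle. balls_in_bins_batches is a hypothetical name, not the library class.

```python
import numpy as np


def balls_in_bins_batches(num_examples, iterations, cycle_length, rng=None):
    """Hypothetical sketch of balls-in-bins sampling (not the library code).

    At the start of every cycle, each example is thrown into one of
    `cycle_length` bins uniformly at random; iteration t uses bin
    t % cycle_length, so each example is used exactly once per cycle.
    """
    rng = np.random.default_rng(rng)
    for t in range(iterations):
        if t % cycle_length == 0:
            bins = rng.integers(cycle_length, size=num_examples)
        yield np.flatnonzero(bins == t % cycle_length)


batches = list(balls_in_bins_batches(12, 8, cycle_length=4, rng=0))
```

Unlike a shuffle-and-split partition, batch sizes here vary from iteration to iteration.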

BatchSelectionStrategy()

Abstract base class for batch selection strategies.
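The doctests above call batch_iterator(num_examples, rng=rng) on a strategy object. Below is a minimal sketch of that calling convention as an abstract base class. It assumes only the method visible in those examples, and the real base class may declare more. StrategySketch and ShuffledEpochs are hypothetical names.

```python
import abc
from typing import Iterator

import numpy as np


class StrategySketch(abc.ABC):
    """Hypothetical stand-in for the shape of a batch selection strategy."""

    @abc.abstractmethod
    def batch_iterator(self, num_examples: int,
                       rng: np.random.Generator) -> Iterator[np.ndarray]:
        """Yields one array of example indices per iteration."""


class ShuffledEpochs(StrategySketch):
    """Toy strategy: reshuffle every epoch, then emit consecutive slices."""

    def __init__(self, batch_size: int, iterations: int):
        self.batch_size = batch_size
        self.iterations = iterations

    def batch_iterator(self, num_examples, rng):
        order = np.empty(0, dtype=np.int64)
        for _ in range(self.iterations):
            if order.size < self.batch_size:
                # Start a new epoch, dropping any leftover partial batch.
                order = rng.permutation(num_examples)
            batch, order = order[:self.batch_size], order[self.batch_size:]
            yield batch


strategy = ShuffledEpochs(batch_size=3, iterations=8)
batches = list(strategy.batch_iterator(12, rng=np.random.default_rng(0)))
```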

CyclicPoissonSampling(sampling_prob, iterations)

Implements Poisson sampling, possibly with cyclic sampling and truncation.
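One reason to truncate is to cap the variable size of a Poisson-sampled batch. Below is a minimal sketch of one way to do that, assuming an oversized batch is cut down to a uniformly random subset. How CyclicPoissonSampling actually truncates, and what its parameter is called, are not stated here. truncate_batch is a hypothetical helper.

```python
import numpy as np


def truncate_batch(batch, max_batch_size, rng):
    """Hypothetical sketch: caps a Poisson-sampled batch at `max_batch_size`
    by keeping a uniformly random subset when it is too large."""
    if batch.size <= max_batch_size:
        return batch
    return rng.choice(batch, size=max_batch_size, replace=False)


rng = np.random.default_rng(0)
batch = np.flatnonzero(rng.random(100) < 0.3)  # Poisson sample, E[size] = 30
capped = truncate_batch(batch, 25, rng)
```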

FixedBatchSampling(batch_size, iterations[, ...])

Implements fixed-size batch sampling.
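Below is a minimal sketch of fixed-size batches. It assumes each batch is an independent uniform draw of batch_size distinct examples. The optional arguments of FixedBatchSampling may select a different scheme, such as shuffled epochs. fixed_size_batches is a hypothetical name.

```python
import numpy as np


def fixed_size_batches(num_examples, batch_size, iterations, rng=None):
    """Hypothetical sketch: every batch is an independent uniform draw of
    exactly `batch_size` distinct examples (sampling without replacement)."""
    rng = np.random.default_rng(rng)
    for _ in range(iterations):
        yield rng.choice(num_examples, size=batch_size, replace=False)


batches = list(fixed_size_batches(12, 3, 5, rng=0))
```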

PartitionType(value)

An enum specifying how examples should be assigned to groups.
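The members of PartitionType are not listed here. As an illustration only, below is a hypothetical enum, GroupAssignment, with two plausible ways to assign examples to groups:
- Independently at random, so group sizes vary.
- By shuffling and splitting, so group sizes are near-equal.

```python
import enum

import numpy as np


class GroupAssignment(enum.Enum):
    """Hypothetical enum; the real PartitionType members may differ."""
    INDEPENDENT = "independent"  # each example picks a group uniformly at random
    EQUAL_SPLIT = "equal_split"  # shuffle, then split into near-equal groups


def assign_groups(num_examples, num_groups, kind, rng):
    """Returns a list of `num_groups` index arrays that partition the data."""
    if kind is GroupAssignment.INDEPENDENT:
        labels = rng.integers(num_groups, size=num_examples)
        return [np.flatnonzero(labels == g) for g in range(num_groups)]
    return np.array_split(rng.permutation(num_examples), num_groups)


rng = np.random.default_rng(0)
equal = assign_groups(12, 4, GroupAssignment.EQUAL_SPLIT, rng)
independent = assign_groups(12, 4, GroupAssignment.INDEPENDENT, rng)
```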

RandomAllocationSampling(...)

Implements k-out-of-t random allocation (also known as balanced-iteration sampling).
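In k-out-of-t random allocation, each example participates in exactly k of the t iterations, with those k chosen uniformly at random. Below is a minimal sketch of that definition. RandomAllocationSampling's elided arguments are not reproduced, and random_allocation_batches is a hypothetical name.

```python
import numpy as np


def random_allocation_batches(num_examples, iterations, k, rng=None):
    """Hypothetical sketch of k-out-of-t random allocation.

    Each example is assigned to exactly k distinct iterations chosen
    uniformly at random from the t = `iterations` available ones.
    """
    rng = np.random.default_rng(rng)
    # Argsort of i.i.d. uniforms is a uniform random permutation per row;
    # its first k entries form a uniformly random k-subset of iterations.
    slots = np.argsort(rng.random((num_examples, iterations)), axis=1)[:, :k]
    return [np.flatnonzero((slots == t).any(axis=1)) for t in range(iterations)]


batches = random_allocation_batches(10, iterations=6, k=2, rng=0)
```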

UserSelectionStrategy(base_strategy[, ...])

A strategy that applies a base_strategy at the user level.
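Below is a minimal sketch of user-level selection. It makes two assumptions:
- The base strategy is run over user indices.
- Every example owned by a selected user is included.

The optional arguments of UserSelectionStrategy are not described here and are ignored. user_level_batches is a hypothetical name.

```python
import numpy as np


def user_level_batches(user_ids, user_batches):
    """Hypothetical sketch: expands batches of *user* indices into batches of
    *example* indices, taking every example that belongs to a chosen user.

    `user_ids[i]` is the user that owns example i; `user_batches` yields
    arrays of user indices, e.g. from any example-level strategy run over
    the number of users.
    """
    user_ids = np.asarray(user_ids)
    for users in user_batches:
        yield np.flatnonzero(np.isin(user_ids, users))


user_ids = np.array([0, 0, 1, 2, 2, 2, 3])        # 7 examples, 4 users
user_batches = [np.array([2]), np.array([0, 3])]  # stand-in base strategy output
print([b.tolist() for b in user_level_batches(user_ids, user_batches)])
# [[3, 4, 5], [0, 1, 6]]
```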