jax_privacy.batch_selection
API and implementations for batch selection strategies.
The BatchSelectionStrategy API specifies how batches should be formed from examples in a framework-agnostic manner. It produces indices into a dataset, not example elements themselves, and hence relies on having a way to access individual examples by index efficiently, such as with an in-memory list, array, or pytree of arrays. For datasets which do not fit in memory, we recommend using pygrain (https://github.com/google/grain), or using an offline job to reorder the data on disk before loading into your training pipeline.
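Because a strategy yields index arrays rather than examples, the caller is responsible for gathering the examples. The following is a minimal sketch of that step, assuming the dataset is an in-memory dict of NumPy arrays. The gather helper and the hard-coded index batches are hypothetical stand-ins for illustration and are not part of jax_privacy.

```python
import numpy as np

# Hypothetical in-memory dataset: a dict (a simple pytree) of arrays that
# share the same leading "example" axis.
dataset = {
    "features": np.arange(24, dtype=np.float32).reshape(12, 2),
    "labels": np.arange(12) % 3,
}


def gather(data, indices):
    """Selects the rows given by `indices` from every array in the dict."""
    return {name: array[indices] for name, array in data.items()}


# Stand-ins for the index batches that a batch selection strategy would yield.
index_batches = [np.array([9, 2, 7]), np.array([4, 5, 11])]

batches = [gather(dataset, idx) for idx in index_batches]
print(batches[0]["labels"])  # prints [0 2 1]
```

Any container that supports integer-array indexing along the example axis would work the same way.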
The implementations in this file generally materialize a vector of indices of size num_examples, so that vector must fit in memory; roughly, this requires num_examples < 1e9.
The examples below demonstrate how to use the BatchSelectionStrategy API via the CyclicPoissonSampling implementation (all with expected batch size 3):
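To see where a bound near 1e9 comes from, the sketch below estimates the size of a single index vector. It assumes 8-byte (int64) indices; the dtype the library actually uses is not stated here, so treat the numbers as an order-of-magnitude guide. The index_vector_bytes helper is hypothetical.

```python
import numpy as np


def index_vector_bytes(num_examples, dtype=np.int64):
    """Bytes needed to hold one index per example (dtype is an assumption)."""
    return num_examples * np.dtype(dtype).itemsize


for n in (10**6, 10**8, 10**9):
    gib = index_vector_bytes(n) / 2**30
    print(f"{n:>13,d} examples -> {gib:.2f} GiB")
```

Under this assumption, 1e9 examples need roughly 7.5 GiB for the index vector alone, before counting any temporary copies made while shuffling.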
- Example Usage (fixed order + multi-epoch) [1]:
>>> rng = np.random.default_rng(0)
>>> b = CyclicPoissonSampling(sampling_prob=1, iterations=8, cycle_length=4)
>>> print(*b.batch_iterator(12, rng=rng), sep=' ')
[9 2 7] [ 4 5 11] [0 3 6] [10 8 1] [9 2 7] [ 4 5 11] [0 3 6] [10 8 1]
- Example Usage (standard Poisson sampling) [2]:
>>> b = CyclicPoissonSampling(sampling_prob=0.25, iterations=8)
>>> print(*b.batch_iterator(12, rng=rng), sep=' ')
[5 6 7] [5 8 3 7 2] [ 1 5 11] [0 3] [ 5 1 3 4 10] [2] [4 5 1 3] [6]
- Example Usage (BandMF-style sampling) [3]:
>>> p = 0.5
>>> b = CyclicPoissonSampling(sampling_prob=p, iterations=6, cycle_length=2)
>>> print(*b.batch_iterator(12, rng=rng), sep=' ')
[2 4] [1 8 9] [2 7 5 4] [11 1 3] [10 2 5 0 4] [ 1 11 6]
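The three regimes above can be understood through one simplified model: shuffle the examples once, split them into cycle_length groups, and at each iteration keep each example of the active group independently with probability sampling_prob. The sketch below implements that model in plain NumPy. It is an illustration only and not the jax_privacy implementation: the function name cyclic_poisson_sketch is hypothetical, truncation is not modelled, and the handling of a num_examples that is not divisible by cycle_length (here, nearly equal groups) is an assumption. Its outputs will not match the library's outputs for the same seed.

```python
import numpy as np


def cyclic_poisson_sketch(num_examples, sampling_prob, iterations,
                          cycle_length=1, rng=None):
    """Yields index batches under a simplified cyclic Poisson model."""
    rng = np.random.default_rng() if rng is None else rng
    # Shuffle once, then split into cycle_length groups of nearly equal size.
    groups = np.array_split(rng.permutation(num_examples), cycle_length)
    for i in range(iterations):
        group = groups[i % cycle_length]
        # Keep each example of the active group independently with
        # probability sampling_prob.
        keep = rng.random(group.size) < sampling_prob
        yield group[keep]


rng = np.random.default_rng(0)
# Fixed order + multi-epoch: every example of the active group is kept.
fixed = list(cyclic_poisson_sketch(12, 1.0, 8, cycle_length=4, rng=rng))
# Standard Poisson sampling: a single group containing all examples.
poisson = list(cyclic_poisson_sketch(12, 0.25, 8, rng=rng))
# BandMF-style: Poisson sampling within alternating halves of the data.
bandmf = list(cyclic_poisson_sketch(12, 0.5, 6, cycle_length=2, rng=rng))
```

In this model the expected batch size is sampling_prob * num_examples / cycle_length, which equals 3 in all three configurations.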
References:
[1] https://arxiv.org/abs/2211.06530
[2] https://arxiv.org/abs/1607.00133
[3] https://arxiv.org/abs/2306.08153
Functions
- Pads the last dimension of indices to a multiple of multiple.
- Splits a global batch of indices into a list of fixed-size minibatches.
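Poisson-sampled batches vary in size, so padding and splitting are what make them usable with fixed-shape computation. The sketch below shows one way these two operations could look in plain NumPy. The function names, the 1-D restriction on the split helper, and the padding value of -1 are all assumptions made for illustration; they are not the signatures or conventions of the library functions summarized above.

```python
import numpy as np


def pad_last_dim_sketch(indices, multiple, pad_value=-1):
    """Pads the last axis up to the next multiple of `multiple`."""
    indices = np.asarray(indices)
    remainder = -indices.shape[-1] % multiple
    # Pad only at the end of the last axis; leave leading axes untouched.
    pad_width = [(0, 0)] * (indices.ndim - 1) + [(0, remainder)]
    return np.pad(indices, pad_width, constant_values=pad_value)


def split_into_minibatches_sketch(indices, minibatch_size, pad_value=-1):
    """Splits a 1-D index batch into equally sized, padded minibatches."""
    padded = pad_last_dim_sketch(indices, minibatch_size, pad_value)
    return list(padded.reshape(-1, minibatch_size))


batch = np.array([5, 8, 3, 7, 2])
print(pad_last_dim_sketch(batch, 4))  # prints [ 5  8  3  7  2 -1 -1 -1]
minibatches = split_into_minibatches_sketch(batch, 2)
```

Downstream code would then need to mask out the padded entries, for example by giving them zero weight when aggregating gradients.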
Classes
- Implements b-min-sep sampling.
- Implements balls-in-bins sampling.
- Abstract base class for batch selection strategies.
- Implements Poisson sampling, possibly with cyclic sampling and truncation.
- Implements fixed-size batch sampling.
- An enum specifying how examples should be assigned to groups.
- Implements k-out-of-t random allocation (aka balanced-iteration sampling).
- A strategy that applies a base_strategy at the user level.