jax_privacy.batch_selection.BMinSepSampling
- class jax_privacy.batch_selection.BMinSepSampling(sampling_prob, iterations, min_sep, warm_start=True, truncated_batch_size=None)
Bases: BatchSelectionStrategy
Implements b-min-sep sampling.
Each batch is drawn by Poisson sampling over the eligible examples, where an example is ineligible if it participated in any of the previous min_sep - 1 iterations. Here b denotes min_sep.
Both this strategy and cyclic Poisson sampling enforce the b-min-sep property. Cyclic Poisson does so by making each example eligible to participate only once every b iterations, whereas here every example is eligible in every iteration, as long as it did not participate in the previous b - 1 iterations. The trade-off is in privacy accounting: cyclic Poisson can be analyzed with privacy loss distribution (PLD) accounting, whereas b-min-sep sampling is currently only known to be analyzable via Monte Carlo accounting.
This is a generalization of balls-in-bins sampling, which also enforces the b-min-sep property. In particular, this strategy reduces to balls-in-bins when warm_start = True and sampling_prob = 1.
See https://arxiv.org/abs/2602.09338 for more details.
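The sampling rule described above can be illustrated with a minimal NumPy sketch. This is not the library's implementation: the function name `b_min_sep_batches` and its `num_examples` and `seed` parameters are hypothetical, the sketch starts cold (every example is eligible at iteration 0), and it does not model `warm_start` or `truncated_batch_size`.

```python
import numpy as np


def b_min_sep_batches(num_examples, sampling_prob, iterations, min_sep, seed=0):
    """Hypothetical sketch of b-min-sep sampling; yields 1D index arrays."""
    rng = np.random.default_rng(seed)
    # Iteration at which each example last participated. The initial value
    # makes every example eligible at iteration 0 (cold start, an assumption).
    last_used = np.full(num_examples, -min_sep, dtype=np.int64)
    for t in range(iterations):
        # Eligible only if the example did not participate in any of the
        # previous min_sep - 1 iterations.
        eligible = (t - last_used) >= min_sep
        # Poisson sampling: each eligible example joins independently.
        sampled = eligible & (rng.random(num_examples) < sampling_prob)
        last_used[sampled] = t
        yield np.flatnonzero(sampled)


# Example: 50 examples, 40 iterations, at least 4 iterations between
# two participations of the same example.
batches = list(b_min_sep_batches(50, 0.5, 40, 4, seed=1))
```

In this cold-start sketch, setting `sampling_prob = 1` makes every example participate together at iterations 0, `min_sep`, 2 * `min_sep`, and so on, with empty batches in between, which illustrates why the start of the process affects batch sizes early in training.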
- Variables:
sampling_prob – The probability that an example is sampled in a given iteration, given that it was not sampled in any of the previous min_sep - 1 iterations. Note that the expected batch size is dataset_size / (min_sep - 1 + 1 / sampling_prob), not just dataset_size * sampling_prob.
iterations – The total number of iterations (batches) to generate.
min_sep – The minimum separation, in iterations, between two participations of the same example (the b in b-min-sep).
warm_start – If True, we initialize the b-min-sep sampling process at a warm start. This ensures the batch size is consistent from the start of training.
truncated_batch_size – If set, we truncate the batch to this size. To keep the participation of examples independent prior to truncation, examples that were sampled and then truncated are still excluded for the next min_sep - 1 iterations.
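The expected batch size formula in the sampling_prob note can be checked with a short calculation. One way to read it: after participating, an example sits out min_sep - 1 iterations and then waits 1 / sampling_prob iterations on average before being sampled again, so it participates about once every min_sep - 1 + 1 / sampling_prob iterations. The helper name `expected_batch_size` and the numbers below are illustrative assumptions, not part of the library.

```python
def expected_batch_size(dataset_size, sampling_prob, min_sep):
    """Expected batch size per the formula in the sampling_prob note.

    Hypothetical helper for illustration; not part of jax_privacy.
    """
    return dataset_size / (min_sep - 1 + 1 / sampling_prob)


# Illustrative values: 10,000 examples, sampling_prob = 0.1.
with_separation = expected_batch_size(10_000, 0.1, min_sep=4)     # roughly 769
without_separation = expected_batch_size(10_000, 0.1, min_sep=1)  # roughly 1000

print(with_separation, without_separation)
```

With min_sep = 1 the formula reduces to dataset_size * sampling_prob, matching plain Poisson sampling; larger min_sep values shrink the expected batch because a share of the dataset is ineligible in each iteration.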
Methods
__init__
Yields 1D batches of data indices.
Attributes
- sampling_prob: float
- iterations: int
- min_sep: int
- warm_start: bool = True
- truncated_batch_size: int | None = None