jax_privacy.batch_selection.UserSelectionStrategy
- class jax_privacy.batch_selection.UserSelectionStrategy(base_strategy, examples_per_user_per_batch=1, shuffle_per_user=False)
Bases: object
A strategy that applies a base_strategy at the user level.
Each batch returned by batch_iterator is a 2D array of integer indices in which all entries in the same row are examples owned by the same user. The examples in each user-batch (row) are chosen cyclically, after an optional shuffle when shuffle_per_user is True. For example, if a user owns the 3 examples [0, 5, 10], then each time this user is selected, its row continues through the sequence [0, 5, 10, 0, 5, 10, 0, 5, 10, …]. The gradient is expected to be evaluated and clipped with respect to each user-batch before being aggregated across users.
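The cyclic selection described above can be sketched with plain index arithmetic. This is a minimal illustration, not the library's internals: the helper `next_user_batch` is hypothetical, and it assumes a per-user position (`offset`) that persists between selections of the same user, which is consistent with the example output below, where user 0 first contributes [0 1] and then [2 0].

```python
import numpy as np


def next_user_batch(user_examples: np.ndarray, offset: int, k: int) -> tuple[np.ndarray, int]:
    """Hypothetical helper: take k examples cyclically, starting at position offset.

    Returns the selected example indices and the offset to use next time
    this user is selected.
    """
    n = len(user_examples)
    # Wrap around the user's examples, so k may exceed n (e.g. one example, k=2).
    positions = (offset + np.arange(k)) % n
    return user_examples[positions], (offset + k) % n


examples = np.array([0, 5, 10])  # a user owning examples 0, 5 and 10
offset = 0
batches = []
for _ in range(3):  # this user is selected three times
    batch, offset = next_user_batch(examples, offset, 2)
    batches.append(batch.tolist())
print(batches)  # [[0, 5], [10, 0], [5, 10]]
```

A user who owns fewer examples than examples_per_user_per_batch simply repeats them, e.g. a single example 5 with k=2 yields [5, 5].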
- Example Usage:
>>> import numpy as np
>>> from jax_privacy.batch_selection import CyclicPoissonSampling, UserSelectionStrategy
>>> rng = np.random.default_rng(0)
>>> base_strategy = CyclicPoissonSampling(sampling_prob=1, iterations=5)
>>> strategy = UserSelectionStrategy(base_strategy, 2)
>>> user_ids = np.array([0, 0, 0, 1, 1, 2])
>>> iterator = strategy.batch_iterator(user_ids, rng)
>>> print(next(iterator))
[[5 5]
 [0 1]
 [3 4]]
>>> print(next(iterator))
[[5 5]
 [2 0]
 [3 4]]
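The row-ownership property of a user-batch can be checked with plain NumPy. The sketch below hard-codes the first batch printed in the example above rather than calling the library, and verifies that every row contains examples of a single user; in that particular output, each of the three users also appears in exactly one row.

```python
import numpy as np

user_ids = np.array([0, 0, 0, 1, 1, 2])
# First batch printed in the example above (copied, not recomputed).
batch = np.array([[5, 5], [0, 1], [3, 4]])

# Owner of every selected example; same shape as batch.
row_owners = user_ids[batch]
# A row is valid if every entry has the same owner as the row's first entry.
same_user_per_row = (row_owners == row_owners[:, :1]).all(axis=1)
print(batch.shape, same_user_per_row.all())  # (3, 2) True
```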
- Variables:
base_strategy – The base batch selection strategy, applied at the user level to select batches of users from the set of users.
examples_per_user_per_batch – The number of examples to select for each user in each batch. Determines the number of columns in the returned batches.
shuffle_per_user – Whether to shuffle each user's examples before they are selected cyclically, examples_per_user_per_batch at a time.
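The effect of shuffle_per_user on the cyclic order can be sketched as follows. This assumes one permutation is drawn per user before cycling begins and is then reused on every pass; the library may reshuffle differently. The function `cycled_examples` is hypothetical and only shows which examples a single user contributes across selections.

```python
import itertools

import numpy as np


def cycled_examples(user_examples, shuffle_per_user, rng, count):
    """Hypothetical: the first `count` examples a user contributes across selections."""
    order = np.asarray(user_examples)
    if shuffle_per_user:
        # Assumption: a single permutation per user, drawn before cycling starts.
        order = rng.permutation(order)
    # Cycle through the (possibly shuffled) order and take `count` entries.
    return [int(x) for x in itertools.islice(itertools.cycle(order), count)]


rng = np.random.default_rng(0)
print(cycled_examples([0, 5, 10], False, rng, 6))  # [0, 5, 10, 0, 5, 10]
print(cycled_examples([0, 5, 10], True, rng, 6))   # a permutation of [0, 5, 10], repeated
```

Under this assumption, shuffling changes only the order in which a user's examples are visited; every example is still used once per pass.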
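Methods
__init__(base_strategy, examples_per_user_per_batch=1, shuffle_per_user=False)
batch_iterator(user_ids[, rng]) – Yields 2D batches of data indices.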
Attributes
- base_strategy: BatchSelectionStrategy
- examples_per_user_per_batch: int = 1
- shuffle_per_user: bool = False
- batch_iterator(user_ids, rng=None)
Yields 2D batches of data indices.
- Parameters:
user_ids (ndarray) – A 1D array that maps each example to a user id, where each user id can be an arbitrary integer.
rng (Generator | int | None) – A random seed or random number generator.
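Because user ids may be arbitrary integers (not necessarily 0..N-1), a first step in any user-level scheme is to group example indices by owner. The helper below is a hypothetical sketch of that grouping, not a function exported by jax_privacy.

```python
import numpy as np


def group_examples_by_user(user_ids):
    """Hypothetical helper: map each distinct user id to the example indices it owns."""
    user_ids = np.asarray(user_ids)
    # np.unique returns the distinct ids in sorted order; they may be
    # non-contiguous or negative integers.
    return {int(u): np.flatnonzero(user_ids == u) for u in np.unique(user_ids)}


groups = group_examples_by_user(np.array([0, 0, 0, 1, 1, 2]))
print({u: idx.tolist() for u, idx in groups.items()})
# {0: [0, 1, 2], 1: [3, 4], 2: [5]}
```

This comparison-based version scans user_ids once per distinct user, which keeps it readable; a sort-based grouping would scale better for very large datasets.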
- Yields:
2D batches of data indices, where users are sampled according to the base_strategy and all entries in the same row are examples owned by the same selected user.
- Return type:
Iterator[ndarray]
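The class description expects gradients to be evaluated and clipped per user-batch before aggregation across users. The sketch below shows one way a consumer of the yielded batches might do this with NumPy. The array `per_example_grads`, the function `clipped_user_sum`, and the choice to average (rather than sum) within a row are all assumptions for illustration; noise addition is omitted, and jax_privacy's own clipping and aggregation may differ.

```python
import numpy as np


def clipped_user_sum(per_example_grads, batch, clip_norm):
    """Sum over users of user-level gradients, each clipped to clip_norm.

    per_example_grads: (num_examples, dim) gradient of each example's loss.
    batch: (num_users, examples_per_user) indices, one row per selected user.
    """
    # Gradient of the mean loss over a row equals the mean of its per-example gradients.
    user_grads = per_example_grads[batch].mean(axis=1)            # (num_users, dim)
    norms = np.linalg.norm(user_grads, axis=1, keepdims=True)     # (num_users, 1)
    # Scale each user's gradient down to norm clip_norm if it is larger.
    scale = np.minimum(1.0, clip_norm / np.maximum(norms, 1e-12))
    return (user_grads * scale).sum(axis=0)                       # (dim,)


# Toy per-example gradients for the 6 examples of the earlier example.
grads = np.array([[3.0, 4.0], [0.0, 0.0], [1.0, 0.0],
                  [0.0, 2.0], [0.0, 0.0], [6.0, 8.0]])
batch = np.array([[5, 5], [0, 1], [3, 4]])  # first batch from the example
print(clipped_user_sum(grads, batch, clip_norm=1.0))  # approximately [1.2 2.6]
```

Clipping per row bounds each user's total contribution to the aggregate by clip_norm, regardless of how many examples that user owns.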