jax_privacy.batch_selection.UserSelectionStrategy

class jax_privacy.batch_selection.UserSelectionStrategy(base_strategy, examples_per_user_per_batch=1, shuffle_per_user=False)

Bases: object

A strategy that applies a base_strategy at the user level.

Each batch yielded by batch_iterator is a 2D array of integer indices in which all entries of a row are examples owned by the same user. The examples in each such user-batch are chosen cyclically from the user's examples (optionally after shuffling them). For example, if a user owns the 3 examples [0, 5, 10], then each time this user is selected, the next examples_per_user_per_batch examples are drawn from the stream [0, 5, 10, 0, 5, 10, 0, 5, 10, …]. The gradient is expected to be evaluated and clipped with respect to each user-batch before being aggregated across users.
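The cyclic selection described above can be sketched for a single user in plain Python and NumPy. This is a minimal illustration under stated assumptions, not jax_privacy's implementation: the helper `cyclic_user_batches` and its arguments are hypothetical, and shuffling (standing in for shuffle_per_user=True) is assumed to happen once, before cycling starts.

```python
import itertools

import numpy as np


def cyclic_user_batches(examples, k, num_selections, rng=None, shuffle=False):
    """Hypothetical helper: the k-example user-batches one user contributes.

    Each time the user is selected, the next k examples are taken from an
    endless cycle over the user's examples (optionally shuffled once first).
    """
    examples = np.asarray(examples)
    if shuffle:
        # Assumed stand-in for shuffle_per_user=True: permute once up front.
        examples = np.random.default_rng(rng).permutation(examples)
    stream = itertools.cycle(examples.tolist())
    # One row of k examples per selection; the cycle resumes where it stopped.
    return [list(itertools.islice(stream, k)) for _ in range(num_selections)]


print(cyclic_user_batches([0, 5, 10], k=2, num_selections=3))
# [[0, 5], [10, 0], [5, 10]]
```

With a user who owns [0, 1, 2] and k=2, the helper yields [0, 1] and then [2, 0], which matches the rows for user 0 in the example usage.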

Example Usage:
>>> import numpy as np
>>> from jax_privacy.batch_selection import CyclicPoissonSampling, UserSelectionStrategy
>>> rng = np.random.default_rng(0)
>>> base_strategy = CyclicPoissonSampling(sampling_prob=1, iterations=5)
>>> strategy = UserSelectionStrategy(base_strategy, examples_per_user_per_batch=2)
>>> user_ids = np.array([0, 0, 0, 1, 1, 2])
>>> iterator = strategy.batch_iterator(user_ids, rng)
>>> print(next(iterator))
[[5 5]
 [0 1]
 [3 4]]
>>> print(next(iterator))
[[5 5]
 [2 0]
 [3 4]]
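Since each row is one user's contribution, one way a consumer might use a batch is to gather per-example gradients by row, reduce them to one gradient per user, clip that gradient, and only then sum across users. The NumPy sketch below illustrates this for the first batch from the example; the helper name `clip_and_sum_user_gradients`, the per-row mean, the `clip_norm` parameter, and the toy gradient array are all illustrative assumptions, not part of jax_privacy's API.

```python
import numpy as np


def clip_and_sum_user_gradients(batch, per_example_grads, clip_norm):
    """Hypothetical helper: clip each user's gradient, then sum across users.

    batch: int array of shape (num_users, examples_per_user_per_batch).
    per_example_grads: float array of shape (num_examples, dim).
    """
    # Gather per-example gradients by index: shape (num_users, k, dim).
    grads = per_example_grads[batch]
    # One gradient per user (row): average over that user's k examples.
    user_grads = grads.mean(axis=1)
    # Scale each user's gradient so its L2 norm is at most clip_norm.
    norms = np.linalg.norm(user_grads, axis=1, keepdims=True)
    scale = np.minimum(1.0, clip_norm / np.maximum(norms, 1e-12))
    # Aggregate the clipped per-user gradients across users.
    return (user_grads * scale).sum(axis=0)


batch = np.array([[5, 5], [0, 1], [3, 4]])  # first batch from the example
per_example_grads = np.arange(12, dtype=float).reshape(6, 2)  # toy gradients
print(clip_and_sum_user_gradients(batch, per_example_grads, clip_norm=1.0))
```

Clipping per row rather than per example bounds each user's total influence on the aggregate, regardless of how many of that user's examples appear in the row.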
Variables:
  • base_strategy – The base batch selection strategy to apply at the user level. It is used to select a batch of users, rather than of individual examples, at each iteration.

  • examples_per_user_per_batch – The number of examples to select for each user in each batch. Determines the number of columns in the returned batches.

  • shuffle_per_user – Whether to shuffle each user's examples before cycling through them to select examples_per_user_per_batch examples at a time.

Methods

  • __init__

  • batch_iterator – Yields 2D batches of data indices.

Attributes

  • base_strategy: BatchSelectionStrategy

  • examples_per_user_per_batch: int = 1

  • shuffle_per_user: bool = False

batch_iterator(user_ids, rng=None)

Yields 2D batches of data indices.

Parameters:
  • user_ids (ndarray) – A 1D array mapping each example index to the id of the user who owns it. User ids may be arbitrary integers and need not be contiguous.

  • rng (Generator | int | None) – A random seed or random number generator.
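Because user ids can be arbitrary integers, recovering which example indices each user owns takes a grouping step. The NumPy sketch below shows one way to do it; `group_examples_by_user` is a hypothetical name used for illustration, not a jax_privacy function.

```python
import numpy as np


def group_examples_by_user(user_ids):
    """Hypothetical helper: map each distinct user id to its example indices."""
    user_ids = np.asarray(user_ids)
    # A stable sort keeps each user's examples in their original index order.
    order = np.argsort(user_ids, kind="stable")
    # Position in the sorted array where each distinct user id first appears.
    unique_ids, starts = np.unique(user_ids[order], return_index=True)
    groups = np.split(order, starts[1:])
    return {int(uid): group.tolist() for uid, group in zip(unique_ids, groups)}


print(group_examples_by_user(np.array([7, -3, 7, 100, -3, 7])))
# {-3: [1, 4], 7: [0, 2, 5], 100: [3]}
```

For the user_ids in the example usage, [0, 0, 0, 1, 1, 2], this gives {0: [0, 1, 2], 1: [3, 4], 2: [5]}.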

Yields:

2D batches of data indices, where users are sampled according to the base_strategy and all entries in the same row are examples owned by the same selected user.

Return type:

Iterator[ndarray]
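The row invariant stated under Yields can be checked directly on a batch. The sketch below is a hypothetical test helper, not part of jax_privacy: it assumes each batch is a 2D integer array of example indices and verifies that every row has examples_per_user_per_batch columns, all owned by one user.

```python
import numpy as np


def rows_are_single_user(batch, user_ids, examples_per_user_per_batch):
    """Hypothetical check: each row holds k examples owned by one user."""
    batch = np.asarray(batch)
    user_ids = np.asarray(user_ids)
    if batch.ndim != 2 or batch.shape[1] != examples_per_user_per_batch:
        return False
    owners = user_ids[batch]  # owner of every entry, same shape as batch
    # Every entry must share the owner of the first entry in its row.
    return bool(np.all(owners == owners[:, :1]))


user_ids = np.array([0, 0, 0, 1, 1, 2])
print(rows_are_single_user([[5, 5], [0, 1], [3, 4]], user_ids, 2))  # True
print(rows_are_single_user([[5, 0], [1, 2]], user_ids, 2))  # False
```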