jax_privacy.batch_selection.split_and_pad_global_batch

jax_privacy.batch_selection.split_and_pad_global_batch(indices, minibatch_size, microbatch_size=None)

Splits a global batch of indices into a list of fixed-size minibatches.

The last minibatch is padded with -1 indices as needed to bring it to minibatch_size. It is crucial that downstream users account for this correctly, e.g., by loading a dummy example not derived from real data for each -1 index, and by explicitly passing is_padding_example to the clipped gradient function so that the gradients of these examples are zeroed out.
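The padding contract above can be illustrated with a small NumPy sketch. The helpers `gather_with_padding` and `zero_padding_gradients`, the all-zeros dummy example, and the stand-in per-example gradients are hypothetical and not part of the jax_privacy API; in real use the `is_padding_example` mask would be passed to the clipped gradient function rather than applied by hand as here.

```python
import numpy as np


def gather_with_padding(dataset: np.ndarray, minibatch: np.ndarray):
    """Hypothetical helper: looks up examples, replacing -1 padding with a dummy.

    Returns the gathered examples and a boolean is_padding_example mask.
    """
    is_padding_example = minibatch == -1
    # Redirect padding slots to row 0 so the lookup never uses index -1
    # (which NumPy would silently interpret as the last row of the dataset).
    safe_indices = np.where(is_padding_example, 0, minibatch)
    examples = dataset[safe_indices]  # fancy indexing returns a copy
    # Overwrite padding rows with an all-zeros dummy example (an assumption;
    # any fixed example not derived from real data would do).
    examples[is_padding_example] = 0
    return examples, is_padding_example


def zero_padding_gradients(per_example_grads: np.ndarray,
                           is_padding_example: np.ndarray) -> np.ndarray:
    """Hypothetical helper: zeroes the per-example gradients of padding examples."""
    keep = (~is_padding_example).astype(per_example_grads.dtype)
    # Broadcast the [batch] mask over the trailing gradient dimensions.
    keep = keep.reshape((-1,) + (1,) * (per_example_grads.ndim - 1))
    return per_example_grads * keep


dataset = np.arange(12, dtype=np.float32).reshape(6, 2)  # 6 toy examples
minibatch = np.array([4, 5, -1, -1])
examples, is_padding_example = gather_with_padding(dataset, minibatch)
per_example_grads = np.ones((4, 3), dtype=np.float32)  # stand-in gradients
summed = zero_padding_gradients(per_example_grads, is_padding_example).sum(axis=0)
print(examples.tolist())  # [[8.0, 9.0], [10.0, 11.0], [0.0, 0.0], [0.0, 0.0]]
print(summed.tolist())    # [2.0, 2.0, 2.0]
```

Only the two real examples contribute to the summed gradient; without the mask, the two dummy rows would add to the sum as if they were real data.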

Example Usage:
>>> import numpy as np
>>> from jax_privacy.batch_selection import split_and_pad_global_batch
>>> indices = np.arange(10)
>>> split_and_pad_global_batch(indices, minibatch_size=4)
[array([0, 1, 2, 3]), array([4, 5, 6, 7]), array([ 8,  9, -1, -1])]
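For a 1D input without microbatch_size, the splitting and padding shown in the example can be sketched in a few lines of NumPy. `split_and_pad_1d` is a hypothetical stand-in written for illustration, not the library's implementation; it assumes a signed integer dtype (so that -1 is representable) and does not handle 2D inputs or microbatch reordering.

```python
import numpy as np


def split_and_pad_1d(indices: np.ndarray, minibatch_size: int) -> list[np.ndarray]:
    """Hypothetical sketch: splits 1D indices into -1-padded minibatches."""
    if indices.ndim != 1:
        raise ValueError("This sketch only handles 1D index arrays.")
    # Ceiling division: number of minibatches needed to cover every index.
    num_minibatches = -(-indices.size // minibatch_size)
    num_padding = num_minibatches * minibatch_size - indices.size
    # Append -1 padding so the total length is a multiple of minibatch_size.
    padded = np.concatenate(
        [indices, np.full(num_padding, -1, dtype=indices.dtype)])
    # Each row of the reshaped array is one minibatch.
    return list(padded.reshape(num_minibatches, minibatch_size))


print(split_and_pad_1d(np.arange(10), minibatch_size=4))
# [array([0, 1, 2, 3]), array([4, 5, 6, 7]), array([ 8,  9, -1, -1])]
```

Because every minibatch has the same shape, a compiled gradient function only needs to be traced once, which is the usual motivation for padding rather than emitting a shorter final minibatch.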

Parameters:
  • indices (ndarray) – A 1D or 2D numpy array of indices representing the global batch.

  • minibatch_size (int) – The desired size of each minibatch. Minibatches of this size will typically be passed into a compiled function that computes and accumulates clipped gradients.

  • microbatch_size (int | None) – The size of each microbatch. If set, the last minibatch is reordered so that the padding indices land in the positions that enable early stopping during the gradient evaluation of that minibatch. See microbatching.compute_early_stopping_order for more details.
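Why the position of padding matters can be illustrated with a hypothetical helper that counts how many microbatches must be evaluated if evaluation stops after the last microbatch containing real data. This sketch assumes microbatches are contiguous slices of a 1D minibatch; the layout actually assumed by microbatching.compute_early_stopping_order may differ, and `num_microbatches_to_evaluate` is not part of jax_privacy.

```python
import numpy as np


def num_microbatches_to_evaluate(minibatch: np.ndarray, microbatch_size: int) -> int:
    """Hypothetical helper: microbatches needed if evaluation stops after the
    last microbatch that contains a real (non -1) index.

    Assumes a 1D minibatch split into contiguous slices of microbatch_size.
    """
    if minibatch.size % microbatch_size != 0:
        raise ValueError("minibatch size must be a multiple of microbatch_size")
    microbatches = minibatch.reshape(-1, microbatch_size)
    has_real_data = (microbatches != -1).any(axis=1)
    real_microbatch_ids = np.flatnonzero(has_real_data)
    # If there is no real data at all, nothing needs to be evaluated.
    return int(real_microbatch_ids[-1]) + 1 if real_microbatch_ids.size else 0


interleaved = np.array([8, -1, 9, -1])  # padding spread across microbatches
grouped = np.array([8, 9, -1, -1])      # padding grouped at the end
print(num_microbatches_to_evaluate(interleaved, microbatch_size=2))  # 2
print(num_microbatches_to_evaluate(grouped, microbatch_size=2))      # 1
```

With the padding grouped into trailing microbatches, half of the work in this toy case can be skipped; spreading the same padding across microbatches forces every microbatch to be evaluated.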

Return type:

list[ndarray]

Returns:

A list of minibatches of indices, each of size exactly minibatch_size. The last minibatch may contain -1 indices, representing padding examples, to bring it to that size.