jax_privacy.batch_selection.pad_to_multiple_of

jax_privacy.batch_selection.pad_to_multiple_of(indices, multiple)

Pads the last dimension of indices with -1 so that its length is a multiple of the multiple argument.

Example Usage:
>>> import numpy as np
>>> from jax_privacy.batch_selection import pad_to_multiple_of
>>> indices = np.arange(10)
>>> pad_to_multiple_of(indices, multiple=4)
array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, -1, -1])

Parameters:
  • indices (ndarray) – A 1D array of batch indices.

  • multiple (int) – A positive integer. The input batch will be padded to a multiple of this value.

Return type:

ndarray

Returns:

A new 1D array of indices padded with -1.
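The padding arithmetic can be sketched in plain NumPy. This is a hypothetical stand-in for illustration only: pad_to_multiple_of_sketch is not part of jax_privacy, and it assumes the padded length is the smallest multiple of multiple that is at least the current length (ceil(n / multiple) * multiple), which matches the example above (10 entries, multiple=4, padded to 12). The library's actual implementation may differ in details such as input validation.

```python
import numpy as np


def pad_to_multiple_of_sketch(indices: np.ndarray, multiple: int) -> np.ndarray:
    """Hypothetical re-implementation, not the jax_privacy function.

    Pads the last axis of `indices` with -1 up to the smallest multiple of
    `multiple` that is >= its current length (assumed behavior).
    """
    if multiple <= 0:
        raise ValueError(f"multiple must be positive, got {multiple}")
    n = indices.shape[-1]
    # Integer ceiling division: ceil(n / multiple) * multiple.
    padded_len = -(-n // multiple) * multiple
    # Pad only the end of the last axis; leave leading axes untouched.
    pad_width = [(0, 0)] * (indices.ndim - 1) + [(0, padded_len - n)]
    return np.pad(indices, pad_width, constant_values=-1)


padded = pad_to_multiple_of_sketch(np.arange(10), multiple=4)
print(padded)
# Batch indices are non-negative, so -1 marks padding unambiguously.
mask = padded >= 0
print(mask.sum())
```

Because real batch indices are never negative, a mask such as padded >= 0 is one simple way to separate genuine entries from the -1 padding downstream.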