jax_privacy.batch_selection.pad_to_multiple_of
- jax_privacy.batch_selection.pad_to_multiple_of(indices, multiple)
Pads the last dimension of indices with -1 so that its length is divisible by multiple.
- Example Usage:
>>> indices = np.arange(10)
>>> pad_to_multiple_of(indices, multiple=4)
array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, -1, -1])
- Parameters:
indices (ndarray) – A 1D array of batch indices.
multiple (int) – A positive integer. The input batch is padded to a multiple of this value.
- Return type:
ndarray
- Returns:
A new 1D array of indices, padded with -1.
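The padding rule can be sketched in plain NumPy. This is a hypothetical re-implementation for illustration only: pad_to_multiple_of_sketch is not part of jax_privacy, and it assumes (consistent with the documented example, but not confirmed against the library source) that padding is appended at the end of the last axis, uses the value -1, and adds the fewest entries needed to reach the next multiple.

```python
import numpy as np


def pad_to_multiple_of_sketch(indices: np.ndarray, multiple: int) -> np.ndarray:
    """Hypothetical sketch: pad the last axis with -1 up to a multiple of `multiple`."""
    if multiple <= 0:
        raise ValueError(f"multiple must be a positive integer, got {multiple}")
    size = indices.shape[-1]
    # Smallest multiple of `multiple` that is >= size (ceiling division).
    target = -(-size // multiple) * multiple
    pad = target - size
    # Pad only at the end of the last axis; any leading axes are left untouched.
    pad_width = [(0, 0)] * (indices.ndim - 1) + [(0, pad)]
    return np.pad(indices, pad_width, mode="constant", constant_values=-1)


padded = pad_to_multiple_of_sketch(np.arange(10), multiple=4)
print(padded)  # [ 0  1  2  3  4  5  6  7  8  9 -1 -1]

# Since real batch indices are non-negative, a mask recovers them (assumption).
valid = padded >= 0
print(padded[valid])  # [0 1 2 3 4 5 6 7 8 9]
```

If the input length is already a multiple, the sketch returns an unchanged copy. The mask in the usage lines assumes that genuine batch indices are never negative, so -1 can only mean padding.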