jax_privacy.experimental.compilation_utils.optimal_physical_batch_sizes
- jax_privacy.experimental.compilation_utils.optimal_physical_batch_sizes(batch_sizes, num_compilations)
Find a set of compiled batch sizes that minimizes wasted compute.
Given a list of batch sizes $B_1, \ldots, B_n$ and a compilation budget $C$, this function finds compiled batch sizes $M_1, \ldots, M_C$ that minimize the following objective:
$L(M_1, \ldots, M_C) = \sum_{i=1}^{n} \min_{j \,:\, M_j \geq B_i} (M_j - B_i)$
The term $M_j - B_i$ in this objective represents the wasted compute of evaluating gradients for a batch of size $M_j$ when the true batch size is $B_i$.
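A direct transcription of this objective is handy for checking a candidate set of compiled sizes by hand. In the sketch below, `wasted_compute` is a hypothetical helper, not part of jax_privacy. Returning infinity when some $B_i$ exceeds every $M_j$ is an assumption, since the objective is undefined in that case.

```python
from __future__ import annotations

import bisect
from typing import Iterable, Sequence


def wasted_compute(batch_sizes: Sequence[int], compiled_sizes: Iterable[int]) -> float:
    """Evaluates L(M_1, ..., M_C) = sum_i min_{j : M_j >= B_i} (M_j - B_i).

    Hypothetical helper for illustration; not part of jax_privacy.
    """
    compiled = sorted(compiled_sizes)
    total = 0
    for b in batch_sizes:
        # Index of the smallest compiled size M_j with M_j >= b.
        idx = bisect.bisect_left(compiled, b)
        if idx == len(compiled):
            # Assumption: a batch that fits no compiled size makes the choice infeasible.
            return float("inf")
        total += compiled[idx] - b
    return total
```

For batch sizes `[3, 5, 8]`, compiling `{5, 8}` wastes 2 units (padding 3 up to 5), while `{3, 8}` wastes 3 (padding 5 up to 8).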
The time complexity of this function is $O(C \cdot b^2)$, where $b$ is the number of unique batch sizes in the list. It is currently not highly optimized.
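The algorithm itself is not described here. One standard way to reach an $O(C \cdot b^2)$ bound is a dynamic program over the sorted unique batch sizes; the sketch below is an assumption for illustration, not necessarily what jax_privacy implements. It relies on two observations: any compiled size can be lowered to the largest batch size it serves without increasing the objective, so each $M_j$ may be taken from the observed batch sizes, and the largest batch size must always be compiled. The name `optimal_compiled_sizes_dp` is hypothetical, and raising `ValueError` when `num_compilations` is 0 but `batch_sizes` is non-empty is an assumed convention.

```python
from __future__ import annotations

from collections import Counter
from typing import Sequence


def optimal_compiled_sizes_dp(batch_sizes: Sequence[int], num_compilations: int) -> set[int]:
    """Hypothetical O(C * b^2) dynamic program for the wasted-compute objective."""
    if not batch_sizes:
        return set()
    if num_compilations <= 0:
        # Assumed convention: a non-empty batch list cannot be covered by zero sizes.
        raise ValueError("At least one compiled size is needed for a non-empty batch list.")
    counts = Counter(batch_sizes)
    u = sorted(counts)  # Unique sizes u[0] < u[1] < ... < u[b-1].
    b = len(u)
    if num_compilations >= b:
        return set(u)  # Zero waste: compile every distinct size.

    # Prefix sums of multiplicities and of multiplicity-weighted sizes.
    cnt = [0] * (b + 1)
    tot = [0] * (b + 1)
    for t, size in enumerate(u):
        cnt[t + 1] = cnt[t] + counts[size]
        tot[t + 1] = tot[t] + counts[size] * size

    def cost(a: int, i: int) -> int:
        # O(1) waste when every batch of size u[a..i] (inclusive) is padded up to u[i].
        return u[i] * (cnt[i + 1] - cnt[a]) - (tot[i + 1] - tot[a])

    inf = float("inf")
    # best[k][i]: min waste covering u[0..i] with k compiled sizes, the largest being u[i].
    best = [[inf] * b for _ in range(num_compilations + 1)]
    prev = [[-1] * b for _ in range(num_compilations + 1)]
    for i in range(b):
        best[1][i] = cost(0, i)
    for k in range(2, num_compilations + 1):
        for i in range(b):
            for j in range(i):  # u[j] is the next-largest compiled size below u[i].
                candidate = best[k - 1][j] + cost(j + 1, i)
                if candidate < best[k][i]:
                    best[k][i] = candidate
                    prev[k][i] = j

    # Backtrack from the largest unique size, which must always be compiled.
    chosen: set[int] = set()
    k, i = num_compilations, b - 1
    while k >= 1 and i >= 0:
        chosen.add(u[i])
        i = prev[k][i]
        k -= 1
    return chosen
```

The three nested loops over `k`, `i`, and `j` give the $O(C \cdot b^2)$ cost, and the prefix sums keep each segment cost $O(1)$. For instance, `optimal_compiled_sizes_dp([1, 2, 3, 10, 11, 12], 2)` returns `{3, 12}`, for a total waste of 6.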
- Parameters:
batch_sizes (list[int]) – A list of non-negative integers $B_1, \ldots, B_n$.
num_compilations (int) – A non-negative integer representing the number of unique batch sizes to return (and to compile downstream functions for).
- Return type:
set[int]
- Returns:
A set of integers: the selected compiled batch sizes.
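For small inputs, an exhaustive search is a simple way to sanity-check what the returned set should contain. The sketch below is a hypothetical reference helper, `brute_force_compiled_sizes`, not part of jax_privacy. It assumes each compiled size can be drawn from the observed batch sizes and that the largest batch size is always compiled, and it raises `ValueError` for a zero budget with non-empty input as an assumed convention. It enumerates $\binom{b-1}{C-1}$ candidates, so it is only practical when $b$ is small.

```python
from __future__ import annotations

import bisect
import itertools
from typing import Sequence


def brute_force_compiled_sizes(batch_sizes: Sequence[int], num_compilations: int) -> set[int]:
    """Exhaustive reference for small inputs (hypothetical helper, not part of jax_privacy)."""
    unique = sorted(set(batch_sizes))
    if not unique:
        return set()
    if num_compilations <= 0:
        raise ValueError("At least one compiled size is needed for a non-empty batch list.")
    if num_compilations >= len(unique):
        return set(unique)  # Every distinct size gets its own compilation.

    def waste(compiled: tuple[int, ...]) -> int:
        # `compiled` is sorted and ends with max(batch_sizes), so every batch fits.
        return sum(compiled[bisect.bisect_left(compiled, b)] - b for b in batch_sizes)

    largest = unique[-1]
    # combinations() of a sorted list yields sorted tuples; appending `largest` keeps them sorted.
    candidates = (
        combo + (largest,)
        for combo in itertools.combinations(unique[:-1], num_compilations - 1)
    )
    return set(min(candidates, key=waste))
```

For example, `brute_force_compiled_sizes([4, 4, 4, 9], 1)` returns `{9}`, since a single compilation must cover the largest batch.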