jax_privacy.matrix_factorization.streaming_matrix.StreamingMatrix
- class jax_privacy.matrix_factorization.streaming_matrix.StreamingMatrix(init_multiply, multiply_next)
Bases: Generic[State]
A linear mapping x -> A x for a lower-triangular (streaming) matrix A.
The init_multiply and multiply_next attributes let you compute the linear mapping x -> A x efficiently in streaming fashion, one element at a time. What counts as efficient depends on the implementation; examples include constant memory overhead and never fully materializing A or x.
- Example Usage:
>>> import jax.numpy as jnp
>>> from jax_privacy.matrix_factorization.streaming_matrix import prefix_sum
>>> A = prefix_sum()
>>> x = jnp.arange(1, 5).astype(float)
>>> slices = []
>>> state = A.init_multiply(x[0])
>>> for i in range(len(x)):
...   result_slice, state = A.multiply_next(x[i], state)
...   slices.append(result_slice)
>>> Ax = jnp.stack(slices)
>>> print(Ax)
[ 1.  3.  6. 10.]
>>> print(jnp.cumsum(x))
[ 1.  3.  6. 10.]
See the constructor docstring for a full description of init_multiply and multiply_next.
Importantly, this design encodes the fact that Ax[i] may depend only on x[i] and on the state captured while computing Ax[0], …, Ax[i-1]. This is equivalent to A having a lower-triangular matrix representation in the standard basis.
In general, A and x may both be infinite, so the class does not fix how many elements of A x are computed; the caller decides by choosing how many times to call multiply_next.
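To make the lower-triangular claim concrete, here is a minimal sketch in plain JAX that does not use jax_privacy. The functions decay_init and decay_next are hypothetical stand-ins for an init_multiply / multiply_next pair, and BETA is an illustrative constant. They implement y[i] = x[i] + BETA * y[i-1] with a single carried value as state (constant memory), and the streamed result is checked against the equivalent dense lower-triangular matrix with entries BETA**(i - j) for j <= i.

```python
import jax.numpy as jnp

BETA = 0.5  # Hypothetical decay factor chosen for this illustration.


def decay_init(x_shape_like):
  # State: the previous output y[i-1]; starts at zero with the input's shape.
  return jnp.zeros_like(x_shape_like)


def decay_next(x_i, state):
  # y[i] = x[i] + BETA * y[i-1]; the new state is y[i] itself.
  y_i = x_i + BETA * state
  return y_i, y_i


x = jnp.arange(1, 5).astype(float)
state = decay_init(x[0])
slices = []
for i in range(len(x)):
  y_i, state = decay_next(x[i], state)
  slices.append(y_i)
streamed = jnp.stack(slices)

# Dense equivalent: A[i, j] = BETA**(i - j) for j <= i, and 0 above the diagonal.
n = len(x)
i_idx, j_idx = jnp.meshgrid(jnp.arange(n), jnp.arange(n), indexing='ij')
dense_A = jnp.where(j_idx <= i_idx, BETA ** (i_idx - j_idx), 0.0)

print(jnp.allclose(streamed, dense_A @ x))  # True
```

Each output slice only ever sees the current input and the carried state, which is exactly why no entry above the diagonal of dense_A can be nonzero.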
- Variables:
init_multiply – A function that returns the initial state given the expected shape of inputs to each call to multiply_next.
multiply_next – A function that returns (next_slice, updated_state) from (next_input, current_state).
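Because multiply_next has the shape (next_input, current_state) -> (next_slice, updated_state), it maps naturally onto jax.lax.scan, whose body has the shape (carry, x) -> (carry, y); only the order of the outputs needs swapping. The sketch below assumes a prefix-sum pair written as plain functions (prefix_init, prefix_next and stream_multiply are hypothetical names, not jax_privacy APIs) and shows such an adapter.

```python
import jax
import jax.numpy as jnp


def prefix_init(x_shape_like):
  # State: running sum of all inputs seen so far.
  return jnp.zeros_like(x_shape_like)


def prefix_next(x_i, state):
  new_state = state + x_i
  return new_state, new_state  # (next_slice, updated_state)


def stream_multiply(init_multiply, multiply_next, xs):
  """Computes A @ xs by scanning multiply_next over the leading axis of xs."""

  def body(carry, x_i):
    # lax.scan expects (carry, x) -> (carry, y), while multiply_next
    # returns (y, carry), so the two outputs are swapped here.
    y_i, new_carry = multiply_next(x_i, carry)
    return new_carry, y_i

  _, ys = jax.lax.scan(body, init_multiply(xs[0]), xs)
  return ys


x = jnp.arange(1, 5).astype(float)
print(stream_multiply(prefix_init, prefix_next, x))  # [ 1.  3.  6. 10.]
```

Using lax.scan instead of a Python loop keeps the computation traceable, so the same adapter can be wrapped in jax.jit.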
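Methods
__init__(init_multiply, multiply_next)
from_array_implementation(init_multiply_fn, multiply_next_fn) – Construct a StreamingMatrix object from an implementation of init/next.
materialize(n) – A utility method to materialize this matrix as an n x n ndarray.
row_norms_squared(n[, scan_fn]) – Computes the row-wise L2^2 norm of the matrix.
scale_rows_and_columns([row_scale, col_scale]) – Returns a new StreamingMatrix with scaled rows and/or cols.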
Attributes
- init_multiply: Callable[[ArrayTree], State]
- multiply_next: Callable[[ArrayTree, State], tuple[ArrayTree, State]]
Here ArrayTree = Union[Array, ndarray, bool, number, Iterable[ArrayTree], Mapping[Any, ArrayTree]] and State = TypeVar(State, bound=ArrayTree).
- classmethod from_array_implementation(init_multiply_fn, multiply_next_fn)
Construct a StreamingMatrix object from an implementation of init/next.
This class method expects the init_multiply_fn and multiply_next_fn to be defined w.r.t. a single jax.Array input. These implementations will be “lifted” to operate on pytrees of arrays.
- Parameters:
init_multiply_fn (Callable[[Array | ShapeDtypeStruct], State]) – A function that returns the initial state given the expected shape of inputs to each call to multiply_next_fn.
multiply_next_fn (Callable[[Array, State], tuple[Array, State]]) – A function that returns (next_slice, updated_state) from (next_input, current_state).
- Return type:
StreamingMatrix
- Returns:
A StreamingMatrix that operates over PyTrees of jax.Array objects.
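One plausible way to perform the "lifting" described above is to flatten the input pytree, run the single-array functions independently on each leaf, and keep one per-leaf state in a list. This is a hypothetical sketch (lift_to_pytrees, prefix_init and prefix_next are illustrative names, not jax_privacy functions) and may differ from how from_array_implementation actually works.

```python
import jax
import jax.numpy as jnp


def lift_to_pytrees(init_fn, next_fn):
  """Hypothetical helper: lifts single-array init/next functions to pytrees."""

  def init_multiply(x_tree):
    # One independent state per leaf of the input pytree.
    return [init_fn(leaf) for leaf in jax.tree_util.tree_leaves(x_tree)]

  def multiply_next(x_tree, state):
    leaves, treedef = jax.tree_util.tree_flatten(x_tree)
    # Apply the single-array step to each leaf with its own state.
    results = [next_fn(leaf, s) for leaf, s in zip(leaves, state)]
    next_slice = jax.tree_util.tree_unflatten(treedef, [r[0] for r in results])
    return next_slice, [r[1] for r in results]

  return init_multiply, multiply_next


def prefix_init(x):
  return jnp.zeros_like(x)


def prefix_next(x, state):
  new_state = state + x
  return new_state, new_state


init_multiply, multiply_next = lift_to_pytrees(prefix_init, prefix_next)
xs = [{'w': jnp.ones(2), 'b': jnp.array(float(i))} for i in range(3)]
state = init_multiply(xs[0])
for x_i in xs:
  out, state = multiply_next(x_i, state)
print(out['w'], out['b'])  # [3. 3.] 3.0
```

Because both functions flatten with the same pytree structure, the leaf order (and hence the pairing of leaves with states) stays consistent from step to step.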
- materialize(n)
A utility method to materialize this matrix as an n x n ndarray.
Note that n must be passed explicitly because a general StreamingMatrix can represent an infinite-dimensional matrix.
This method is intended primarily for debugging and testing implementations of init and next.
- Parameters:
n (int) – The size of the square matrix to materialize.
- Return type:
Array
- Returns:
An n x n materialization of this matrix.
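The streaming interface alone is enough to recover the dense matrix: if the i-th input slice is the i-th row e_i of the n x n identity, linearity makes the i-th output slice equal to sum_j A[i, j] e_j, which is row i of A. The sketch below applies this to a hypothetical prefix-sum pair written in plain JAX, assuming the implementation applies the same coefficient A[i, j] to every entry of a slice. It illustrates the idea and is not necessarily how materialize is implemented; materialize_sketch is an illustrative name.

```python
import jax
import jax.numpy as jnp


def prefix_init(x):
  return jnp.zeros_like(x)


def prefix_next(x_i, state):
  new_state = state + x_i
  return new_state, new_state


def materialize_sketch(init_multiply, multiply_next, n):
  """Hypothetical helper: builds the n x n matrix by streaming identity rows."""
  eye = jnp.eye(n)

  def body(state, e_i):
    # Output slice i is sum_j A[i, j] * e_j, i.e. row i of A.
    row_i, state = multiply_next(e_i, state)
    return state, row_i

  _, rows = jax.lax.scan(body, init_multiply(eye[0]), eye)
  return rows


print(materialize_sketch(prefix_init, prefix_next, 3))
# [[1. 0. 0.]
#  [1. 1. 0.]
#  [1. 1. 1.]]
```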
- row_norms_squared(n, scan_fn=<function scan>)
Computes the row-wise L2^2 norm of the matrix.
Given a StreamingMatrix B = A C^{-1}, this function computes the per-query expected squared error of the factorization A = BC. The expected total squared error and the maximum expected squared error can be computed from this vector via jnp.sum and jnp.max respectively.
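The link between row norms and error can be spelled out under an assumed noise model (the usual matrix-mechanism setup; the noise distribution is not specified here): the released estimate of Ax is B(Cx + z) with independent zero-mean noise z of variance σ² per coordinate. The cross terms vanish because the z_j are independent with mean zero.

```latex
\widehat{Ax} = B(Cx + z) = Ax + Bz, \qquad \mathbb{E}[z] = 0, \quad \operatorname{Cov}(z) = \sigma^2 I

\mathbb{E}\left[(\widehat{Ax} - Ax)_i^2\right]
  = \mathbb{E}\Big[\Big(\sum_j B_{ij} z_j\Big)^2\Big]
  = \sigma^2 \sum_j B_{ij}^2
  = \sigma^2 \lVert B_{i,:} \rVert_2^2
```

With σ² = 1, the i-th entry is the squared L2 norm of row i of B; summing over i gives the expected total squared error and taking the maximum gives the worst per-query error.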
This function accepts an optional scan_fn argument. It is mainly useful when you need to backpropagate through this function: a checkpointed scan can then help avoid running out of memory on GPUs. For example, the scan_fn defined below stores only 8 intermediate states of the scan, rather than the entire scan history that the default stores during backpropagation. The number of checkpoints controls the compute/memory tradeoff.
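import functools
import equinox.internal

# Checkpointed scan: keeps 8 intermediate states for the backward pass.
scan_fn = functools.partial(
    equinox.internal.scan,
    kind='checkpointed',
    checkpoints=8,
)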
- Parameters:
n (int) – The number of rows to compute squared norms of.
scan_fn – A function with the same signature as jax.lax.scan.
- Return type:
Array
- Returns:
A vector of length n containing the row-wise L2^2 norm of the matrix.
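A minimal sketch of the computation: stream the rows of the identity so that each scan step produces one row of the matrix, and keep only that row's squared norm, so the dense matrix is never stored (each step still touches one length-n row). The names row_norms_squared_sketch, prefix_init and prefix_next are hypothetical; the scan_fn argument mirrors the documented parameter and defaults to jax.lax.scan.

```python
import jax
import jax.numpy as jnp


def prefix_init(x):
  return jnp.zeros_like(x)


def prefix_next(x_i, state):
  new_state = state + x_i
  return new_state, new_state


def row_norms_squared_sketch(init_multiply, multiply_next, n,
                             scan_fn=jax.lax.scan):
  """Hypothetical helper: squared L2 norm of each of the first n rows."""
  eye = jnp.eye(n)

  def body(state, e_i):
    row_i, state = multiply_next(e_i, state)
    return state, jnp.sum(row_i**2)  # Keep only the squared norm of row i.

  _, norms = scan_fn(body, init_multiply(eye[0]), eye)
  return norms


norms = row_norms_squared_sketch(prefix_init, prefix_next, 4)
print(norms)  # [1. 2. 3. 4.]
print(jnp.sum(norms), jnp.max(norms))  # 10.0 4.0
```

The last line applies the jnp.sum and jnp.max reductions mentioned above to obtain the total and the maximum.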
- scale_rows_and_columns(row_scale=None, col_scale=None)
Returns a new StreamingMatrix with scaled rows and/or cols.
Assumes row_scale and col_scale can be indexed for as many outputs as the matrix generates. If jax.Array objects are used, note that row_scale[i] for i >= len(row_scale) returns row_scale[-1] (and likewise for col_scale), because JAX clamps out-of-bounds indices when reading from an array.
- Parameters:
matrix (StreamingMatrix) – The matrix to wrap (when called as a method, the instance itself).
row_scale (Array | None) – Multipliers to apply to the rows of matrix, equivalent to jnp.diag(row_scale) @ matrix.
col_scale (Array | None) – Multipliers to apply to the columns of matrix, equivalent to matrix @ jnp.diag(col_scale).
- Return type:
StreamingMatrix
- Returns:
The wrapped StreamingMatrix.
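A minimal sketch of how row and column scaling can be layered on top of an init/next pair. The state is extended with a step counter i; the input slice is multiplied by col_scale[i] before it enters the wrapped matrix (giving A @ diag(col_scale)), and the output slice is multiplied by row_scale[i] (giving diag(row_scale) @ A). The names scale_sketch, prefix_init and prefix_next are hypothetical, and this is not claimed to be the jax_privacy implementation. The deliberately short col_scale shows the clamped indexing described above.

```python
import jax.numpy as jnp


def prefix_init(x):
  return jnp.zeros_like(x)


def prefix_next(x_i, state):
  new_state = state + x_i
  return new_state, new_state


def scale_sketch(init_multiply, multiply_next, row_scale, col_scale):
  """Hypothetical helper: streams diag(row_scale) @ A @ diag(col_scale)."""

  def scaled_init(x):
    return (0, init_multiply(x))  # (step counter i, wrapped state)

  def scaled_next(x_i, state):
    i, inner_state = state
    # Column scaling acts on the input; row scaling acts on the output.
    y_i, inner_state = multiply_next(col_scale[i] * x_i, inner_state)
    return row_scale[i] * y_i, (i + 1, inner_state)

  return scaled_init, scaled_next


row_scale = jnp.array([1.0, 2.0, 3.0])
col_scale = jnp.array([1.0, 10.0])  # Shorter than the stream: index is clamped.
init_fn, next_fn = scale_sketch(prefix_init, prefix_next, row_scale, col_scale)

x = jnp.ones(4)
state = init_fn(x[0])
out = []
for x_i in x:
  y_i, state = next_fn(x_i, state)
  out.append(y_i)
print(jnp.stack(out))  # [ 1. 22. 63. 93.]
```

Here the effective column scales are [1, 10, 10, 10] and the effective row scales are [1, 2, 3, 3], so the prefix sums [1, 11, 21, 31] become [1, 22, 63, 93].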