jax_privacy.matrix_factorization.streaming_matrix.StreamingMatrix

class jax_privacy.matrix_factorization.streaming_matrix.StreamingMatrix(init_multiply, multiply_next)

Bases: Generic[State]

A linear mapping x -> A x for a lower-triangular (streaming) matrix A.

Through the attributes init_multiply and multiply_next, this class lets you efficiently compute the linear mapping x -> A x in streaming fashion (one element at a time). What "efficiently" means is implementation-dependent; examples include constant memory overhead, avoiding full materialization of A or x, or both.

Example Usage:
>>> import jax.numpy as jnp
>>> A = prefix_sum()
>>> x = jnp.arange(1, 5).astype(float)
>>> slices = []
>>> state = A.init_multiply(x[0])
>>> for i in range(len(x)):
...   result_slice, state = A.multiply_next(x[i], state)
...   slices.append(result_slice)
>>> Ax = jnp.stack(slices)
>>> print(Ax)
[ 1.  3.  6. 10.]
>>> print(jnp.cumsum(x))
[ 1.  3.  6. 10.]

See the constructor docstring for a full description of init_multiply and multiply_next.

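Importantly, this design encodes the fact that Ax[i] may depend only on x[i] and on state captured while computing Ax[0], …, Ax[i-1]. Since that state can only carry information about x[0], …, x[i-1], Ax[i] depends only on x[0], …, x[i]; for a linear map, this is equivalent to A having a lower-triangular matrix representation in the standard basis.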
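In symbols, writing s_i for the state passed into the call that produces Ax[i] (s_i and f_i are notation introduced only for this derivation), the argument above reads:

```latex
\begin{aligned}
s_0 &= \texttt{init\_multiply}(\text{shape of } x_0), \\
\big((Ax)_i,\ s_{i+1}\big) &= \texttt{multiply\_next}(x_i,\ s_i), \\
s_i &= f_i(x_0, \dots, x_{i-1})
  \;\Longrightarrow\; (Ax)_i = \sum_{j=0}^{i} A_{ij}\, x_j,
  \quad\text{i.e. } A_{ij} = 0 \text{ for } j > i.
\end{aligned}
```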

In general, A and x may both be infinite, so the class does not fix how many elements of A x are computed; the caller decides, for example by choosing how many times to call multiply_next, or by passing n to materialize and row_norms_squared.

Variables:
  • init_multiply – A function that returns the initial state given the expected shape of inputs to each call to multiply_next.

  • multiply_next – A function that returns (next_slice, updated_state) from (next_input, current_state).
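As a minimal sketch, assuming only the call pattern shown in the usage example above, here is a hand-written init_multiply / multiply_next pair for a decayed prefix sum. The matrix, BETA, and stream are hypothetical illustrations, not part of jax_privacy. The state is a single array the size of one input, so memory stays constant however long the stream runs; a dense lower-triangular matrix serves as the reference.

```python
import jax.numpy as jnp

# Hypothetical matrix: A[i, j] = BETA**(i - j) for j <= i and 0 otherwise
# (a decayed prefix sum; BETA = 1.0 would give the plain prefix sum).
BETA = 0.5


def init_multiply(x):
    # Only the shape and dtype of x are used: the state is a zero running sum.
    return jnp.zeros(x.shape, x.dtype)


def multiply_next(next_input, state):
    # Constant-size update: s_{i+1} = BETA * s_i + x[i]; the output Ax[i] is s_{i+1}.
    new_state = BETA * state + next_input
    return new_state, new_state


def stream(x):
    # Same call pattern as the usage example: one element of x per call.
    state = init_multiply(x[0])
    slices = []
    for i in range(len(x)):
        result_slice, state = multiply_next(x[i], state)
        slices.append(result_slice)
    return jnp.stack(slices)


x = jnp.arange(1, 5).astype(float)
idx = jnp.arange(len(x))
# Dense lower-triangular reference: tril zeroes the entries with j > i.
dense = jnp.tril(BETA ** (idx[:, None] - idx[None, :]).astype(float))
print(jnp.allclose(stream(x), dense @ x))  # True
```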

Methods

  • __init__

  • from_array_implementation – Construct a StreamingMatrix object from an implementation of init/next.

  • materialize – A utility method to materialize this matrix as an n x n ndarray.

  • row_norms_squared – Computes the row-wise L2^2 norm of the matrix.

  • scale_rows_and_columns – Returns a new StreamingMatrix with scaled rows and/or cols.

Attributes

  • init_multiply

  • multiply_next

ArrayTree below abbreviates Union[Array, ndarray, bool, number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], and State is TypeVar(State, bound=ArrayTree).

init_multiply: Callable[[ArrayTree], State]
multiply_next: Callable[[ArrayTree, State], tuple[ArrayTree, State]]

classmethod from_array_implementation(init_multiply_fn, multiply_next_fn)

Construct a StreamingMatrix object from an implementation of init/next.

This class method expects init_multiply_fn and multiply_next_fn to be defined with respect to a single jax.Array input. These implementations will be “lifted” to operate on pytrees of arrays.

Parameters:
  • init_multiply_fn (Callable[[Array | ShapeDtypeStruct], State], where State is a TypeVar bound to ArrayTree) – a function that returns the initial state given the expected shape of inputs to each call to multiply_next_fn.

  • multiply_next_fn (Callable[[Array, State], tuple[Array, State]]) – a function that returns (next_slice, updated_state) from (next_input, current_state).

Return type:

StreamingMatrix

Returns:

A StreamingMatrix that operates over PyTrees of jax.Array objects.
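The “lifting” step can be pictured with the sketch below. It is an assumption about the behavior, not the library's code: lift_to_pytrees, prefix_init and prefix_next are hypothetical names, and the real implementation may organize the per-leaf state differently. The sketch flattens the input pytree, runs the single-array functions leaf by leaf, and rebuilds the output with the same structure; the init function may receive jax.ShapeDtypeStruct leaves because only shapes and dtypes are needed.

```python
import jax
import jax.numpy as jnp


def lift_to_pytrees(init_fn, next_fn):
    """Hypothetical helper: lift single-array init/next functions to pytrees.

    The lifted state is a list with one per-leaf state, in pytree leaf order.
    """
    def init(tree_like):
        # Leaves may be jax.Array or jax.ShapeDtypeStruct objects.
        return [init_fn(leaf) for leaf in jax.tree_util.tree_leaves(tree_like)]

    def next_(tree_input, states):
        leaves, treedef = jax.tree_util.tree_flatten(tree_input)
        outputs, new_states = [], []
        for leaf, state in zip(leaves, states):
            out, new_state = next_fn(leaf, state)
            outputs.append(out)
            new_states.append(new_state)
        # Rebuild the output with the same structure as the input.
        return jax.tree_util.tree_unflatten(treedef, outputs), new_states

    return init, next_


# Single-array prefix sum: the state is the running total.
def prefix_init(x):
    return jnp.zeros(x.shape, x.dtype)


def prefix_next(x, total):
    total = total + x
    return total, total


init, next_ = lift_to_pytrees(prefix_init, prefix_next)
spec = {"w": jax.ShapeDtypeStruct((2,), jnp.float32),
        "b": jax.ShapeDtypeStruct((), jnp.float32)}
state = init(spec)
for step in range(3):
    grads = {"w": jnp.ones(2), "b": jnp.float32(step)}
    out, state = next_(grads, state)
print(out)  # running sums after three steps: w = [3, 3], b = 0 + 1 + 2 = 3
```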

materialize(n)

A utility method to materialize this matrix as an n x n ndarray.

n must be passed explicitly because a general StreamingMatrix can represent an infinite-dimensional matrix.

NOTE: Intended primarily for debugging and testing implementations of init_multiply and multiply_next.

Parameters:

n (int) – The size of the square matrix to materialize.

Return type:

Array

Returns:

An n x n materialization of this matrix.
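One way such a materialization can work, sketched under the assumption that the library does something equivalent (materialize_sketch, BETA and the decayed-sum matrix are hypothetical): stream the rows of the n x n identity matrix through multiply_next. Because each x[i] is then the basis vector e_i, linearity makes the i-th output equal to row i of A, so n streaming steps recover the whole n x n block.

```python
import jax.numpy as jnp

BETA = 0.5  # hypothetical decayed prefix sum: A[i, j] = BETA**(i - j) for j <= i


def init_multiply(x):
    return jnp.zeros(x.shape, x.dtype)


def multiply_next(x, state):
    state = BETA * state + x
    return state, state


def materialize_sketch(n):
    # Stream the rows of the identity: x[i] = e_i. By linearity,
    # Ax[i] = sum_j A[i, j] e_j, which is exactly row i of A.
    basis = jnp.eye(n)
    state = init_multiply(basis[0])
    rows = []
    for i in range(n):
        row, state = multiply_next(basis[i], state)
        rows.append(row)
    return jnp.stack(rows)


# Expected value: [[1, 0, 0], [0.5, 1, 0], [0.25, 0.5, 1]] (lower-triangular).
print(materialize_sketch(3))
```

The result holds n rows of length n, so the sketch needs O(n^2) memory regardless of how compact the streaming state is.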

row_norms_squared(n, scan_fn=jax.lax.scan)

Computes the squared L2 norm of each row of the matrix.

Given a StreamingMatrix B = A C^{-1}, this function computes the per-query expected squared error of the factorization A = BC. The expected total squared error and the maximum expected squared error can be computed from this vector via jnp.sum and jnp.max respectively.
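For small n, the same quantity can be checked densely. The sketch below uses explicit matrices rather than a StreamingMatrix; per_query_error and the two choices of C are hypothetical illustrations, and the sketch reproduces only the row-norm computation, not any noise calibration.

```python
import jax.numpy as jnp


def per_query_error(A, C):
    # B = A C^{-1}, obtained by solving B C = A, i.e. C^T B^T = A^T.
    B = jnp.linalg.solve(C.T, A.T).T
    # Squared L2 norm of each row of B: the per-query expected squared error.
    return jnp.sum(B ** 2, axis=1)


n = 4
A = jnp.tril(jnp.ones((n, n)))  # prefix-sum workload

# Two illustrative factorizations of the same A.
err_identity_c = per_query_error(A, jnp.eye(n))  # C = I, so B = A
err_c_equals_a = per_query_error(A, A)           # C = A, so B = I

for err in (err_identity_c, err_c_equals_a):
    print(err, "total:", jnp.sum(err), "max:", jnp.max(err))
```

With C = I, row i of the prefix-sum matrix contains i + 1 ones, so the error vector is [1, 2, 3, 4] with total 10 and maximum 4; with C = A, B is the identity and every query has error 1.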

This function accepts an optional scan_fn argument, which is mainly useful when you need to backpropagate through this function; in that case, a checkpointed scan can help avoid running out of GPU memory. For example, the scan_fn defined below stores 8 intermediate states of the scan in memory, rather than the entire scan history that the default keeps during backpropagation. The number of checkpoints controls the computation / memory tradeoff.

import functools

import equinox.internal

scan_fn = functools.partial(
  equinox.internal.scan,
  kind='checkpointed',
  checkpoints=8,
)

Parameters:
  • n (int) – The number of rows to compute squared-norms of.

  • scan_fn – A function with the same signature as jax.lax.scan.

Return type:

Array

Returns:

A vector of length n containing the squared L2 norm of each row of the matrix.
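To show why scan_fn only needs jax.lax.scan's signature, here is a hypothetical streaming computation of the same vector (row_norms_squared_sketch and the decayed-sum matrix are illustrative, not the library's implementation). Each scan step feeds the next basis vector e_i, receives row i of the matrix, and emits its squared norm, so only one row-sized state is carried between steps. The checkpointed equinox scan_fn from the text above would plug into the same slot; the sketch passes jax.lax.scan with its unroll argument instead so it runs without equinox.

```python
import functools

import jax
import jax.numpy as jnp

BETA = 0.5  # hypothetical decayed prefix sum: B[i, j] = BETA**(i - j) for j <= i


def init_multiply(x):
    # Works for jax.Array and jax.ShapeDtypeStruct alike: only shape/dtype are used.
    return jnp.zeros(x.shape, x.dtype)


def multiply_next(x, state):
    state = BETA * state + x
    return state, state


def row_norms_squared_sketch(n, scan_fn=jax.lax.scan):
    def step(state, i):
        # Feeding the basis vector e_i makes the output row i of the matrix.
        row, state = multiply_next(jax.nn.one_hot(i, n), state)
        return state, jnp.sum(row ** 2)

    init_state = init_multiply(jax.ShapeDtypeStruct((n,), jnp.float32))
    _, norms = scan_fn(step, init_state, jnp.arange(n))
    return norms


default = row_norms_squared_sketch(4)
# Any callable with jax.lax.scan's signature can be swapped in, e.g. an unrolled scan.
unrolled = row_norms_squared_sketch(
    4, scan_fn=functools.partial(jax.lax.scan, unroll=2))
print(jnp.allclose(default, unrolled))  # True
```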

scale_rows_and_columns(row_scale=None, col_scale=None)

Returns a new StreamingMatrix with scaled rows and/or cols.

Assumes row_scale and col_scale can be indexed for as many outputs as are generated from matrix. If jax.Array objects are used, note that row_scale[i] for i >= len(row_scale) returns row_scale[-1], because JAX clamps out-of-bounds indices on retrieval (and likewise for col_scale).

Parameters:
  • matrix (StreamingMatrix) – The matrix to wrap; when called as a method, this is the StreamingMatrix instance itself.

  • row_scale (Array | None) – Multipliers to apply to the rows of matrix, equivalent to jnp.diag(row_scale) @ matrix.

  • col_scale (Array | None) – Multipliers to apply to the columns of matrix, equivalent to matrix @ jnp.diag(col_scale).

Return type:

StreamingMatrix

Returns:

The wrapped StreamingMatrix.
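A minimal sketch of how row and column scaling can be layered on top of any init/next pair, assuming the wrapper tracks the step index in its state (scale_sketch, materialize_dense and the decayed-sum matrix are hypothetical, not jax_privacy code). Scaling input i by col_scale[i] implements matrix @ jnp.diag(col_scale), and scaling output i by row_scale[i] implements jnp.diag(row_scale) @ matrix. Because the step counter is a jax.Array, indexing past the end of a scale vector clamps to its last entry, which the example relies on for the two-element row_scale.

```python
import jax.numpy as jnp

BETA = 0.5  # hypothetical decayed prefix sum: A[i, j] = BETA**(i - j) for j <= i


def init_multiply(x):
    return jnp.zeros(x.shape, x.dtype)


def multiply_next(x, state):
    state = BETA * state + x
    return state, state


def scale_sketch(init_fn, next_fn, row_scale=None, col_scale=None):
    # Wrapped state = (step counter, inner state); the counter indexes the scales.
    def init(x):
        return jnp.array(0), init_fn(x)

    def next_(x, state):
        i, inner = state
        if col_scale is not None:
            x = col_scale[i] * x  # A @ diag(c): scale input i by c[i]
        y, inner = next_fn(x, inner)
        if row_scale is not None:
            y = row_scale[i] * y  # diag(r) @ A: scale output i by r[i]
        return y, (i + 1, inner)

    return init, next_


def materialize_dense(init_fn, next_fn, n):
    # Stream identity rows so that output i is row i of the matrix.
    basis = jnp.eye(n)
    state = init_fn(basis[0])
    rows = []
    for i in range(n):
        row, state = next_fn(basis[i], state)
        rows.append(row)
    return jnp.stack(rows)


n = 4
r = jnp.array([1.0, 2.0])  # shorter than n: r[2] and r[3] clamp to r[-1]
c = jnp.array([1.0, 3.0, 1.0, 1.0])
scaled = materialize_dense(*scale_sketch(init_multiply, multiply_next, r, c), n)
A = materialize_dense(init_multiply, multiply_next, n)
expected = jnp.diag(jnp.array([1.0, 2.0, 2.0, 2.0])) @ A @ jnp.diag(c)
print(jnp.allclose(scaled, expected))  # True
```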