# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Definition of streamin matrix interface."""
from __future__ import annotations
from collections.abc import Callable
import dataclasses
from typing import Any, Generic, TypeAlias, TypeVar
import chex
import jax
from jax import numpy as jnp
State = TypeVar('State', bound=chex.ArrayTree)
Shape: TypeAlias = tuple[int, ...]
ShapePyTree = Any
# Disabling pylint invalid-name to allow mathematical notation including
# single-capital-letter variables for matrices.
# See README.md for notation conventions.
# pylint:disable=invalid-name
[docs]
@dataclasses.dataclass(frozen=True)
class StreamingMatrix(Generic[State]):
"""A linear mapping x -> A x for a lower-triangular (streaming) A matrix.
Via the attributes / member functions `init_multiply` and `multiply_next`,
this class allows you to efficiently compute a linear mapping x -> A x
in streaming fashion (one element at a time). The precise meaning of the term
`efficiently` is implementation-dependent, with examples including constant
memory overhead, and / or without fully materializing A or x.
Example Usage:
>>> A = prefix_sum()
>>> x = jnp.arange(1, 5).astype(float)
>>> slices = []
>>> state = A.init_multiply(x[0])
>>> for i in range(len(x)):
... result_slice, state = A.multiply_next(x[i], state)
... slices.append(result_slice)
>>> Ax = jnp.stack(slices)
>>> print(Ax)
[ 1. 3. 6. 10.]
>>> print(jnp.cumsum(x))
[ 1. 3. 6. 10.]
See the constructor docstring for a full description of `init_multiply` and
`multiply_next`.
Importantly, this design encodes the fact that Ax[i] may only depend on
x[i] and state captured from computing Ax[0], ..., Ax[i-1]. This is equivalent
to `A` having a lower-triangular matrix representation in the standard basis.
In general, `A` and `x` may both be infinite; thus we sidestep the
question of how many elements of `A x` one wishes to compute by assuming the
user provides a range.
Attributes:
init_multiply: A function that returns the initial state given the expected
shape of inputs to each call to multiply_next.
multiply_next: A function that returns (next_slice, updated_state) from
(next_input, current_state).
"""
init_multiply: Callable[[chex.ArrayTree], State]
multiply_next: Callable[[chex.ArrayTree, State], tuple[chex.ArrayTree, State]]
[docs]
@classmethod
def from_array_implementation(
cls,
init_multiply_fn: Callable[[jax.Array | jax.ShapeDtypeStruct], State],
multiply_next_fn: Callable[[jax.Array, State], tuple[jax.Array, State]],
) -> StreamingMatrix:
"""Construct a StreamingMatrix object from an implementation of init/next.
This class method expects the `init_multiply_fn` and `multiply_next_fn` to
be defined w.r.t. a single `jax.Array` input. These implementations will
be "lifted" to operate on pytrees of arrays.
Args:
init_multiply_fn: a function that returns the initial state given the
expected shape of inputs to each call to next_fn.
multiply_next_fn: a function that returns (next_slice, updated_state) from
(next_input, current_state).
Returns:
A StreamingMatrix that operates over PyTrees of `jax.Array` objects.
"""
def tree_unzip(tree, treedef):
leaves = treedef.flatten_up_to(tree)
return tuple(treedef.unflatten(x) for x in zip(*leaves))
def lifted_init(abstract_value):
return jax.tree.map(init_multiply_fn, abstract_value)
def lifted_next(value, state):
return tree_unzip(
jax.tree.map(multiply_next_fn, value, state),
jax.tree.structure(value),
)
return cls(lifted_init, lifted_next)
[docs]
def materialize(self, n: int) -> jax.Array:
"""A utility method to materialize this matrix as an n x n ndarray.
Note `n` needs to be a parameter, because a general `StreamingMatrix`
can represent an infinite-dimensional matrix.
NOTE: Primarily for debugging and testing implementations of init and next.
Args:
n: The size of the square matrix to materialize.
Returns:
An n x n materialization of this matrix.
"""
return self @ jnp.eye(n)
[docs]
def row_norms_squared(self, n: int, scan_fn=jax.lax.scan) -> jax.Array:
"""Computes the row-wise L2^2 norm of the matrix.
Given a StreamingMatrix B = A C^{-1}, this function computes the per-query
expected squared error of the factorization A = BC. The expected total
squared error and the maximum expected squared error can be computed from
this vector via jnp.sum and jnp.max respectively.
This function consumes an optional scan_fn argument, which is primarily
useful if you need to backpropagate through this function, in which case
using a checkpointed scan can be helpful to avoid OOMing on GPUs.
For example, the scan_fn defined below would store 8 intermediate states
of the scan in memory, rather than the default which stores the entire
scan history in memory during backpropagation. The number of checkpoints
determines the computation / memory tradeoff.
::
scan_fn = functools.partial(
equinox.internal.scan,
kind='checkpointed',
checkpoints=8,
)
Args:
n: The number of rows to compute squared-norms of.
scan_fn: A function with the same signature as jax.lax.scan.
Returns:
A vector of length n containing the row-wise L2^2 norm of the matrix.
"""
zero = jnp.zeros(n)
def next_state_and_row_norm(state, i):
# Note: state is first for use in scan_fn.
ei = zero.at[i].set(1)
row, state = self.multiply_next(ei, state)
return state, row @ row
return scan_fn(
next_state_and_row_norm, self.init_multiply(zero), jnp.arange(n)
)[1]
# TODO: b/329444015 - Consider making protected and updating callsites
# to call the member-function directly.
[docs]
def scale_rows_and_columns(
matrix: StreamingMatrix,
row_scale: jax.Array | None = None,
col_scale: jax.Array | None = None,
) -> StreamingMatrix:
"""Returns a new `StreamingMatrix` with scaled rows and/or cols.
Assumes row_scale and col_scale can be indexed into for as many outputs
are generated from matrix. If `jax.Array` objects are used, note
row_scale[i] for i > len(row_scale) will return row_scale[-1].
Args:
matrix: The matrix to wrap.
row_scale: Multipliers to apply to the rows of `matrix`, equivalent to
jnp.diag(row_scale) @ matrix.
col_scale: Multipliers to apply to the columns of `matrix`, equivalent to
matrix @ jnp.diag(col_scale).
Returns:
The wrapped `StreamingMatrix`.
"""
result = matrix
if row_scale is not None:
result = multiply_streaming_matrices(diagonal(row_scale), result)
if col_scale is not None:
result = multiply_streaming_matrices(result, diagonal(col_scale))
return result
[docs]
def multiply_array(A: StreamingMatrix, x: jax.Array) -> jax.Array:
"""Computes the matrix-vector product A x."""
# Reverse (value, state) -> (state, value) for scan.
def f(state, value):
return A.multiply_next(value, state)[::-1]
return jax.lax.scan(f, A.init_multiply(x[0]), x)[1]
# TODO: b/329444015 - Consider making protected and updating callsites
# to call the member-function directly.
[docs]
def multiply_streaming_matrices(
A: StreamingMatrix,
B: StreamingMatrix,
) -> StreamingMatrix:
"""Multiply a StreamingMatrix by another StreamingMatrix.
Args:
A: The left hand side matrix
B: The right hand side matrix
Returns:
A B, represented as another StreamingMatrix.
"""
def init_multiply(abstract_value):
return A.init_multiply(abstract_value), B.init_multiply(abstract_value)
def multiply_next(value, state):
A_state, B_state = state
inner, B_state = B.multiply_next(value, B_state)
outer, A_state = A.multiply_next(inner, A_state)
return outer, (A_state, B_state)
return StreamingMatrix(init_multiply, multiply_next)
[docs]
def identity() -> StreamingMatrix:
"""An implicit representation of the identity matrix."""
return StreamingMatrix(lambda _: (), lambda value, _: (value, ()))
[docs]
def prefix_sum() -> StreamingMatrix:
"""An implicit representation of the lower triangular matrix of ones."""
def init_multiply(abstract_value):
return jnp.zeros_like(abstract_value)
def multiply_next(state, value):
result = state + value
return result, result
return StreamingMatrix.from_array_implementation(init_multiply, multiply_next)
[docs]
def diagonal(diag: jax.Array) -> StreamingMatrix:
"""An implicit representation of a diagonal matrix.
The returned StreamingMatrix represents an infinitely large diagonal matrix.
The diagonal elements are taken from the provided array `diag` up to row
n = diag.size, and is equal to diag[-1] beyond that point.
Args:
diag: A 1D array of diagonal elements.
Returns:
A StreamingMatrix representing the corresponding diagonal matrix.
"""
return StreamingMatrix.from_array_implementation(
lambda _: jnp.array(0),
lambda value, i: (value * diag.at[i].get(mode='clip'), i + 1),
)
[docs]
def momentum_sgd_matrix(
momentum: float = 0, learning_rates: jax.Array | None = None
) -> StreamingMatrix:
"""An implicit representation of the momentum sgd matrix."""
lr_sched = jnp.ones(1) if learning_rates is None else learning_rates
if lr_sched.min() <= 0.0:
raise ValueError(
'Learning rates must be positive (zero learning rates may prevent '
f'matrix factorization from succeeding.) Found {learning_rates}'
)
def init_multiply(abstract_value):
dtype = jnp.promote_types(abstract_value.dtype, lr_sched.dtype)
zero = jnp.zeros_like(abstract_value, dtype=dtype)
return jnp.array(0), zero, zero
def multiply_next(
value: jax.Array, state: tuple[int, jax.Array, jax.Array]
) -> tuple[jax.Array, tuple[int, jax.Array, jax.Array]]:
index, momentum_buf, result = state
momentum_buf = momentum * momentum_buf + value
# If index is out-of-bounds, return the last value in the array.
result = result + lr_sched.at[index].get(mode='clip') * momentum_buf
updated_state = (index + 1, momentum_buf, result)
return result, updated_state
return StreamingMatrix.from_array_implementation(init_multiply, multiply_next)
T = TypeVar('T', StreamingMatrix, jax.Array)
def _multiply_any(self: StreamingMatrix, other: T) -> T:
"""Multiply a StreamingMatrix by an array or a StreamingMatrix."""
if isinstance(other, StreamingMatrix):
return multiply_streaming_matrices(self, other)
elif isinstance(other, jax.Array):
return multiply_array(self, other)
else:
raise ValueError(f'Unsupported type for multiplication: {type(other)}')
def _multiply_scalar(self: StreamingMatrix, other: float) -> StreamingMatrix:
"""Multiply a StreamingMatrix by a scalar."""
return multiply_streaming_matrices(self, diagonal(jnp.array([other])))
# Add some syntax sugar.
StreamingMatrix.__matmul__ = _multiply_any
StreamingMatrix.__mul__ = _multiply_scalar
StreamingMatrix.scale_rows_and_columns = scale_rows_and_columns