jax_privacy.matrix_factorization.streaming_matrix.scale_rows_and_columns
- jax_privacy.matrix_factorization.streaming_matrix.scale_rows_and_columns(matrix, row_scale=None, col_scale=None)
Returns a new StreamingMatrix with scaled rows and/or columns.
Assumes row_scale and col_scale can be indexed for as many outputs as matrix generates. If jax.Array objects are used, note that row_scale[i] for i >= len(row_scale) returns row_scale[-1] (and likewise for col_scale), because JAX clamps out-of-bounds indices instead of raising an error.
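The clamping mentioned above is ordinary JAX out-of-bounds indexing behavior. The short sketch below, which assumes only `jax.numpy` and uses a hypothetical three-entry `row_scale`, shows what a wrapped matrix would read at steps past the end of the array.

```python
import jax.numpy as jnp

# Hypothetical row multipliers covering only the first three outputs.
row_scale = jnp.array([1.0, 0.5, 0.25])

in_range = row_scale[1]   # 0.5, an ordinary in-bounds read
past_end = row_scale[3]   # 0.25, clamped to row_scale[-1]
far_past = row_scale[10]  # 0.25, still row_scale[-1]

# A NumPy array or Python list would raise IndexError for the last two
# reads, so with those types the sequence must be long enough for every
# output the wrapped matrix produces.
```

In practice this means a short jax.Array effectively repeats its final multiplier for all later steps, which is convenient when the scaling becomes constant after a warm-up period but can silently hide a length mismatch.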
- Parameters:
  matrix (StreamingMatrix) – The matrix to wrap.
  row_scale (Array | None) – Multipliers to apply to the rows of matrix, equivalent to jnp.diag(row_scale) @ matrix.
  col_scale (Array | None) – Multipliers to apply to the columns of matrix, equivalent to matrix @ jnp.diag(col_scale).
- Returns:
  The wrapped StreamingMatrix.
- Return type:
  StreamingMatrix
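To illustrate why the two parameters are equivalent to jnp.diag(row_scale) @ matrix @ jnp.diag(col_scale), the sketch below models a streaming matrix as a hypothetical `(init, step)` pair that emits one output per input. This is not the jax_privacy StreamingMatrix interface; `prefix_sum_matrix`, `scale_rows_and_columns_sketch`, and `run` are illustrative names. Scaling column j only affects input x_j, and scaling row i only affects output y_i, so the wrapper can apply col_scale[i] before the wrapped step and row_scale[i] after it.

```python
import jax.numpy as jnp

# Hypothetical streaming interface for illustration only (not the
# jax_privacy StreamingMatrix API): a streaming matrix M is a pair
# (init, step), where step(state, x_i) -> (y_i, new_state) emits the
# i-th entry of y = M @ x after seeing inputs x_0, ..., x_i.


def prefix_sum_matrix():
    """Streams the lower-triangular all-ones matrix, i.e. running sums."""

    def init():
        return jnp.zeros(())

    def step(state, x_i):
        total = state + x_i
        return total, total

    return init, step


def scale_rows_and_columns_sketch(streaming, row_scale=None, col_scale=None):
    """Wraps (init, step) so it computes diag(row_scale) @ M @ diag(col_scale)."""
    init, step = streaming

    def scaled_init():
        # Pair the wrapped state with the index of the next output.
        return 0, init()

    def scaled_step(state, x_i):
        i, inner = state
        if col_scale is not None:
            # Column j of M only ever multiplies input x_j, so scaling
            # column j is the same as scaling x_j before M consumes it.
            x_i = col_scale[i] * x_i
        y_i, inner = step(inner, x_i)
        if row_scale is not None:
            # Row i of M only contributes to output y_i.
            y_i = row_scale[i] * y_i
        return y_i, (i + 1, inner)

    return scaled_init, scaled_step


def run(streaming, xs):
    """Feeds xs through a streaming matrix one entry at a time."""
    init, step = streaming
    state = init()
    ys = []
    for x_i in xs:
        y_i, state = step(state, x_i)
        ys.append(y_i)
    return jnp.stack(ys)


xs = jnp.array([1.0, 2.0, 3.0, 4.0])
row_scale = jnp.array([1.0, 2.0, 3.0, 4.0])
col_scale = jnp.array([0.5, 1.0, 1.0, 2.0])

streamed = run(
    scale_rows_and_columns_sketch(prefix_sum_matrix(), row_scale, col_scale), xs
)

# Dense reference: diag(row_scale) @ M @ diag(col_scale) @ xs.
M = jnp.tril(jnp.ones((4, 4)))
dense = jnp.diag(row_scale) @ M @ jnp.diag(col_scale) @ xs
# Both equal [0.5, 5.0, 16.5, 54.0].
```

Because the wrapper only needs row_scale[i] and col_scale[i] at step i, it never materializes the dense matrix, which is the usual motivation for a streaming representation.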