jax_privacy.matrix_factorization.streaming_matrix.scale_rows_and_columns

jax_privacy.matrix_factorization.streaming_matrix.scale_rows_and_columns(matrix, row_scale=None, col_scale=None)

Returns a new StreamingMatrix with scaled rows and/or columns.

Assumes row_scale and col_scale can be indexed for as many outputs as are generated from matrix. If jax.Array objects are used, note that out-of-bounds indices are clamped: row_scale[i] for i >= len(row_scale) returns row_scale[-1] rather than raising an error.
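The clamping behavior mentioned above can be checked with plain jax.numpy, independent of this function. The following is a minimal sketch; the array values are made up for illustration, and it does not call scale_rows_and_columns itself.

```python
import jax.numpy as jnp

# Illustrative scale vector with three entries (values are arbitrary).
row_scale = jnp.array([1.0, 2.0, 3.0])

# In-bounds indexing behaves as usual.
in_bounds = row_scale[1]

# An out-of-bounds read does not raise for jax.Array; the index is
# clamped to the last valid position, so this equals row_scale[-1].
clamped = row_scale[10]
last = row_scale[-1]
```

In practice this means a scale array that is shorter than the number of generated outputs silently reuses its final entry, so it may be worth checking lengths up front if that is not the intended behavior.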

Parameters:
  • matrix (StreamingMatrix) – The matrix to wrap.

  • row_scale (Array | None) – Multipliers to apply to the rows of matrix, equivalent to jnp.diag(row_scale) @ matrix.

  • col_scale (Array | None) – Multipliers to apply to the columns of matrix, equivalent to matrix @ jnp.diag(col_scale).
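The two equivalences stated for row_scale and col_scale can be illustrated on an ordinary dense array. This is only a sketch of the linear algebra under the assumption that the matrix fits in memory; the dense array stands in for a StreamingMatrix, all values are made up, and scale_rows_and_columns is not called.

```python
import jax.numpy as jnp

# Dense stand-in for the matrix (a lower-triangular matrix of ones).
dense = jnp.array([[1.0, 0.0, 0.0],
                   [1.0, 1.0, 0.0],
                   [1.0, 1.0, 1.0]])
row_scale = jnp.array([1.0, 2.0, 3.0])
col_scale = jnp.array([10.0, 100.0, 1000.0])

# Scaling as written in the parameter descriptions:
# diag(row_scale) @ matrix @ diag(col_scale).
via_diag = jnp.diag(row_scale) @ dense @ jnp.diag(col_scale)

# The same result by broadcasting: entry (i, j) is multiplied by
# row_scale[i] * col_scale[j].
via_broadcast = row_scale[:, None] * dense * col_scale[None, :]
```

Both expressions give the same array, which shows that row scaling multiplies each row i by row_scale[i] and column scaling multiplies each column j by col_scale[j].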

Return type:

StreamingMatrix

Returns:

The wrapped StreamingMatrix.