jax_privacy.matrix_factorization.optimization

Simple wrapper around optax for strategy optimization.

Functions

jax_enable_x64(fn)

Decorator to enable x64 precision for a function.
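A decorator of this kind can be sketched by running the wrapped function inside JAX's x64 context manager. This is a hypothetical re-implementation named `jax_enable_x64_sketch`, not the library's code. It assumes `jax.experimental.enable_x64` is available in the installed JAX version; experimental APIs can change between releases.

```python
import functools
from typing import Any, Callable, TypeVar

import jax.numpy as jnp
from jax.experimental import enable_x64

F = TypeVar("F", bound=Callable[..., Any])


def jax_enable_x64_sketch(fn: F) -> F:
  """Hypothetical sketch: call `fn` with 64-bit JAX types enabled."""

  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    # enable_x64() turns on x64 mode for the duration of the `with` block
    # and restores the previous setting on exit.
    with enable_x64():
      return fn(*args, **kwargs)

  return wrapper  # type: ignore[return-value]


@jax_enable_x64_sketch
def ones_vector(n: int):
  # Created inside the x64 context, so the default float dtype is float64.
  return jnp.ones(n)


print(ones_vector(3).dtype)
```

Scoping precision to a single call keeps the rest of the program in JAX's default 32-bit mode, which is why a decorator is a convenient fit for numerically sensitive steps such as strategy optimization.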

optimize(*args, **kwargs)

Classes

CallbackArgs(step, loss, grad, params, state)

Information passed to the callback function on each optimization step.