jax_privacy.matrix_factorization.optimization.CallbackArgs
class jax_privacy.matrix_factorization.optimization.CallbackArgs(step, loss, grad, params, state)
Bases: object

Information passed to the callback function on each optimization step.
Properties:
- step: The current optimization step.
- loss: The loss value at the current step.
- grad: The gradient at the current step.
- params: The current parameters.
- state: The current optimizer state.
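The sketch below illustrates how a per-step callback might consume these fields. It is plain Python so that it runs without JAX. `CallbackArgs` here is a local stand-in dataclass with the documented fields (not the library class), and `toy_optimize` and `record_loss` are hypothetical names for a gradient-descent loop on a scalar quadratic. It makes no claim about how jax_privacy's own optimizer invokes its callback.

```python
import dataclasses
from typing import Any, Callable, List


@dataclasses.dataclass
class CallbackArgs:
    """Local stand-in with the documented fields (NOT the jax_privacy class)."""
    step: int
    loss: float
    grad: Any
    params: Any
    state: Any


def toy_optimize(params: float, num_steps: int, learning_rate: float,
                 callback: Callable[[CallbackArgs], None]) -> float:
    """Hypothetical gradient descent on loss(x) = (x - 3)**2.

    Before each update, `callback` receives the step index, the loss and
    gradient at the current parameters, the parameters themselves, and the
    optimizer state (here just a counter of updates applied so far).
    """
    state = {"count": 0}
    for step in range(num_steps):
        loss = (params - 3.0) ** 2
        grad = 2.0 * (params - 3.0)
        callback(CallbackArgs(step=step, loss=loss, grad=grad,
                              params=params, state=state))
        params = params - learning_rate * grad
        state = {"count": state["count"] + 1}
    return params


losses: List[float] = []


def record_loss(args: CallbackArgs) -> None:
    """Example callback: keep the loss history and log every 10 steps."""
    losses.append(args.loss)
    if args.step % 10 == 0:
        print(f"step={args.step} loss={args.loss:.6f}")


final = toy_optimize(params=0.0, num_steps=50, learning_rate=0.1,
                     callback=record_loss)
```

Because the callback only reads its argument, the same loop can drive logging, metric collection, or checkpointing without changing the optimizer code.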
Methods
- __init__

Attributes
- step: int
- loss: Array
- grad: Union[Array, ndarray, bool, number, Iterable[Union[Array, ndarray, bool, number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], Mapping[Any, Union[Array, ndarray, bool, number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], None]
- params: Union[Array, ndarray, bool, number, Iterable[Union[Array, ndarray, bool, number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], Mapping[Any, Union[Array, ndarray, bool, number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]]]
- state: Any
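`grad` and `params` are typed as array trees (arrays, or nested iterables and mappings of them), and `grad` additionally admits `None`. A callback that logs a gradient norm therefore has to walk the tree and tolerate `None`. The following plain-Python sketch does this over nested dicts, lists, and tuples whose leaves are Python floats rather than JAX arrays. `iter_leaves` and `global_grad_norm` are hypothetical helpers; with real JAX arrays, `jax.tree_util.tree_leaves` (which treats `None` as having no leaves) or `optax.global_norm` would be the more likely choice.

```python
import math
from collections.abc import Mapping
from typing import Any, Iterator


def iter_leaves(tree: Any) -> Iterator[float]:
    """Yield the scalar leaves of a nested dict/list/tuple tree, skipping None."""
    if tree is None:
        return
    if isinstance(tree, Mapping):
        for child in tree.values():
            yield from iter_leaves(child)
    elif isinstance(tree, (list, tuple)):
        for child in tree:
            yield from iter_leaves(child)
    else:
        yield float(tree)


def global_grad_norm(grad: Any) -> float:
    """Global L2 norm over all gradient leaves; 0.0 when grad is None."""
    return math.sqrt(sum(leaf * leaf for leaf in iter_leaves(grad)))


print(global_grad_norm({"w": [3.0, 0.0], "b": (4.0,)}))  # 5.0
print(global_grad_norm(None))                             # 0.0
```

Inside a callback this would typically be applied as `global_grad_norm(args.grad)`, so that steps where no gradient is supplied are logged as a zero norm instead of raising an error.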