jax_privacy.matrix_factorization.optimization.CallbackArgs

class jax_privacy.matrix_factorization.optimization.CallbackArgs(step, loss, grad, params, state)

Bases: object

Information passed to the callback function on each optimization step.

Properties:

step: The current optimization step.
loss: The loss value at the current step.
grad: The gradient at the current step.
params: The current parameters.
state: The current optimizer state.
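These fields are enough to write a simple per-step logging callback. The sketch below is standalone and does not import jax_privacy: `FakeCallbackArgs` is a hypothetical stand-in dataclass that mirrors the documented fields, and `LossRecorder` is likewise illustrative. It assumes only what is stated above, namely that the callback receives one `CallbackArgs` per optimization step; any return-value convention the optimizer may expect is not described here, so the sketch returns nothing.

```python
import dataclasses
from typing import Any

import jax.numpy as jnp


# Hypothetical stand-in mirroring the documented CallbackArgs fields.
# In real code the optimizer would pass jax_privacy's CallbackArgs instead.
@dataclasses.dataclass(frozen=True)
class FakeCallbackArgs:
    step: int
    loss: jnp.ndarray
    grad: Any
    params: Any
    state: Any


class LossRecorder:
    """Illustrative callback that records (step, loss) every `every` steps."""

    def __init__(self, every: int = 10):
        self.every = every
        self.history: list[tuple[int, float]] = []

    def __call__(self, args) -> None:
        # Reads only the documented `step` and `loss` attributes.
        if args.step % self.every == 0:
            # float() converts the scalar loss array to a Python float.
            self.history.append((args.step, float(args.loss)))


# Drive the callback with a toy gradient-descent loop on sum(x**2).
recorder = LossRecorder(every=2)
x = jnp.array([3.0, -4.0])
for step in range(5):
    loss = jnp.sum(x ** 2)
    grad = 2 * x
    recorder(FakeCallbackArgs(step=step, loss=loss, grad=grad, params=x, state=None))
    x = x - 0.25 * grad  # each step halves x

print(recorder.history)  # [(0, 25.0), (2, 1.5625), (4, 0.09765625)]
```

Using a callable object rather than a plain function keeps the recorded history accessible after optimization finishes without relying on globals.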

Methods

__init__

Attributes

step: int
loss: Array
grad: Optional[ArrayTree]
params: ArrayTree
state: Any

Here ArrayTree is the recursive pytree type Union[Array, ndarray, bool, number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]: a single array or scalar, or an arbitrarily nested iterable or mapping of them. grad may additionally be None.
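Because grad and params are array pytrees rather than single arrays, a callback usually reduces them with `jax.tree_util` before logging. The two helpers below are a hypothetical sketch, not part of jax_privacy; they rely only on the types listed above, including the fact that grad may be None.

```python
from typing import Any, Optional

import jax
import jax.numpy as jnp


def grad_global_norm(grad: Any) -> Optional[float]:
    """L2 norm over all leaves of a gradient pytree, or None if grad is None."""
    if grad is None:
        # The documented type of `grad` admits None; tree_leaves(None) would
        # silently return [], so handle it explicitly.
        return None
    leaves = jax.tree_util.tree_leaves(grad)
    return float(jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in leaves)))


def num_params(params: Any) -> int:
    """Total number of scalar entries across all leaves of a params pytree."""
    return int(sum(jnp.size(leaf) for leaf in jax.tree_util.tree_leaves(params)))


# Example pytrees: a dict of arrays, one valid shape of ArrayTree.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
grad = {"w": jnp.full((2, 3), 2.0), "b": jnp.array([1.0, 0.0, 0.0])}

print(num_params(params))      # 9
print(grad_global_norm(grad))  # 5.0, i.e. sqrt(6 * 2.0**2 + 1.0**2)
print(grad_global_norm(None))  # None
```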