# Source code for f3dasm_optimize._src.optax_optimizers

#                                                                       Modules
# =============================================================================

# Third-party
from typing import Optional

import optax
from f3dasm import Block

# Local
from .adapters.optax_implementations import OptaxOptimizer

#                                                          Authorship & Credits
# =============================================================================
__author__ = 'Martin van der Schelling (M.P.vanderSchelling@tudelft.nl)'
__credits__ = ['Martin van der Schelling']
__status__ = 'Stable'
# =============================================================================
#
# =============================================================================


def adam(learning_rate: float = 0.001, beta_1: float = 0.9,
         beta_2: float = 0.999, epsilon: float = 1e-07,
         eps_root: float = 0.0, seed: Optional[int] = None,
         **kwargs) -> Block:
    """
    Adam optimizer.

    Adapted from the Optax library (``optax.adam``).

    Parameters
    ----------
    learning_rate : float, optional
        The learning rate (step size), by default 0.001.
    beta_1 : float, optional
        Exponential decay rate for the first moment estimates, by default
        0.9. Passed on under Optax's parameter name ``b1``.
    beta_2 : float, optional
        Exponential decay rate for the second moment estimates, by default
        0.999. Passed on under Optax's parameter name ``b2``.
    epsilon : float, optional
        Small constant added to the denominator *outside* the square root
        for numerical stability, by default 1e-07. Passed on under Optax's
        parameter name ``eps``. Note that Optax's own default is 1e-8.
    eps_root : float, optional
        Small constant added to the denominator *inside* the square root,
        by default 0.0.
    seed : int, optional
        Random seed handed to ``OptaxOptimizer``, by default None.
    **kwargs
        Additional keyword arguments, passed unchanged to
        ``OptaxOptimizer``.

    Returns
    -------
    Block
        An ``OptaxOptimizer`` block configured with ``optax.adam`` and the
        hyperparameters above.
    """
    return OptaxOptimizer(
        algorithm_cls=optax.adam,
        seed=seed,
        # The keyword names below match optax.adam's signature, hence the
        # renaming of beta_1/beta_2/epsilon to b1/b2/eps.
        learning_rate=learning_rate,
        b1=beta_1,
        b2=beta_2,
        eps=epsilon,
        eps_root=eps_root,
        **kwargs,
    )
# =============================================================================
def sgd(learning_rate: float = 0.01, momentum: float = 0.0,
        nesterov: bool = False, seed: Optional[int] = None,
        **kwargs) -> Block:
    """
    Stochastic Gradient Descent (SGD) optimizer.

    Adapted from the Optax library (``optax.sgd``).

    Parameters
    ----------
    learning_rate : float, optional
        The learning rate (step size), by default 0.01.
    momentum : float, optional
        Momentum parameter (decay rate of the gradient trace), by default
        0.0.
    nesterov : bool, optional
        Whether to use Nesterov momentum, by default False. With
        ``momentum=0.0`` the update reduces to plain SGD either way.
    seed : int, optional
        Random seed handed to ``OptaxOptimizer``, by default None.
    **kwargs
        Additional keyword arguments, passed unchanged to
        ``OptaxOptimizer``.

    Returns
    -------
    Block
        An ``OptaxOptimizer`` block configured with ``optax.sgd`` and the
        hyperparameters above.
    """
    return OptaxOptimizer(
        algorithm_cls=optax.sgd,
        seed=seed,
        learning_rate=learning_rate,
        # NOTE: optax.sgd's own default is momentum=None (no trace state).
        # Passing 0.0 instead builds a zero-decay trace, which is
        # numerically plain SGD but carries an extra state entry.
        momentum=momentum,
        nesterov=nesterov,
        **kwargs,
    )
# =============================================================================