from keras_core import operations as ops
from keras_core.optimizers import optimizer


class SGD(optimizer.Optimizer):
    """Gradient descent (with momentum) optimizer.

    Update rule for parameter `w` with gradient `g` when `momentum` is 0:

    ```python
    w = w - learning_rate * g
    ```

    Update rule when `momentum` is larger than 0:

    ```python
    velocity = momentum * velocity - learning_rate * g
    w = w + velocity
    ```

    When `nesterov=True`, this rule becomes:

    ```python
    velocity = momentum * velocity - learning_rate * g
    w = w + momentum * velocity - learning_rate * g
    ```

    Args:
        learning_rate: A `Tensor`, floating point value, a schedule that is a
            `keras_core.optimizers.schedules.LearningRateSchedule`, or a
            callable that takes no arguments and returns the actual value to
            use. The learning rate. Defaults to `0.01`.
        momentum: float hyperparameter >= 0 that accelerates gradient descent
            in the relevant direction and dampens oscillations. Defaults to
            `0.0`, i.e., vanilla gradient descent.
        nesterov: boolean. Whether to apply Nesterov momentum.
            Defaults to `False`.
        {{base_optimizer_keyword_args}}
    """

    def __init__(
        self,
        learning_rate=0.01,
        momentum=0.0,
        nesterov=False,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        name="SGD",
        **kwargs
    ):
        super().__init__(
            learning_rate=learning_rate,
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            **kwargs
        )
        if not isinstance(momentum, float) or momentum < 0 or momentum > 1:
            raise ValueError("`momentum` must be a float in the range [0, 1].")
        self.momentum = momentum
        self.nesterov = nesterov

    def build(self, variables):
        """Initialize optimizer variables.

        The SGD optimizer adds one slot variable per model variable,
        `momentums`, which is only used when `self.momentum` is not 0.

        Args:
            variables: list of model variables to build SGD variables on.
        """
        if self.built:
            return
        super().build(variables)
        self.momentums = []
        for variable in variables:
            self.momentums.append(
                self.add_variable_from_reference(
                    reference_variable=variable, name="m"
                )
            )

    def update_step(self, gradient, variable, learning_rate):
        """Update step given gradient and the associated model variable."""
        learning_rate = ops.cast(learning_rate, variable.dtype)
        momentum = ops.cast(self.momentum, variable.dtype)
        m = None
        if self.momentum != 0:
            m = self.momentums[self._get_variable_index(variable)]

        if m is not None:
            # velocity = momentum * velocity - learning_rate * g
            m.assign(-gradient * learning_rate + m * momentum)
            if self.nesterov:
                # w = w + momentum * velocity - learning_rate * g
                variable.assign(
                    variable - gradient * learning_rate + m * momentum
                )
            else:
                # w = w + velocity
                variable.assign(variable + m)
        else:
            # Vanilla gradient descent: w = w - learning_rate * g
            variable.assign(variable - gradient * learning_rate)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "momentum": self.momentum,
                "nesterov": self.nesterov,
            }
        )
        return config


SGD.__doc__ = SGD.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
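

# --- Usage sketch (illustrative, not part of the library module) ------------
# A minimal, plain-Python check of the update rules documented in the class
# docstring above. It uses no keras_core APIs; the constants (learning rate
# 0.1, momentum 0.9, gradient 0.5) are arbitrary example values.
if __name__ == "__main__":
    w = 1.0
    velocity = 0.0
    learning_rate, momentum, g = 0.1, 0.9, 0.5

    # Momentum update, as documented:
    #   velocity = momentum * velocity - learning_rate * g
    #   w = w + velocity
    velocity = momentum * velocity - learning_rate * g
    w_momentum = w + velocity

    # Nesterov variant, as documented:
    #   w = w + momentum * velocity - learning_rate * g
    w_nesterov = w + momentum * velocity - learning_rate * g

    print("momentum step:", w_momentum)  # expected: 1.0 - 0.05 = 0.95
    print("nesterov step:", w_nesterov)  # expected: 1.0 - 0.045 - 0.05 = 0.905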