import re
import warnings

from keras_core import backend
from keras_core import initializers
from keras_core import operations as ops
from keras_core.optimizers.schedules import learning_rate_schedule
from keras_core.utils.naming import auto_name
from keras_core.utils.tracking import Tracker


class Optimizer:
    # TODO: support jit_compile
    def __init__(
        self,
        learning_rate,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        name=None,
    ):
        # Default to an auto-generated name so that `ops.name_scope(self.name)`
        # in `apply()` always receives a valid string.
        if name is None:
            name = auto_name(self.__class__.__name__)
        self.name = name
        self.weight_decay = weight_decay
        self.clipnorm = clipnorm
        self.global_clipnorm = global_clipnorm
        self.clipvalue = clipvalue
        self.use_ema = use_ema

        if use_ema:
            # Verify the arguments related to EMA.
            if ema_momentum > 1 or ema_momentum < 0:
                raise ValueError(
                    "`ema_momentum` must be in the range [0, 1]. "
                    f"Received: ema_momentum={ema_momentum}"
                )
            if ema_overwrite_frequency and (
                not isinstance(ema_overwrite_frequency, int)
                or ema_overwrite_frequency < 1
            ):
                raise ValueError(
                    "`ema_overwrite_frequency` must be an integer >= 1 or "
                    "None. Received: ema_overwrite_frequency="
                    f"{ema_overwrite_frequency}"
                )
        self.ema_momentum = ema_momentum
        self.ema_overwrite_frequency = ema_overwrite_frequency

        if self.clipnorm is not None and self.global_clipnorm is not None:
            raise ValueError(
                "Only one of `clipnorm` and `global_clipnorm` can "
                f"be set. Received: clipnorm={self.clipnorm}, "
                f"global_clipnorm={self.global_clipnorm}"
            )

        self.built = False
        self.iterations = backend.Variable(
            0, name="iteration", dtype="int64", trainable=False
        )

        if isinstance(
            learning_rate, learning_rate_schedule.LearningRateSchedule
        ):
            self._learning_rate = learning_rate
        elif callable(learning_rate):
            self._learning_rate = learning_rate
        else:
            if not isinstance(learning_rate, float):
                raise ValueError(
                    "Argument `learning_rate` should be float, or an "
                    "instance of LearningRateSchedule, or a callable "
                    "(that takes in the current iteration value "
                    "and returns the corresponding learning rate value). "
                    "Received instead: "
                    f"learning_rate={learning_rate}"
                )
            self._learning_rate = backend.Variable(
                learning_rate,
                name="learning_rate",
                dtype=backend.floatx(),
                trainable=False,
            )

        self._variables = []
        self._trainable_variables = []
        self._tracker = Tracker(
            {
                "variables": (
                    lambda x: isinstance(x, backend.Variable),
                    self._variables,
                ),
            }
        )
        self._trainable_variables_indices = {}

    def build(self, variables):
        for i, variable in enumerate(variables):
            self._trainable_variables_indices[id(variable)] = i
        self._trainable_variables = variables[:]
        self.built = True

    @property
    def variables(self):
        return self._variables[:]

    def _get_variable_index(self, variable):
        return self._trainable_variables_indices[id(variable)]

    def add_variable(
        self,
        shape,
        initializer,
        dtype=None,
        name=None,
    ):
        self._check_super_called()
        if callable(initializer):
            value = initializer(shape=shape, dtype=dtype)
        else:
            raise ValueError(f"Invalid initializer: {initializer}")
        variable = backend.Variable(
            value=value,
            dtype=dtype,
            trainable=False,
            name=name,
        )
        self._variables.append(variable)
        # Prevent double-tracking.
        self._tracker.stored_ids["variables"].add(id(variable))
        return variable

    def add_variable_from_reference(self, reference_variable, name=None):
        """Add an all-zeros variable with the shape and dtype of a reference
        variable.
        """
        initializer = initializers.Zeros()
        name = name or auto_name(self.__class__.__name__)
        return self.add_variable(
            shape=reference_variable.shape,
            initializer=initializer,
            dtype=reference_variable.dtype,
            name=name,
        )

    def _check_variables_are_known(self, variables):
        for v in variables:
            if id(v) not in self._trainable_variables_indices:
                raise ValueError(
                    f"Unknown variable: {v}. This optimizer can only "
                    "be called for the variables it was originally built "
                    "with. When working with a new set of variables, you "
                    "should recreate a new optimizer instance."
                )

    def update_step(self, gradient, variable, learning_rate):
        raise NotImplementedError

    def apply_gradients(self, grads_and_vars):
        grads, trainable_variables = zip(*grads_and_vars)
        return self.apply(grads, trainable_variables)

    def apply(self, grads, variables=None):
        """
        `grads` should be a list of gradient tensors with 1:1 mapping to the
        list of variables the optimizer was built with.

        `variables` can be provided on the first call to build the optimizer.
        """
        grads = list(grads)
        if len(grads) == 0:
            # It is possible that the grad is empty. In this case,
            # `apply_gradients` is a no-op.
            return

        if variables is None:
            if not self.built:
                raise ValueError(
                    "When passing `grads` without `variables`, the optimizer "
                    "must already be built on a list of variables. "
                    "Call `optimizer.build(trainable_variables)` first."
                )
            if len(grads) != len(self._trainable_variables_indices):
                raise ValueError(
                    "When passing `grads` as a list of gradient tensors, the "
                    "gradients must match `optimizer.variables` one-to-one. "
                    f"Received a list of {len(grads)} gradients, but the "
                    f"optimizer is tracking {len(self._trainable_variables)} "
                    "trainable variables."
                )
            trainable_variables = self._trainable_variables
        else:
            trainable_variables = list(variables)
            # Optionally build optimizer.
            if not self.built:
                with ops.name_scope(self.name):
                    self.build(trainable_variables)
                self.built = True
            self._check_variables_are_known(trainable_variables)

        grads_and_vars = list(zip(grads, trainable_variables))

        with ops.name_scope(self.name):
            # Filter empty gradients.
            grads_and_vars = self._filter_empty_gradients(grads_and_vars)
            if len(grads_and_vars) == 0:
                return

            # Apply clipping and weight decay.
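            # Note that weight decay is decoupled from the gradient update:
            # `_apply_weight_decay` subtracts `learning_rate * weight_decay *
            # variable` directly from each variable before `update_step` runs.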
            grads, trainable_variables = zip(*grads_and_vars)
            grads = self._clip_gradients(grads)
            self._apply_weight_decay(trainable_variables)

            # Apply gradient updates.
            learning_rate = self._get_current_learning_rate()
            for grad, var in zip(grads, trainable_variables):
                self.update_step(grad, var, learning_rate)
            self.iterations.assign(self.iterations + 1)

            # Apply variable constraints after applying gradients.
            for variable in trainable_variables:
                if getattr(variable, "constraint", None) is not None:
                    variable.assign(variable.constraint(variable))

    def stateless_apply(self, grads, optimizer_variables):
        self._check_super_called()

        if not self.built:
            raise ValueError(
                f"To call `stateless_apply`, {self.__class__.__name__} "
                "must be built (i.e. its variables must have been already "
                "created). You can build it via "
                "`optimizer.build(trainable_variables)`."
            )
        if len(optimizer_variables) != len(self.variables):
            raise ValueError(
                "Argument `optimizer_variables` must be a list of tensors "
                f"corresponding 1:1 to {self.__class__.__name__}().variables. "
                f"Received list with length {len(optimizer_variables)}, but "
                f"expected {len(self.variables)} variables."
            )

        # Gather variable mapping.
        mapping = list(zip(self.variables, optimizer_variables))

        # Call in stateless scope.
        with backend.StatelessScope(state_mapping=mapping) as scope:
            self.apply(grads)

        # Gather updated variables.
        trainable_variables = []
        for v in self._trainable_variables:
            new_v = scope.get_current_value(v)
            if new_v is not None:
                trainable_variables.append(new_v)
            else:
                trainable_variables.append(v)
        optimizer_variables = []
        for v in self.variables:
            new_v = scope.get_current_value(v)
            if new_v is not None:
                optimizer_variables.append(new_v)
            else:
                optimizer_variables.append(v)
        return trainable_variables, optimizer_variables

    @property
    def learning_rate(self):
        return self._get_current_learning_rate()

    def _get_current_learning_rate(self):
        if isinstance(
            self._learning_rate, learning_rate_schedule.LearningRateSchedule
        ):
            return self._learning_rate(self.iterations)
        elif callable(self._learning_rate):
            return self._learning_rate(self.iterations)
        return self._learning_rate

    def _filter_empty_gradients(self, grads_and_vars):
        filtered = [(g, v) for g, v in grads_and_vars if g is not None]
        if not filtered:
            raise ValueError("No gradients provided for any variable.")
        if len(filtered) < len(grads_and_vars):
            missing_grad_vars = [v for g, v in grads_and_vars if g is None]
            warnings.warn(
                "Gradients do not exist for variables "
                f"{[v.name for v in missing_grad_vars]} when minimizing the "
                "loss. If you're using `model.compile()`, did you forget to "
                "provide a `loss` argument?"
            )
        return filtered

    def _clip_gradients(self, grads):
        if self.clipnorm and self.clipnorm > 0:
            raise NotImplementedError  # TODO
            # clipped_grads = []
            # for g in grads:
            #     if g is None:
            #         clipped_grads.append(g)
            #     else:
            #         clipped_grads.append(tf.clip_by_norm(g, self.clipnorm))
            # return clipped_grads
        if self.global_clipnorm and self.global_clipnorm > 0:
            raise NotImplementedError  # TODO
            # return tf.clip_by_global_norm(grads, self.global_clipnorm)[0]
        if self.clipvalue and self.clipvalue > 0:
            raise NotImplementedError  # TODO
            # clipped_grads = []
            # for g in grads:
            #     if g is None:
            #         clipped_grads.append(g)
            #     else:
            #         clipped_grads.append(
            #             tf.clip_by_value(
            #                 g,
            #                 clip_value_min=-self.clipvalue,
            #                 clip_value_max=self.clipvalue,
            #             )
            #         )
            # return clipped_grads
        return grads

    def exclude_from_weight_decay(self, var_list=None, var_names=None):
        """Exclude variables from weight decay.
        This method must be called before the optimizer's `build` method is
        called. You can pass specific variables to exclude, or a list of
        strings to use as anchor words: if any of these strings appears in a
        variable's name, the variable is excluded from weight decay.

        Args:
            var_list: A list of variables to exclude from weight decay.
            var_names: A list of strings. If any string in `var_names` appears
                in a model variable's name, that variable is excluded from
                weight decay. For example, `var_names=['bias']` excludes all
                bias variables from weight decay.
        """
        if getattr(self, "built", False):
            raise ValueError(
                "`exclude_from_weight_decay()` can only be configured "
                "before the optimizer is built."
            )

        if var_list:
            self._exclude_from_weight_decay = [
                id(variable) for variable in var_list
            ]
        else:
            self._exclude_from_weight_decay = []
        self._exclude_from_weight_decay_names = var_names or []

    def _use_weight_decay(self, variable):
        exclude_from_weight_decay = getattr(
            self, "_exclude_from_weight_decay", []
        )
        exclude_from_weight_decay_names = getattr(
            self, "_exclude_from_weight_decay_names", []
        )
        variable_id = id(variable)
        for exclude_id in exclude_from_weight_decay:
            if variable_id == exclude_id:
                return False
        for name in exclude_from_weight_decay_names:
            if re.search(name, variable.name) is not None:
                return False
        return True

    def _apply_weight_decay(self, variables):
        if self.weight_decay is None:
            return
        for variable in variables:
            if self._use_weight_decay(variable):
                lr = ops.cast(
                    self._get_current_learning_rate(), variable.dtype
                )
                wd = ops.cast(self.weight_decay, variable.dtype)
                variable.assign(variable - variable * wd * lr)

    def _check_super_called(self):
        if not hasattr(self, "_tracker"):
            raise RuntimeError(
                f"In optimizer '{self.__class__.__name__}', you forgot to "
                "call `super().__init__()` in the `__init__()` method. "
                "Go add it!"
            )

    def _update_model_variables_moving_average(self, var_list):
        """Update the stored moving average using the latest value."""
        if self.use_ema:
            for var, average in zip(
                var_list, self._model_variables_moving_average
            ):
                average.assign(
                    self.ema_momentum * average
                    + (1 - self.ema_momentum) * var
                )

    def _overwrite_model_variables_with_average_value(self, var_list):
        """Overwrite model variables with their moving averages."""
        if len(var_list) != len(self._model_variables_moving_average):
            raise ValueError(
                f"The length of model variables ({len(var_list)}) to "
                "override does not match the length of model variables "
                "stored in the optimizer "
                f"({len(self._model_variables_moving_average)}). Please "
                "check if the optimizer was called on your model."
            )
        self._overwrite_model_variables_with_average_value_helper(var_list)

    def _overwrite_model_variables_with_average_value_helper(self, var_list):
        """Helper function that overwrites model variables."""
        for var, average_var in zip(
            var_list, self._model_variables_moving_average
        ):
            var.assign(average_var)

    def finalize_variable_values(self, var_list):
        """Set the final value of the model's trainable variables.

        Sometimes there are some extra steps before ending the variable
        updates, such as overriding the model variables with their average
        values.

        Args:
            var_list: list of model variables.
        """
        if self.use_ema:
            # If the optimizer uses EMA, then when finalizing, we replace the
            # model variable value with its moving average stored inside
            # the optimizer.
            self._overwrite_model_variables_with_average_value(var_list)

    def get_config(self):
        """Returns the config of the optimizer.
        An optimizer config is a Python dictionary (serializable) containing
        the configuration of an optimizer. The same optimizer can be
        reinstantiated later (without any saved state) from this
        configuration.

        Subclassed optimizers should override this method to include other
        hyperparameters.

        Returns:
            Python dictionary.
        """
        if isinstance(
            self._learning_rate, learning_rate_schedule.LearningRateSchedule
        ):
            learning_rate = learning_rate_schedule.serialize(
                self._learning_rate
            )
        elif isinstance(self._learning_rate, backend.Variable):
            learning_rate = float(self._learning_rate.numpy())
        elif ops.is_tensor(self._learning_rate):
            learning_rate = float(self._learning_rate)
        elif callable(self._learning_rate):
            # TODO: serialize custom object
            learning_rate = self._learning_rate

        config = {
            "name": self.name,
            "learning_rate": learning_rate,
            "weight_decay": self.weight_decay,
            "clipnorm": self.clipnorm,
            "global_clipnorm": self.global_clipnorm,
            "clipvalue": self.clipvalue,
            "use_ema": self.use_ema,
            "ema_momentum": self.ema_momentum,
            "ema_overwrite_frequency": self.ema_overwrite_frequency,
            # TODO: include `jit_compile` once it is supported.
        }
        return config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Creates an optimizer from its config.

        This method is the reverse of `get_config`, capable of instantiating
        the same optimizer from the config dictionary.

        Args:
            config: A Python dictionary, typically the output of
                `get_config`.
            custom_objects: A Python dictionary mapping names to additional
                user-defined Python objects needed to recreate this optimizer.

        Returns:
            An optimizer instance.
        """
        if "learning_rate" in config:
            if isinstance(config["learning_rate"], dict):
                config["learning_rate"] = learning_rate_schedule.deserialize(
                    config["learning_rate"], custom_objects=custom_objects
                )
        return cls(**config)


base_optimizer_keyword_args = """name: String. The name to use
        for momentum accumulator weights created by
        the optimizer.
    weight_decay: Float, defaults to None. If set, weight decay is applied.
    clipnorm: Float. If set, the gradient of each weight is individually
        clipped so that its norm is no higher than this value.
    clipvalue: Float. If set, the gradient of each weight is clipped to be
        no higher than this value.
    global_clipnorm: Float. If set, the gradient of all weights is clipped
        so that their global norm is no higher than this value.
    use_ema: Boolean, defaults to False. If True, exponential moving average
        (EMA) is applied. EMA consists of computing an exponential moving
        average of the weights of the model (as the weight values change
        after each training batch), and periodically overwriting the weights
        with their moving average.
    ema_momentum: Float, defaults to 0.99. Only used if `use_ema=True`. This
        is the momentum to use when computing the EMA of the model's weights:
        `new_average = ema_momentum * old_average + (1 - ema_momentum) *
        current_variable_value`.
    ema_overwrite_frequency: Int or None, defaults to None. Only used if
        `use_ema=True`. Every `ema_overwrite_frequency` steps of iterations,
        we overwrite the model variable by its moving average. If None, the
        optimizer does not overwrite model variables in the middle of
        training, and you need to explicitly overwrite the variables at the
        end of training by calling `optimizer.finalize_variable_values()`
        (which updates the model variables in-place). When using the built-in
        `fit()` training loop, this happens automatically after the last
        epoch, and you don't need to do anything."""
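

if __name__ == "__main__":
    # Illustrative sketch, not part of the library API: a minimal
    # SGD-with-momentum subclass showing how `build`,
    # `add_variable_from_reference`, `_get_variable_index`, and `update_step`
    # are meant to fit together. The names `SimpleSGD` and `momentums` are
    # hypothetical, and the snippet assumes the numpy-style ops (`ops.ones`,
    # `ops.multiply`, `ops.subtract`, `ops.add`) exported by
    # `keras_core.operations`.

    class SimpleSGD(Optimizer):
        def __init__(self, learning_rate=0.01, momentum=0.0, **kwargs):
            super().__init__(learning_rate=learning_rate, **kwargs)
            self.momentum = momentum

        def build(self, variables):
            super().build(variables)
            # One all-zeros accumulator slot per trainable variable.
            self.momentums = [
                self.add_variable_from_reference(v, name="momentum")
                for v in variables
            ]

        def update_step(self, gradient, variable, learning_rate):
            m = self.momentums[self._get_variable_index(variable)]
            # m <- momentum * m - lr * grad; variable <- variable + m
            m.assign(
                ops.subtract(
                    ops.multiply(m, self.momentum),
                    ops.multiply(gradient, learning_rate),
                )
            )
            variable.assign(ops.add(variable, m))

    # Smoke test: one update step on a single variable, using a constant
    # stand-in gradient.
    w = backend.Variable(ops.ones((2, 2)), name="w")
    optimizer = SimpleSGD(learning_rate=0.1)
    optimizer.apply_gradients([(ops.ones((2, 2)), w)])
    print(w)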