import re
import warnings
import numpy as np
from keras_core import backend
from keras_core import initializers
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.optimizers.schedules import learning_rate_schedule
from keras_core.utils.naming import auto_name
from keras_core.utils.tracking import Tracker
@keras_core_export(["keras_core.Optimizer", "keras_core.optimizers.Optimizer"])
class Optimizer:
def __init__(
self,
learning_rate,
weight_decay=None,
clipnorm=None,
clipvalue=None,
global_clipnorm=None,
use_ema=False,
ema_momentum=0.99,
ema_overwrite_frequency=None,
name=None,
):
self.name = name
self.weight_decay = weight_decay
self.clipnorm = clipnorm
self.global_clipnorm = global_clipnorm
self.clipvalue = clipvalue
self.use_ema = use_ema
if use_ema:
# Verify the arguments related to EMA.
if ema_momentum > 1 or ema_momentum < 0:
raise ValueError(
"`ema_momentum` must be in the range [0, 1]. "
f"Received: ema_momentum={ema_momentum}"
)
if ema_overwrite_frequency and (
not isinstance(ema_overwrite_frequency, int)
or ema_overwrite_frequency < 1
):
raise ValueError(
"`ema_overwrite_frequency` must be an integer >= 1 or "
"None. Received: ema_overwrite_frequency="
f"{ema_overwrite_frequency}"
)
self.ema_momentum = ema_momentum
self.ema_overwrite_frequency = ema_overwrite_frequency
if self.clipnorm is not None and self.global_clipnorm is not None:
raise ValueError(
"Only one of `clipnorm` and `global_clipnorm` can "
f"be set. Received: clipnorm={self.clipnorm}, "
f"global_clipnorm={self.global_clipnorm}"
)
self.built = False
# NOTE: the below will not work in a stateless scope.
# optimizers should be created outside of any stateless scope,
# at this time.
self.iterations = backend.Variable(
0, name="iteration", dtype="int32", trainable=False
)
if isinstance(
learning_rate, learning_rate_schedule.LearningRateSchedule
):
self._learning_rate = learning_rate
elif callable(learning_rate):
self._learning_rate = learning_rate
else:
if not isinstance(learning_rate, float):
raise ValueError(
"Argument `learning_rate` should be float, or an instance "
"of LearningRateSchedule, or a callable "
"(that takes in the current iteration value "
"and returns the corresponding learning rate value). "
f"Received instead: learning_rate={learning_rate}"
)
self._learning_rate = backend.Variable(
learning_rate,
name="learning_rate",
dtype=backend.floatx(),
trainable=False,
)
self._variables = []
self._trainable_variables = []
self._tracker = Tracker(
{
"variables": (
lambda x: isinstance(x, backend.Variable),
self._variables,
),
}
)
self._trainable_variables_indices = {}
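
    # Illustrative note (assumed, not part of the original file): the
    # accepted `learning_rate` forms are a float, a `LearningRateSchedule`
    # instance, or a callable taking the current iteration, e.g.
    # (hypothetical optimizer subclass and values):
    #
    #   opt = MyOptimizer(learning_rate=1e-3)
    #   opt = MyOptimizer(learning_rate=some_learning_rate_schedule)
    #   opt = MyOptimizer(learning_rate=lambda it: 1e-3 * 0.99**it)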
def build(self, variables):
for i, variable in enumerate(variables):
self._trainable_variables_indices[id(variable)] = i
self._trainable_variables = variables[:]
self.built = True
@property
def variables(self):
return self._variables[:]
def _get_variable_index(self, variable):
return self._trainable_variables_indices[id(variable)]
def add_variable(
self,
shape,
initializer,
dtype=None,
name=None,
):
self._check_super_called()
initializer = initializers.get(initializer)
variable = backend.Variable(
initializer=initializer,
shape=shape,
dtype=dtype,
trainable=False,
name=name,
)
self._variables.append(variable)
# Prevent double-tracking
self._tracker.stored_ids["variables"].add(id(variable))
return variable
def add_variable_from_reference(self, reference_variable, name=None):
"""Add an all-zeros variable with the shape and dtype of a reference
variable.
"""
initializer = initializers.Zeros()
name = name or auto_name(self.__class__.__name__)
return self.add_variable(
shape=reference_variable.shape,
initializer=initializer,
dtype=reference_variable.dtype,
name=name,
)
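
    # Illustrative sketch (assumed, not part of the original file): a
    # subclass `build()` would typically create one slot variable per model
    # variable, e.g. a momentum accumulator:
    #
    #   def build(self, variables):
    #       super().build(variables)
    #       self._momentums = [
    #           self.add_variable_from_reference(v, name="momentum")
    #           for v in variables
    #       ]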
def _check_variables_are_known(self, variables):
for v in variables:
if id(v) not in self._trainable_variables_indices:
raise ValueError(
f"Unknown variable: {v}. This optimizer can only "
"be called for the variables it was originally built with. "
"When working with a new set of variables, you should "
"recreate a new optimizer instance."
)
def update_step(self, gradient, variable, learning_rate):
raise NotImplementedError
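
    # Illustrative sketch (assumed, not part of the original file): a minimal
    # SGD-style subclass would override `update_step` roughly as follows:
    #
    #   def update_step(self, gradient, variable, learning_rate):
    #       lr = ops.cast(learning_rate, variable.dtype)
    #       variable.assign(variable - lr * gradient)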
def apply_gradients(self, grads_and_vars):
grads, trainable_variables = zip(*grads_and_vars)
return self.apply(grads, trainable_variables)
def apply(self, grads, variables=None):
"""
`grads` should be a list of gradient tensors
with 1:1 mapping to the list of variables the optimizer was built with.
`variables` can be provided on the first call to build the optimizer.
"""
grads = list(grads)
if len(grads) == 0:
# It is possible that the grad is empty. In this case,
# `apply_gradients` is a no-op.
return
if variables is None:
if not self.built:
raise ValueError(
"When passing `grads` without `variables`, the optimizer "
"must already be built on a list of variables. "
"Call `optimizer.build(trainable_variables)` first. "
)
if len(grads) != len(self._trainable_variables_indices):
raise ValueError(
"When passing `grads` as a list of gradient tensors, the "
f"gradients must match `optimizer.variables` one-to-on. "
f"Received a list of {len(grads)} gradients, but the "
f"optimizer is tracking {len(self._trainable_variables)} "
"trainable variables."
)
trainable_variables = self._trainable_variables
else:
trainable_variables = list(variables)
# Optionally build optimizer.
if not self.built:
with ops.name_scope(self.name):
self.build(trainable_variables)
self.built = True
self._check_variables_are_known(trainable_variables)
        grads_and_vars = list(zip(grads, trainable_variables))
with ops.name_scope(self.name):
# Filter empty gradients.
grads_and_vars = self._filter_empty_gradients(grads_and_vars)
            if len(grads_and_vars) == 0:
return
# Apply clipping and weight decay.
grads, trainable_variables = zip(*grads_and_vars)
grads = self._clip_gradients(grads)
self._apply_weight_decay(trainable_variables)
# Apply gradient updates.
learning_rate = self._get_current_learning_rate()
for grad, var in zip(grads, trainable_variables):
self.update_step(grad, var, learning_rate)
self.iterations.assign(self.iterations + 1)
# Apply variable constraints after applying gradients.
for variable in trainable_variables:
if getattr(variable, "constraint", None) is not None:
variable.assign(variable.constraint(variable))
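
    # Illustrative usage (assumed, not part of the original file): with a
    # list of gradients `grads` computed for `model.trainable_variables`,
    # either of the following applies one update step:
    #
    #   optimizer.apply_gradients(zip(grads, model.trainable_variables))
    #   # or, once the optimizer has been built on those variables:
    #   optimizer.apply(grads)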
def stateless_apply(self, grads, trainable_variables, optimizer_variables):
self._check_super_called()
if not self.built:
raise ValueError(
"To call stateless_apply_gradients, {self.__class__.__name__} "
"must be built (i.e. its variables must have been created). "
"You can build it via `optimizer.build(trainable_variables)`."
)
if len(optimizer_variables) != len(self.variables):
raise ValueError(
"Argument `optimizer_variables` must be a list of tensors "
f"corresponding 1:1 to {self.__class__.__name__}().variables. "
f"Received list with length {len(optimizer_variables)}, but "
f"expected {len(self.variables)} variables."
)
if len(trainable_variables) != len(self._trainable_variables):
raise ValueError(
"Argument `optimizer_variables` must be a list of tensors "
"corresponding 1:1 to the trainable variables list that "
"the optimizer was built with. Received "
f"len(trainable_variables) == {len(trainable_variables)} "
"whereas the optimizer was built with "
f"{len(self._trainable_variables)} variables."
)
# Gather variable mapping
mapping = list(
zip(self._trainable_variables, trainable_variables)
) + list(zip(self.variables, optimizer_variables))
# Call in stateless scope
with backend.StatelessScope(state_mapping=mapping) as scope:
self.apply(grads)
# Gather updated variables
trainable_variables = []
for v in self._trainable_variables:
new_v = scope.get_current_value(v)
if new_v is not None:
trainable_variables.append(new_v)
else:
trainable_variables.append(v)
optimizer_variables = []
for v in self.variables:
new_v = scope.get_current_value(v)
if new_v is not None:
optimizer_variables.append(new_v)
else:
optimizer_variables.append(v)
return trainable_variables, optimizer_variables
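
    # Illustrative usage (assumed, not part of the original file): in a
    # functional training step (e.g. with a JAX-style backend), no variable
    # is mutated; the updated values are returned and threaded through the
    # training loop instead:
    #
    #   new_trainable, new_optimizer_vars = optimizer.stateless_apply(
    #       grads, trainable_variables, optimizer_variables
    #   )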
@property
def learning_rate(self):
return self._get_current_learning_rate()
def save_own_variables(self, store):
"""Get the state of this optimizer object."""
for i, variable in enumerate(self.variables):
store[str(i)] = np.array(variable)
def load_own_variables(self, store):
"""Set the state of this optimizer object."""
if len(store.keys()) != len(self.variables):
msg = (
f"Skipping variable loading for optimizer '{self.name}', "
f"because it has {len(self.variables)} variables whereas "
f"the saved optimizer has {len(store.keys())} variables. "
)
if len(self.variables) == 0:
msg += (
"This is likely because the optimizer has not been "
"called/built yet."
)
warnings.warn(msg, stacklevel=2)
return
for i, variable in enumerate(self.variables):
variable.assign(store[str(i)])
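
    # Illustrative usage (assumed, not part of the original file): these two
    # hooks let a saving format round-trip optimizer state through a
    # dict-like store:
    #
    #   store = {}
    #   optimizer.save_own_variables(store)   # store["0"], store["1"], ...
    #   optimizer.load_own_variables(store)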
def _get_current_learning_rate(self):
if isinstance(
self._learning_rate, learning_rate_schedule.LearningRateSchedule
):
return self._learning_rate(self.iterations)
elif callable(self._learning_rate):
return self._learning_rate(self.iterations)
return self._learning_rate
def _filter_empty_gradients(self, grads_and_vars):
filtered = [(g, v) for g, v in grads_and_vars if g is not None]
if not filtered:
raise ValueError("No gradients provided for any variable.")
if len(filtered) < len(grads_and_vars):
missing_grad_vars = [v for g, v in grads_and_vars if g is None]
warnings.warn(
"Gradients do not exist for variables "
f"{[v.name for v in missing_grad_vars]} when minimizing the "
"loss. If you're using `model.compile()`, did you forget to "
"provide a `loss` argument?"
)
return filtered
def _clip_gradients(self, grads):
if self.clipnorm and self.clipnorm > 0:
raise NotImplementedError # TODO
# clipped_grads = []
# for g in grads:
# if g is None:
# clipped_grads.append(g)
# else:
# clipped_grads.append(tf.clip_by_norm(g, self.clipnorm))
# return clipped_grads
if self.global_clipnorm and self.global_clipnorm > 0:
raise NotImplementedError # TODO
# return tf.clip_by_global_norm(grads, self.global_clipnorm)[0]
if self.clipvalue and self.clipvalue > 0:
raise NotImplementedError # TODO
# clipped_grads = []
# for g in grads:
# if g is None:
# clipped_grads.append(g)
# else:
# clipped_grads.append(
# tf.clip_by_value(
# g,
# clip_value_min=-self.clipvalue,
# clip_value_max=self.clipvalue,
# )
# )
# return clipped_grads
return grads
def exclude_from_weight_decay(self, var_list=None, var_names=None):
"""Exclude variables from weight decay.
This method must be called before the optimizer's `build` method is
called. You can set specific variables to exclude out, or set a list of
strings as the anchor words, if any of which appear in a variable's
name, then the variable is excluded.
Args:
var_list: A list of `tf.Variable`s to exclude from weight decay.
var_names: A list of strings. If any string in `var_names` appear
in the model variable's name, then this model variable is
excluded from weight decay. For example, `var_names=['bias']`
excludes all bias variables from weight decay.
"""
if hasattr(self, "_built") and self._built:
raise ValueError(
"`exclude_from_weight_decay()` can only be configued before "
"the optimizer is built."
)
if var_list:
self._exclude_from_weight_decay = [
id(variable) for variable in var_list
]
else:
self._exclude_from_weight_decay = []
self._exclude_from_weight_decay_names = var_names or []
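
    # Illustrative usage (assumed, not part of the original file): exclusions
    # must be registered before the optimizer is built, e.g.:
    #
    #   optimizer.exclude_from_weight_decay(var_names=["bias"])
    #   optimizer.build(model.trainable_variables)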
def _use_weight_decay(self, variable):
exclude_from_weight_decay = getattr(
self, "_exclude_from_weight_decay", []
)
exclude_from_weight_decay_names = getattr(
self, "_exclude_from_weight_decay_names", []
)
variable_id = id(variable)
for exclude_id in exclude_from_weight_decay:
if variable_id == exclude_id:
return False
for name in exclude_from_weight_decay_names:
if re.search(name, variable.name) is not None:
return False
return True
def _apply_weight_decay(self, variables):
if self.weight_decay is None:
return
for variable in variables:
if self._use_weight_decay(variable):
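                # Decoupled weight decay: shrink the variable directly
                # (scaled by the current learning rate) rather than adding
                # a penalty term to the gradients.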
lr = ops.cast(self._get_current_learning_rate(), variable.dtype)
wd = ops.cast(self.weight_decay, variable.dtype)
variable.assign(variable - variable * wd * lr)
def _check_super_called(self):
if not hasattr(self, "_tracker"):
raise RuntimeError(
f"In optimizer '{self.__class__.__name__}', you forgot to call "
"`super().__init__()` in the `__init__()` method. "
"Go add it!"
)
def _update_model_variables_moving_average(self, var_list):
"""Update the stored moving average using the latest value."""
if self.use_ema:
for var, average in zip(
var_list, self._model_variables_moving_average
):
average.assign(
self.ema_momentum * average + (1 - self.ema_momentum) * var
)
def _overwrite_model_variables_with_average_value(self, var_list):
"""Overwrite model variables with its moving average."""
if len(var_list) != len(self._model_variables_moving_average):
raise ValueError(
f"The length of model variables ({len(var_list)}) to "
"override does not match the length of model variables "
"stored in the optimizer "
f"({len(self._model_variables_moving_average)}). Please "
"check if the optimizer was called on your model."
)
self._overwrite_model_variables_with_average_value_helper(var_list)
def _overwrite_model_variables_with_average_value_helper(self, var_list):
"""Helper function that overwrites model variables."""
for var, average_var in zip(
var_list, self._model_variables_moving_average
):
var.assign(average_var)
def finalize_variable_values(self, var_list):
"""Set the final value of model's trainable variables.
Sometimes there are some extra steps before ending the variable updates,
such as overriding the model variables with its average value.
Args:
var_list: list of model variables.
"""
if self.use_ema:
# If the optimizer uses EMA, then when finalizing, we replace the
# model variable value with its moving average stored inside
# optimizer.
self._overwrite_model_variables_with_average_value(var_list)
def get_config(self):
"""Returns the config of the optimizer.
An optimizer config is a Python dictionary (serializable)
containing the configuration of an optimizer.
The same optimizer can be reinstantiated later
(without any saved state) from this configuration.
Subclass optimizer should override this method to include other
hyperparameters.
Returns:
Python dictionary.
"""
if isinstance(
self._learning_rate, learning_rate_schedule.LearningRateSchedule
):
learning_rate = learning_rate_schedule.serialize(
self._learning_rate
)
elif isinstance(self._learning_rate, backend.Variable):
learning_rate = float(self._learning_rate.numpy())
elif ops.is_tensor(self._learning_rate):
learning_rate = float(self._learning_rate)
elif callable(self._learning_rate):
# TODO: serialize custom object
learning_rate = self._learning_rate
config = {
"name": self.name,
"learning_rate": learning_rate,
"weight_decay": self.weight_decay,
"clipnorm": self.clipnorm,
"global_clipnorm": self.global_clipnorm,
"clipvalue": self.clipvalue,
"use_ema": self.use_ema,
"ema_momentum": self.ema_momentum,
"ema_overwrite_frequency": self.ema_overwrite_frequency,
}
return config
@classmethod
def from_config(cls, config, custom_objects=None):
"""Creates an optimizer from its config.
This method is the reverse of `get_config`, capable of instantiating the
same optimizer from the config dictionary.
Args:
config: A Python dictionary, typically the output of get_config.
custom_objects: A Python dictionary mapping names to additional
user-defined Python objects needed to recreate this optimizer.
Returns:
An optimizer instance.
"""
if "learning_rate" in config:
if isinstance(config["learning_rate"], dict):
config["learning_rate"] = learning_rate_schedule.deserialize(
config["learning_rate"], custom_objects=custom_objects
)
return cls(**config)
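
    # Illustrative round trip (assumed, not part of the original file):
    #
    #   config = optimizer.get_config()
    #   restored = optimizer.__class__.from_config(config)
    #
    # `restored` has the same hyperparameters but none of the saved state
    # (slot variables, iteration count).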
base_optimizer_keyword_args = """name: String. The name to use
for momentum accumulator weights created by
the optimizer.
weight_decay: Float, defaults to None. If set, weight decay is applied.
clipnorm: Float. If set, the gradient of each weight is individually
clipped so that its norm is no higher than this value.
    clipvalue: Float. If set, the gradient of each weight is clipped so that
        its absolute value is no higher than this value.
global_clipnorm: Float. If set, the gradient of all weights is clipped so
that their global norm is no higher than this value.
use_ema: Boolean, defaults to False. If True, exponential moving average
(EMA) is applied. EMA consists of computing an exponential moving
average of the weights of the model (as the weight values change after
each training batch), and periodically overwriting the weights with
their moving average.
ema_momentum: Float, defaults to 0.99. Only used if `use_ema=True`.
This is the momentum to use when computing
the EMA of the model's weights:
`new_average = ema_momentum * old_average + (1 - ema_momentum) *
current_variable_value`.
    ema_overwrite_frequency: Int or None, defaults to None. Only used if
        `use_ema=True`. Every `ema_overwrite_frequency` steps of iterations,
        the model variables are overwritten by their moving average. If None,
        the optimizer does not overwrite model variables in the middle of
        training, and you need to explicitly overwrite the variables at the
        end of training by calling `optimizer.finalize_variable_values()`
        (which updates the model variables in-place). When using the built-in
        `fit()` training loop, this happens automatically after the last
        epoch, and you don't need to do anything."""
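
# Illustrative EMA configuration (assumed, not part of the original file): a
# concrete subclass accepting the keyword arguments documented above could be
# configured as, e.g.:
#
#   optimizer = SomeOptimizer(
#       learning_rate=1e-3,
#       use_ema=True,
#       ema_momentum=0.99,
#       ema_overwrite_frequency=100,
#   )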