keras/keras_core/backend/tensorflow/optimizer.py

125 lines
4.7 KiB
Python
Raw Normal View History

import tensorflow as tf
from keras_core import backend
from keras_core.optimizers import base_optimizer
class TFOptimizer(base_optimizer.Optimizer):
    """A class for TensorFlow-specific optimizer logic.

    The major behavior change for this class is for `tf.distribute`.

    It overrides methods from the base Keras Core `Optimizer` to provide
    distribution-specific functionality: colocated variable creation, and
    running weight decay, gradient updates and EMA overwrites through the
    strategy's `extended.update` mechanism.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Strategy captured at construction time; variable creation and
        # updates in this class are routed through it.
        self._distribution_strategy = tf.distribute.get_strategy()

    def add_variable_from_reference(self, reference_variable, name=None):
        """Create an optimizer variable colocated with `reference_variable`.

        Args:
            reference_variable: The variable the new variable is derived
                from. May be a Keras `backend.Variable` or a raw TensorFlow
                variable.
            name: Optional name forwarded to the base implementation.

        Returns:
            Whatever the base class `add_variable_from_reference` returns.
        """
        # Unwrap Keras variables so the strategy colocates with the
        # underlying TensorFlow variable.
        if isinstance(reference_variable, backend.Variable):
            colocate_var = reference_variable.value
        else:
            colocate_var = reference_variable

        with self._distribution_strategy.extended.colocate_vars_with(
            colocate_var
        ):
            return super().add_variable_from_reference(
                reference_variable, name=name
            )

    def _var_key(self, variable):
        """Return the `_unique_id` used to key per-variable optimizer state.

        Keras variables are unwrapped to their TensorFlow variable. Per-device
        components of a distributed variable are mapped back to their
        distributed container, so they resolve to the container's id.
        """
        if isinstance(variable, backend.Variable):
            variable = variable.value  # Convert to tf.Variable
        if hasattr(variable, "_distributed_container"):
            variable = variable._distributed_container()
        elif (
            isinstance(variable, tf.__internal__.CompositeTensor)
            and hasattr(variable, "handle")
            and hasattr(variable.handle, "_distributed_container")
        ):
            # For ResourceVariables, the _distributed_container attribute
            # is added to their handle tensors.
            variable = variable.handle._distributed_container()
        return variable._unique_id

    def _apply_weight_decay(self, variables):
        """Apply weight decay to `variables` through the strategy.

        Each variable accepted by `_use_weight_decay` is updated in place as
        `variable - variable * weight_decay * learning_rate`. Does nothing
        when `weight_decay` is None.
        """
        if self.weight_decay is None:
            return

        def distributed_apply_weight_decay(distribution, variables, **kwargs):
            def weight_decay_fn(variable):
                if self._use_weight_decay(variable):
                    lr = tf.cast(self.learning_rate, variable.dtype)
                    wd = tf.cast(self.weight_decay, variable.dtype)
                    variable.assign(variable - variable * wd * lr)

            for variable in variables:
                distribution.extended.update(
                    variable, weight_decay_fn, group=False
                )

        tf.__internal__.distribute.interim.maybe_merge_call(
            distributed_apply_weight_decay,
            self._distribution_strategy,
            variables,
        )

    def _internal_apply_gradients(self, grads_and_vars):
        """Apply `grads_and_vars` via `maybe_merge_call` on the strategy."""
        tf.__internal__.distribute.interim.maybe_merge_call(
            self._distributed_apply_gradients_fn,
            self._distribution_strategy,
            grads_and_vars,
        )

    def _distributed_apply_gradients_fn(
        self, distribution, grads_and_vars, **kwargs
    ):
        """`apply_gradients` using a `DistributionStrategy`.

        Runs `update_step` for every `(gradient, variable)` pair inside
        `distribution.extended.update`. If EMA is enabled, it then updates
        the moving averages and, on qualifying steps, overwrites the model
        variables with them. It finally increments `iterations`.

        Args:
            distribution: The strategy whose `extended.update` is used.
            grads_and_vars: Iterable of `(gradient, variable)` pairs. It is
                materialized once, so a one-shot iterator is accepted.
            **kwargs: Unused.
        """
        # The pairs are read twice: by the update loop and again to collect
        # variables for EMA. Materialize once so an iterator is not
        # exhausted before the EMA step.
        grads_and_vars = list(grads_and_vars)

        def apply_grad_to_update_var(var, grad):
            learning_rate = self._get_current_learning_rate()
            return self.update_step(grad, var, learning_rate)

        for grad, var in grads_and_vars:
            distribution.extended.update(
                var, apply_grad_to_update_var, args=(grad,), group=False
            )

        if self.use_ema:
            # Unlike `zip(*grads_and_vars)`, this does not raise when there
            # are no pairs.
            var_list = [var for _, var in grads_and_vars]
            self._update_model_variables_moving_average(var_list)
            if self.ema_overwrite_frequency:
                # Only when ema_overwrite_frequency is set do we overwrite
                # the model variables, and only on steps where the
                # (about-to-be-incremented) iteration count is a multiple
                # of it.
                should_overwrite_model_vars = (
                    self.iterations + 1
                ) % self.ema_overwrite_frequency == 0

                def overwrite_fn():
                    return self._overwrite_model_variables_with_average_value(
                        var_list
                    )

                tf.cond(
                    tf.cast(should_overwrite_model_vars, tf.bool),
                    true_fn=overwrite_fn,
                    false_fn=lambda: None,
                )
        self.iterations.assign(self.iterations + 1)

    def _overwrite_model_variables_with_average_value(self, var_list):
        """Overwrite model variables with their moving average values.

        This function overwrites variables on each device.

        Args:
            var_list: List of model variables. They are paired positionally
                with `self._model_variables_moving_average`.
        """
        strategy = self._distribution_strategy
        # Override model variable by the stored average value on all devices.
        for var, average_var in zip(
            var_list, self._model_variables_moving_average
        ):
            strategy.extended.update(
                var, lambda a, b: a.assign(b), args=(average_var,)
            )