Fix the loss scale optimizer for tf.distribute. (#18691)

* Fix the loss scale optimizer for tf.distribute.

* Address review comments.
This commit is contained in:
Qianli Scott Zhu 2023-10-26 12:59:32 -07:00 committed by GitHub
parent 123b61fee4
commit 85dcc8b9f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 142 additions and 75 deletions

@ -44,7 +44,9 @@ def test_model_fit():
# Fit from numpy arrays: # Fit from numpy arrays:
with strategy.scope(): with strategy.scope():
model.compile( model.compile(
optimizer=optimizers.SGD(learning_rate=0.001, momentum=0.01), optimizer=optimizers.LossScaleOptimizer(
optimizers.SGD(learning_rate=0.001, momentum=0.01)
),
loss=losses.MeanSquaredError(), loss=losses.MeanSquaredError(),
metrics=[metrics.MeanSquaredError()], metrics=[metrics.MeanSquaredError()],
# TODO(scottzhu): Find out where is the variable # TODO(scottzhu): Find out where is the variable

@ -89,8 +89,20 @@ class LossScaleOptimizer(optimizer.Optimizer):
"must be built (i.e. its variables must have been created). " "must be built (i.e. its variables must have been created). "
"You can build it via `optimizer.build(trainable_variables)`." "You can build it via `optimizer.build(trainable_variables)`."
) )
finite = self.check_finite(grads)
return ops.cond(
finite,
lambda: self._stateless_handle_finite_grads(
optimizer_variables, grads, trainable_variables
),
lambda: self._stateless_handle_non_finite_grads(
optimizer_variables, trainable_variables
),
)
def handle_finite_grads(): def _stateless_handle_finite_grads(
self, optimizer_variables, grads, trainable_variables
):
def upscale(): def upscale():
mapping = list(zip(self.variables, optimizer_variables)) mapping = list(zip(self.variables, optimizer_variables))
with backend.StatelessScope(state_mapping=mapping) as scope: with backend.StatelessScope(state_mapping=mapping) as scope:
@ -130,7 +142,9 @@ class LossScaleOptimizer(optimizer.Optimizer):
new_optimizer_variables = own_variables + new_inner_variables new_optimizer_variables = own_variables + new_inner_variables
return new_trainable_variables, new_optimizer_variables return new_trainable_variables, new_optimizer_variables
def handle_non_finite_grads(): def _stateless_handle_non_finite_grads(
self, optimizer_variables, trainable_variables
):
mapping = list(zip(self.variables, optimizer_variables)) mapping = list(zip(self.variables, optimizer_variables))
with backend.StatelessScope(state_mapping=mapping) as scope: with backend.StatelessScope(state_mapping=mapping) as scope:
self.step_counter.assign(0) self.step_counter.assign(0)
@ -140,9 +154,6 @@ class LossScaleOptimizer(optimizer.Optimizer):
new_optimizer_variables.append(scope.get_current_value(v)) new_optimizer_variables.append(scope.get_current_value(v))
return trainable_variables, new_optimizer_variables return trainable_variables, new_optimizer_variables
finite = self.check_finite(grads)
return ops.cond(finite, handle_finite_grads, handle_non_finite_grads)
def apply(self, grads, trainable_variables=None): def apply(self, grads, trainable_variables=None):
# Optionally build optimizer. # Optionally build optimizer.
if not self.built: if not self.built:
@ -150,7 +161,12 @@ class LossScaleOptimizer(optimizer.Optimizer):
self.build(trainable_variables) self.build(trainable_variables)
self.built = True self.built = True
def handle_finite_grads(): if backend.backend() == "tensorflow":
self._tf_apply(grads, trainable_variables)
else:
self._common_apply(grads, trainable_variables)
def _stateful_handle_finite_grads(self, grads, trainable_variables):
scale = self.dynamic_scale scale = self.dynamic_scale
# Unscale gradients. # Unscale gradients.
unscaled_grads = [ unscaled_grads = [
@ -174,13 +190,62 @@ class LossScaleOptimizer(optimizer.Optimizer):
increment, increment,
) )
def handle_non_finite_grads(): def _stateful_handle_non_finite_grads(self):
# If any inf or nan in grads, downscale loss and reset counter. # If any inf or nan in grads, downscale loss and reset counter.
self.step_counter.assign(0) self.step_counter.assign(0)
self.dynamic_scale.assign(self.dynamic_scale / 2.0) self.dynamic_scale.assign(self.dynamic_scale / 2.0)
def _common_apply(self, grads, trainable_variables=None):
finite = self.check_finite(grads) finite = self.check_finite(grads)
ops.cond(finite, handle_finite_grads, handle_non_finite_grads) ops.cond(
finite,
lambda: self._stateful_handle_finite_grads(
grads, trainable_variables
),
self._stateful_handle_non_finite_grads,
)
def _tf_apply(self, grads, trainable_variables=None):
"""Tensorflow specific logic for apply, which handles distribution."""
from keras.utils.module_utils import tensorflow as tf
if tf.distribute.in_cross_replica_context():
raise ValueError("apply() must be called in a replica context.")
if tf.__internal__.distribute.strategy_supports_no_merge_call():
self._common_apply(grads, trainable_variables=trainable_variables)
else:
def _handle_cross_replica(distribution, grads, trainable_variables):
finite_per_replica = (
distribution.extended.call_for_each_replica(
self.check_finite, args=(grads,)
)
)
# Each replica computed the same `finite` value, since
# `grads` is all-reduced across replicas. Arbitrarily take
# `finite` from the first replica.
finite = distribution.experimental_local_results(
finite_per_replica
)[0]
def apply_fn():
distribution.extended.call_for_each_replica(
self._stateful_handle_finite_grads,
args=(grads, trainable_variables),
)
# Note: We must call this cond() in a cross-replica context.
# DistributionStrategy does not support having a cond in a
# replica context with a branch that calls `merge_call`, and
# self._optimizer.apply_gradients calls `merge_call`.
ops.cond(
finite, apply_fn, self._stateful_handle_non_finite_grads
)
tf.distribute.get_replica_context().merge_call(
_handle_cross_replica, args=(grads, trainable_variables)
)
def check_finite(self, grads): def check_finite(self, grads):
tensor_grads = [g for g in grads if g is not None] tensor_grads = [g for g in grads if g is not None]