Fix the loss scale optimizer for tf.distribute. (#18691)

* Fix the loss scale optimizer for tf.distribute.

* Address review comments.
This commit is contained in:
Qianli Scott Zhu 2023-10-26 12:59:32 -07:00 committed by GitHub
parent 123b61fee4
commit 85dcc8b9f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 142 additions and 75 deletions

@ -44,7 +44,9 @@ def test_model_fit():
# Fit from numpy arrays:
with strategy.scope():
model.compile(
optimizer=optimizers.SGD(learning_rate=0.001, momentum=0.01),
optimizer=optimizers.LossScaleOptimizer(
optimizers.SGD(learning_rate=0.001, momentum=0.01)
),
loss=losses.MeanSquaredError(),
metrics=[metrics.MeanSquaredError()],
# TODO(scottzhu): Find out where is the variable

@ -89,8 +89,20 @@ class LossScaleOptimizer(optimizer.Optimizer):
"must be built (i.e. its variables must have been created). "
"You can build it via `optimizer.build(trainable_variables)`."
)
finite = self.check_finite(grads)
return ops.cond(
finite,
lambda: self._stateless_handle_finite_grads(
optimizer_variables, grads, trainable_variables
),
lambda: self._stateless_handle_non_finite_grads(
optimizer_variables, trainable_variables
),
)
def handle_finite_grads():
def _stateless_handle_finite_grads(
self, optimizer_variables, grads, trainable_variables
):
def upscale():
mapping = list(zip(self.variables, optimizer_variables))
with backend.StatelessScope(state_mapping=mapping) as scope:
@ -130,7 +142,9 @@ class LossScaleOptimizer(optimizer.Optimizer):
new_optimizer_variables = own_variables + new_inner_variables
return new_trainable_variables, new_optimizer_variables
def handle_non_finite_grads():
def _stateless_handle_non_finite_grads(
self, optimizer_variables, trainable_variables
):
mapping = list(zip(self.variables, optimizer_variables))
with backend.StatelessScope(state_mapping=mapping) as scope:
self.step_counter.assign(0)
@ -140,9 +154,6 @@ class LossScaleOptimizer(optimizer.Optimizer):
new_optimizer_variables.append(scope.get_current_value(v))
return trainable_variables, new_optimizer_variables
finite = self.check_finite(grads)
return ops.cond(finite, handle_finite_grads, handle_non_finite_grads)
def apply(self, grads, trainable_variables=None):
# Optionally build optimizer.
if not self.built:
@ -150,7 +161,12 @@ class LossScaleOptimizer(optimizer.Optimizer):
self.build(trainable_variables)
self.built = True
def handle_finite_grads():
if backend.backend() == "tensorflow":
self._tf_apply(grads, trainable_variables)
else:
self._common_apply(grads, trainable_variables)
def _stateful_handle_finite_grads(self, grads, trainable_variables):
scale = self.dynamic_scale
# Unscale gradients.
unscaled_grads = [
@ -174,13 +190,62 @@ class LossScaleOptimizer(optimizer.Optimizer):
increment,
)
def handle_non_finite_grads():
def _stateful_handle_non_finite_grads(self):
# If any inf or nan in grads, downscale loss and reset counter.
self.step_counter.assign(0)
self.dynamic_scale.assign(self.dynamic_scale / 2.0)
def _common_apply(self, grads, trainable_variables=None):
finite = self.check_finite(grads)
ops.cond(finite, handle_finite_grads, handle_non_finite_grads)
ops.cond(
finite,
lambda: self._stateful_handle_finite_grads(
grads, trainable_variables
),
self._stateful_handle_non_finite_grads,
)
def _tf_apply(self, grads, trainable_variables=None):
"""Tensorflow specific logic for apply, which handles distribution."""
from keras.utils.module_utils import tensorflow as tf
if tf.distribute.in_cross_replica_context():
raise ValueError("apply() must be called in a replica context.")
if tf.__internal__.distribute.strategy_supports_no_merge_call():
self._common_apply(grads, trainable_variables=trainable_variables)
else:
def _handle_cross_replica(distribution, grads, trainable_variables):
finite_per_replica = (
distribution.extended.call_for_each_replica(
self.check_finite, args=(grads,)
)
)
# Each replica computed the same `finite` value, since
# `grads` is all-reduced across replicas. Arbitrarily take
# `finite` from the first replica.
finite = distribution.experimental_local_results(
finite_per_replica
)[0]
def apply_fn():
distribution.extended.call_for_each_replica(
self._stateful_handle_finite_grads,
args=(grads, trainable_variables),
)
# Note: We must call this cond() in a cross-replica context.
# DistributionStrategy does not support having a cond in a
# replica context with a branch that calls `merge_call`, and
# self._optimizer.apply_gradients calls `merge_call`.
ops.cond(
finite, apply_fn, self._stateful_handle_non_finite_grads
)
tf.distribute.get_replica_context().merge_call(
_handle_cross_replica, args=(grads, trainable_variables)
)
def check_finite(self, grads):
tensor_grads = [g for g in grads if g is not None]