Fix the loss scale optimizer for tf.distribute. (#18691)

* Fix the loss scale optimizer for tf.distribute.

* Address review comments.
Qianli Scott Zhu authored on 2023-10-26 12:59:32 -07:00; committed by GitHub
parent 123b61fee4
commit 85dcc8b9f5
2 changed files with 142 additions and 75 deletions

@@ -44,7 +44,9 @@ def test_model_fit():
     # Fit from numpy arrays:
     with strategy.scope():
         model.compile(
-            optimizer=optimizers.SGD(learning_rate=0.001, momentum=0.01),
+            optimizer=optimizers.LossScaleOptimizer(
+                optimizers.SGD(learning_rate=0.001, momentum=0.01)
+            ),
             loss=losses.MeanSquaredError(),
             metrics=[metrics.MeanSquaredError()],
             # TODO(scottzhu): Find out where is the variable
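For context, the test change above wraps the inner SGD optimizer in `optimizers.LossScaleOptimizer` inside the strategy scope, which is exactly the combination this commit fixes. A rough standalone sketch of the same setup, assuming a `tf.distribute.MirroredStrategy` and a toy regression model (the model and data below are hypothetical stand-ins, not taken from the diff):

import numpy as np
import tensorflow as tf

from keras import layers, losses, metrics, models, optimizers

# Hypothetical strategy and model, mirroring the test change above.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = models.Sequential([layers.Input((8,)), layers.Dense(1)])
    model.compile(
        # The fix targets this combination: a LossScaleOptimizer wrapping an
        # inner optimizer while training under tf.distribute.
        optimizer=optimizers.LossScaleOptimizer(
            optimizers.SGD(learning_rate=0.001, momentum=0.01)
        ),
        loss=losses.MeanSquaredError(),
        metrics=[metrics.MeanSquaredError()],
    )

# Toy data; fit() drives the distributed train step that exercises the fix.
x = np.random.random((64, 8)).astype("float32")
y = np.random.random((64, 1)).astype("float32")
model.fit(x, y, epochs=1, batch_size=16)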

@@ -89,59 +89,70 @@ class LossScaleOptimizer(optimizer.Optimizer):
                 "must be built (i.e. its variables must have been created). "
                 "You can build it via `optimizer.build(trainable_variables)`."
             )
+        finite = self.check_finite(grads)
+        return ops.cond(
+            finite,
+            lambda: self._stateless_handle_finite_grads(
+                optimizer_variables, grads, trainable_variables
+            ),
+            lambda: self._stateless_handle_non_finite_grads(
+                optimizer_variables, trainable_variables
+            ),
+        )
 
-        def handle_finite_grads():
-            def upscale():
-                mapping = list(zip(self.variables, optimizer_variables))
-                with backend.StatelessScope(state_mapping=mapping) as scope:
-                    self.step_counter.assign(0)
-                    self.dynamic_scale.assign(self.dynamic_scale * 2.0)
-                return [scope.get_current_value(v) for v in self._variables]
-
-            def increment():
-                mapping = list(zip(self.variables, optimizer_variables))
-                with backend.StatelessScope(state_mapping=mapping) as scope:
-                    self.step_counter.assign_add(1)
-                return [scope.get_current_value(v) for v in self._variables]
-
-            mapping = list(zip(self.variables, optimizer_variables))
-            with backend.StatelessScope(state_mapping=mapping):
-                # Potentially upscale loss and reset counter.
-                own_variables = ops.cond(
-                    ops.equal(self.step_counter, self.dynamic_growth_steps - 1),
-                    upscale,
-                    increment,
-                )
-                # Unscale gradients.
-                scale = self.dynamic_scale
-                unscaled_grads = [
-                    g if g is None else ops.divide(g, scale) for g in grads
-                ]
-                (
-                    new_trainable_variables,
-                    new_inner_variables,
-                ) = self.inner_optimizer.stateless_apply(
-                    self.inner_optimizer.variables,
-                    unscaled_grads,
-                    trainable_variables,
-                )
-                new_optimizer_variables = own_variables + new_inner_variables
-            return new_trainable_variables, new_optimizer_variables
-
-        def handle_non_finite_grads():
-            mapping = list(zip(self.variables, optimizer_variables))
-            with backend.StatelessScope(state_mapping=mapping) as scope:
-                self.step_counter.assign(0)
-                self.dynamic_scale.assign(self.dynamic_scale / 2.0)
-            new_optimizer_variables = []
-            for v in self.variables:
-                new_optimizer_variables.append(scope.get_current_value(v))
-            return trainable_variables, new_optimizer_variables
-
-        finite = self.check_finite(grads)
-        return ops.cond(finite, handle_finite_grads, handle_non_finite_grads)
+    def _stateless_handle_finite_grads(
+        self, optimizer_variables, grads, trainable_variables
+    ):
+        def upscale():
+            mapping = list(zip(self.variables, optimizer_variables))
+            with backend.StatelessScope(state_mapping=mapping) as scope:
+                self.step_counter.assign(0)
+                self.dynamic_scale.assign(self.dynamic_scale * 2.0)
+            return [scope.get_current_value(v) for v in self._variables]
+
+        def increment():
+            mapping = list(zip(self.variables, optimizer_variables))
+            with backend.StatelessScope(state_mapping=mapping) as scope:
+                self.step_counter.assign_add(1)
+            return [scope.get_current_value(v) for v in self._variables]
+
+        mapping = list(zip(self.variables, optimizer_variables))
+        with backend.StatelessScope(state_mapping=mapping):
+            # Potentially upscale loss and reset counter.
+            own_variables = ops.cond(
+                ops.equal(self.step_counter, self.dynamic_growth_steps - 1),
+                upscale,
+                increment,
+            )
+            # Unscale gradients.
+            scale = self.dynamic_scale
+            unscaled_grads = [
+                g if g is None else ops.divide(g, scale) for g in grads
+            ]
+            (
+                new_trainable_variables,
+                new_inner_variables,
+            ) = self.inner_optimizer.stateless_apply(
+                self.inner_optimizer.variables,
+                unscaled_grads,
+                trainable_variables,
+            )
+            new_optimizer_variables = own_variables + new_inner_variables
+        return new_trainable_variables, new_optimizer_variables
+
+    def _stateless_handle_non_finite_grads(
+        self, optimizer_variables, trainable_variables
+    ):
+        mapping = list(zip(self.variables, optimizer_variables))
+        with backend.StatelessScope(state_mapping=mapping) as scope:
+            self.step_counter.assign(0)
+            self.dynamic_scale.assign(self.dynamic_scale / 2.0)
+        new_optimizer_variables = []
+        for v in self.variables:
+            new_optimizer_variables.append(scope.get_current_value(v))
+        return trainable_variables, new_optimizer_variables
 
     def apply(self, grads, trainable_variables=None):
         # Optionally build optimizer.
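The `stateless_apply` branch above now dispatches through `ops.cond` to the two new `_stateless_handle_*` methods instead of nested closures, and both branches return the same `(trainable_variables, optimizer_variables)` structure. A minimal sketch of driving that stateless path from user code, assuming an already built optimizer and toy parameters (every name below is illustrative, not taken from the diff):

import numpy as np

import keras
from keras import optimizers

# Hypothetical parameters and gradients, just to exercise stateless_apply().
params = [keras.Variable(np.ones((3,), dtype="float32")) for _ in range(2)]
grads = [np.full((3,), 0.5, dtype="float32") for _ in range(2)]

opt = optimizers.LossScaleOptimizer(optimizers.SGD(learning_rate=0.1))
opt.build(params)  # required before stateless_apply(), per the error above

# Pass current values in, get updated values back; no state on the optimizer
# is mutated in place, which is what the StatelessScope mappings guarantee.
new_params, new_opt_vars = opt.stateless_apply(
    [v.value for v in opt.variables],
    grads,
    [v.value for v in params],
)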
@@ -150,37 +161,91 @@ class LossScaleOptimizer(optimizer.Optimizer):
             self.build(trainable_variables)
             self.built = True
 
-        def handle_finite_grads():
-            scale = self.dynamic_scale
-            # Unscale gradients.
-            unscaled_grads = [
-                g if g is None else ops.divide(g, scale) for g in grads
-            ]
-            self.inner_optimizer.apply(
-                unscaled_grads, trainable_variables=trainable_variables
-            )
-
-            def upscale():
-                self.step_counter.assign(0)
-                self.dynamic_scale.assign(self.dynamic_scale * 2.0)
-
-            def increment():
-                self.step_counter.assign_add(1)
-
-            # Potentially upscale loss and reset counter.
-            ops.cond(
-                ops.equal(self.step_counter, self.dynamic_growth_steps - 1),
-                upscale,
-                increment,
-            )
-
-        def handle_non_finite_grads():
-            # If any inf or nan in grads, downscale loss and reset counter.
-            self.step_counter.assign(0)
-            self.dynamic_scale.assign(self.dynamic_scale / 2.0)
-
-        finite = self.check_finite(grads)
-        ops.cond(finite, handle_finite_grads, handle_non_finite_grads)
+        if backend.backend() == "tensorflow":
+            self._tf_apply(grads, trainable_variables)
+        else:
+            self._common_apply(grads, trainable_variables)
+
+    def _stateful_handle_finite_grads(self, grads, trainable_variables):
+        scale = self.dynamic_scale
+        # Unscale gradients.
+        unscaled_grads = [
+            g if g is None else ops.divide(g, scale) for g in grads
+        ]
+        self.inner_optimizer.apply(
+            unscaled_grads, trainable_variables=trainable_variables
+        )
+
+        def upscale():
+            self.step_counter.assign(0)
+            self.dynamic_scale.assign(self.dynamic_scale * 2.0)
+
+        def increment():
+            self.step_counter.assign_add(1)
+
+        # Potentially upscale loss and reset counter.
+        ops.cond(
+            ops.equal(self.step_counter, self.dynamic_growth_steps - 1),
+            upscale,
+            increment,
+        )
+
+    def _stateful_handle_non_finite_grads(self):
+        # If any inf or nan in grads, downscale loss and reset counter.
+        self.step_counter.assign(0)
+        self.dynamic_scale.assign(self.dynamic_scale / 2.0)
+
+    def _common_apply(self, grads, trainable_variables=None):
+        finite = self.check_finite(grads)
+        ops.cond(
+            finite,
+            lambda: self._stateful_handle_finite_grads(
+                grads, trainable_variables
+            ),
+            self._stateful_handle_non_finite_grads,
+        )
+
+    def _tf_apply(self, grads, trainable_variables=None):
+        """Tensorflow specific logic for apply, which handles distribution."""
+        from keras.utils.module_utils import tensorflow as tf
+
+        if tf.distribute.in_cross_replica_context():
+            raise ValueError("apply() must be called in a replica context.")
+
+        if tf.__internal__.distribute.strategy_supports_no_merge_call():
+            self._common_apply(grads, trainable_variables=trainable_variables)
+        else:
+
+            def _handle_cross_replica(distribution, grads, trainable_variables):
+                finite_per_replica = (
+                    distribution.extended.call_for_each_replica(
+                        self.check_finite, args=(grads,)
+                    )
+                )
+                # Each replica computed the same `finite` value, since
+                # `grads` is all-reduced across replicas. Arbitrarily take
+                # `finite` from the first replica.
+                finite = distribution.experimental_local_results(
+                    finite_per_replica
+                )[0]
+
+                def apply_fn():
+                    distribution.extended.call_for_each_replica(
+                        self._stateful_handle_finite_grads,
+                        args=(grads, trainable_variables),
+                    )
+
+                # Note: We must call this cond() in a cross-replica context.
+                # DistributionStrategy does not support having a cond in a
+                # replica context with a branch that calls `merge_call`, and
+                # self._optimizer.apply_gradients calls `merge_call`.
+                ops.cond(
+                    finite, apply_fn, self._stateful_handle_non_finite_grads
+                )
+
+            tf.distribute.get_replica_context().merge_call(
+                _handle_cross_replica, args=(grads, trainable_variables)
+            )
 
     def check_finite(self, grads):
         tensor_grads = [g for g in grads if g is not None]
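The TensorFlow-specific `_tf_apply` above insists on being called from a replica context and, when the strategy does not support skipping `merge_call`, runs the finite/non-finite `cond` in the cross-replica context. A rough sketch of a custom training step that satisfies that requirement, assuming a `MirroredStrategy` and a toy model (all names below are illustrative, not taken from the diff):

import numpy as np
import tensorflow as tf

from keras import layers, losses, models, optimizers

# Hypothetical setup.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = models.Sequential([layers.Input((8,)), layers.Dense(1)])
    optimizer = optimizers.LossScaleOptimizer(
        optimizers.SGD(learning_rate=0.001)
    )
    optimizer.build(model.trainable_variables)
    loss_fn = losses.MeanSquaredError()


@tf.function
def train_step(x, y):
    def step_fn(x, y):
        with tf.GradientTape() as tape:
            # scale_loss() multiplies the loss by the current dynamic scale.
            loss = optimizer.scale_loss(loss_fn(y, model(x)))
        grads = tape.gradient(loss, model.trainable_variables)
        # apply() runs here inside strategy.run(), i.e. in a replica context,
        # as _tf_apply requires; it unscales the gradients and updates the
        # dynamic scale, going through merge_call when the strategy needs it.
        optimizer.apply(grads, model.trainable_variables)

    strategy.run(step_fn, args=(x, y))


x = np.random.random((16, 8)).astype("float32")
y = np.random.random((16, 1)).astype("float32")
train_step(x, y)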