Fix the loss scale optimizer for tf.distribute. (#18691)
* Fix the loss scale optimizer for tf.distribute.
* Address review comments.
parent 123b61fee4
commit 85dcc8b9f5
@@ -44,7 +44,9 @@ def test_model_fit():
     # Fit from numpy arrays:
     with strategy.scope():
         model.compile(
-            optimizer=optimizers.SGD(learning_rate=0.001, momentum=0.01),
+            optimizer=optimizers.LossScaleOptimizer(
+                optimizers.SGD(learning_rate=0.001, momentum=0.01)
+            ),
             loss=losses.MeanSquaredError(),
             metrics=[metrics.MeanSquaredError()],
             # TODO(scottzhu): Find out where is the variable
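For context, the hunk above is the updated distribution test: the SGD optimizer is now wrapped in a LossScaleOptimizer inside the strategy scope. Below is a minimal, self-contained sketch of the same usage pattern, assuming the TensorFlow backend; the model, data, and MirroredStrategy setup are illustrative placeholders rather than the test's own fixtures:

import numpy as np
import tensorflow as tf
from keras import layers, losses, metrics, models, optimizers

# Illustrative strategy; the real test configures its own (virtual) devices.
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    inputs = layers.Input((8,))
    outputs = layers.Dense(1)(inputs)
    model = models.Model(inputs, outputs)
    model.compile(
        # Wrap the inner optimizer so gradients are dynamically loss-scaled.
        optimizer=optimizers.LossScaleOptimizer(
            optimizers.SGD(learning_rate=0.001, momentum=0.01)
        ),
        loss=losses.MeanSquaredError(),
        metrics=[metrics.MeanSquaredError()],
    )

x = np.random.rand(32, 8).astype("float32")
y = np.random.rand(32, 1).astype("float32")
model.fit(x, y, batch_size=16, epochs=1, verbose=0)

Calling fit here drives the optimizer's apply() through the TensorFlow-specific path added later in this diff.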
@@ -89,8 +89,20 @@ class LossScaleOptimizer(optimizer.Optimizer):
                 "must be built (i.e. its variables must have been created). "
                 "You can build it via `optimizer.build(trainable_variables)`."
             )
 
-        def handle_finite_grads():
+        finite = self.check_finite(grads)
+        return ops.cond(
+            finite,
+            lambda: self._stateless_handle_finite_grads(
+                optimizer_variables, grads, trainable_variables
+            ),
+            lambda: self._stateless_handle_non_finite_grads(
+                optimizer_variables, trainable_variables
+            ),
+        )
+
+    def _stateless_handle_finite_grads(
+        self, optimizer_variables, grads, trainable_variables
+    ):
         def upscale():
             mapping = list(zip(self.variables, optimizer_variables))
             with backend.StatelessScope(state_mapping=mapping) as scope:
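The hunk above replaces the local handle_finite_grads closure: stateless_apply now computes finite once and routes to the two new private methods through keras.ops.cond. Because both branches feed the same cond, they must return values with matching structure (here, the trainable and optimizer variable lists). A toy, self-contained illustration of that constraint; clip_or_zero and its values are made up for illustration:

from keras import ops

def clip_or_zero(grad, max_norm=1.0):
    # Both branches return a tensor of the same shape and dtype, just as the
    # finite / non-finite handlers both return matching variable lists.
    norm = ops.sqrt(ops.sum(ops.square(grad)))
    return ops.cond(
        norm <= max_norm,
        lambda: grad,
        lambda: ops.zeros_like(grad),
    )

print(clip_or_zero(ops.ones((3,))))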
@@ -130,7 +142,9 @@ class LossScaleOptimizer(optimizer.Optimizer):
         new_optimizer_variables = own_variables + new_inner_variables
         return new_trainable_variables, new_optimizer_variables
 
-        def handle_non_finite_grads():
+    def _stateless_handle_non_finite_grads(
+        self, optimizer_variables, trainable_variables
+    ):
         mapping = list(zip(self.variables, optimizer_variables))
         with backend.StatelessScope(state_mapping=mapping) as scope:
             self.step_counter.assign(0)
@@ -140,9 +154,6 @@ class LossScaleOptimizer(optimizer.Optimizer):
             new_optimizer_variables.append(scope.get_current_value(v))
         return trainable_variables, new_optimizer_variables
 
-        finite = self.check_finite(grads)
-
-        return ops.cond(finite, handle_finite_grads, handle_non_finite_grads)
 
     def apply(self, grads, trainable_variables=None):
         # Optionally build optimizer.
         if not self.built:
@@ -150,7 +161,12 @@ class LossScaleOptimizer(optimizer.Optimizer):
             self.build(trainable_variables)
             self.built = True
 
-        def handle_finite_grads():
+        if backend.backend() == "tensorflow":
+            self._tf_apply(grads, trainable_variables)
+        else:
+            self._common_apply(grads, trainable_variables)
+
+    def _stateful_handle_finite_grads(self, grads, trainable_variables):
         scale = self.dynamic_scale
         # Unscale gradients.
         unscaled_grads = [
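The hunk above also shows the start of the finite-gradient path: before delegating to the inner optimizer, the gradients are divided by the current dynamic scale, undoing the loss scaling applied during the backward pass. A standalone sketch of that unscaling step; the helper name and toy values are illustrative:

from keras import ops

def unscale_grads(grads, scale):
    # Gradients were computed from a loss multiplied by `scale`, so each one
    # is `scale` times too large; dividing restores the true gradients.
    # `None` entries (variables with no gradient) are passed through.
    return [g if g is None else ops.divide(g, scale) for g in grads]

print(unscale_grads([ops.ones((2,)) * 1024.0, None], 1024.0))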
@@ -174,13 +190,62 @@ class LossScaleOptimizer(optimizer.Optimizer):
             increment,
         )
 
-        def handle_non_finite_grads():
+    def _stateful_handle_non_finite_grads(self):
         # If any inf or nan in grads, downscale loss and reset counter.
         self.step_counter.assign(0)
         self.dynamic_scale.assign(self.dynamic_scale / 2.0)
 
+    def _common_apply(self, grads, trainable_variables=None):
         finite = self.check_finite(grads)
-        ops.cond(finite, handle_finite_grads, handle_non_finite_grads)
+        ops.cond(
+            finite,
+            lambda: self._stateful_handle_finite_grads(
+                grads, trainable_variables
+            ),
+            self._stateful_handle_non_finite_grads,
+        )
+
+    def _tf_apply(self, grads, trainable_variables=None):
+        """Tensorflow specific logic for apply, which handles distribution."""
+        from keras.utils.module_utils import tensorflow as tf
+
+        if tf.distribute.in_cross_replica_context():
+            raise ValueError("apply() must be called in a replica context.")
+
+        if tf.__internal__.distribute.strategy_supports_no_merge_call():
+            self._common_apply(grads, trainable_variables=trainable_variables)
+        else:
+
+            def _handle_cross_replica(distribution, grads, trainable_variables):
+                finite_per_replica = (
+                    distribution.extended.call_for_each_replica(
+                        self.check_finite, args=(grads,)
+                    )
+                )
+                # Each replica computed the same `finite` value, since
+                # `grads` is all-reduced across replicas. Arbitrarily take
+                # `finite` from the first replica.
+                finite = distribution.experimental_local_results(
+                    finite_per_replica
+                )[0]
+
+                def apply_fn():
+                    distribution.extended.call_for_each_replica(
+                        self._stateful_handle_finite_grads,
+                        args=(grads, trainable_variables),
+                    )
+
+                # Note: We must call this cond() in a cross-replica context.
+                # DistributionStrategy does not support having a cond in a
+                # replica context with a branch that calls `merge_call`, and
+                # self._optimizer.apply_gradients calls `merge_call`.
+                ops.cond(
+                    finite, apply_fn, self._stateful_handle_non_finite_grads
+                )
+
+            tf.distribute.get_replica_context().merge_call(
+                _handle_cross_replica, args=(grads, trainable_variables)
+            )
 
     def check_finite(self, grads):
         tensor_grads = [g for g in grads if g is not None]
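Both apply paths branch on check_finite, whose beginning appears in the trailing context above. As a rough illustration only (not the method's actual body), a finiteness check over a gradient list can be expressed with keras.ops like this:

from keras import ops

def grads_are_finite(grads):
    # Hypothetical helper: reduce each non-None gradient to a scalar
    # "all finite" flag, then AND the flags together; any inf or NaN
    # anywhere yields False.
    flags = [ops.all(ops.isfinite(g)) for g in grads if g is not None]
    return ops.all(ops.stack(flags))

print(grads_are_finite([ops.ones((2, 2)), ops.ones((3,)) * float("inf")]))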