Fix the loss scale optimizer for tf.distribute. (#18691)
* Fix the loss scale optimizer for tf.distribute.
* Address review comments.
This commit is contained in:
parent
123b61fee4
commit
85dcc8b9f5
@ -44,7 +44,9 @@ def test_model_fit():
|
||||
# Fit from numpy arrays:
|
||||
with strategy.scope():
|
||||
model.compile(
|
||||
optimizer=optimizers.SGD(learning_rate=0.001, momentum=0.01),
|
||||
optimizer=optimizers.LossScaleOptimizer(
|
||||
optimizers.SGD(learning_rate=0.001, momentum=0.01)
|
||||
),
|
||||
loss=losses.MeanSquaredError(),
|
||||
metrics=[metrics.MeanSquaredError()],
|
||||
# TODO(scottzhu): Find out where is the variable
|
||||
|
@ -89,59 +89,70 @@ class LossScaleOptimizer(optimizer.Optimizer):
|
||||
"must be built (i.e. its variables must have been created). "
|
||||
"You can build it via `optimizer.build(trainable_variables)`."
|
||||
)
|
||||
finite = self.check_finite(grads)
|
||||
return ops.cond(
|
||||
finite,
|
||||
lambda: self._stateless_handle_finite_grads(
|
||||
optimizer_variables, grads, trainable_variables
|
||||
),
|
||||
lambda: self._stateless_handle_non_finite_grads(
|
||||
optimizer_variables, trainable_variables
|
||||
),
|
||||
)
|
||||
|
||||
def handle_finite_grads():
|
||||
def upscale():
|
||||
mapping = list(zip(self.variables, optimizer_variables))
|
||||
with backend.StatelessScope(state_mapping=mapping) as scope:
|
||||
self.step_counter.assign(0)
|
||||
self.dynamic_scale.assign(self.dynamic_scale * 2.0)
|
||||
return [scope.get_current_value(v) for v in self._variables]
|
||||
|
||||
def increment():
|
||||
mapping = list(zip(self.variables, optimizer_variables))
|
||||
with backend.StatelessScope(state_mapping=mapping) as scope:
|
||||
self.step_counter.assign_add(1)
|
||||
return [scope.get_current_value(v) for v in self._variables]
|
||||
|
||||
mapping = list(zip(self.variables, optimizer_variables))
|
||||
with backend.StatelessScope(state_mapping=mapping):
|
||||
# Potentially upscale loss and reset counter.
|
||||
own_variables = ops.cond(
|
||||
ops.equal(self.step_counter, self.dynamic_growth_steps - 1),
|
||||
upscale,
|
||||
increment,
|
||||
)
|
||||
|
||||
# Unscale gradients.
|
||||
scale = self.dynamic_scale
|
||||
unscaled_grads = [
|
||||
g if g is None else ops.divide(g, scale) for g in grads
|
||||
]
|
||||
(
|
||||
new_trainable_variables,
|
||||
new_inner_variables,
|
||||
) = self.inner_optimizer.stateless_apply(
|
||||
self.inner_optimizer.variables,
|
||||
unscaled_grads,
|
||||
trainable_variables,
|
||||
)
|
||||
|
||||
new_optimizer_variables = own_variables + new_inner_variables
|
||||
return new_trainable_variables, new_optimizer_variables
|
||||
|
||||
def handle_non_finite_grads():
|
||||
def _stateless_handle_finite_grads(
    self, optimizer_variables, grads, trainable_variables
):
    """Stateless update for a step whose gradients are all finite.

    Advances the loss-scale bookkeeping, unscales the gradients and
    delegates the actual update to the inner optimizer's `stateless_apply`.

    Args:
        optimizer_variables: Values paired positionally with
            `self.variables` to seed the stateless scopes.
        grads: Scaled gradients. `None` entries are passed through as-is.
        trainable_variables: Forwarded to the inner optimizer's
            `stateless_apply`.

    Returns:
        A `(new_trainable_variables, new_optimizer_variables)` tuple. The
        optimizer values are this optimizer's own values (one per entry of
        `self._variables`) followed by the inner optimizer's values.
    """
    mapping = list(zip(self.variables, optimizer_variables))

    def upscale():
        # Growth interval reached: reset the counter and double the scale.
        with backend.StatelessScope(state_mapping=mapping) as scope:
            self.step_counter.assign(0)
            self.dynamic_scale.assign(self.dynamic_scale * 2.0)
            return [scope.get_current_value(v) for v in self._variables]

    def increment():
        # Growth interval not reached yet: only count the step.
        with backend.StatelessScope(state_mapping=mapping) as scope:
            self.step_counter.assign_add(1)
            return [scope.get_current_value(v) for v in self._variables]

    with backend.StatelessScope(state_mapping=mapping):
        # Potentially upscale loss and reset counter.
        own_variables = ops.cond(
            ops.equal(self.step_counter, self.dynamic_growth_steps - 1),
            upscale,
            increment,
        )

        # Unscale gradients. The scale is read inside this scope, whose
        # state comes from `optimizer_variables`; the branch results above
        # are returned as values and are not written back into this scope.
        scale = self.dynamic_scale
        unscaled_grads = [
            g if g is None else ops.divide(g, scale) for g in grads
        ]
        (
            new_trainable_variables,
            new_inner_variables,
        ) = self.inner_optimizer.stateless_apply(
            self.inner_optimizer.variables,
            unscaled_grads,
            trainable_variables,
        )

        new_optimizer_variables = own_variables + new_inner_variables
        return new_trainable_variables, new_optimizer_variables
|
||||
|
||||
def _stateless_handle_non_finite_grads(
    self, optimizer_variables, trainable_variables
):
    """Stateless update for a step whose gradients are not all finite.

    No gradient is applied: the step counter is reset to 0 and the dynamic
    scale is halved inside a stateless scope.

    Args:
        optimizer_variables: Values paired positionally with
            `self.variables` to seed the stateless scope.
        trainable_variables: Returned unchanged.

    Returns:
        A `(trainable_variables, new_optimizer_variables)` tuple, where the
        second item holds the scope's current value for every entry of
        `self.variables`.
    """
    state = list(zip(self.variables, optimizer_variables))
    with backend.StatelessScope(state_mapping=state) as scope:
        self.step_counter.assign(0)
        self.dynamic_scale.assign(self.dynamic_scale / 2.0)
        updated_values = [scope.get_current_value(v) for v in self.variables]
    return trainable_variables, updated_values
|
||||
|
||||
def apply(self, grads, trainable_variables=None):
|
||||
# Optionally build optimizer.
|
||||
@ -150,37 +161,91 @@ class LossScaleOptimizer(optimizer.Optimizer):
|
||||
self.build(trainable_variables)
|
||||
self.built = True
|
||||
|
||||
def handle_finite_grads():
|
||||
scale = self.dynamic_scale
|
||||
# Unscale gradients.
|
||||
unscaled_grads = [
|
||||
g if g is None else ops.divide(g, scale) for g in grads
|
||||
]
|
||||
self.inner_optimizer.apply(
|
||||
unscaled_grads, trainable_variables=trainable_variables
|
||||
)
|
||||
if backend.backend() == "tensorflow":
|
||||
self._tf_apply(grads, trainable_variables)
|
||||
else:
|
||||
self._common_apply(grads, trainable_variables)
|
||||
|
||||
def upscale():
|
||||
self.step_counter.assign(0)
|
||||
self.dynamic_scale.assign(self.dynamic_scale * 2.0)
|
||||
def _stateful_handle_finite_grads(self, grads, trainable_variables):
    """Apply finite gradients in place and advance the loss-scale state.

    Divides every non-`None` gradient by the current `dynamic_scale`, hands
    the result to the inner optimizer's `apply`, then either doubles the
    scale and resets `step_counter` (when the counter equals
    `dynamic_growth_steps - 1`) or increments the counter.

    Args:
        grads: Scaled gradients. `None` entries are passed through as-is.
        trainable_variables: Forwarded to the inner optimizer's `apply`.
    """
    scale = self.dynamic_scale
    # Unscale gradients.
    unscaled_grads = [
        g if g is None else ops.divide(g, scale) for g in grads
    ]
    self.inner_optimizer.apply(
        unscaled_grads, trainable_variables=trainable_variables
    )

    def upscale():
        self.step_counter.assign(0)
        self.dynamic_scale.assign(self.dynamic_scale * 2.0)

    def increment():
        self.step_counter.assign_add(1)

    # Potentially upscale loss and reset counter.
    ops.cond(
        ops.equal(self.step_counter, self.dynamic_growth_steps - 1),
        upscale,
        increment,
    )
|
||||
|
||||
def _stateful_handle_non_finite_grads(self):
    """Handle a step whose gradients contain inf or NaN.

    Resets `step_counter` to 0 and halves `dynamic_scale` in place. No
    gradient is applied here.
    """
    self.step_counter.assign(0)
    halved_scale = self.dynamic_scale / 2.0
    self.dynamic_scale.assign(halved_scale)
|
||||
|
||||
def _common_apply(self, grads, trainable_variables=None):
    """Dispatch to the finite or non-finite stateful handler.

    Args:
        grads: Scaled gradients, checked with `check_finite`.
        trainable_variables: Forwarded to `_stateful_handle_finite_grads`
            when the gradients are finite.
    """
    finite = self.check_finite(grads)
    ops.cond(
        finite,
        # The finite branch needs arguments, so it is wrapped in a lambda;
        # the non-finite handler takes none and is passed directly.
        lambda: self._stateful_handle_finite_grads(
            grads, trainable_variables
        ),
        self._stateful_handle_non_finite_grads,
    )
|
||||
|
||||
def _tf_apply(self, grads, trainable_variables=None):
    """TensorFlow-specific `apply` path that accounts for `tf.distribute`.

    If the current strategy supports running without `merge_call`, this
    defers to `_common_apply`. Otherwise the finite check and the
    conditional update run in a cross-replica context entered through
    `merge_call`.

    Args:
        grads: Gradients handed to `check_finite` and to the update
            handlers.
        trainable_variables: Forwarded unchanged to the update handlers.

    Raises:
        ValueError: If called from a cross-replica context.
    """
    from keras.utils.module_utils import tensorflow as tf

    if tf.distribute.in_cross_replica_context():
        raise ValueError("apply() must be called in a replica context.")

    if tf.__internal__.distribute.strategy_supports_no_merge_call():
        self._common_apply(grads, trainable_variables=trainable_variables)
        return

    def _handle_cross_replica(distribution, grads, trainable_variables):
        extended = distribution.extended
        per_replica_finite = extended.call_for_each_replica(
            self.check_finite, args=(grads,)
        )
        # `grads` is all-reduced across replicas, so every replica computed
        # the same `finite` value; the first replica's result is taken
        # arbitrarily.
        local_results = distribution.experimental_local_results(
            per_replica_finite
        )
        finite = local_results[0]

        def apply_fn():
            extended.call_for_each_replica(
                self._stateful_handle_finite_grads,
                args=(grads, trainable_variables),
            )

        # Note: this cond() must run in a cross-replica context.
        # DistributionStrategy does not support a cond in a replica context
        # with a branch that calls `merge_call`, and the finite branch ends
        # in the inner optimizer's `apply`, which calls `merge_call`.
        ops.cond(finite, apply_fn, self._stateful_handle_non_finite_grads)

    replica_context = tf.distribute.get_replica_context()
    replica_context.merge_call(
        _handle_cross_replica, args=(grads, trainable_variables)
    )
|
||||
|
||||
def check_finite(self, grads):
|
||||
tensor_grads = [g for g in grads if g is not None]
|
||||
|
Loading…
Reference in New Issue
Block a user