Minor performance optimizations for eager.

Francois Chollet 2023-07-12 13:43:52 -07:00
parent 13caae7284
commit f41817b345
3 changed files with 35 additions and 27 deletions

@@ -157,7 +157,13 @@ devices = mesh_utils.create_device_mesh((8,))
 # data will be split along the batch axis
 data_mesh = Mesh(devices, axis_names=("batch",))  # naming axes of the mesh
 # naming axes of the sharded partition
-data_sharding = NamedSharding(data_mesh,P("batch",),)
+data_sharding = NamedSharding(
+    data_mesh,
+    P(
+        "batch",
+    ),
+)
 # all variables will be replicated on all devices
 var_mesh = Mesh(devices, axis_names=("_"))
 # in NamedSharding, axes that are not mentioned are replicated (all axes here)
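For context (not part of the diff), a minimal sketch of how a NamedSharding built this way splits data along the batch axis of the 1-D mesh; the array shape and the 8-device setup are illustrative assumptions.

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Requires 8 addressable devices (e.g. a TPU v3-8, or 8 CPU devices forced
# via XLA_FLAGS=--xla_force_host_platform_device_count=8).
devices = mesh_utils.create_device_mesh((8,))
data_mesh = Mesh(devices, axis_names=("batch",))
data_sharding = NamedSharding(data_mesh, P("batch"))

# A batch of 32 samples (shape chosen arbitrarily) is split into 8 shards of
# 4 samples each along the leading axis, one shard per device.
x = jnp.zeros((32, 28, 28, 1))
x = jax.device_put(x, data_sharding)
print(x.sharding)  # e.g. NamedSharding(..., spec=PartitionSpec('batch',))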
@@ -269,7 +275,7 @@ def train_step(train_state, x, y):
     )
     trainable_variables, optimizer_variables = optimizer.stateless_apply(
-        train_state.optimizer_variables, grads, train_state.trainable_variables
+        grads, train_state.trainable_variables, train_state.optimizer_variables
     )
     return loss_value, TrainingState(

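For context (not part of the diff), a sketch of the full stateless update this hunk is editing, with the argument order taken from the added line (it appears to mirror the new `apply(grads, trainable_variables)` signature); the TrainingState fields and surrounding names are assumptions made for illustration.

from collections import namedtuple

# Stand-in for the guide's training-state container (field names assumed).
TrainingState = namedtuple(
    "TrainingState",
    ["trainable_variables", "non_trainable_variables", "optimizer_variables"],
)

def apply_update(optimizer, train_state, grads, loss_value, non_trainable_variables):
    # Argument order as in the added line of this hunk:
    # gradients first, then trainable variables, then optimizer state.
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        grads, train_state.trainable_variables, train_state.optimizer_variables
    )
    return loss_value, TrainingState(
        trainable_variables=trainable_variables,
        non_trainable_variables=non_trainable_variables,
        optimizer_variables=optimizer_variables,
    )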
@@ -43,20 +43,16 @@ class TorchTrainer(base_trainer.Trainer):
         # Compute gradients
         if self.trainable_weights:
             # Backpropagation
-            trainable_weights = [v for v in self.trainable_weights]
             # Call torch.Tensor.backward() on the loss to compute gradients
             # for the weights.
             loss.backward()
+            trainable_weights = self.trainable_weights[:]
             gradients = [v.value.grad for v in trainable_weights]
             # Update weights
             with torch.no_grad():
-                self.optimizer.apply_gradients(
-                    zip(gradients, trainable_weights)
-                )
+                self.optimizer.apply(gradients, trainable_weights)
         else:
             warnings.warn("The model does not have any trainable weights.")

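For context (not part of the diff), a hedged sketch of the same list-based update written as a user-level custom training loop on the torch backend; the tiny model, loss, and shapes are illustrative, and the backend must be selected before importing keras.

import os
os.environ["KERAS_BACKEND"] = "torch"  # must be set before importing keras

import torch
import keras

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
optimizer = keras.optimizers.SGD(learning_rate=0.1)
loss_fn = keras.losses.MeanSquaredError()

x = torch.randn(8, 4)
y = torch.randn(8, 1)

# The forward pass builds a torch autograd graph; backward() populates .grad
# on the torch tensors backing each Keras variable (v.value).
loss = loss_fn(y, model(x))
loss.backward()

trainable_weights = model.trainable_weights[:]
gradients = [v.value.grad for v in trainable_weights]

# Update weights with the list-based API used in the new version of this file.
with torch.no_grad():
    optimizer.apply(gradients, trainable_weights)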
@@ -193,7 +193,6 @@ class BaseOptimizer:
         `variables` can be provided on the first call to build the optimizer.
         """
         grads = list(grads)
         if len(grads) == 0:
             # It is possible that the grad is empty. In this case,
             # `apply_gradients` is a no-op.
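For context (not part of the diff), the early return documented above makes an empty gradient list a no-op; a small hedged illustration, assuming a working Keras 3 / keras-core install (adjust the import accordingly) with any backend, optimizer choice arbitrary.

import keras

optimizer = keras.optimizers.Adam()
optimizer.apply([], [])   # returns immediately: nothing to build or update
print(optimizer.built)    # False, the early return happens before build()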
@@ -224,16 +223,15 @@
             self.built = True
         self._check_variables_are_known(trainable_variables)
-        grads_and_vars = list(zip(grads, self._trainable_variables))
         with ops.name_scope(self.name):
             # Filter empty gradients.
-            grads_and_vars = self._filter_empty_gradients(grads_and_vars)
-            if len(list(grads_and_vars)) == 0:
+            grads, trainable_variables = self._filter_empty_gradients(
+                grads, trainable_variables
+            )
+            if len(list(grads)) == 0:
                 return
             # Apply clipping and weight decay.
-            grads, trainable_variables = zip(*grads_and_vars)
             grads = self._clip_gradients(grads)
             self._apply_weight_decay(trainable_variables)
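For context (not part of the diff), a small pure-Python illustration of the structural change in `apply` above: the (grad, var) pairs are no longer materialized and unzipped on every step; the lists stay parallel and are only rebuilt when a None gradient is actually present, which is the rare case in eager training. Values below are stand-ins and the warning branch is omitted.

grads = [0.1, 0.2, None, 0.4]         # stand-ins for gradient tensors
variables = ["w0", "w1", "w2", "w3"]  # stand-ins for trainable variables

# Old shape of the logic: build pairs, filter them, then unzip again.
pairs = [(g, v) for g, v in zip(grads, variables) if g is not None]
old_grads, old_vars = zip(*pairs)

# New shape of the logic: scan for None first and fall through untouched in
# the common case where every variable received a gradient.
def filter_empty_gradients(grads, variables):
    for g in grads:
        if g is None:
            kept = [(g, v) for g, v in zip(grads, variables) if g is not None]
            return zip(*kept)
    return grads, variables

new_grads, new_vars = filter_empty_gradients(grads, variables)
assert list(new_grads) == list(old_grads)
assert list(new_vars) == list(old_vars)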
@@ -363,19 +361,27 @@ class BaseOptimizer:
             return self._learning_rate(self.iterations)
         return self._learning_rate
-    def _filter_empty_gradients(self, grads_and_vars):
-        filtered = [(g, v) for g, v in grads_and_vars if g is not None]
-        if not filtered:
-            raise ValueError("No gradients provided for any variable.")
-        if len(filtered) < len(grads_and_vars):
-            missing_grad_vars = [v for g, v in grads_and_vars if g is None]
-            warnings.warn(
-                "Gradients do not exist for variables "
-                f"{[v.name for v in missing_grad_vars]} when minimizing the "
-                "loss. If you're using `model.compile()`, did you forget to "
-                "provide a `loss` argument?"
-            )
-        return filtered
+    def _filter_empty_gradients(self, grads, vars):
+        for grad in grads:
+            if grad is None:
+                # Filtering is required.
+                filtered = [
+                    (g, v) for g, v in zip(grads, vars) if g is not None
+                ]
+                if not filtered:
+                    raise ValueError("No gradients provided for any variable.")
+                if len(filtered) < len(grads):
+                    missing_grad_vars = [
+                        v for g, v in zip(grads, vars) if g is None
+                    ]
+                    warnings.warn(
+                        "Gradients do not exist for variables "
+                        f"{[v.name for v in missing_grad_vars]} when "
+                        "minimizing the loss. If using `model.compile()`, "
+                        "did you forget to provide a `loss` argument?"
+                    )
+                return zip(*filtered)
+        return grads, vars
     def _clip_gradients(self, grads):
         if self.clipnorm and self.clipnorm > 0:
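For context (not part of the diff), `clipnorm` here is assumed to apply the usual clip-by-norm rule to each gradient independently (rescale any gradient whose L2 norm exceeds the threshold); a hedged numeric illustration, not the actual Keras implementation.

import numpy as np

def clip_by_norm(g, clipnorm):
    # Rescale g so its L2 norm is at most `clipnorm`; leave it unchanged
    # otherwise. Illustrative only.
    norm = np.linalg.norm(g)
    return g * (clipnorm / norm) if norm > clipnorm else g

g = np.array([3.0, 4.0])       # L2 norm 5.0
print(clip_by_norm(g, 1.0))    # [0.6 0.8] -- rescaled to norm 1.0
print(clip_by_norm(g, 10.0))   # [3. 4.]  -- already within the threshold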