Minor performance optimizations for eager.
parent 13caae7284
commit f41817b345
@@ -157,7 +157,13 @@ devices = mesh_utils.create_device_mesh((8,))
 # data will be split along the batch axis
 data_mesh = Mesh(devices, axis_names=("batch",)) # naming axes of the mesh
 # naming axes of the sharded partition
-data_sharding = NamedSharding(data_mesh,P("batch",),)
+data_sharding = NamedSharding(
+    data_mesh,
+    P(
+        "batch",
+    ),
+)
 
 # all variables will be replicated on all devices
 var_mesh = Mesh(devices, axis_names=("_"))
 # in NamedSharding, axes that are not mentioned are replicated (all axes here)
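For orientation, the reshaped call above builds the same sharding object as the single-line version it replaces. A minimal sketch of how such a sharding is typically applied to a batch; the toy `batch` array and the `jax.device_put` call are illustrative additions, not part of this diff, and assume a host with 8 JAX devices as in the guide:

import jax
import numpy as np
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = mesh_utils.create_device_mesh((8,))
data_mesh = Mesh(devices, axis_names=("batch",))
# Partition the leading (batch) axis across the mesh; unnamed axes are replicated.
data_sharding = NamedSharding(
    data_mesh,
    P(
        "batch",
    ),
)

batch = np.zeros((32, 128), dtype="float32")  # placeholder data
sharded_batch = jax.device_put(batch, data_sharding)  # split along axis 0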
@@ -269,7 +275,7 @@ def train_step(train_state, x, y):
     )
 
     trainable_variables, optimizer_variables = optimizer.stateless_apply(
-        train_state.optimizer_variables, grads, train_state.trainable_variables
+        grads, train_state.trainable_variables, train_state.optimizer_variables
    )
 
    return loss_value, TrainingState(
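As a reading aid, a rough sketch of a stateless update step using the argument order on the added (+) line above: the optimizer mutates nothing in place, it takes the current variable and optimizer state and returns updated copies. The `TrainingState` field names and surrounding variables are assumptions based on the visible context, not taken from this diff:

# Sketch only; mirrors the functional pattern of the guide's train_step.
trainable_variables, optimizer_variables = optimizer.stateless_apply(
    grads, train_state.trainable_variables, train_state.optimizer_variables
)
new_state = TrainingState(
    trainable_variables=trainable_variables,
    non_trainable_variables=train_state.non_trainable_variables,  # assumed field
    optimizer_variables=optimizer_variables,
)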
@@ -43,20 +43,16 @@ class TorchTrainer(base_trainer.Trainer):
 
         # Compute gradients
         if self.trainable_weights:
-            # Backpropagation
-            trainable_weights = [v for v in self.trainable_weights]
-
             # Call torch.Tensor.backward() on the loss to compute gradients
             # for the weights.
             loss.backward()
 
+            trainable_weights = self.trainable_weights[:]
             gradients = [v.value.grad for v in trainable_weights]
 
             # Update weights
             with torch.no_grad():
-                self.optimizer.apply_gradients(
-                    zip(gradients, trainable_weights)
-                )
+                self.optimizer.apply(gradients, trainable_weights)
         else:
             warnings.warn("The model does not have any trainable weights.")
 
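Two small things change in the hunk above: the weight-list copy and the update call. The copy itself is a micro-optimization; the removed comprehension and the added slice both produce a new list with the same elements, the slice just avoids a per-element Python loop. A self-contained illustration with placeholder data, not library code:

import timeit

weights = list(range(1000))  # placeholder for a list of variables

t_comprehension = timeit.timeit(lambda: [v for v in weights], number=10_000)
t_slice = timeit.timeit(lambda: weights[:], number=10_000)
print(f"comprehension: {t_comprehension:.4f}s  slice: {t_slice:.4f}s")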
@@ -193,7 +193,6 @@ class BaseOptimizer:
 
         `variables` can be provided on the first call to build the optimizer.
         """
         grads = list(grads)
         if len(grads) == 0:
             # It is possible that the grad is empty. In this case,
             # `apply_gradients` is a no-op.
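The guard shown above makes an update with no gradients a no-op rather than an error. A minimal sketch of what that implies for callers, assuming a current Keras install and the two-list `apply(grads, variables)` call style used elsewhere in this commit; the optimizer choice is arbitrary:

from keras import optimizers

optimizer = optimizers.SGD(learning_rate=0.1)
# An empty gradient list is silently skipped instead of raising.
optimizer.apply([], [])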
@@ -224,16 +223,15 @@ class BaseOptimizer:
             self.built = True
         self._check_variables_are_known(trainable_variables)
 
-        grads_and_vars = list(zip(grads, self._trainable_variables))
-
         with ops.name_scope(self.name):
             # Filter empty gradients.
-            grads_and_vars = self._filter_empty_gradients(grads_and_vars)
-            if len(list(grads_and_vars)) == 0:
+            grads, trainable_variables = self._filter_empty_gradients(
+                grads, trainable_variables
+            )
+            if len(list(grads)) == 0:
                 return
 
             # Apply clipping and weight decay.
-            grads, trainable_variables = zip(*grads_and_vars)
             grads = self._clip_gradients(grads)
             self._apply_weight_decay(trainable_variables)
 
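The rewritten body above threads `grads` and `trainable_variables` through as two parallel sequences instead of building a `list(zip(...))` of pairs only to unzip it again, removing avoidable Python-level work on the eager path. A stand-alone illustration of the two data-flow shapes with toy values, not library code:

grads = [0.1, None, 0.3]
variables = ["w0", "w1", "w2"]  # placeholders for real variables

# Old shape: pair up, filter, then unzip back into two sequences.
grads_and_vars = [(g, v) for g, v in zip(grads, variables) if g is not None]
old_grads, old_vars = zip(*grads_and_vars)

# New shape: filter the two sequences directly, no pairing round-trip.
new_grads = [g for g in grads if g is not None]
new_vars = [v for g, v in zip(grads, variables) if g is not None]

assert list(old_grads) == new_grads
assert list(old_vars) == new_vars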
@@ -363,19 +361,27 @@ class BaseOptimizer:
             return self._learning_rate(self.iterations)
         return self._learning_rate
 
-    def _filter_empty_gradients(self, grads_and_vars):
-        filtered = [(g, v) for g, v in grads_and_vars if g is not None]
-        if not filtered:
-            raise ValueError("No gradients provided for any variable.")
-        if len(filtered) < len(grads_and_vars):
-            missing_grad_vars = [v for g, v in grads_and_vars if g is None]
-            warnings.warn(
-                "Gradients do not exist for variables "
-                f"{[v.name for v in missing_grad_vars]} when minimizing the "
-                "loss. If you're using `model.compile()`, did you forget to "
-                "provide a `loss` argument?"
-            )
-        return filtered
+    def _filter_empty_gradients(self, grads, vars):
+        for grad in grads:
+            if grad is None:
+                # Filtering is required.
+                filtered = [
+                    (g, v) for g, v in zip(grads, vars) if g is not None
+                ]
+                if not filtered:
+                    raise ValueError("No gradients provided for any variable.")
+                if len(filtered) < len(grads):
+                    missing_grad_vars = [
+                        v for g, v in zip(grads, vars) if g is None
+                    ]
+                    warnings.warn(
+                        "Gradients do not exist for variables "
+                        f"{[v.name for v in missing_grad_vars]} when "
+                        "minimizing the loss. If using `model.compile()`, "
+                        "did you forget to provide a `loss` argument?"
+                    )
+                return zip(*filtered)
+        return grads, vars
 
     def _clip_gradients(self, grads):
         if self.clipnorm and self.clipnorm > 0:
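The replacement implementation keeps the common case cheap: it only builds and filters gradient/variable pairs if it actually encounters a `None` gradient, and otherwise returns its inputs untouched. A rough sketch of the same fast-path idea as a free function, with toy inputs and the warning branch omitted:

def filter_empty_gradients(grads, variables):
    # Fast path: no missing gradients, return the inputs as-is.
    if all(g is not None for g in grads):
        return grads, variables
    filtered = [(g, v) for g, v in zip(grads, variables) if g is not None]
    if not filtered:
        raise ValueError("No gradients provided for any variable.")
    return zip(*filtered)


g, v = filter_empty_gradients([1.0, None, 3.0], ["a", "b", "c"])
print(list(g), list(v))  # [1.0, 3.0] ['a', 'c']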