diff --git a/examples/demo_jax_distributed.py b/examples/demo_jax_distributed.py
index 68cf068d6..9aef7060a 100644
--- a/examples/demo_jax_distributed.py
+++ b/examples/demo_jax_distributed.py
@@ -157,7 +157,13 @@ devices = mesh_utils.create_device_mesh((8,))
 # data will be split along the batch axis
 data_mesh = Mesh(devices, axis_names=("batch",))  # naming axes of the mesh
 # naming axes of the sharded partition
-data_sharding = NamedSharding(data_mesh,P("batch",),)
+data_sharding = NamedSharding(
+    data_mesh,
+    P(
+        "batch",
+    ),
+)
+
 # all variables will be replicated on all devices
 var_mesh = Mesh(devices, axis_names=("_"))
 # in NamedSharding, axes that are not mentioned are replicated (all axes here)
@@ -269,7 +275,7 @@ def train_step(train_state, x, y):
     )
 
     trainable_variables, optimizer_variables = optimizer.stateless_apply(
-        train_state.optimizer_variables, grads, train_state.trainable_variables
+        grads, train_state.trainable_variables, train_state.optimizer_variables
     )
 
     return loss_value, TrainingState(
diff --git a/keras_core/backend/torch/trainer.py b/keras_core/backend/torch/trainer.py
index 4a214c97e..e46d41378 100644
--- a/keras_core/backend/torch/trainer.py
+++ b/keras_core/backend/torch/trainer.py
@@ -43,20 +43,16 @@ class TorchTrainer(base_trainer.Trainer):
 
             # Compute gradients
             if self.trainable_weights:
-                # Backpropagation
-                trainable_weights = [v for v in self.trainable_weights]
-
                 # Call torch.Tensor.backward() on the loss to compute gradients
                 # for the weights.
                 loss.backward()
+                trainable_weights = self.trainable_weights[:]
                 gradients = [v.value.grad for v in trainable_weights]
 
                 # Update weights
                 with torch.no_grad():
-                    self.optimizer.apply_gradients(
-                        zip(gradients, trainable_weights)
-                    )
+                    self.optimizer.apply(gradients, trainable_weights)
             else:
                 warnings.warn("The model does not have any trainable weights.")
diff --git a/keras_core/optimizers/base_optimizer.py b/keras_core/optimizers/base_optimizer.py
index d8f107ec1..192356eb8 100644
--- a/keras_core/optimizers/base_optimizer.py
+++ b/keras_core/optimizers/base_optimizer.py
@@ -193,7 +193,6 @@ class BaseOptimizer:
             `variables` can be provided on the first call to build
             the optimizer.
         """
-        grads = list(grads)
         if len(grads) == 0:
             # It is possible that the grad is empty. In this case,
             # `apply_gradients` is a no-op.
@@ -224,16 +223,15 @@ class BaseOptimizer:
             self.built = True
             self._check_variables_are_known(trainable_variables)
 
-        grads_and_vars = list(zip(grads, self._trainable_variables))
-
         with ops.name_scope(self.name):
             # Filter empty gradients.
-            grads_and_vars = self._filter_empty_gradients(grads_and_vars)
-            if len(list(grads_and_vars)) == 0:
+            grads, trainable_variables = self._filter_empty_gradients(
+                grads, trainable_variables
+            )
+            if len(list(grads)) == 0:
                 return
 
             # Apply clipping and weight decay.
-            grads, trainable_variables = zip(*grads_and_vars)
             grads = self._clip_gradients(grads)
             self._apply_weight_decay(trainable_variables)
 
@@ -363,19 +361,27 @@ class BaseOptimizer:
             return self._learning_rate(self.iterations)
         return self._learning_rate
 
-    def _filter_empty_gradients(self, grads_and_vars):
-        filtered = [(g, v) for g, v in grads_and_vars if g is not None]
-        if not filtered:
-            raise ValueError("No gradients provided for any variable.")
-        if len(filtered) < len(grads_and_vars):
-            missing_grad_vars = [v for g, v in grads_and_vars if g is None]
-            warnings.warn(
-                "Gradients do not exist for variables "
-                f"{[v.name for v in missing_grad_vars]} when minimizing the "
-                "loss. If you're using `model.compile()`, did you forget to "
-                "provide a `loss` argument?"
-            )
-        return filtered
+    def _filter_empty_gradients(self, grads, vars):
+        for grad in grads:
+            if grad is None:
+                # Filtering is required.
+                filtered = [
+                    (g, v) for g, v in zip(grads, vars) if g is not None
+                ]
+                if not filtered:
+                    raise ValueError("No gradients provided for any variable.")
+                if len(filtered) < len(grads):
+                    missing_grad_vars = [
+                        v for g, v in zip(grads, vars) if g is None
+                    ]
+                    warnings.warn(
+                        "Gradients do not exist for variables "
+                        f"{[v.name for v in missing_grad_vars]} when "
+                        "minimizing the loss. If using `model.compile()`, "
+                        "did you forget to provide a `loss` argument?"
+                    )
+                return zip(*filtered)
+        return grads, vars
 
     def _clip_gradients(self, grads):
         if self.clipnorm and self.clipnorm > 0:
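
Note: the list-based `_filter_empty_gradients(grads, vars)` introduced above replaces the
`(grad, var)`-pair version, and the call sites now pass two parallel lists (for example
`optimizer.apply(gradients, trainable_weights)` in the torch trainer). The snippet below is a
minimal standalone sketch of that filtering logic using plain Python values; the
`filter_empty_gradients` name and the example inputs are illustrative only, not part of the
keras_core API.

import warnings


def filter_empty_gradients(grads, variables):
    """Drop (grad, var) pairs whose grad is None and warn about them."""
    for grad in grads:
        if grad is None:
            # Slow path: at least one gradient is missing, so filter the pairs.
            filtered = [
                (g, v) for g, v in zip(grads, variables) if g is not None
            ]
            if not filtered:
                raise ValueError("No gradients provided for any variable.")
            missing = [v for g, v in zip(grads, variables) if g is None]
            warnings.warn(
                f"Gradients do not exist for variables {missing} "
                "when minimizing the loss."
            )
            # zip(*filtered) unzips the pairs back into (grads, variables).
            return zip(*filtered)
    # Fast path: every gradient is present, return the inputs untouched.
    return grads, variables


# Example with hypothetical values: the second gradient is missing and gets dropped.
grads, variables = filter_empty_gradients([0.1, None, 0.3], ["w0", "w1", "w2"])
print(list(grads), list(variables))  # [0.1, 0.3] ['w0', 'w2']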