From 59321fc09a8b32d931b5cb111c4644bbd515c75b Mon Sep 17 00:00:00 2001
From: Francois Chollet
Date: Fri, 18 Aug 2023 21:18:35 -0700
Subject: [PATCH] Add name_scope feature

---
 keras_core/backend/__init__.py                |  1 +
 keras_core/backend/common/name_scope.py       | 57 +++++++++++
 keras_core/backend/common/name_scope_test.py  | 36 +++++++
 keras_core/backend/common/variables.py        | 10 +-
 keras_core/backend/exports.py                 | 11 ++-
 keras_core/backend/jax/__init__.py            |  1 -
 keras_core/backend/jax/core.py                |  4 -
 keras_core/backend/numpy/__init__.py          |  1 -
 keras_core/backend/numpy/core.py              |  7 --
 keras_core/backend/tensorflow/core.py         | 35 ++++++-
 .../backend/tensorflow/name_scope_test.py     | 37 ++++++++
 keras_core/backend/torch/__init__.py          |  1 -
 keras_core/backend/torch/core.py              |  4 -
 keras_core/layers/layer.py                    | 95 ++++++++++---------
 keras_core/layers/rnn/rnn.py                  | 16 ++--
 keras_core/metrics/metric.py                  | 17 ++--
 keras_core/models/model.py                    |  4 +-
 keras_core/models/variable_mapping.py         | 19 ++--
 keras_core/optimizers/adafactor.py            | 15 +--
 keras_core/optimizers/adam.py                 |  6 +-
 keras_core/optimizers/adamax.py               |  4 +-
 keras_core/optimizers/base_optimizer.py       | 45 +++++----
 keras_core/optimizers/lion.py                 |  2 +-
 keras_core/optimizers/nadam.py                |  4 +-
 keras_core/optimizers/sgd.py                  |  2 +-
 25 files changed, 301 insertions(+), 133 deletions(-)
 create mode 100644 keras_core/backend/common/name_scope.py
 create mode 100644 keras_core/backend/common/name_scope_test.py
 create mode 100644 keras_core/backend/tensorflow/name_scope_test.py

diff --git a/keras_core/backend/__init__.py b/keras_core/backend/__init__.py
index 3f6ab97df..4c66b2663 100644
--- a/keras_core/backend/__init__.py
+++ b/keras_core/backend/__init__.py
@@ -9,6 +9,7 @@ if backend() == "torch":
 from keras_core.backend.common.keras_tensor import KerasTensor
 from keras_core.backend.common.keras_tensor import any_symbolic_tensors
 from keras_core.backend.common.keras_tensor import is_keras_tensor
+from keras_core.backend.common.name_scope import name_scope
 from keras_core.backend.common.stateless_scope import StatelessScope
 from keras_core.backend.common.stateless_scope import get_stateless_scope
 from keras_core.backend.common.stateless_scope import in_stateless_scope
diff --git a/keras_core/backend/common/name_scope.py b/keras_core/backend/common/name_scope.py
new file mode 100644
index 000000000..72302ed6b
--- /dev/null
+++ b/keras_core/backend/common/name_scope.py
@@ -0,0 +1,57 @@
+from keras_core.backend.common import global_state
+
+
+class name_scope:
+    """Creates a sub-namespace for variable paths.
+
+    Args:
+        name: Name of the current scope (string).
+        caller: Optional ID of a caller object (e.g. class instance).
+        deduplicate: If `True` and `caller` was passed,
+            do not reenter a new namespace if the previous
+            caller matches the current caller and the
+            previous name matches the current name.
+    """
+
+    def __init__(self, name, caller=None, deduplicate=True):
+        if not isinstance(name, str) or "/" in name:
+            raise ValueError(
+                "Argument `name` must be a string and "
+                "cannot contain character `/`. "
" + f"Received: name={name}" + ) + self.name = name + self.caller = caller + self.deduplicate = deduplicate + self._pop_on_exit = False + + def __enter__(self): + name_scope_stack = global_state.get_global_attribute( + "name_scope_stack", default=[], set_to_default=True + ) + if self.deduplicate and name_scope_stack: + parent_caller = name_scope_stack[-1].caller + parent_name = name_scope_stack[-1].name + if ( + self.caller is not None + and self.caller is parent_caller + and self.name == parent_name + ): + return self + name_scope_stack.append(self) + self._pop_on_exit = True + return self + + def __exit__(self, *args, **kwargs): + if self._pop_on_exit: + name_scope_stack = global_state.get_global_attribute( + "name_scope_stack" + ) + name_scope_stack.pop() + + +def current_path(): + name_scope_stack = global_state.get_global_attribute("name_scope_stack") + if name_scope_stack is None: + return "" + return "/".join(x.name for x in name_scope_stack) diff --git a/keras_core/backend/common/name_scope_test.py b/keras_core/backend/common/name_scope_test.py new file mode 100644 index 000000000..b13f8d73a --- /dev/null +++ b/keras_core/backend/common/name_scope_test.py @@ -0,0 +1,36 @@ +from keras_core import testing +from keras_core.backend.common.name_scope import current_path +from keras_core.backend.common.name_scope import name_scope + + +class NameScopeTest(testing.TestCase): + def test_stacking(self): + self.assertEqual(current_path(), "") + with name_scope("outer") as outer: + self.assertEqual(outer.name, "outer") + self.assertEqual(current_path(), "outer") + with name_scope("middle") as middle: + self.assertEqual(middle.name, "middle") + self.assertEqual(current_path(), "outer/middle") + with name_scope("inner") as inner: + self.assertEqual(inner.name, "inner") + self.assertEqual(current_path(), "outer/middle/inner") + self.assertEqual(current_path(), "outer/middle") + self.assertEqual(current_path(), "outer") + self.assertEqual(current_path(), "") + + def test_deduplication(self): + self.assertEqual(current_path(), "") + with name_scope("name", caller=1): + with name_scope("name", caller=1): + self.assertEqual(current_path(), "name") + self.assertEqual(current_path(), "") + with name_scope("name"): + with name_scope("name"): + self.assertEqual(current_path(), "name/name") + + def test_errors(self): + with self.assertRaisesRegex(ValueError, "must be a string"): + name_scope("foo/bar") + with self.assertRaisesRegex(ValueError, "must be a string"): + name_scope(4) diff --git a/keras_core/backend/common/variables.py b/keras_core/backend/common/variables.py index 53bb5edf5..f9299664d 100644 --- a/keras_core/backend/common/variables.py +++ b/keras_core/backend/common/variables.py @@ -2,6 +2,7 @@ import numpy as np from keras_core.backend import config from keras_core.backend.common import global_state +from keras_core.backend.common.name_scope import current_path from keras_core.backend.common.stateless_scope import get_stateless_scope from keras_core.backend.common.stateless_scope import in_stateless_scope from keras_core.utils.naming import auto_name @@ -19,6 +20,11 @@ class KerasVariable: f"Received: name={name}" ) self.name = name + parent_path = current_path() + if parent_path: + self.path = current_path() + "/" + self.name + else: + self.path = self.name dtype = standardize_dtype(dtype) self._dtype = dtype self._shape = None @@ -68,7 +74,7 @@ class KerasVariable: def _deferred_initialize(self): if self._value is not None: - raise ValueError(f"Variable {self.name} is already 
+            raise ValueError(f"Variable {self.path} is already initialized.")
 
        if in_stateless_scope():
            raise ValueError(
@@ -147,7 +153,7 @@ class KerasVariable:
    def __repr__(self):
        return (
            f"<KerasVariable shape={self.shape}, dtype={self.dtype}, "
-            f"name={self.name}>"
+            f"path={self.path}>"
        )
 
    def _initialize(self, value):
diff --git a/keras_core/backend/exports.py b/keras_core/backend/exports.py
index 3fef17bbd..8aa97ecca 100644
--- a/keras_core/backend/exports.py
+++ b/keras_core/backend/exports.py
@@ -3,18 +3,27 @@ from keras_core.api_export import keras_core_export
 
 if backend.backend() == "tensorflow":
     BackendVariable = backend.tensorflow.core.Variable
+    backend_name_scope = backend.tensorflow.core.name_scope
 elif backend.backend() == "jax":
     BackendVariable = backend.jax.core.Variable
+    backend_name_scope = backend.common.name_scope.name_scope
 elif backend.backend() == "torch":
     BackendVariable = backend.torch.core.Variable
+    backend_name_scope = backend.common.name_scope.name_scope
 elif backend.backend() == "numpy":
     from keras_core.backend.numpy.core import Variable as NumpyVariable
 
     BackendVariable = NumpyVariable
+    backend_name_scope = backend.common.name_scope.name_scope
 else:
     raise RuntimeError(f"Invalid backend: {backend.backend()}")
 
 
-@keras_core_export("keras_core.backend.Variable")
+@keras_core_export("keras_core.Variable")
 class Variable(BackendVariable):
     pass
+
+
+@keras_core_export("keras_core.name_scope")
+class name_scope(backend_name_scope):
+    pass
diff --git a/keras_core/backend/jax/__init__.py b/keras_core/backend/jax/__init__.py
index 47975d733..0943a87d3 100644
--- a/keras_core/backend/jax/__init__.py
+++ b/keras_core/backend/jax/__init__.py
@@ -12,7 +12,6 @@ from keras_core.backend.jax.core import cond
 from keras_core.backend.jax.core import convert_to_numpy
 from keras_core.backend.jax.core import convert_to_tensor
 from keras_core.backend.jax.core import is_tensor
-from keras_core.backend.jax.core import name_scope
 from keras_core.backend.jax.core import scatter
 from keras_core.backend.jax.core import shape
 from keras_core.backend.jax.core import stop_gradient
diff --git a/keras_core/backend/jax/core.py b/keras_core/backend/jax/core.py
index 019aff77a..2c50e48e1 100644
--- a/keras_core/backend/jax/core.py
+++ b/keras_core/backend/jax/core.py
@@ -60,10 +60,6 @@ def cast(x, dtype):
     return convert_to_tensor(x, dtype=dtype)
 
 
-def name_scope(name):
-    return jax.named_scope(name)
-
-
 # Shape / dtype inference util
 def compute_output_spec(fn, *args, **kwargs):
     with StatelessScope():
diff --git a/keras_core/backend/numpy/__init__.py b/keras_core/backend/numpy/__init__.py
index 6470587f2..e92d10e04 100644
--- a/keras_core/backend/numpy/__init__.py
+++ b/keras_core/backend/numpy/__init__.py
@@ -12,7 +12,6 @@ from keras_core.backend.numpy.core import cond
 from keras_core.backend.numpy.core import convert_to_numpy
 from keras_core.backend.numpy.core import convert_to_tensor
 from keras_core.backend.numpy.core import is_tensor
-from keras_core.backend.numpy.core import name_scope
 from keras_core.backend.numpy.core import shape
 from keras_core.backend.numpy.core import vectorized_map
 from keras_core.backend.numpy.rnn import cudnn_ok
diff --git a/keras_core/backend/numpy/core.py b/keras_core/backend/numpy/core.py
index 880e00548..792cf8069 100644
--- a/keras_core/backend/numpy/core.py
+++ b/keras_core/backend/numpy/core.py
@@ -1,5 +1,3 @@
-from contextlib import nullcontext
-
 import numpy as np
 from tensorflow import nest
 
@@ -60,11 +58,6 @@ def cond(pred, true_fn, false_fn):
     return false_fn()
 
 
-def name_scope(name):
-    # There is no need for a named context for NumPy.
-    return nullcontext()
-
-
 def vectorized_map(function, elements):
     if len(elements) == 1:
         return function(elements)
diff --git a/keras_core/backend/tensorflow/core.py b/keras_core/backend/tensorflow/core.py
index 284ef3a82..5399166be 100644
--- a/keras_core/backend/tensorflow/core.py
+++ b/keras_core/backend/tensorflow/core.py
@@ -5,8 +5,10 @@ import tensorflow as tf
 from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice
 
 from keras_core.backend.common import KerasVariable
+from keras_core.backend.common import global_state
 from keras_core.backend.common import standardize_dtype
 from keras_core.backend.common.keras_tensor import KerasTensor
+from keras_core.backend.common.name_scope import name_scope as base_name_scope
 from keras_core.backend.common.stateless_scope import StatelessScope
 from keras_core.utils.naming import auto_name
 
@@ -111,10 +113,6 @@ def cast(x, dtype):
     return tf.cast(x, dtype=dtype)
 
 
-def name_scope(name):
-    return tf.name_scope(name)
-
-
 def compute_output_spec(fn, *args, **kwargs):
     with StatelessScope():
         graph_name = auto_name("scratch_graph")
@@ -203,3 +201,32 @@ def stop_gradient(variable):
 
 def unstack(x, num=None, axis=0):
     return tf.unstack(x, num=num, axis=axis)
+
+
+class name_scope(base_name_scope):
+    def __init__(self, name, **kwargs):
+        super().__init__(name, **kwargs)
+        self._tf_name_scope = tf.name_scope(name)
+
+    def __enter__(self):
+        name_scope_stack = global_state.get_global_attribute(
+            "name_scope_stack", default=[], set_to_default=True
+        )
+        if self.deduplicate and name_scope_stack:
+            parent_caller = name_scope_stack[-1].caller
+            parent_name = name_scope_stack[-1].name
+            if (
+                self.caller is not None
+                and self.caller is parent_caller
+                and self.name == parent_name
+            ):
+                return self
+        name_scope_stack.append(self)
+        self._pop_on_exit = True
+        self._tf_name_scope.__enter__()
+        return self
+
+    def __exit__(self, *args, **kwargs):
+        super().__exit__(*args, **kwargs)
+        if self._pop_on_exit:
+            self._tf_name_scope.__exit__(*args, **kwargs)
diff --git a/keras_core/backend/tensorflow/name_scope_test.py b/keras_core/backend/tensorflow/name_scope_test.py
new file mode 100644
index 000000000..1bceff4d0
--- /dev/null
+++ b/keras_core/backend/tensorflow/name_scope_test.py
@@ -0,0 +1,37 @@
+import tensorflow as tf
+
+from keras_core.backend.tensorflow.core import name_scope
+from keras_core.testing import TestCase
+
+
+class TFNameScopeTest(TestCase):
+    def test_stacking(self):
+        self.assertEqual(tf.Variable(0, name="x").name, "x:0")
+        with name_scope("outer") as outer:
+            self.assertEqual(outer.name, "outer")
+            self.assertEqual(tf.Variable(0, name="x").name, "outer/x:0")
+            with name_scope("middle") as middle:
+                self.assertEqual(middle.name, "middle")
+                self.assertEqual(
+                    tf.Variable(0, name="x").name, "outer/middle/x:0"
+                )
+                with name_scope("inner") as inner:
+                    self.assertEqual(inner.name, "inner")
+                    self.assertEqual(
+                        tf.Variable(0, name="x").name, "outer/middle/inner/x:0"
+                    )
+                self.assertEqual(
+                    tf.Variable(0, name="x").name, "outer/middle/x:0"
+                )
+            self.assertEqual(tf.Variable(0, name="x").name, "outer/x:0")
+        self.assertEqual(tf.Variable(0, name="x").name, "x:0")
+
+    def test_deduplicate(self):
+        self.assertEqual(tf.Variable(0, name="x").name, "x:0")
+        with name_scope("name", caller=1):
+            with name_scope("name", caller=1):
+                self.assertEqual(tf.Variable(0, name="x").name, "name/x:0")
+        self.assertEqual(tf.Variable(0, name="x").name, "x:0")
+        with name_scope("name"):
+            with name_scope("name"):
+                self.assertEqual(tf.Variable(0, name="x").name, "name/name/x:0")
name="x").name, "name/name/x:0") diff --git a/keras_core/backend/torch/__init__.py b/keras_core/backend/torch/__init__.py index 2c5051273..bd5eb5fad 100644 --- a/keras_core/backend/torch/__init__.py +++ b/keras_core/backend/torch/__init__.py @@ -28,7 +28,6 @@ from keras_core.backend.torch.core import cond from keras_core.backend.torch.core import convert_to_numpy from keras_core.backend.torch.core import convert_to_tensor from keras_core.backend.torch.core import is_tensor -from keras_core.backend.torch.core import name_scope from keras_core.backend.torch.core import scatter from keras_core.backend.torch.core import shape from keras_core.backend.torch.core import stop_gradient diff --git a/keras_core/backend/torch/core.py b/keras_core/backend/torch/core.py index d2ced0f51..fa9209990 100644 --- a/keras_core/backend/torch/core.py +++ b/keras_core/backend/torch/core.py @@ -174,10 +174,6 @@ def cast(x, dtype): return convert_to_tensor(x, dtype) -def name_scope(name): - return contextlib.nullcontext() - - # Shape / dtype inference util def compute_output_spec(fn, *args, **kwargs): def has_none_shape(x): diff --git a/keras_core/layers/layer.py b/keras_core/layers/layer.py index cd6e37012..bc015a8b6 100644 --- a/keras_core/layers/layer.py +++ b/keras_core/layers/layer.py @@ -439,13 +439,14 @@ class Layer(BackendLayer, Operation): # TODO: handle layout self._check_super_called() initializer = initializers.get(initializer) - variable = backend.Variable( - initializer=initializer, - shape=shape, - dtype=dtype or self.variable_dtype, - trainable=trainable, - name=name, - ) + with backend.name_scope(self.name, caller=self): + variable = backend.Variable( + initializer=initializer, + shape=shape, + dtype=dtype or self.variable_dtype, + trainable=trainable, + name=name, + ) # Will be added to layer.losses variable.regularizer = regularizer variable.constraint = constraint @@ -683,7 +684,8 @@ class Layer(BackendLayer, Operation): ################ # 4. Call build. - self._maybe_build(call_spec) + with backend.name_scope(self.name, caller=self): + self._maybe_build(call_spec) ########################## # 5. Infer training value @@ -737,7 +739,7 @@ class Layer(BackendLayer, Operation): #################### # 7. Call the layer. try: - with backend.name_scope(self.name): + with backend.name_scope(self.name, caller=self): if self.autocast and self.compute_dtype != self.variable_dtype: # For mixed precision, we automatically cast layer variables # (float ones only) to the compute dtype upon access. 
@@ -1083,45 +1085,44 @@ class Layer(BackendLayer, Operation):
 
         shapes_dict = get_shapes_dict(call_spec)
         self._build_shapes_dict = shapes_dict
-        with backend.name_scope(self.name):
-            if not utils.is_default(self.build):
-                shapes_dict = update_shapes_dict_for_target_fn(
-                    self.build,
-                    shapes_dict=shapes_dict,
-                    call_spec=call_spec,
-                    class_name=self.__class__.__name__,
+        if not utils.is_default(self.build):
+            shapes_dict = update_shapes_dict_for_target_fn(
+                self.build,
+                shapes_dict=shapes_dict,
+                call_spec=call_spec,
+                class_name=self.__class__.__name__,
+            )
+            self.build(**shapes_dict)
+        elif might_have_unbuilt_state(self):
+            if len(shapes_dict) == 1:
+                # Single arg: pass it positionally
+                success = self._build_by_run_for_single_pos_arg(
+                    tuple(shapes_dict.values())[0]
+                )
+            else:
+                success = self._build_by_run_for_kwargs(shapes_dict)
+            if not success:
+                if call_spec.eager:
+                    # Will let the actual eager call do state-building
+                    return
+                raise ValueError(
+                    f"Layer '{self.name}' looks like it has "
+                    "unbuilt state, but Keras is not able to "
+                    "trace the layer `call()` in order to "
+                    "build it automatically. Possible causes:\n"
+                    "1. The `call()` method of your layer may be "
+                    "crashing. Try to `__call__()` the layer "
+                    "eagerly on some test input "
+                    "first to see if it works. "
+                    "E.g. `x = np.random.random((3, 4)); "
+                    "y = layer(x)`\n"
+                    "2. If the `call()` method is correct, "
+                    "then you may need to implement "
+                    "the `def build(self, input_shape)` method on your "
+                    "layer. It should create all variables used by the "
+                    "layer (e.g. by calling `layer.build()` on all its "
+                    "children layers)."
                 )
-                self.build(**shapes_dict)
-            elif might_have_unbuilt_state(self):
-                if len(shapes_dict) == 1:
-                    # Single arg: pass it positionally
-                    success = self._build_by_run_for_single_pos_arg(
-                        tuple(shapes_dict.values())[0]
-                    )
-                else:
-                    success = self._build_by_run_for_kwargs(shapes_dict)
-                if not success:
-                    if call_spec.eager:
-                        # Will let the actual eager call do state-building
-                        return
-                    raise ValueError(
-                        f"Layer '{self.name}' looks like it has "
-                        "unbuilt state, but Keras is not able to "
-                        "trace the layer `call()` in order to "
-                        "build it automatically. Possible causes:\n"
-                        "1. The `call()` method of your layer may be "
-                        "crashing. Try to `__call__()` the layer "
-                        "eagerly on some test input "
-                        "first to see if it works. "
-                        "E.g. `x = np.random.random((3, 4)); "
-                        "y = layer(x)`\n"
-                        "2. If the `call()` method is correct, "
-                        "then you may need to implement "
-                        "the `def build(self, input_shape)` method on your "
-                        "layer. It should create all variables used by the "
-                        "layer (e.g. by calling `layer.build()` on all its "
-                        "children layers)."
-                    )
         self.built = True
 
         # Check input spec again (after build, since self.input_spec
diff --git a/keras_core/layers/rnn/rnn.py b/keras_core/layers/rnn/rnn.py
index b0460bf47..b5c3baf8d 100644
--- a/keras_core/layers/rnn/rnn.py
+++ b/keras_core/layers/rnn/rnn.py
@@ -288,12 +288,16 @@ class RNN(Layer):
 
     @tracking.no_automatic_dependency_tracking
     def _create_state_variables(self, batch_size):
-        self.states = tree.map_structure(
-            lambda value: backend.Variable(
-                value, trainable=False, dtype=self.variable_dtype
-            ),
-            self.get_initial_state(batch_size),
-        )
+        with backend.name_scope(self.name, caller=self):
+            self.states = tree.map_structure(
+                lambda value: backend.Variable(
+                    value,
+                    trainable=False,
+                    dtype=self.variable_dtype,
+                    name="rnn_state",
+                ),
+                self.get_initial_state(batch_size),
+            )
 
     def get_initial_state(self, batch_size):
         get_initial_state_fn = getattr(self.cell, "get_initial_state", None)
diff --git a/keras_core/metrics/metric.py b/keras_core/metrics/metric.py
index a7d324bea..fa98585a3 100644
--- a/keras_core/metrics/metric.py
+++ b/keras_core/metrics/metric.py
@@ -166,14 +166,15 @@ class Metric:
 
     def add_variable(self, shape, initializer, dtype=None, name=None):
         self._check_super_called()
-        initializer = initializers.get(initializer)
-        variable = backend.Variable(
-            initializer=initializer,
-            shape=shape,
-            dtype=dtype,
-            trainable=False,
-            name=name,
-        )
+        with backend.name_scope(self.name, caller=self):
+            initializer = initializers.get(initializer)
+            variable = backend.Variable(
+                initializer=initializer,
+                shape=shape,
+                dtype=dtype,
+                trainable=False,
+                name=name,
+            )
         # Prevent double-tracking
         self._tracker.add_to_store("variables", variable)
         return variable
diff --git a/keras_core/models/model.py b/keras_core/models/model.py
index f27557add..c467bbc0a 100644
--- a/keras_core/models/model.py
+++ b/keras_core/models/model.py
@@ -542,9 +542,7 @@ class Model(Trainer, Layer):
 
     def _get_variable_map(self):
         store = {}
-        map_trackable_variables(
-            self, store=store, inner_path="", visited_trackables=set()
-        )
+        map_trackable_variables(self, store=store, visited_trackables=set())
         return store
 
 
diff --git a/keras_core/models/variable_mapping.py b/keras_core/models/variable_mapping.py
index 90d121375..e443c9536 100644
--- a/keras_core/models/variable_mapping.py
+++ b/keras_core/models/variable_mapping.py
@@ -2,10 +2,9 @@ from keras_core.layers.layer import Layer
 from keras_core.metrics.metric import Metric
 from keras_core.optimizers.optimizer import Optimizer
 from keras_core.saving import saving_lib
-from keras_core.utils import file_utils
 
 
-def map_trackable_variables(trackable, store, inner_path, visited_trackables):
+def map_trackable_variables(trackable, store, visited_trackables):
     # If the trackable has already been saved, skip it.
     if id(trackable) in visited_trackables:
         return
@@ -22,7 +21,15 @@ def map_trackable_variables(trackable, store, inner_path, visited_trackables):
     elif isinstance(trackable, Metric):
         variables = trackable._variables
     for v in variables:
-        store[inner_path + "/" + v.name] = v
+        if v.path in store:
+            raise ValueError(
+                "The model contains two variables with a duplicate path: "
+                f"path='{v.path}' appears at least twice. "
+                f"This path is used for {v} and for {store[v.path]}. "
+                "In order to get a variable map, make sure to use "
+                "unique paths/names for each variable."
+            )
+        store[v.path] = v
 
     # Recursively save state of children trackables (layers, optimizers, etc.)
     for child_attr, child_obj in saving_lib._walk_trackable(trackable):
@@ -30,28 +37,24 @@ def map_trackable_variables(trackable, store, inner_path, visited_trackables):
             map_trackable_variables(
                 child_obj,
                 store,
-                inner_path=file_utils.join(inner_path, child_obj.name),
                 visited_trackables=visited_trackables,
             )
         elif isinstance(child_obj, (list, dict, tuple, set)):
             map_container_variables(
                 child_obj,
                 store,
-                inner_path=file_utils.join(inner_path, child_attr),
                 visited_trackables=visited_trackables,
             )
 
 
-def map_container_variables(container, store, inner_path, visited_trackables):
+def map_container_variables(container, store, visited_trackables):
     if isinstance(container, dict):
         container = list(container.values())
     for trackable in container:
         if saving_lib._is_keras_trackable(trackable):
-            name = trackable.name
             map_trackable_variables(
                 trackable,
                 store,
-                inner_path=file_utils.join(inner_path, name),
                 visited_trackables=visited_trackables,
             )
diff --git a/keras_core/optimizers/adafactor.py b/keras_core/optimizers/adafactor.py
index b14fc423e..bc328480f 100644
--- a/keras_core/optimizers/adafactor.py
+++ b/keras_core/optimizers/adafactor.py
@@ -96,12 +96,13 @@ class Adafactor(optimizer.Optimizer):
             if len(var.shape) < 2:
                 # Don't factor if variable is of dimension < 2, but we still
                 # need to create dummy variables as placeholder.
-                self._r.append(
-                    backend.Variable(0, name=var.name, trainable=False)
-                )
-                self._c.append(
-                    backend.Variable(0, name=var.name, trainable=False)
-                )
+                with backend.name_scope(self.name, caller=self):
+                    self._r.append(
+                        backend.Variable(0, name=var.name, trainable=False)
+                    )
+                    self._c.append(
+                        backend.Variable(0, name=var.name, trainable=False)
+                    )
             else:
                 # Always factor the last 2 dimensions.
                 r_shape = var.shape[:-1]
@@ -122,7 +123,7 @@ class Adafactor(optimizer.Optimizer):
             )
             self._v.append(
                 self.add_variable_from_reference(
-                    reference_variable=var, name="v"
+                    reference_variable=var, name="velocity"
                 )
             )
 
diff --git a/keras_core/optimizers/adam.py b/keras_core/optimizers/adam.py
index 6107cad63..a6aed7b82 100644
--- a/keras_core/optimizers/adam.py
+++ b/keras_core/optimizers/adam.py
@@ -91,12 +91,12 @@ class Adam(optimizer.Optimizer):
         for var in var_list:
             self._momentums.append(
                 self.add_variable_from_reference(
-                    reference_variable=var, name="m"
+                    reference_variable=var, name="momentum"
                 )
             )
             self._velocities.append(
                 self.add_variable_from_reference(
-                    reference_variable=var, name="v"
+                    reference_variable=var, name="velocity"
                 )
             )
         if self.amsgrad:
@@ -104,7 +104,7 @@ class Adam(optimizer.Optimizer):
             for var in var_list:
                 self._velocity_hats.append(
                     self.add_variable_from_reference(
-                        reference_variable=var, name="vhat"
+                        reference_variable=var, name="velocity_hat"
                     )
                 )
 
diff --git a/keras_core/optimizers/adamax.py b/keras_core/optimizers/adamax.py
index a305436a3..ff904b9a7 100644
--- a/keras_core/optimizers/adamax.py
+++ b/keras_core/optimizers/adamax.py
@@ -99,12 +99,12 @@ class Adamax(optimizer.Optimizer):
         for var in var_list:
             self._m.append(
                 self.add_variable_from_reference(
-                    reference_variable=var, name="m"
+                    reference_variable=var, name="momentum"
                 )
             )
             self._u.append(
                 self.add_variable_from_reference(
-                    reference_variable=var, name="u"
+                    reference_variable=var, name="norm"
                 )
             )
 
diff --git a/keras_core/optimizers/base_optimizer.py b/keras_core/optimizers/base_optimizer.py
index a408e9247..12aef014b 100644
--- a/keras_core/optimizers/base_optimizer.py
+++ b/keras_core/optimizers/base_optimizer.py
@@ -33,6 +33,8 @@ class BaseOptimizer:
 
         if kwargs:
             raise ValueError(f"Argument(s) not recognized: {kwargs}")
ValueError(f"Argument(s) not recognized: {kwargs}") + if name is None: + name = auto_name(self.__class__.__name__) self.name = name self.weight_decay = weight_decay self.clipnorm = clipnorm @@ -83,9 +85,10 @@ class BaseOptimizer: # Create iteration variable # Note: dtype="int" will resolve to int32 in JAX # (since int64 is disallowed in JAX) and to int64 in TF. - iterations = backend.Variable( - 0, name="iteration", dtype="int", trainable=False - ) + with backend.name_scope(self.name, caller=self): + iterations = backend.Variable( + 0, name="iteration", dtype="int", trainable=False + ) self._track_variable(iterations) self.iterations = iterations @@ -105,12 +108,13 @@ class BaseOptimizer: "and returns the corresponding learning rate value). " f"Received instead: learning_rate={learning_rate}" ) - learning_rate = backend.Variable( - learning_rate, - name="learning_rate", - dtype=backend.floatx(), - trainable=False, - ) + with backend.name_scope(self.name, caller=self): + learning_rate = backend.Variable( + learning_rate, + name="learning_rate", + dtype=backend.floatx(), + trainable=False, + ) self._track_variable(learning_rate) self._learning_rate = learning_rate @@ -154,13 +158,14 @@ class BaseOptimizer: ): self._check_super_called() initializer = initializers.get(initializer) - variable = backend.Variable( - initializer=initializer, - shape=shape, - dtype=dtype, - trainable=False, - name=name, - ) + with backend.name_scope(self.name, caller=self): + variable = backend.Variable( + initializer=initializer, + shape=shape, + dtype=dtype, + trainable=False, + name=name, + ) self._track_variable(variable) return variable @@ -169,12 +174,12 @@ class BaseOptimizer: variable. """ initializer = initializers.Zeros() - name = name or auto_name(self.__class__.__name__) + name = name or "var" return self.add_variable( shape=reference_variable.shape, initializer=initializer, dtype=reference_variable.dtype, - name=name, + name=reference_variable.path.replace("/", "_") + "_" + name, ) def _check_variables_are_known(self, variables): @@ -226,12 +231,12 @@ class BaseOptimizer: trainable_variables = list(trainable_variables) # Optionally build optimizer. if not self.built: - with ops.name_scope(self.name): + with backend.name_scope(self.name, caller=self): self.build(trainable_variables) self.built = True self._check_variables_are_known(trainable_variables) - with ops.name_scope(self.name): + with backend.name_scope(self.name, caller=self): # Filter empty gradients. 
             grads, trainable_variables = self._filter_empty_gradients(
                 grads, trainable_variables
diff --git a/keras_core/optimizers/lion.py b/keras_core/optimizers/lion.py
index 425d0c812..22c6f1b4d 100644
--- a/keras_core/optimizers/lion.py
+++ b/keras_core/optimizers/lion.py
@@ -91,7 +91,7 @@ class Lion(optimizer.Optimizer):
         for var in var_list:
             self._momentums.append(
                 self.add_variable_from_reference(
-                    reference_variable=var, name="m"
+                    reference_variable=var, name="momentum"
                 )
             )
diff --git a/keras_core/optimizers/nadam.py b/keras_core/optimizers/nadam.py
index efcbb6387..98b9df08d 100644
--- a/keras_core/optimizers/nadam.py
+++ b/keras_core/optimizers/nadam.py
@@ -90,12 +90,12 @@ class Nadam(optimizer.Optimizer):
         for var in var_list:
             self._momentums.append(
                 self.add_variable_from_reference(
-                    reference_variable=var, name="m"
+                    reference_variable=var, name="momentum"
                 )
             )
             self._velocities.append(
                 self.add_variable_from_reference(
-                    reference_variable=var, name="v"
+                    reference_variable=var, name="velocity"
                 )
             )
diff --git a/keras_core/optimizers/sgd.py b/keras_core/optimizers/sgd.py
index b22d04529..9aa29b11a 100644
--- a/keras_core/optimizers/sgd.py
+++ b/keras_core/optimizers/sgd.py
@@ -89,7 +89,7 @@ class SGD(optimizer.Optimizer):
         for variable in variables:
             self.momentums.append(
                 self.add_variable_from_reference(
-                    reference_variable=variable, name="m"
+                    reference_variable=variable, name="momentum"
                 )
             )
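-- 
Usage sketch: a minimal example assuming the API exactly as introduced by
this patch (`keras_core.name_scope`, `keras_core.Variable`, and the
`current_path()` helper); the scope names "outer"/"inner" and the variable
name "x" are illustrative only, not part of the diff above.

    import keras_core
    from keras_core.backend.common.name_scope import current_path

    # Entering scopes pushes onto the global name_scope stack; nested
    # scopes compose into a "/"-separated path (see name_scope_test.py).
    with keras_core.name_scope("outer"):
        with keras_core.name_scope("inner"):
            assert current_path() == "outer/inner"
            # The variable records the scope path active at creation time.
            v = keras_core.Variable(0, name="x")

    # `KerasVariable.path` prefixes the name with the enclosing scopes,
    # which is what the new variable map keys on.
    assert v.path == "outer/inner/x"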