Add name_scope feature
This commit is contained in:
parent 6b53ab21cd
commit 59321fc09a
@@ -9,6 +9,7 @@ if backend() == "torch":
 from keras_core.backend.common.keras_tensor import KerasTensor
 from keras_core.backend.common.keras_tensor import any_symbolic_tensors
 from keras_core.backend.common.keras_tensor import is_keras_tensor
+from keras_core.backend.common.name_scope import name_scope
 from keras_core.backend.common.stateless_scope import StatelessScope
 from keras_core.backend.common.stateless_scope import get_stateless_scope
 from keras_core.backend.common.stateless_scope import in_stateless_scope
keras_core/backend/common/name_scope.py · 57 lines · Normal file
@@ -0,0 +1,57 @@
+from keras_core.backend.common import global_state
+
+
+class name_scope:
+    """Creates a sub-namespace for variable paths.
+
+    Args:
+        name: Name of the current scope (string).
+        caller: Optional ID of a caller object (e.g. class instance).
+        deduplicate: If `True`, if `caller` was passed,
+            and the previous caller matches the current caller,
+            and the previous name matches the current name,
+            do not reenter a new namespace.
+    """
+
+    def __init__(self, name, caller=None, deduplicate=True):
+        if not isinstance(name, str) or "/" in name:
+            raise ValueError(
+                "Argument `name` must be a string and "
+                "cannot contain character `/`. "
+                f"Received: name={name}"
+            )
+        self.name = name
+        self.caller = caller
+        self.deduplicate = deduplicate
+        self._pop_on_exit = False
+
+    def __enter__(self):
+        name_scope_stack = global_state.get_global_attribute(
+            "name_scope_stack", default=[], set_to_default=True
+        )
+        if self.deduplicate and name_scope_stack:
+            parent_caller = name_scope_stack[-1].caller
+            parent_name = name_scope_stack[-1].name
+            if (
+                self.caller is not None
+                and self.caller is parent_caller
+                and self.name == parent_name
+            ):
+                return self
+        name_scope_stack.append(self)
+        self._pop_on_exit = True
+        return self
+
+    def __exit__(self, *args, **kwargs):
+        if self._pop_on_exit:
+            name_scope_stack = global_state.get_global_attribute(
+                "name_scope_stack"
+            )
+            name_scope_stack.pop()
+
+
+def current_path():
+    name_scope_stack = global_state.get_global_attribute("name_scope_stack")
+    if name_scope_stack is None:
+        return ""
+    return "/".join(x.name for x in name_scope_stack)
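For illustration (not part of the diff), a minimal usage sketch of the scope machinery added above; the scope names are arbitrary:

    from keras_core.backend.common.name_scope import current_path
    from keras_core.backend.common.name_scope import name_scope

    # Scopes nest; current_path() joins the active scope names with "/".
    with name_scope("outer"):
        with name_scope("inner"):
            assert current_path() == "outer/inner"

    # With deduplicate=True (the default) and a matching caller, re-entering
    # the same name does not add a second path component.
    owner = object()
    with name_scope("block", caller=owner):
        with name_scope("block", caller=owner):
            assert current_path() == "block"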
keras_core/backend/common/name_scope_test.py · 36 lines · Normal file
@@ -0,0 +1,36 @@
+from keras_core import testing
+from keras_core.backend.common.name_scope import current_path
+from keras_core.backend.common.name_scope import name_scope
+
+
+class NameScopeTest(testing.TestCase):
+    def test_stacking(self):
+        self.assertEqual(current_path(), "")
+        with name_scope("outer") as outer:
+            self.assertEqual(outer.name, "outer")
+            self.assertEqual(current_path(), "outer")
+            with name_scope("middle") as middle:
+                self.assertEqual(middle.name, "middle")
+                self.assertEqual(current_path(), "outer/middle")
+                with name_scope("inner") as inner:
+                    self.assertEqual(inner.name, "inner")
+                    self.assertEqual(current_path(), "outer/middle/inner")
+                self.assertEqual(current_path(), "outer/middle")
+            self.assertEqual(current_path(), "outer")
+        self.assertEqual(current_path(), "")
+
+    def test_deduplication(self):
+        self.assertEqual(current_path(), "")
+        with name_scope("name", caller=1):
+            with name_scope("name", caller=1):
+                self.assertEqual(current_path(), "name")
+        self.assertEqual(current_path(), "")
+        with name_scope("name"):
+            with name_scope("name"):
+                self.assertEqual(current_path(), "name/name")
+
+    def test_errors(self):
+        with self.assertRaisesRegex(ValueError, "must be a string"):
+            name_scope("foo/bar")
+        with self.assertRaisesRegex(ValueError, "must be a string"):
+            name_scope(4)
@@ -2,6 +2,7 @@ import numpy as np

 from keras_core.backend import config
 from keras_core.backend.common import global_state
+from keras_core.backend.common.name_scope import current_path
 from keras_core.backend.common.stateless_scope import get_stateless_scope
 from keras_core.backend.common.stateless_scope import in_stateless_scope
 from keras_core.utils.naming import auto_name
@@ -19,6 +20,11 @@ class KerasVariable:
                 f"Received: name={name}"
             )
        self.name = name
+        parent_path = current_path()
+        if parent_path:
+            self.path = current_path() + "/" + self.name
+        else:
+            self.path = self.name
         dtype = standardize_dtype(dtype)
         self._dtype = dtype
         self._shape = None
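For illustration (not part of the diff), the effect on variable paths; the scope and variable names are arbitrary:

    from keras_core import backend

    with backend.name_scope("outer"):
        with backend.name_scope("inner"):
            v = backend.Variable(0.0, name="kernel")

    # `name` stays local; `path` carries the enclosing scopes.
    assert v.name == "kernel"
    assert v.path == "outer/inner/kernel"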
@@ -68,7 +74,7 @@ class KerasVariable:

     def _deferred_initialize(self):
         if self._value is not None:
-            raise ValueError(f"Variable {self.name} is already initialized.")
+            raise ValueError(f"Variable {self.path} is already initialized.")

         if in_stateless_scope():
             raise ValueError(
@@ -147,7 +153,7 @@ class KerasVariable:
     def __repr__(self):
         return (
             f"<KerasVariable shape={self.shape}, dtype={self.dtype}, "
-            f"name={self.name}>"
+            f"path={self.path}>"
         )

     def _initialize(self, value):
@@ -3,18 +3,27 @@ from keras_core.api_export import keras_core_export

 if backend.backend() == "tensorflow":
     BackendVariable = backend.tensorflow.core.Variable
+    backend_name_scope = backend.tensorflow.core.name_scope
 elif backend.backend() == "jax":
     BackendVariable = backend.jax.core.Variable
+    backend_name_scope = backend.common.name_scope.name_scope
 elif backend.backend() == "torch":
     BackendVariable = backend.torch.core.Variable
+    backend_name_scope = backend.common.name_scope.name_scope
 elif backend.backend() == "numpy":
     from keras_core.backend.numpy.core import Variable as NumpyVariable

     BackendVariable = NumpyVariable
+    backend_name_scope = backend.common.name_scope.name_scope
 else:
     raise RuntimeError(f"Invalid backend: {backend.backend()}")


 @keras_core_export("keras_core.backend.Variable")
 @keras_core_export("keras_core.Variable")
 class Variable(BackendVariable):
     pass
+
+
+@keras_core_export("keras_core.name_scope")
+class name_scope(backend_name_scope):
+    pass
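For illustration (not part of the diff), the resulting public API; on TensorFlow the scope also enters a tf.name_scope, while on the other backends it only affects Keras variable paths:

    import keras_core

    with keras_core.name_scope("my_block"):
        v = keras_core.Variable(0.0, name="state")
    assert v.path == "my_block/state"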
@@ -12,7 +12,6 @@ from keras_core.backend.jax.core import cond
 from keras_core.backend.jax.core import convert_to_numpy
 from keras_core.backend.jax.core import convert_to_tensor
 from keras_core.backend.jax.core import is_tensor
-from keras_core.backend.jax.core import name_scope
 from keras_core.backend.jax.core import scatter
 from keras_core.backend.jax.core import shape
 from keras_core.backend.jax.core import stop_gradient
@@ -60,10 +60,6 @@ def cast(x, dtype):
     return convert_to_tensor(x, dtype=dtype)


-def name_scope(name):
-    return jax.named_scope(name)
-
-
 # Shape / dtype inference util
 def compute_output_spec(fn, *args, **kwargs):
     with StatelessScope():
@@ -12,7 +12,6 @@ from keras_core.backend.numpy.core import cond
 from keras_core.backend.numpy.core import convert_to_numpy
 from keras_core.backend.numpy.core import convert_to_tensor
 from keras_core.backend.numpy.core import is_tensor
-from keras_core.backend.numpy.core import name_scope
 from keras_core.backend.numpy.core import shape
 from keras_core.backend.numpy.core import vectorized_map
 from keras_core.backend.numpy.rnn import cudnn_ok
@@ -1,5 +1,3 @@
-from contextlib import nullcontext
-
 import numpy as np
 from tensorflow import nest

@@ -60,11 +58,6 @@ def cond(pred, true_fn, false_fn):
     return false_fn()


-def name_scope(name):
-    # There is no need for a named context for NumPy.
-    return nullcontext()
-
-
 def vectorized_map(function, elements):
     if len(elements) == 1:
         return function(elements)
@@ -5,8 +5,10 @@ import tensorflow as tf
 from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

 from keras_core.backend.common import KerasVariable
+from keras_core.backend.common import global_state
 from keras_core.backend.common import standardize_dtype
 from keras_core.backend.common.keras_tensor import KerasTensor
+from keras_core.backend.common.name_scope import name_scope as base_name_scope
 from keras_core.backend.common.stateless_scope import StatelessScope
 from keras_core.utils.naming import auto_name

@@ -111,10 +113,6 @@ def cast(x, dtype):
     return tf.cast(x, dtype=dtype)


-def name_scope(name):
-    return tf.name_scope(name)
-
-
 def compute_output_spec(fn, *args, **kwargs):
     with StatelessScope():
         graph_name = auto_name("scratch_graph")
@@ -203,3 +201,32 @@ def stop_gradient(variable):

 def unstack(x, num=None, axis=0):
     return tf.unstack(x, num=num, axis=axis)
+
+
+class name_scope(base_name_scope):
+    def __init__(self, name, **kwargs):
+        super().__init__(name, **kwargs)
+        self._tf_name_scope = tf.name_scope(name)
+
+    def __enter__(self):
+        name_scope_stack = global_state.get_global_attribute(
+            "name_scope_stack", default=[], set_to_default=True
+        )
+        if self.deduplicate and name_scope_stack:
+            parent_caller = name_scope_stack[-1].caller
+            parent_name = name_scope_stack[-1].name
+            if (
+                self.caller is not None
+                and self.caller is parent_caller
+                and self.name == parent_name
+            ):
+                return self
+        name_scope_stack.append(self)
+        self._pop_on_exit = True
+        self._tf_name_scope.__enter__()
+        return self
+
+    def __exit__(self, *args, **kwargs):
+        super().__exit__(*args, **kwargs)
+        if self._pop_on_exit:
+            self._tf_name_scope.__exit__(*args, **kwargs)
keras_core/backend/tensorflow/name_scope_test.py · 37 lines · Normal file
@@ -0,0 +1,37 @@
+import tensorflow as tf
+
+from keras_core.backend.tensorflow.core import name_scope
+from keras_core.testing import TestCase
+
+
+class TFNameScopeTest(TestCase):
+    def test_stacking(self):
+        self.assertEqual(tf.Variable(0, name="x").name, "x:0")
+        with name_scope("outer") as outer:
+            self.assertEqual(outer.name, "outer")
+            self.assertEqual(tf.Variable(0, name="x").name, "outer/x:0")
+            with name_scope("middle") as middle:
+                self.assertEqual(middle.name, "middle")
+                self.assertEqual(
+                    tf.Variable(0, name="x").name, "outer/middle/x:0"
+                )
+                with name_scope("inner") as inner:
+                    self.assertEqual(inner.name, "inner")
+                    self.assertEqual(
+                        tf.Variable(0, name="x").name, "outer/middle/inner/x:0"
+                    )
+                self.assertEqual(
+                    tf.Variable(0, name="x").name, "outer/middle/x:0"
+                )
+            self.assertEqual(tf.Variable(0, name="x").name, "outer/x:0")
+        self.assertEqual(tf.Variable(0, name="x").name, "x:0")
+
+    def test_deduplicate(self):
+        self.assertEqual(tf.Variable(0, name="x").name, "x:0")
+        with name_scope("name", caller=1):
+            with name_scope("name", caller=1):
+                self.assertEqual(tf.Variable(0, name="x").name, "name/x:0")
+        self.assertEqual(tf.Variable(0, name="x").name, "x:0")
+        with name_scope("name"):
+            with name_scope("name"):
+                self.assertEqual(tf.Variable(0, name="x").name, "name/name/x:0")
@@ -28,7 +28,6 @@ from keras_core.backend.torch.core import cond
 from keras_core.backend.torch.core import convert_to_numpy
 from keras_core.backend.torch.core import convert_to_tensor
 from keras_core.backend.torch.core import is_tensor
-from keras_core.backend.torch.core import name_scope
 from keras_core.backend.torch.core import scatter
 from keras_core.backend.torch.core import shape
 from keras_core.backend.torch.core import stop_gradient
@@ -174,10 +174,6 @@ def cast(x, dtype):
     return convert_to_tensor(x, dtype)


-def name_scope(name):
-    return contextlib.nullcontext()
-
-
 # Shape / dtype inference util
 def compute_output_spec(fn, *args, **kwargs):
     def has_none_shape(x):
@@ -439,13 +439,14 @@ class Layer(BackendLayer, Operation):
         # TODO: handle layout
         self._check_super_called()
         initializer = initializers.get(initializer)
-        variable = backend.Variable(
-            initializer=initializer,
-            shape=shape,
-            dtype=dtype or self.variable_dtype,
-            trainable=trainable,
-            name=name,
-        )
+        with backend.name_scope(self.name, caller=self):
+            variable = backend.Variable(
+                initializer=initializer,
+                shape=shape,
+                dtype=dtype or self.variable_dtype,
+                trainable=trainable,
+                name=name,
+            )
         # Will be added to layer.losses
         variable.regularizer = regularizer
         variable.constraint = constraint
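For illustration (not part of the diff), the effect on layer weights; the layer name is arbitrary:

    import keras_core

    layer = keras_core.layers.Dense(4, name="my_dense")
    layer.build((None, 8))
    # add_weight() now runs under name_scope(self.name), so the weight's
    # path is prefixed with the layer name while its name stays local.
    assert layer.kernel.name == "kernel"
    assert layer.kernel.path == "my_dense/kernel"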
@@ -683,7 +684,8 @@ class Layer(BackendLayer, Operation):

         ################
         # 4. Call build.
-        self._maybe_build(call_spec)
+        with backend.name_scope(self.name, caller=self):
+            self._maybe_build(call_spec)

         ##########################
         # 5. Infer training value
@@ -737,7 +739,7 @@ class Layer(BackendLayer, Operation):
         ####################
         # 7. Call the layer.
         try:
-            with backend.name_scope(self.name):
+            with backend.name_scope(self.name, caller=self):
                 if self.autocast and self.compute_dtype != self.variable_dtype:
                     # For mixed precision, we automatically cast layer variables
                     # (float ones only) to the compute dtype upon access.
@@ -1083,45 +1085,44 @@ class Layer(BackendLayer, Operation):
         shapes_dict = get_shapes_dict(call_spec)
         self._build_shapes_dict = shapes_dict

-        with backend.name_scope(self.name):
-            if not utils.is_default(self.build):
-                shapes_dict = update_shapes_dict_for_target_fn(
-                    self.build,
-                    shapes_dict=shapes_dict,
-                    call_spec=call_spec,
-                    class_name=self.__class__.__name__,
-                )
-                self.build(**shapes_dict)
-            elif might_have_unbuilt_state(self):
-                if len(shapes_dict) == 1:
-                    # Single arg: pass it positionally
-                    success = self._build_by_run_for_single_pos_arg(
-                        tuple(shapes_dict.values())[0]
-                    )
-                else:
-                    success = self._build_by_run_for_kwargs(shapes_dict)
-                if not success:
-                    if call_spec.eager:
-                        # Will let the actual eager call do state-building
-                        return
-                    raise ValueError(
-                        f"Layer '{self.name}' looks like it has "
-                        "unbuilt state, but Keras is not able to "
-                        "trace the layer `call()` in order to "
-                        "build it automatically. Possible causes:\n"
-                        "1. The `call()` method of your layer may be "
-                        "crashing. Try to `__call__()` the layer "
-                        "eagerly on some test input "
-                        "first to see if it works. "
-                        "E.g. `x = np.random.random((3, 4)); "
-                        "y = layer(x)`\n"
-                        "2. If the `call()` method is correct, "
-                        "then you may need to implement "
-                        "the `def build(self, input_shape)` method on your "
-                        "layer. It should create all variables used by the "
-                        "layer (e.g. by calling `layer.build()` on all its "
-                        "children layers)."
-                    )
+        if not utils.is_default(self.build):
+            shapes_dict = update_shapes_dict_for_target_fn(
+                self.build,
+                shapes_dict=shapes_dict,
+                call_spec=call_spec,
+                class_name=self.__class__.__name__,
+            )
+            self.build(**shapes_dict)
+        elif might_have_unbuilt_state(self):
+            if len(shapes_dict) == 1:
+                # Single arg: pass it positionally
+                success = self._build_by_run_for_single_pos_arg(
+                    tuple(shapes_dict.values())[0]
+                )
+            else:
+                success = self._build_by_run_for_kwargs(shapes_dict)
+            if not success:
+                if call_spec.eager:
+                    # Will let the actual eager call do state-building
+                    return
+                raise ValueError(
+                    f"Layer '{self.name}' looks like it has "
+                    "unbuilt state, but Keras is not able to "
+                    "trace the layer `call()` in order to "
+                    "build it automatically. Possible causes:\n"
+                    "1. The `call()` method of your layer may be "
+                    "crashing. Try to `__call__()` the layer "
+                    "eagerly on some test input "
+                    "first to see if it works. "
+                    "E.g. `x = np.random.random((3, 4)); "
+                    "y = layer(x)`\n"
+                    "2. If the `call()` method is correct, "
+                    "then you may need to implement "
+                    "the `def build(self, input_shape)` method on your "
+                    "layer. It should create all variables used by the "
+                    "layer (e.g. by calling `layer.build()` on all its "
+                    "children layers)."
+                )
         self.built = True

         # Check input spec again (after build, since self.input_spec
@@ -288,12 +288,16 @@ class RNN(Layer):

     @tracking.no_automatic_dependency_tracking
     def _create_state_variables(self, batch_size):
-        self.states = tree.map_structure(
-            lambda value: backend.Variable(
-                value, trainable=False, dtype=self.variable_dtype
-            ),
-            self.get_initial_state(batch_size),
-        )
+        with backend.name_scope(self.name, caller=self):
+            self.states = tree.map_structure(
+                lambda value: backend.Variable(
+                    value,
+                    trainable=False,
+                    dtype=self.variable_dtype,
+                    name="rnn_state",
+                ),
+                self.get_initial_state(batch_size),
+            )

     def get_initial_state(self, batch_size):
         get_initial_state_fn = getattr(self.cell, "get_initial_state", None)
@@ -166,14 +166,15 @@ class Metric:

     def add_variable(self, shape, initializer, dtype=None, name=None):
         self._check_super_called()
-        initializer = initializers.get(initializer)
-        variable = backend.Variable(
-            initializer=initializer,
-            shape=shape,
-            dtype=dtype,
-            trainable=False,
-            name=name,
-        )
+        with backend.name_scope(self.name, caller=self):
+            initializer = initializers.get(initializer)
+            variable = backend.Variable(
+                initializer=initializer,
+                shape=shape,
+                dtype=dtype,
+                trainable=False,
+                name=name,
+            )
         # Prevent double-tracking
         self._tracker.add_to_store("variables", variable)
         return variable
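For illustration (not part of the diff), the same applies to metric variables; the metric name is arbitrary:

    import keras_core

    m = keras_core.metrics.Sum(name="my_sum")
    # Sum creates a "total" variable via add_variable(), which now runs
    # under name_scope(self.name).
    print(m.variables[0].path)  # expected: "my_sum/total"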
@@ -542,9 +542,7 @@ class Model(Trainer, Layer):

     def _get_variable_map(self):
         store = {}
-        map_trackable_variables(
-            self, store=store, inner_path="", visited_trackables=set()
-        )
+        map_trackable_variables(self, store=store, visited_trackables=set())
         return store


@@ -2,10 +2,9 @@ from keras_core.layers.layer import Layer
 from keras_core.metrics.metric import Metric
 from keras_core.optimizers.optimizer import Optimizer
 from keras_core.saving import saving_lib
-from keras_core.utils import file_utils


-def map_trackable_variables(trackable, store, inner_path, visited_trackables):
+def map_trackable_variables(trackable, store, visited_trackables):
     # If the trackable has already been saved, skip it.
     if id(trackable) in visited_trackables:
         return
@@ -22,7 +21,15 @@ def map_trackable_variables(trackable, store, inner_path, visited_trackables):
     elif isinstance(trackable, Metric):
         variables = trackable._variables
     for v in variables:
-        store[inner_path + "/" + v.name] = v
+        if v.path in store:
+            raise ValueError(
+                "The model contains two variables with a duplicate path: "
+                f"path='{v.path}' appears at least twice. "
+                f"This path is used for {v} and for {store[v.path]}. "
+                "In order to get a variable map, make sure to use "
+                "unique paths/names for each variable."
+            )
+        store[v.path] = v

     # Recursively save state of children trackables (layers, optimizers, etc.)
     for child_attr, child_obj in saving_lib._walk_trackable(trackable):
@@ -30,28 +37,24 @@ def map_trackable_variables(trackable, store, inner_path, visited_trackables):
             map_trackable_variables(
                 child_obj,
                 store,
-                inner_path=file_utils.join(inner_path, child_obj.name),
                 visited_trackables=visited_trackables,
             )
         elif isinstance(child_obj, (list, dict, tuple, set)):
             map_container_variables(
                 child_obj,
                 store,
-                inner_path=file_utils.join(inner_path, child_attr),
                 visited_trackables=visited_trackables,
             )


-def map_container_variables(container, store, inner_path, visited_trackables):
+def map_container_variables(container, store, visited_trackables):
     if isinstance(container, dict):
         container = list(container.values())

     for trackable in container:
         if saving_lib._is_keras_trackable(trackable):
-            name = trackable.name
             map_trackable_variables(
                 trackable,
                 store,
-                inner_path=file_utils.join(inner_path, name),
                 visited_trackables=visited_trackables,
             )
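For illustration (not part of the diff), the effect on the variable map; model and layer names are arbitrary:

    import keras_core

    model = keras_core.Sequential(
        [keras_core.layers.Dense(2, name="dense")], name="model"
    )
    model.build((None, 4))
    # Variables are now keyed by their unique `path` instead of a recomputed
    # inner_path, e.g. "dense/kernel" (possibly prefixed by parent scopes);
    # a duplicate path raises a ValueError rather than silently overwriting.
    print(sorted(model._get_variable_map()))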
@@ -96,12 +96,13 @@ class Adafactor(optimizer.Optimizer):
             if len(var.shape) < 2:
                 # Don't factor if variable is of dimension < 2, but we still
                 # need to create dummy variables as placeholder.
-                self._r.append(
-                    backend.Variable(0, name=var.name, trainable=False)
-                )
-                self._c.append(
-                    backend.Variable(0, name=var.name, trainable=False)
-                )
+                with backend.name_scope(self.name, caller=self):
+                    self._r.append(
+                        backend.Variable(0, name=var.name, trainable=False)
+                    )
+                    self._c.append(
+                        backend.Variable(0, name=var.name, trainable=False)
+                    )
             else:
                 # Always factor the last 2 dimensions.
                 r_shape = var.shape[:-1]
@@ -122,7 +123,7 @@ class Adafactor(optimizer.Optimizer):
             )
             self._v.append(
                 self.add_variable_from_reference(
-                    reference_variable=var, name="v"
+                    reference_variable=var, name="velocity"
                 )
             )

@@ -91,12 +91,12 @@ class Adam(optimizer.Optimizer):
         for var in var_list:
             self._momentums.append(
                 self.add_variable_from_reference(
-                    reference_variable=var, name="m"
+                    reference_variable=var, name="momentum"
                 )
             )
             self._velocities.append(
                 self.add_variable_from_reference(
-                    reference_variable=var, name="v"
+                    reference_variable=var, name="velocity"
                 )
             )
         if self.amsgrad:
@@ -104,7 +104,7 @@ class Adam(optimizer.Optimizer):
             for var in var_list:
                 self._velocity_hats.append(
                     self.add_variable_from_reference(
-                        reference_variable=var, name="vhat"
+                        reference_variable=var, name="velocity_hat"
                     )
                 )

@@ -99,12 +99,12 @@ class Adamax(optimizer.Optimizer):
         for var in var_list:
             self._m.append(
                 self.add_variable_from_reference(
-                    reference_variable=var, name="m"
+                    reference_variable=var, name="momentum"
                 )
             )
             self._u.append(
                 self.add_variable_from_reference(
-                    reference_variable=var, name="u"
+                    reference_variable=var, name="norm"
                 )
             )

@@ -33,6 +33,8 @@ class BaseOptimizer:
         if kwargs:
             raise ValueError(f"Argument(s) not recognized: {kwargs}")

+        if name is None:
+            name = auto_name(self.__class__.__name__)
         self.name = name
         self.weight_decay = weight_decay
         self.clipnorm = clipnorm
@@ -83,9 +85,10 @@ class BaseOptimizer:

         # Create iteration variable
         # Note: dtype="int" will resolve to int32 in JAX
         # (since int64 is disallowed in JAX) and to int64 in TF.
-        iterations = backend.Variable(
-            0, name="iteration", dtype="int", trainable=False
-        )
+        with backend.name_scope(self.name, caller=self):
+            iterations = backend.Variable(
+                0, name="iteration", dtype="int", trainable=False
+            )
         self._track_variable(iterations)
         self.iterations = iterations
@@ -105,12 +108,13 @@ class BaseOptimizer:
                 "and returns the corresponding learning rate value). "
                 f"Received instead: learning_rate={learning_rate}"
             )
-        learning_rate = backend.Variable(
-            learning_rate,
-            name="learning_rate",
-            dtype=backend.floatx(),
-            trainable=False,
-        )
+        with backend.name_scope(self.name, caller=self):
+            learning_rate = backend.Variable(
+                learning_rate,
+                name="learning_rate",
+                dtype=backend.floatx(),
+                trainable=False,
+            )
         self._track_variable(learning_rate)
         self._learning_rate = learning_rate

@@ -154,13 +158,14 @@ class BaseOptimizer:
     ):
         self._check_super_called()
         initializer = initializers.get(initializer)
-        variable = backend.Variable(
-            initializer=initializer,
-            shape=shape,
-            dtype=dtype,
-            trainable=False,
-            name=name,
-        )
+        with backend.name_scope(self.name, caller=self):
+            variable = backend.Variable(
+                initializer=initializer,
+                shape=shape,
+                dtype=dtype,
+                trainable=False,
+                name=name,
+            )
         self._track_variable(variable)
         return variable

@@ -169,12 +174,12 @@ class BaseOptimizer:
         variable.
         """
         initializer = initializers.Zeros()
-        name = name or auto_name(self.__class__.__name__)
+        name = name or "var"
         return self.add_variable(
             shape=reference_variable.shape,
             initializer=initializer,
             dtype=reference_variable.dtype,
-            name=name,
+            name=reference_variable.path.replace("/", "_") + "_" + name,
         )

     def _check_variables_are_known(self, variables):
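For illustration (not part of the diff), the resulting slot-variable naming; the scope and variable names are arbitrary:

    import keras_core
    from keras_core import backend

    with backend.name_scope("my_dense"):
        var = backend.Variable(0.0, name="kernel")  # path: "my_dense/kernel"

    opt = keras_core.optimizers.Adam()
    opt.build([var])
    # Slashes in the reference path become underscores in the slot name,
    # e.g. this variable's momentum slot is named "my_dense_kernel_momentum".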
@@ -226,12 +231,12 @@ class BaseOptimizer:
         trainable_variables = list(trainable_variables)
         # Optionally build optimizer.
         if not self.built:
-            with ops.name_scope(self.name):
+            with backend.name_scope(self.name, caller=self):
                 self.build(trainable_variables)
             self.built = True
         self._check_variables_are_known(trainable_variables)

-        with ops.name_scope(self.name):
+        with backend.name_scope(self.name, caller=self):
             # Filter empty gradients.
             grads, trainable_variables = self._filter_empty_gradients(
                 grads, trainable_variables
@@ -91,7 +91,7 @@ class Lion(optimizer.Optimizer):
         for var in var_list:
             self._momentums.append(
                 self.add_variable_from_reference(
-                    reference_variable=var, name="m"
+                    reference_variable=var, name="momentum"
                 )
             )

@@ -90,12 +90,12 @@ class Nadam(optimizer.Optimizer):
         for var in var_list:
             self._momentums.append(
                 self.add_variable_from_reference(
-                    reference_variable=var, name="m"
+                    reference_variable=var, name="momentum"
                 )
             )
             self._velocities.append(
                 self.add_variable_from_reference(
-                    reference_variable=var, name="v"
+                    reference_variable=var, name="velocity"
                 )
             )

@@ -89,7 +89,7 @@ class SGD(optimizer.Optimizer):
         for variable in variables:
             self.momentums.append(
                 self.add_variable_from_reference(
-                    reference_variable=variable, name="m"
+                    reference_variable=variable, name="momentum"
                 )
             )
