Add tracking opt out.

commit d663d5bcc1 (parent b49553435e)
@@ -1,3 +1,4 @@
 from keras_core.layers.core.dense import Dense
+from keras_core.layers.layer import Layer

 # from keras_core.layers.regularization.dropout import Dropout
@@ -273,8 +273,9 @@ class Layer(Operation):
                     f"should be passed as a keyword argument: {arg}"
                 )

-        # 4. Check input spec.
-        self._assert_input_compatibility(*args, **kwargs)
+        # 4. Check input spec for 1st positional arg.
+        # TODO: consider extending this to all args and kwargs.
+        self._assert_input_compatibility(*args)
         ######################################

         ###############
@@ -471,7 +472,7 @@ class Layer(Operation):

             # Check input spec again (after build, since self.input_spec
             # may have been updated
-            self._assert_input_compatibility(*args, **kwargs)
+            self._assert_input_compatibility(*args)

     def __repr__(self):
         # TODO: improve
@@ -495,7 +496,7 @@ class Layer(Operation):
                 "Go add it!"
             )

-    def _assert_input_compatibility(self, *args, **kwargs):
+    def _assert_input_compatibility(self, *args):
         if args and self.input_spec:
             input_spec.assert_input_compatibility(
                 self.input_spec, args[0], layer_name=self.name
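Net effect of the three layer.py hunks above: `_assert_input_compatibility` no longer receives keyword arguments, and only the first positional argument is validated against `self.input_spec` (the TODO flags a possible extension to all args and kwargs). A small, framework-free sketch of that narrowed check follows; the function and names are illustrative and not part of keras_core.

import numpy as np


def check_first_positional_arg(spec_ndim, *args):
    """Illustrative stand-in: validate only the first positional input."""
    if args and spec_ndim is not None:
        ndim = getattr(args[0], "ndim", None)
        if ndim != spec_ndim:
            raise ValueError(f"Expected rank {spec_ndim}, got rank {ndim}.")


check_first_positional_arg(2, np.zeros((4, 3)))             # first arg is validated
check_first_positional_arg(2, np.zeros((4, 3)), "ignored")  # later args are skipped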
@@ -6,6 +6,7 @@ from keras_core import operations as ops
 from keras_core.layers.layer import Layer
 from keras_core.models.model import Model
 from keras_core.operations.function import Function
+from keras_core.utils import tracking


 class Functional(Function, Model):
@@ -22,6 +23,7 @@ class Functional(Function, Model):
     Symbolic add_loss
     """

+    @tracking.no_automatic_dependency_tracking
     def __init__(self, inputs, outputs, name=None, **kwargs):
         # This is used by the Model class, since we have some logic to swap the
         # class in the __new__ method, which will lead to __init__ get invoked
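The decorator added here (defined in the new tracking module further down) turns off attribute tracking for the duration of `Functional.__init__`, presumably so that the graph bookkeeping assigned there is not re-registered by `Tracker.track()`. A minimal sketch of the effect, using only helpers introduced in this commit; `GraphHolder` is a purely illustrative stand-in for `Functional`.

from keras_core.utils import tracking


class GraphHolder:
    @tracking.no_automatic_dependency_tracking
    def __init__(self, nodes):
        # The thread-local flag is off for the whole constructor body,
        # so Tracker.track() would return these attributes unmodified.
        assert not tracking.is_tracking_enabled()
        self._nodes = list(nodes)


holder = GraphHolder(nodes=["n1", "n2"])
assert tracking.is_tracking_enabled()  # flag is restored once __init__ returns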
@@ -130,13 +132,12 @@ class Functional(Function, Model):


 def operation_fn(operation, training):
-    def call(*arg, **kwargs):
+    def call(*args, **kwargs):
         if (
             hasattr(operation, "_call_has_training_arg")
             and operation._call_has_training_arg()
-            and "training" not in kwargs
         ):
             kwargs["training"] = training
-        return operation(*arg, **kwargs)
+        return operation(*args, **kwargs)

     return call
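Two fixes in `operation_fn`: the `*arg` typo becomes `*args`, and the `"training" not in kwargs` guard is dropped, so the model-level `training` value now overrides any per-call value for operations that declare a training argument. A hedged sketch of that behavior with a toy operation; `ToyOp` is illustrative, and importing `operation_fn` from the functional module is an assumption based on this hunk.

from keras_core.models.functional import operation_fn


class ToyOp:
    """Stand-in for an operation whose call signature accepts `training`."""

    def _call_has_training_arg(self):
        return True

    def __call__(self, x, training=False):
        return "training" if training else "inference"


wrapped = operation_fn(ToyOp(), training=True)
# The graph-level flag wins even though the caller passed training=False.
print(wrapped("x", training=False))  # -> "training"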
@@ -9,7 +9,7 @@ from keras_core.models.functional import Functional


 class FunctionalTest(testing.TestCase):
-    def test_basic_flow(self):
+    def test_basic_flow_multi_input(self):
         input_a = Input(shape=(3,), batch_size=2, name="input_a")
         input_b = Input(shape=(3,), batch_size=2, name="input_b")
         x = input_a + input_b
@@ -29,17 +29,61 @@ class FunctionalTest(testing.TestCase):
         out_val = model(in_val)
         self.assertEqual(out_val.shape, (2, 4))

+    def test_basic_flow_multi_output(self):
+        inputs = Input(shape=(3,), batch_size=2, name="input")
+        x = layers.Dense(5)(inputs)
+        output_a = layers.Dense(4)(x)
+        output_b = layers.Dense(5)(x)
+        model = Functional(inputs, [output_a, output_b])
+
+        # Eager call
+        in_val = np.random.random((2, 3))
+        out_val = model(in_val)
+        self.assertTrue(isinstance(out_val, list))
+        self.assertEqual(len(out_val), 2)
+        self.assertEqual(out_val[0].shape, (2, 4))
+        self.assertEqual(out_val[1].shape, (2, 5))
+
+        # Symbolic call
+        out_val = model(Input(shape=(3,), batch_size=2))
+        self.assertTrue(isinstance(out_val, list))
+        self.assertEqual(len(out_val), 2)
+        self.assertEqual(out_val[0].shape, (2, 4))
+        self.assertEqual(out_val[1].shape, (2, 5))
+
     def test_layer_getters(self):
         # Test mixing ops and layers
-        pass
+        input_a = Input(shape=(3,), batch_size=2, name="input_a")
+        input_b = Input(shape=(3,), batch_size=2, name="input_b")
+        x = input_a + input_b
+        x = layers.Dense(5, name="dense_1")(x)
+        outputs = layers.Dense(4, name="dense_2")(x)
+        model = Functional([input_a, input_b], outputs)
+
+        self.assertEqual(len(model.layers), 4)
+        self.assertEqual(len(model._operations), 5)
+        self.assertEqual(model.get_layer(index=0).name, "input_a")
+        self.assertEqual(model.get_layer(index=1).name, "input_b")
+        self.assertEqual(model.get_layer(index=2).name, "dense_1")
+        self.assertEqual(model.get_layer(index=3).name, "dense_2")
+        self.assertEqual(model.get_layer(name="dense_1").name, "dense_1")

     def test_training_arg(self):
-        pass
+        class Canary(layers.Layer):
+            def call(self, x, training=False):
+                assert training
+                return x
+
+            def compute_output_spec(self, x, training=False):
+                return backend.KerasTensor(x.shape, dtype=x.dtype)
+
+        inputs = Input(shape=(3,), batch_size=2)
+        outputs = Canary()(inputs)
+        model = Functional(inputs, outputs)
+        model(np.random.random((2, 3)), training=True)

     def test_mask_arg(self):
-        pass
-
-    def test_shape_inference(self):
+        # TODO
         pass

     def test_passing_inputs_by_name(self):
@@ -79,7 +79,7 @@ class Function(Operation):
         for depth in depth_keys:
             nodes = nodes_by_depth[depth]
             for node in nodes:
-                if not node.operation:
+                if not node.operation or node.is_input:
                     continue  # Input tensors already exist.

                 if any(x not in tensor_dict for x in node.input_tensors):
@@ -37,7 +37,6 @@ class Node:
         call_kwargs: The keyword arguments the operation was called with.
         outputs: The output tensors of the `op.__call__()` call.
     """
-
     def __init__(
         self, operation, call_args=None, call_kwargs=None, outputs=None
     ):
@@ -77,7 +76,7 @@ class Node:
         )

         # Whether this is a root node.
-        self.is_input = not any_input_has_history
+        self.is_input = not self.arguments.keras_tensors

     def __repr__(self):
         return f"<Node operation={self.operation}, id={id(self)}>"
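The last two hunks work together: `Node.is_input` is now true whenever the node's call arguments contain no KerasTensors at all, rather than being derived from input history, and the function-execution loop skips such nodes explicitly because their tensors already exist. A framework-free sketch of the same rule, for illustration only.

class SymbolicTensor:
    """Illustrative stand-in for a KerasTensor produced by an earlier op."""


def is_input_node(call_args):
    # A node is a root/input node when none of its arguments are symbolic
    # tensors, i.e. it does not consume the output of any other node.
    return not any(isinstance(x, SymbolicTensor) for x in call_args)


assert is_input_node([(2, 3), "float32"])            # plain specs only: root node
assert not is_input_node([SymbolicTensor(), True])   # consumes a symbolic tensor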
@@ -1,3 +1,28 @@
+import threading
+
+GLOBAL_SCOPE_TRACKER = threading.local()
+
+
+class DotNotTrackScope:
+    def __enter__(self):
+        self.original_value = is_tracking_enabled()
+        GLOBAL_SCOPE_TRACKER.tracking_on = False
+
+    def __exit__(self, *args, **kwargs):
+        GLOBAL_SCOPE_TRACKER.tracking_on = self.original_value
+
+
+def is_tracking_enabled():
+    return getattr(GLOBAL_SCOPE_TRACKER, "tracking_on", True)
+
+
+def no_automatic_dependency_tracking(fn):
+    def wrapper(*args, **kwargs):
+        with DotNotTrackScope():
+            return fn(*args, **kwargs)
+
+    return wrapper
+
+
 class Tracker:
     """Attribute tracker, used for e.g. Variable tracking.

@@ -34,6 +59,9 @@ class Tracker:
         self.stored_ids = {name: set() for name in self.config.keys()}

     def track(self, attr):
+        if not is_tracking_enabled():
+            return attr
+
         for name, (is_attr_type, store) in self.config.items():
             if is_attr_type(attr):
                 if id(attr) not in self.stored_ids[name]:
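This is the core of the commit: a thread-local flag, a scope that flips it off, a decorator built on that scope, and an early return in `Tracker.track()` when tracking is disabled. A short usage sketch of the new helpers as defined above; the import path follows the `from keras_core.utils import tracking` line added to functional.py.

from keras_core.utils import tracking

# Tracking defaults to enabled.
assert tracking.is_tracking_enabled()

# As a context manager: attribute tracking is off only inside the scope.
with tracking.DotNotTrackScope():
    assert not tracking.is_tracking_enabled()
assert tracking.is_tracking_enabled()  # previous value restored on exit


# As a decorator: tracking is off for the duration of the call.
@tracking.no_automatic_dependency_tracking
def build_untracked_state():
    return tracking.is_tracking_enabled()


assert build_untracked_state() is False

Because the flag lives on a threading.local(), disabling tracking in one thread does not affect trackers running concurrently in other threads.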