Add tracking opt out.

This commit is contained in:
Francois Chollet 2023-04-12 15:20:56 -07:00
parent b49553435e
commit d663d5bcc1
7 changed files with 91 additions and 17 deletions

@ -1,3 +1,4 @@
from keras_core.layers.core.dense import Dense
from keras_core.layers.layer import Layer
# from keras_core.layers.regularization.dropout import Dropout

@ -273,8 +273,9 @@ class Layer(Operation):
f"should be passed as a keyword argument: {arg}"
)
# 4. Check input spec.
self._assert_input_compatibility(*args, **kwargs)
# 4. Check input spec for 1st positional arg.
# TODO: consider extending this to all args and kwargs.
self._assert_input_compatibility(*args)
######################################
###############
@ -471,7 +472,7 @@ class Layer(Operation):
# Check input spec again (after build, since self.input_spec
# may have been updated)
self._assert_input_compatibility(*args, **kwargs)
self._assert_input_compatibility(*args)
def __repr__(self):
# TODO: improve
@ -495,7 +496,7 @@ class Layer(Operation):
"Go add it!"
)
def _assert_input_compatibility(self, *args, **kwargs):
def _assert_input_compatibility(self, *args):
if args and self.input_spec:
input_spec.assert_input_compatibility(
self.input_spec, args[0], layer_name=self.name

@ -6,6 +6,7 @@ from keras_core import operations as ops
from keras_core.layers.layer import Layer
from keras_core.models.model import Model
from keras_core.operations.function import Function
from keras_core.utils import tracking
class Functional(Function, Model):
@ -22,6 +23,7 @@ class Functional(Function, Model):
Symbolic add_loss
"""
@tracking.no_automatic_dependency_tracking
def __init__(self, inputs, outputs, name=None, **kwargs):
# This is used by the Model class, since we have some logic to swap the
# class in the __new__ method, which will lead to __init__ get invoked
@ -50,7 +52,7 @@ class Functional(Function, Model):
else:
masks = self._flatten_to_reference_inputs(mask)
for x, mask in zip(inputs, masks):
x._keras_mask = mask
x._keras_mask = mask
return self._run_through_graph(
inputs, operation_fn=lambda op: operation_fn(op, training=training)
)
@ -130,13 +132,12 @@ class Functional(Function, Model):
def operation_fn(operation, training):
def call(*arg, **kwargs):
def call(*args, **kwargs):
if (
hasattr(operation, "_call_has_training_arg")
and operation._call_has_training_arg()
and "training" not in kwargs
):
kwargs["training"] = training
return operation(*arg, **kwargs)
return operation(*args, **kwargs)
return call

@ -9,7 +9,7 @@ from keras_core.models.functional import Functional
class FunctionalTest(testing.TestCase):
def test_basic_flow(self):
def test_basic_flow_multi_input(self):
input_a = Input(shape=(3,), batch_size=2, name="input_a")
input_b = Input(shape=(3,), batch_size=2, name="input_b")
x = input_a + input_b
@ -29,17 +29,61 @@ class FunctionalTest(testing.TestCase):
out_val = model(in_val)
self.assertEqual(out_val.shape, (2, 4))
def test_basic_flow_multi_output(self):
inputs = Input(shape=(3,), batch_size=2, name="input")
x = layers.Dense(5)(inputs)
output_a = layers.Dense(4)(x)
output_b = layers.Dense(5)(x)
model = Functional(inputs, [output_a, output_b])
# Eager call
in_val = np.random.random((2, 3))
out_val = model(in_val)
self.assertTrue(isinstance(out_val, list))
self.assertEqual(len(out_val), 2)
self.assertEqual(out_val[0].shape, (2, 4))
self.assertEqual(out_val[1].shape, (2, 5))
# Symbolic call
out_val = model(Input(shape=(3,), batch_size=2))
self.assertTrue(isinstance(out_val, list))
self.assertEqual(len(out_val), 2)
self.assertEqual(out_val[0].shape, (2, 4))
self.assertEqual(out_val[1].shape, (2, 5))
def test_layer_getters(self):
# Test mixing ops and layers
pass
input_a = Input(shape=(3,), batch_size=2, name="input_a")
input_b = Input(shape=(3,), batch_size=2, name="input_b")
x = input_a + input_b
x = layers.Dense(5, name="dense_1")(x)
outputs = layers.Dense(4, name="dense_2")(x)
model = Functional([input_a, input_b], outputs)
self.assertEqual(len(model.layers), 4)
self.assertEqual(len(model._operations), 5)
self.assertEqual(model.get_layer(index=0).name, "input_a")
self.assertEqual(model.get_layer(index=1).name, "input_b")
self.assertEqual(model.get_layer(index=2).name, "dense_1")
self.assertEqual(model.get_layer(index=3).name, "dense_2")
self.assertEqual(model.get_layer(name="dense_1").name, "dense_1")
def test_training_arg(self):
pass
class Canary(layers.Layer):
def call(self, x, training=False):
assert training
return x
def compute_output_spec(self, x, training=False):
return backend.KerasTensor(x.shape, dtype=x.dtype)
inputs = Input(shape=(3,), batch_size=2)
outputs = Canary()(inputs)
model = Functional(inputs, outputs)
model(np.random.random((2, 3)), training=True)
def test_mask_arg(self):
pass
def test_shape_inference(self):
# TODO
pass
def test_passing_inputs_by_name(self):

@ -79,7 +79,7 @@ class Function(Operation):
for depth in depth_keys:
nodes = nodes_by_depth[depth]
for node in nodes:
if not node.operation:
if not node.operation or node.is_input:
continue # Input tensors already exist.
if any(x not in tensor_dict for x in node.input_tensors):

@ -37,7 +37,6 @@ class Node:
call_kwargs: The keyword arguments the operation was called with.
outputs: The output tensors of the `op.__call__()` call.
"""
def __init__(
self, operation, call_args=None, call_kwargs=None, outputs=None
):
@ -77,7 +76,7 @@ class Node:
)
# Whether this is a root node.
self.is_input = not any_input_has_history
self.is_input = not self.arguments.keras_tensors
def __repr__(self):
return f"<Node operation={self.operation}, id={id(self)}>"

@ -1,3 +1,28 @@
import functools
import threading

# Thread-local flag holder: disabling tracking in one thread must not
# affect tracking in any other thread.
GLOBAL_SCOPE_TRACKER = threading.local()


class DotNotTrackScope:
    """Context manager that disables automatic dependency tracking.

    NOTE(review): the name looks like a typo for ``DoNotTrackScope``;
    kept as-is so existing references keep working.
    """

    def __enter__(self):
        # Remember the state on entry so nested scopes restore correctly.
        self.original_value = is_tracking_enabled()
        GLOBAL_SCOPE_TRACKER.tracking_on = False
        return self  # Enable `with DotNotTrackScope() as scope:` usage.

    def __exit__(self, *args, **kwargs):
        # Restore whatever the tracking state was before entering.
        GLOBAL_SCOPE_TRACKER.tracking_on = self.original_value


def is_tracking_enabled():
    """Return whether automatic dependency tracking is currently on.

    Defaults to True in threads where no scope has ever been entered.
    """
    return getattr(GLOBAL_SCOPE_TRACKER, "tracking_on", True)


def no_automatic_dependency_tracking(fn):
    """Decorator that runs `fn` with automatic tracking disabled.

    Args:
        fn: The callable to wrap.

    Returns:
        A wrapper with the same signature that executes `fn` inside a
        `DotNotTrackScope`.
    """

    @functools.wraps(fn)  # Preserve the wrapped function's metadata.
    def wrapper(*args, **kwargs):
        with DotNotTrackScope():
            return fn(*args, **kwargs)

    return wrapper
class Tracker:
"""Attribute tracker, used for e.g. Variable tracking.
@ -34,6 +59,9 @@ class Tracker:
self.stored_ids = {name: set() for name in self.config.keys()}
def track(self, attr):
if not is_tracking_enabled():
return attr
for name, (is_attr_type, store) in self.config.items():
if is_attr_type(attr):
if id(attr) not in self.stored_ids[name]: