Fix SavedModel integration and add associated tests (#522)

* Add saved model test

* Add TF tracking attribute

* Add tests for functional and subclassed

* Fix saving trackables

* Fix test assertions

* Fix formatting

* Add comments for attribute tracking

* Change saved model test description

* Add backend conditional for attribute

* Change package name

* Change epoch nums

* Revert epochs

* Add set verbose logging utility and debug callback tests

* Fix formatting
This commit is contained in:
Neel Kovelamudi 2023-07-18 23:01:13 +00:00 committed by Francois Chollet
parent a0d7776585
commit 6094363015
6 changed files with 169 additions and 2 deletions

@ -6,6 +6,25 @@ class TFLayer(tf.__internal__.tracking.AutoTrackable):
"""Can be overriden to perform post-build actions."""
pass
def _trackable_children(self, save_type="checkpoint", **kwargs):
if save_type == "savedmodel":
# SavedModel needs to ignore the execution functions.
train_function = getattr(self, "train_function", None)
test_function = getattr(self, "test_function", None)
predict_function = getattr(self, "predict_function", None)
self.train_function = None
self.test_function = None
self.predict_function = None
children = super()._trackable_children(save_type, **kwargs)
if save_type == "savedmodel":
self.train_function = train_function
self.test_function = test_function
self.predict_function = predict_function
return children
@property
def _default_save_signature(self):
"""For SavedModel support: returns the default serving signature."""

@ -0,0 +1,99 @@
"""Tests for SavedModel functionality under tf implementation."""
import os
import numpy as np
import pytest
import tensorflow as tf
from keras_core import backend
from keras_core import layers
from keras_core import metrics
from keras_core import models
from keras_core import testing
from keras_core.saving import object_registration
@object_registration.register_keras_serializable(package="my_package")
class CustomModelX(models.Model):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.dense1 = layers.Dense(1)
self.dense2 = layers.Dense(1)
def call(self, inputs):
out = self.dense1(inputs)
return self.dense2(out)
def one(self):
return 1
@pytest.mark.skipif(
backend.backend() != "tensorflow",
reason="The SavedModel test can only run with TF backend.",
)
class SavedModelTest(testing.TestCase):
def test_sequential(self):
model = models.Sequential([layers.Dense(1)])
model.compile(loss="mse", optimizer="adam")
X_train = np.random.rand(100, 3)
y_train = np.random.rand(100, 1)
model.fit(X_train, y_train)
path = os.path.join(self.get_temp_dir(), "my_keras_core_model")
tf.saved_model.save(model, path)
restored_model = tf.saved_model.load(path)
self.assertAllClose(
model(X_train),
restored_model.signatures["serving_default"](
tf.convert_to_tensor(X_train, dtype=tf.float32)
)["output_0"],
rtol=1e-4,
atol=1e-4,
)
def test_functional(self):
inputs = layers.Input(shape=(3,))
x = layers.Dense(1, name="first_dense")(inputs)
outputs = layers.Dense(1, name="second_dense")(x)
model = models.Model(inputs, outputs)
model.compile(
optimizer="adam",
loss="mse",
)
X_train = np.random.rand(100, 3)
y_train = np.random.rand(100, 1)
model.fit(X_train, y_train)
path = os.path.join(self.get_temp_dir(), "my_keras_core_model")
tf.saved_model.save(model, path)
restored_model = tf.saved_model.load(path)
self.assertAllClose(
model(X_train),
restored_model.signatures["serving_default"](
tf.convert_to_tensor(X_train, dtype=tf.float32)
)["output_0"],
rtol=1e-4,
atol=1e-4,
)
def test_subclassed(self):
model = CustomModelX()
model.compile(
optimizer="adam",
loss="mse",
metrics=[metrics.Hinge(), "mse"],
)
X_train = np.random.rand(100, 3)
y_train = np.random.rand(100, 1)
model.fit(X_train, y_train)
path = os.path.join(self.get_temp_dir(), "my_keras_core_model")
tf.saved_model.save(model, path)
restored_model = tf.saved_model.load(path)
self.assertAllClose(
model(X_train),
restored_model.signatures["serving_default"](
tf.convert_to_tensor(X_train, dtype=tf.float32)
)["output_0"],
rtol=1e-4,
atol=1e-4,
)

@ -52,8 +52,9 @@ class LearningRateSchedulerTest(testing.TestCase):
lambda step: 1.0 / (1.0 + step), verbose=1
)
io_utils.disable_interactive_logging()
io_utils.set_logging_verbosity("INFO")
with self.assertLogs(level="INFO") as logs:
with self.assertLogs() as logs:
self.model.fit(
self.x_train,
self.y_train,

@ -83,8 +83,9 @@ class ReduceLROnPlateauTest(testing.TestCase):
patience=1, factor=0.1, monitor="val_loss", min_delta=100, verbose=1
)
io_utils.disable_interactive_logging()
io_utils.set_logging_verbosity("INFO")
with self.assertLogs(level="INFO") as logs:
with self.assertLogs() as logs:
self.model.fit(
self.x_train,
self.y_train,

@ -295,12 +295,23 @@ class Layer(BackendLayer, Operation):
),
}
)
if backend.backend() == "tensorflow":
# Remove attribute tracking for lists (TF-specific attribute)
_self_setattr_tracking = getattr(
self, "_self_setattr_tracking", True
)
self._self_setattr_tracking = False
self._trainable_variables = trainable_variables
self._non_trainable_variables = non_trainable_variables
self._layers = layers
self._metrics = metrics
self._seed_generators = seed_generators
if backend.backend() == "tensorflow":
# Reset attribute tracking (TF-specific)
self._self_setattr_tracking = _self_setattr_tracking
@property
def input_spec(self):
return self._input_spec

@ -58,6 +58,42 @@ def is_interactive_logging_enabled():
return global_state.get_global_setting("interactive_logging", True)
@keras_core_export(
[
"keras_core.config.set_logging_verbosity",
"keras_core.utils.set_logging_verbosity",
]
)
def set_logging_verbosity(level):
"""Sets the verbosity level for logging.
The log levels are as follows:
- "FATAL" (least verbose)
- "ERROR"
- "WARNING"
- "INFO"
- "DEBUG" (most verbose)
Args:
level: A string corresponding to the level of verbosity for logging.
"""
valid_levels = {
"FATAL": logging.FATAL,
"ERROR": logging.ERROR,
"WARNING": logging.WARNING,
"INFO": logging.INFO,
"DEBUG": logging.DEBUG,
}
verbosity = valid_levels.get(level)
if verbosity is None:
raise ValueError(
"Please pass a valid level for logging verbosity. "
f"The valid levels are {valid_levels.keys()}. "
f"Received: {level}"
)
logging.set_verbosity(verbosity)
def print_msg(message, line_break=True):
"""Print the message to absl logging or stdout."""
if is_interactive_logging_enabled():