Fix SavedModel integration and add associated tests (#522)
- Add saved model test
- Add TF tracking attribute
- Add tests for functional and subclassed models
- Fix saving trackables
- Fix test assertions
- Fix formatting
- Add comments for attribute tracking
- Change saved model test description
- Add backend conditional for attribute
- Change package name
- Change epoch nums
- Revert epochs
- Add set-verbose-logging utility and debug callback tests
- Fix formatting
This commit is contained in:
parent
a0d7776585
commit
6094363015
@ -6,6 +6,25 @@ class TFLayer(tf.__internal__.tracking.AutoTrackable):
|
||||
"""Can be overriden to perform post-build actions."""
|
||||
pass
|
||||
|
||||
def _trackable_children(self, save_type="checkpoint", **kwargs):
    """Returns trackable children, hiding execution functions from SavedModel.

    The compiled execution functions (`train_function`, `test_function`,
    `predict_function`) cannot be serialized into a SavedModel, so they
    are temporarily detached while the parent collects children, then
    restored afterwards.

    Args:
        save_type: `"checkpoint"` or `"savedmodel"`; only the SavedModel
            path needs the execution functions hidden.
        **kwargs: Forwarded to the parent implementation.

    Returns:
        The trackable children reported by the parent class.
    """
    if save_type != "savedmodel":
        return super()._trackable_children(save_type, **kwargs)

    # SavedModel needs to ignore the execution functions.
    train_function = getattr(self, "train_function", None)
    test_function = getattr(self, "test_function", None)
    predict_function = getattr(self, "predict_function", None)
    self.train_function = None
    self.test_function = None
    self.predict_function = None
    try:
        return super()._trackable_children(save_type, **kwargs)
    finally:
        # Restore the functions even if the parent call raises, so the
        # model remains usable after a failed export attempt.
        self.train_function = train_function
        self.test_function = test_function
        self.predict_function = predict_function
|
||||
|
||||
@property
|
||||
def _default_save_signature(self):
|
||||
"""For SavedModel support: returns the default serving signature."""
|
||||
|
99
keras_core/backend/tensorflow/saved_model_test.py
Normal file
99
keras_core/backend/tensorflow/saved_model_test.py
Normal file
@ -0,0 +1,99 @@
|
||||
"""Tests for SavedModel functionality under tf implementation."""
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import tensorflow as tf
|
||||
|
||||
from keras_core import backend
|
||||
from keras_core import layers
|
||||
from keras_core import metrics
|
||||
from keras_core import models
|
||||
from keras_core import testing
|
||||
from keras_core.saving import object_registration
|
||||
|
||||
|
||||
@object_registration.register_keras_serializable(package="my_package")
class CustomModelX(models.Model):
    """Minimal subclassed model used to exercise SavedModel round-trips."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dense1 = layers.Dense(1)
        self.dense2 = layers.Dense(1)

    def call(self, inputs):
        # Two stacked single-unit dense projections.
        hidden = self.dense1(inputs)
        return self.dense2(hidden)

    def one(self):
        """Returns the constant 1."""
        return 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    backend.backend() != "tensorflow",
    reason="The SavedModel test can only run with TF backend.",
)
class SavedModelTest(testing.TestCase):
    """Round-trips keras_core models through `tf.saved_model` save/load."""

    def _check_saved_model(self, model, X_train):
        """Saves `model`, reloads it, and compares serving outputs.

        Asserts that the restored `serving_default` signature reproduces
        the in-memory model's predictions on `X_train` within tolerance.
        """
        path = os.path.join(self.get_temp_dir(), "my_keras_core_model")
        tf.saved_model.save(model, path)
        restored_model = tf.saved_model.load(path)
        self.assertAllClose(
            model(X_train),
            restored_model.signatures["serving_default"](
                tf.convert_to_tensor(X_train, dtype=tf.float32)
            )["output_0"],
            rtol=1e-4,
            atol=1e-4,
        )

    def test_sequential(self):
        model = models.Sequential([layers.Dense(1)])
        model.compile(loss="mse", optimizer="adam")
        X_train = np.random.rand(100, 3)
        y_train = np.random.rand(100, 1)
        model.fit(X_train, y_train)
        self._check_saved_model(model, X_train)

    def test_functional(self):
        inputs = layers.Input(shape=(3,))
        x = layers.Dense(1, name="first_dense")(inputs)
        outputs = layers.Dense(1, name="second_dense")(x)
        model = models.Model(inputs, outputs)
        model.compile(
            optimizer="adam",
            loss="mse",
        )
        X_train = np.random.rand(100, 3)
        y_train = np.random.rand(100, 1)
        model.fit(X_train, y_train)
        self._check_saved_model(model, X_train)

    def test_subclassed(self):
        model = CustomModelX()
        model.compile(
            optimizer="adam",
            loss="mse",
            metrics=[metrics.Hinge(), "mse"],
        )
        X_train = np.random.rand(100, 3)
        y_train = np.random.rand(100, 1)
        model.fit(X_train, y_train)
        self._check_saved_model(model, X_train)
|
@ -52,8 +52,9 @@ class LearningRateSchedulerTest(testing.TestCase):
|
||||
lambda step: 1.0 / (1.0 + step), verbose=1
|
||||
)
|
||||
io_utils.disable_interactive_logging()
|
||||
io_utils.set_logging_verbosity("INFO")
|
||||
|
||||
with self.assertLogs(level="INFO") as logs:
|
||||
with self.assertLogs() as logs:
|
||||
self.model.fit(
|
||||
self.x_train,
|
||||
self.y_train,
|
||||
|
@ -83,8 +83,9 @@ class ReduceLROnPlateauTest(testing.TestCase):
|
||||
patience=1, factor=0.1, monitor="val_loss", min_delta=100, verbose=1
|
||||
)
|
||||
io_utils.disable_interactive_logging()
|
||||
io_utils.set_logging_verbosity("INFO")
|
||||
|
||||
with self.assertLogs(level="INFO") as logs:
|
||||
with self.assertLogs() as logs:
|
||||
self.model.fit(
|
||||
self.x_train,
|
||||
self.y_train,
|
||||
|
@ -295,12 +295,23 @@ class Layer(BackendLayer, Operation):
|
||||
),
|
||||
}
|
||||
)
|
||||
if backend.backend() == "tensorflow":
|
||||
# Remove attribute tracking for lists (TF-specific attribute)
|
||||
_self_setattr_tracking = getattr(
|
||||
self, "_self_setattr_tracking", True
|
||||
)
|
||||
self._self_setattr_tracking = False
|
||||
|
||||
self._trainable_variables = trainable_variables
|
||||
self._non_trainable_variables = non_trainable_variables
|
||||
self._layers = layers
|
||||
self._metrics = metrics
|
||||
self._seed_generators = seed_generators
|
||||
|
||||
if backend.backend() == "tensorflow":
|
||||
# Reset attribute tracking (TF-specific)
|
||||
self._self_setattr_tracking = _self_setattr_tracking
|
||||
|
||||
@property
def input_spec(self):
    # Read-only accessor for the layer's input spec, backed by
    # `self._input_spec`.
    return self._input_spec
|
||||
|
@ -58,6 +58,42 @@ def is_interactive_logging_enabled():
|
||||
return global_state.get_global_setting("interactive_logging", True)
|
||||
|
||||
|
||||
@keras_core_export(
    [
        "keras_core.config.set_logging_verbosity",
        "keras_core.utils.set_logging_verbosity",
    ]
)
def set_logging_verbosity(level):
    """Sets the verbosity level for logging.

    The log levels are as follows:

    - "FATAL" (least verbose)
    - "ERROR"
    - "WARNING"
    - "INFO"
    - "DEBUG" (most verbose)

    Args:
        level: A string corresponding to the level of verbosity for logging.

    Raises:
        ValueError: If `level` is not one of the recognized level names.
    """
    valid_levels = {
        "FATAL": logging.FATAL,
        "ERROR": logging.ERROR,
        "WARNING": logging.WARNING,
        "INFO": logging.INFO,
        "DEBUG": logging.DEBUG,
    }
    verbosity = valid_levels.get(level)
    if verbosity is None:
        # `list(...)` so the message shows a clean list of names rather
        # than the `dict_keys([...])` repr.
        raise ValueError(
            "Please pass a valid level for logging verbosity. "
            f"The valid levels are {list(valid_levels.keys())}. "
            f"Received: {level}"
        )
    logging.set_verbosity(verbosity)
|
||||
|
||||
|
||||
def print_msg(message, line_break=True):
|
||||
"""Print the message to absl logging or stdout."""
|
||||
if is_interactive_logging_enabled():
|
||||
|
Loading…
Reference in New Issue
Block a user