diff --git a/keras_core/backend/tensorflow/layer.py b/keras_core/backend/tensorflow/layer.py
index 51f980f54..4304aee96 100644
--- a/keras_core/backend/tensorflow/layer.py
+++ b/keras_core/backend/tensorflow/layer.py
@@ -1,11 +1,61 @@
 import tensorflow as tf
 
+from keras_core.backend.tensorflow import tf_utils
+
 
 class TFLayer(tf.__internal__.tracking.AutoTrackable):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # Export-related attributes
+        self._saved_model_inputs_spec = None
+        self._saved_model_arg_spec = None
+
     def _post_build(self):
         """Can be overriden to perform post-build actions."""
         pass
 
+    @tf.__internal__.tracking.no_automatic_dependency_tracking
+    def _set_save_spec(self, inputs, args=None, kwargs=None):
+        """Defines the save spec so that serialization can trace layer calls.
+
+        The TensorSpecs of the call function `inputs`, `args`, and `kwargs` are
+        saved into a tuple of `([inputs] + args, kwargs)`.
+
+        Args:
+            inputs: possibly nested inputs passed into the call function.
+            args: a list of positional arguments passed into call.
+            kwargs: a dictionary of keyword arguments passed into call.
+        """
+        if self._saved_model_inputs_spec is not None:
+            return  # Already set.
+
+        inputs_spec = tf.nest.map_structure(tf_utils.get_tensor_spec, inputs)
+        args_spec = tf.nest.map_structure(tf_utils.get_tensor_spec, args or [])
+        kwargs_spec = {}
+        # Filter out non-tensor arguments from kwargs. Guard against the
+        # default `kwargs=None`.
+        for key, kwarg in (kwargs or {}).items():
+            flat_kwarg = tf.nest.flatten(kwarg)
+            flat_specs = [tf_utils.get_tensor_spec(x) for x in flat_kwarg]
+            if any(s is None for s in flat_specs):
+                continue
+            kwargs_spec[key] = tf.nest.pack_sequence_as(kwarg, flat_specs)
+
+        self._saved_model_inputs_spec = inputs_spec
+        self._saved_model_arg_spec = (
+            [inputs_spec] + list(args_spec),
+            kwargs_spec,
+        )
+
+    def _get_save_spec(self, dynamic_batch=True, inputs_only=True):
+        if self._saved_model_inputs_spec is None:
+            return None
+
+        spec = tf.nest.map_structure(
+            lambda t: tf_utils.get_tensor_spec(t, dynamic_batch=dynamic_batch),
+            self._saved_model_arg_spec,
+        )
+        return spec[0][0] if inputs_only else spec
+
     def _trackable_children(self, save_type="checkpoint", **kwargs):
         if save_type == "savedmodel":
             # SavedModel needs to ignore the execution functions.
diff --git a/keras_core/utils/tf_utils.py b/keras_core/backend/tensorflow/tf_utils.py
similarity index 81%
rename from keras_core/utils/tf_utils.py
rename to keras_core/backend/tensorflow/tf_utils.py
index 12a29cc87..bf1f4e484 100644
--- a/keras_core/utils/tf_utils.py
+++ b/keras_core/backend/tensorflow/tf_utils.py
@@ -109,3 +109,29 @@ def encode_categorical_inputs(
         )
     else:
         return tf.multiply(bincounts, idf_weights)
+
+
+def get_tensor_spec(t, dynamic_batch=False, name=None):
+    """Returns a `TensorSpec` given a single `Tensor` or `TensorSpec`."""
+    if isinstance(t, tf.TypeSpec):
+        spec = t
+    elif isinstance(t, tf.__internal__.CompositeTensor):
+        # Check for ExtensionTypes
+        spec = t._type_spec
+    elif hasattr(t, "shape") and hasattr(t, "dtype"):
+        spec = tf.TensorSpec(shape=t.shape, dtype=t.dtype, name=name)
+    else:
+        return None  # Allow non-Tensors to pass through.
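+
+    # If `dynamic_batch=True`, the leading (batch) dimension of the spec
+    # is relaxed to `None` below, so that signatures traced from it will
+    # accept any batch size at serving time.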
+ + if not dynamic_batch: + return spec + + shape = spec.shape + if shape.rank is None or shape.rank == 0: + return spec + + shape_list = shape.as_list() + shape_list[0] = None + shape = tf.TensorShape(shape_list) + spec._shape = shape + return spec diff --git a/keras_core/export/__init__.py b/keras_core/export/__init__.py new file mode 100644 index 000000000..2d34bf5fc --- /dev/null +++ b/keras_core/export/__init__.py @@ -0,0 +1 @@ +from keras_core.export.export_lib import ExportArchive diff --git a/keras_core/export/export_lib.py b/keras_core/export/export_lib.py new file mode 100644 index 000000000..144c0e2b0 --- /dev/null +++ b/keras_core/export/export_lib.py @@ -0,0 +1,574 @@ +"""Library for exporting inference-only Keras models/layers.""" + +from keras_core import backend +from keras_core.api_export import keras_core_export +from keras_core.layers import Layer +from keras_core.models import Functional +from keras_core.models import Sequential +from keras_core.utils import io_utils +from keras_core.utils.module_utils import tensorflow as tf + + +@keras_core_export("keras_core.export.ExportArchive") +class ExportArchive(tf.__internal__.tracking.AutoTrackable): + """ExportArchive is used to write SavedModel artifacts (e.g. for inference). + + If you have a Keras model or layer that you want to export as SavedModel for + serving (e.g. via TensorFlow-Serving), you can use `ExportArchive` + to configure the different serving endpoints you need to make available, + as well as their signatures. Simply instantiate an `ExportArchive`, + use `track()` to register the layer(s) or model(s) to be used, + then use the `add_endpoint()` method to register a new serving endpoint. + When done, use the `write_out()` method to save the artifact. + + The resulting artifact is a SavedModel and can be reloaded via + `tf.saved_model.load`. + + Examples: + + Here's how to export a model for inference. + + ```python + export_archive = ExportArchive() + export_archive.track(model) + export_archive.add_endpoint( + name="serve", + fn=model.call, + input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)], + ) + export_archive.write_out("path/to/location") + + # Elsewhere, we can reload the artifact and serve it. + # The endpoint we added is available as a method: + serving_model = tf.saved_model.load("path/to/location") + outputs = serving_model.serve(inputs) + ``` + + Here's how to export a model with one endpoint for inference and one + endpoint for a training-mode forward pass (e.g. with dropout on). + + ```python + export_archive = ExportArchive() + export_archive.track(model) + export_archive.add_endpoint( + name="call_inference", + fn=lambda x: model.call(x, training=False), + input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)], + ) + export_archive.add_endpoint( + name="call_training", + fn=lambda x: model.call(x, training=True), + input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)], + ) + export_archive.write_out("path/to/location") + ``` + + **Note on resource tracking:** + + `ExportArchive` is able to automatically track all `tf.Variables` used + by its endpoints, so most of the time calling `.track(model)` + is not strictly required. However, if your model uses lookup layers such + as `IntegerLookup`, `StringLookup`, or `TextVectorization`, + it will need to be tracked explicitly via `.track(model)`. 
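+
+    For instance (a minimal sketch; the input shape and dtype are
+    illustrative), a model containing a `StringLookup` layer would be
+    tracked explicitly before registering its endpoint:
+
+    ```python
+    export_archive = ExportArchive()
+    # Required: `model` contains a lookup table that must be tracked.
+    export_archive.track(model)
+    export_archive.add_endpoint(
+        name="serve",
+        fn=model.call,
+        input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.string)],
+    )
+    export_archive.write_out("path/to/location")
+    ```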
+ + Explicit tracking is also required if you need to be able to access + the properties `variables`, `trainable_variables`, or + `non_trainable_variables` on the revived archive. + """ + + def __init__(self): + self._endpoint_names = [] + self._endpoint_signatures = {} + self.tensorflow_version = tf.__version__ + self.variables = [] + self.trainable_variables = [] + self.non_trainable_variables = [] + + @tf.__internal__.tracking.no_automatic_dependency_tracking + def track(self, resource): + """Track the variables (and other assets) of a layer or model.""" + if not isinstance(resource, tf.__internal__.tracking.Trackable): + raise ValueError( + "Invalid resource type. Expected an instance of a " + "TensorFlow `Trackable` (such as a Keras `Layer` or `Model`). " + f"Received instead an object of type '{type(resource)}'. " + f"Object received: {resource}" + ) + if isinstance(resource, Layer): + if not resource.built: + raise ValueError( + "The layer provided has not yet been built. " + "It must be built before export." + ) + + # Layers in `_tracked` are not part of the trackables that get saved, + # because we're creating the attribute in a + # no_automatic_dependency_tracking scope. + if not hasattr(self, "_tracked"): + self._tracked = [] + self._tracked.append(resource) + + if isinstance(resource, Layer): + # Variables in the lists below are actually part of the trackables + # that get saved, because the lists are created in __init__. + self.variables += resource.variables + self.trainable_variables += resource.trainable_variables + self.non_trainable_variables += resource.non_trainable_variables + + def add_endpoint(self, name, fn, input_signature=None): + """Register a new serving endpoint. + + Arguments: + name: Str, name of the endpoint. + fn: A function. It should only leverage resources + (e.g. `tf.Variable` objects or `tf.lookup.StaticHashTable` + objects) that are available on the models/layers + tracked by the `ExportArchive` (you can call `.track(model)` + to track a new model). + The shape and dtype of the inputs to the function must be + known. For that purpose, you can either 1) make sure that + `fn` is a `tf.function` that has been called at least once, or + 2) provide an `input_signature` argument that specifies the + shape and dtype of the inputs (see below). + input_signature: Used to specify the shape and dtype of the + inputs to `fn`. List of `tf.TensorSpec` objects (one + per positional input argument of `fn`). Nested arguments are + allowed (see below for an example showing a Functional model + with 2 input arguments). + + Example: + + Adding an endpoint using the `input_signature` argument when the + model has a single input argument: + + ```python + export_archive = ExportArchive() + export_archive.track(model) + export_archive.add_endpoint( + name="serve", + fn=model.call, + input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)], + ) + ``` + + Adding an endpoint using the `input_signature` argument when the + model has two positional input arguments: + + ```python + export_archive = ExportArchive() + export_archive.track(model) + export_archive.add_endpoint( + name="serve", + fn=model.call, + input_signature=[ + tf.TensorSpec(shape=(None, 3), dtype=tf.float32), + tf.TensorSpec(shape=(None, 4), dtype=tf.float32), + ], + ) + ``` + + Adding an endpoint using the `input_signature` argument when the + model has one input argument that is a list of 2 tensors (e.g. 
+ a Functional model with 2 inputs): + + ```python + model = keras_core.Model(inputs=[x1, x2], outputs=outputs) + + export_archive = ExportArchive() + export_archive.track(model) + export_archive.add_endpoint( + name="serve", + fn=model.call, + input_signature=[ + [ + tf.TensorSpec(shape=(None, 3), dtype=tf.float32), + tf.TensorSpec(shape=(None, 4), dtype=tf.float32), + ], + ], + ) + ``` + + This also works with dictionary inputs: + + ```python + model = keras_core.Model(inputs={"x1": x1, "x2": x2}, outputs=outputs) + + export_archive = ExportArchive() + export_archive.track(model) + export_archive.add_endpoint( + name="serve", + fn=model.call, + input_signature=[ + { + "x1": tf.TensorSpec(shape=(None, 3), dtype=tf.float32), + "x2": tf.TensorSpec(shape=(None, 4), dtype=tf.float32), + }, + ], + ) + ``` + + Adding an endpoint that is a `tf.function`: + + ```python + @tf.function() + def serving_fn(x): + return model(x) + + # The function must be traced, i.e. it must be called at least once. + serving_fn(tf.random.normal(shape=(2, 3))) + + export_archive = ExportArchive() + export_archive.track(model) + export_archive.add_endpoint(name="serve", fn=serving_fn) + ``` + """ + if name in self._endpoint_names: + raise ValueError(f"Endpoint name '{name}' is already taken.") + + if input_signature: + decorated_fn = tf.function(fn, input_signature=input_signature) + self._endpoint_signatures[name] = input_signature + else: + if isinstance(fn, tf.types.experimental.GenericFunction): + if not fn._list_all_concrete_functions(): + raise ValueError( + f"The provided tf.function '{fn}' " + "has never been called. " + "To specify the expected shape and dtype " + "of the function's arguments, " + "you must either provide a function that " + "has been called at least once, or alternatively pass " + "an `input_signature` argument in `add_endpoint()`." + ) + decorated_fn = fn + else: + raise ValueError( + "If the `fn` argument provided is not a `tf.function`, " + "you must provide an `input_signature` argument to " + "specify the shape and dtype of the function arguments. " + "Example:\n\n" + "export_archive.add_endpoint(\n" + " name='call',\n" + " fn=model.call,\n" + " input_signature=[\n" + " tf.TensorSpec(\n" + " shape=(None, 224, 224, 3),\n" + " dtype=tf.float32,\n" + " )\n" + " ],\n" + ")" + ) + setattr(self, name, decorated_fn) + self._endpoint_names.append(name) + + def add_variable_collection(self, name, variables): + """Register a set of variables to be retrieved after reloading. + + Arguments: + name: The string name for the collection. + variables: A tuple/list/set of `tf.Variable` instances. + + Example: + + ```python + export_archive = ExportArchive() + export_archive.track(model) + # Register an endpoint + export_archive.add_endpoint( + name="serve", + fn=model.call, + input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)], + ) + # Save a variable collection + export_archive.add_variable_collection( + name="optimizer_variables", variables=model.optimizer.variables) + export_archive.write_out("path/to/location") + + # Reload the object + revived_object = tf.saved_model.load("path/to/location") + # Retrieve the variables + optimizer_variables = revived_object.optimizer_variables + ``` + """ + if not isinstance(variables, (list, tuple, set)): + raise ValueError( + "Expected `variables` to be a list/tuple/set. " + f"Received instead object of type '{type(variables)}'." 
+ ) + # Ensure that all variables added are either tf.Variables + # or Variables created by Keras Core with the TF backend. + if not all(isinstance(v, tf.Variable) for v in variables) and not ( + all( + isinstance(v, (tf.Variable, backend.Variable)) + for v in variables + ) + and (backend.backend() == "tensorflow") + ): + raise ValueError( + "Expected all elements in `variables` to be " + "`tf.Variable` instances. Found instead the following types: " + f"{list(set(type(v) for v in variables))}" + ) + setattr(self, name, list(variables)) + + def write_out(self, filepath, options=None): + """Write the corresponding SavedModel to disk. + + Arguments: + filepath: `str` or `pathlib.Path` object. + Path where to save the artifact. + options: `tf.saved_model.SaveOptions` object that specifies + SavedModel saving options. + + **Note on TF-Serving**: all endpoints registered via `add_endpoint()` + are made visible for TF-Serving in the SavedModel artifact. In addition, + the first endpoint registered is made visible under the alias + `"serving_default"` (unless an endpoint with the name + `"serving_default"` was already registered manually), + since TF-Serving requires this endpoint to be set. + """ + if not self._endpoint_names: + raise ValueError( + "No endpoints have been set yet. Call add_endpoint()." + ) + self._filter_and_track_resources() + + signatures = {} + for name in self._endpoint_names: + signatures[name] = self._get_concrete_fn(name) + # Add "serving_default" signature key for TFServing + if "serving_default" not in self._endpoint_names: + signatures["serving_default"] = self._get_concrete_fn( + self._endpoint_names[0] + ) + tf.saved_model.save( + self, filepath, options=options, signatures=signatures + ) + # Print out available endpoints + endpoints = "\n\n".join( + _print_signature(getattr(self, name), name) + for name in self._endpoint_names + ) + io_utils.print_msg( + f"Saved artifact at '{filepath}'. " + "The following endpoints are available:\n\n" + f"{endpoints}" + ) + + def _get_concrete_fn(self, endpoint): + """Workaround for some SavedModel quirks.""" + if endpoint in self._endpoint_signatures: + return getattr(self, endpoint) + else: + traces = getattr(self, endpoint)._trackable_children("saved_model") + return list(traces.values())[0] + + def _get_variables_used_by_endpoints(self): + fns = [self._get_concrete_fn(name) for name in self._endpoint_names] + return _list_variables_used_by_fns(fns) + + def _filter_and_track_resources(self): + """Track resources used by endpoints / referenced in `track()` calls.""" + # Start by extracting variables from endpoints. + fns = [self._get_concrete_fn(name) for name in self._endpoint_names] + tvs, ntvs = _list_variables_used_by_fns(fns) + self._all_variables = list(tvs + ntvs) + + # Next, track lookup tables. + # Hopefully, one day this will be automated at the tf.function level. 
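+        # For now, walk the object graph of each explicitly tracked root
+        # and register any lookup-style layers as additional assets.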
+        self._misc_assets = []
+        from keras_core.layers import IntegerLookup
+        from keras_core.layers import StringLookup
+        from keras_core.layers import TextVectorization
+
+        if hasattr(self, "_tracked"):
+            for root in self._tracked:
+                descendants = tf.train.TrackableView(root).descendants()
+                for trackable in descendants:
+                    if isinstance(
+                        trackable,
+                        (IntegerLookup, StringLookup, TextVectorization),
+                    ):
+                        self._misc_assets.append(trackable)
+
+
+def export_model(model, filepath):
+    export_archive = ExportArchive()
+    export_archive.track(model)
+    if isinstance(model, (Functional, Sequential)):
+        input_signature = tf.nest.map_structure(_make_tensor_spec, model.inputs)
+        if isinstance(input_signature, list) and len(input_signature) > 1:
+            input_signature = [input_signature]
+        export_archive.add_endpoint("serve", model.__call__, input_signature)
+    else:
+        save_spec = model._get_save_spec()
+        if not save_spec:
+            raise ValueError(
+                "The model provided has never been called. "
+                "It must be called at least once before export."
+            )
+        input_signature = [save_spec]
+        export_archive.add_endpoint("serve", model.__call__, input_signature)
+    export_archive.write_out(filepath)
+
+
+class TFSMLayer(Layer):
+    """Reload a Keras model/layer that was saved via SavedModel / ExportArchive.
+
+    Arguments:
+        filepath: `str` or `pathlib.Path` object. The path to the SavedModel.
+        call_endpoint: Name of the endpoint to use as the `call()` method
+            of the reloaded layer. If the SavedModel was created
+            via `model.export()`,
+            then the default endpoint name is `'serve'`. In other cases
+            it may be named `'serving_default'`.
+        call_training_endpoint: Optional name of the endpoint to use for
+            the training-mode forward pass. If provided, it is called
+            whenever the layer is invoked with `training=True` (see the
+            Limitations section below).
+
+    Example:
+
+    ```python
+    model.export("path/to/artifact")
+    reloaded_layer = TFSMLayer("path/to/artifact")
+    outputs = reloaded_layer(inputs)
+    ```
+
+    The reloaded object can be used like a regular Keras layer, and supports
+    training/fine-tuning of its trainable weights. Note that the reloaded
+    object retains none of the internal structure or custom methods of the
+    original object -- it's a brand new layer created around the saved
+    function.
+
+    **Limitations:**
+
+    * Only call endpoints with a single `inputs` tensor argument
+    (which may optionally be a dict/tuple/list of tensors) are supported.
+    For endpoints with multiple separate input tensor arguments, consider
+    subclassing `TFSMLayer` and implementing a `call()` method with a
+    custom signature.
+    * If you need training-time behavior to differ from inference-time behavior
+    (i.e. if you need the reloaded object to support a `training=True` argument
+    in `__call__()`), make sure that the training-time call function is
+    saved as a standalone endpoint in the artifact, and provide its name
+    to the `TFSMLayer` via the `call_training_endpoint` argument.
+    """
+
+    def __init__(
+        self,
+        filepath,
+        call_endpoint="serve",
+        call_training_endpoint=None,
+        trainable=True,
+        name=None,
+        dtype=None,
+    ):
+        # Initialize an empty layer, then add_weight() etc. as needed.
+        super().__init__(trainable=trainable, name=name, dtype=dtype)
+
+        self._reloaded_obj = tf.saved_model.load(filepath)
+
+        self.filepath = filepath
+        self.call_endpoint = call_endpoint
+        self.call_training_endpoint = call_training_endpoint
+
+        # Resolve the call function.
+        if hasattr(self._reloaded_obj, call_endpoint):
+            # Case 1: it's set as an attribute.
+            self.call_endpoint_fn = getattr(self._reloaded_obj, call_endpoint)
+        elif call_endpoint in self._reloaded_obj.signatures:
+            # Case 2: it's listed in the `signatures` field.
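+            # Note: functions retrieved via `signatures` return a flat dict
+            # mapping output names to tensors, rather than the original
+            # output structure.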
+ self.call_endpoint_fn = self._reloaded_obj.signatures[call_endpoint] + else: + raise ValueError( + f"The endpoint '{call_endpoint}' is neither an " + "attribute of the reloaded SavedModel, nor an entry " + "in the `signatures` field of the reloaded SavedModel. " + ) + + # Resolving the training function. + if call_training_endpoint: + if hasattr(self._reloaded_obj, call_training_endpoint): + self.call_training_endpoint_fn = getattr( + self._reloaded_obj, call_training_endpoint + ) + elif call_training_endpoint in self._reloaded_obj.signatures: + self.call_training_endpoint_fn = self._reloaded_obj.signatures[ + call_training_endpoint + ] + else: + raise ValueError( + f"The endpoint '{call_training_endpoint}' is " + "neither an attribute of the reloaded SavedModel, " + "nor an entry in the `signatures` field of " + "the reloaded SavedModel. " + ) + + # Add trainable and non-trainable weights from the call_endpoint_fn. + all_fns = [self.call_endpoint_fn] + if call_training_endpoint: + all_fns.append(self.call_training_endpoint_fn) + tvs, ntvs = _list_variables_used_by_fns(all_fns) + for v in tvs: + self._add_existing_weight(v) + for v in ntvs: + self._add_existing_weight(v) + self.built = True + + def _add_existing_weight(self, weight): + """Tracks an existing weight.""" + self._track_variable(weight) + + def call(self, inputs, training=False, **kwargs): + if training: + if self.call_training_endpoint: + return self.call_training_endpoint_fn(inputs, **kwargs) + return self.call_endpoint_fn(inputs, **kwargs) + + def get_config(self): + base_config = super().get_config() + config = { + # Note: this is not intended to be portable. + "filepath": self.filepath, + "call_endpoint": self.call_endpoint, + "call_training_endpoint": self.call_training_endpoint, + } + return {**base_config, **config} + + +def _make_tensor_spec(x): + return tf.TensorSpec(x.shape, dtype=x.dtype, name=x.name) + + +def _print_signature(fn, name): + concrete_fn = fn._list_all_concrete_functions()[0] + pprinted_signature = concrete_fn.pretty_printed_signature(verbose=True) + lines = pprinted_signature.split("\n") + lines = [f"* Endpoint '{name}'"] + lines[1:] + endpoint = "\n".join(lines) + return endpoint + + +def _list_variables_used_by_fns(fns): + trainable_variables = [] + non_trainable_variables = [] + trainable_variables_ids = set() + non_trainable_variables_ids = set() + for fn in fns: + if hasattr(fn, "concrete_functions"): + concrete_functions = fn.concrete_functions + elif hasattr(fn, "get_concrete_function"): + concrete_functions = [fn.get_concrete_function()] + else: + concrete_functions = [fn] + for concrete_fn in concrete_functions: + for v in concrete_fn.trainable_variables: + if id(v) not in trainable_variables_ids: + trainable_variables.append(v) + trainable_variables_ids.add(id(v)) + + for v in concrete_fn.variables: + if ( + id(v) not in trainable_variables_ids + and id(v) not in non_trainable_variables_ids + ): + non_trainable_variables.append(v) + non_trainable_variables_ids.add(id(v)) + return trainable_variables, non_trainable_variables diff --git a/keras_core/export/export_lib_test.py b/keras_core/export/export_lib_test.py new file mode 100644 index 000000000..397287f0c --- /dev/null +++ b/keras_core/export/export_lib_test.py @@ -0,0 +1,595 @@ +"""Tests for inference-only model/layer exporting utilities.""" +import os + +import numpy as np +import pytest +import tensorflow as tf + +from keras_core import backend +from keras_core import layers +from keras_core import models +from keras_core import 
testing +from keras_core import utils +from keras_core.export import export_lib +from keras_core.saving import saving_lib + + +def get_model(): + layer_list = [ + layers.Dense(10, activation="relu"), + layers.BatchNormalization(), + layers.Dense(1, activation="sigmoid"), + ] + model = models.Sequential(layer_list) + return model + + +@pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Export only currently supports the TF backend.", +) +class ExportArchiveTest(testing.TestCase): + def test_standard_model_export(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model() + ref_input = tf.random.normal((3, 10)) + ref_output = model(ref_input).numpy() + + export_lib.export_model(model, temp_filepath) + revived_model = tf.saved_model.load(temp_filepath) + self.assertAllClose( + ref_output, revived_model.serve(ref_input).numpy(), atol=1e-6 + ) + + def test_low_level_model_export(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + + model = get_model() + ref_input = tf.random.normal((3, 10)) + ref_output = model(ref_input).numpy() + + # Test variable tracking + export_archive = export_lib.ExportArchive() + export_archive.track(model) + self.assertLen(export_archive.variables, 8) + self.assertLen(export_archive.trainable_variables, 6) + self.assertLen(export_archive.non_trainable_variables, 2) + + @tf.function() + def my_endpoint(x): + return model(x) + + # Test registering an endpoint that is a tf.function (called) + my_endpoint(ref_input) # Trace fn + + export_archive.add_endpoint( + "call", + my_endpoint, + ) + export_archive.write_out(temp_filepath) + + revived_model = tf.saved_model.load(temp_filepath) + self.assertFalse(hasattr(revived_model, "_tracked")) + self.assertAllClose( + ref_output, revived_model.call(ref_input).numpy(), atol=1e-6 + ) + self.assertLen(revived_model.variables, 8) + self.assertLen(revived_model.trainable_variables, 6) + self.assertLen(revived_model.non_trainable_variables, 2) + + # Test registering an endpoint that is NOT a tf.function + export_archive = export_lib.ExportArchive() + export_archive.track(model) + export_archive.add_endpoint( + "call", + model.call, + input_signature=[ + tf.TensorSpec( + shape=(None, 10), + dtype=tf.float32, + ) + ], + ) + export_archive.write_out(temp_filepath) + revived_model = tf.saved_model.load(temp_filepath) + self.assertAllClose( + ref_output, revived_model.call(ref_input).numpy(), atol=1e-6 + ) + + def test_layer_export(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_layer") + + layer = layers.BatchNormalization() + ref_input = tf.random.normal((3, 10)) + ref_output = layer(ref_input).numpy() # Build layer (important) + + export_archive = export_lib.ExportArchive() + export_archive.track(layer) + export_archive.add_endpoint( + "call", + layer.call, + input_signature=[ + tf.TensorSpec( + shape=(None, 10), + dtype=tf.float32, + ) + ], + ) + export_archive.write_out(temp_filepath) + revived_layer = tf.saved_model.load(temp_filepath) + self.assertAllClose( + ref_output, revived_layer.call(ref_input).numpy(), atol=1e-6 + ) + + def test_multi_input_output_functional_model(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + x1 = layers.Input((2,)) + x2 = layers.Input((2,)) + y1 = layers.Dense(3)(x1) + y2 = layers.Dense(3)(x2) + model = models.Model([x1, x2], [y1, y2]) + + ref_inputs = [tf.random.normal((3, 2)), tf.random.normal((3, 2))] + ref_outputs = model(ref_inputs) + + export_archive = 
export_lib.ExportArchive() + export_archive.track(model) + export_archive.add_endpoint( + "serve", + model.call, + input_signature=[ + [ + tf.TensorSpec( + shape=(None, 2), + dtype=tf.float32, + ), + tf.TensorSpec( + shape=(None, 2), + dtype=tf.float32, + ), + ] + ], + ) + export_archive.write_out(temp_filepath) + revived_model = tf.saved_model.load(temp_filepath) + self.assertAllClose( + ref_outputs[0].numpy(), + revived_model.serve(ref_inputs)[0].numpy(), + atol=1e-6, + ) + self.assertAllClose( + ref_outputs[1].numpy(), + revived_model.serve(ref_inputs)[1].numpy(), + atol=1e-6, + ) + + # Now test dict inputs + model = models.Model({"x1": x1, "x2": x2}, [y1, y2]) + + ref_inputs = { + "x1": tf.random.normal((3, 2)), + "x2": tf.random.normal((3, 2)), + } + ref_outputs = model(ref_inputs) + + export_archive = export_lib.ExportArchive() + export_archive.track(model) + export_archive.add_endpoint( + "serve", + model.call, + input_signature=[ + { + "x1": tf.TensorSpec( + shape=(None, 2), + dtype=tf.float32, + ), + "x2": tf.TensorSpec( + shape=(None, 2), + dtype=tf.float32, + ), + } + ], + ) + export_archive.write_out(temp_filepath) + revived_model = tf.saved_model.load(temp_filepath) + self.assertAllClose( + ref_outputs[0].numpy(), + revived_model.serve(ref_inputs)[0].numpy(), + atol=1e-6, + ) + self.assertAllClose( + ref_outputs[1].numpy(), + revived_model.serve(ref_inputs)[1].numpy(), + atol=1e-6, + ) + + # def test_model_with_lookup_table(self): + # tf.debugging.disable_traceback_filtering() + # temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + # text_vectorization = layers.TextVectorization() + # text_vectorization.adapt(["one two", "three four", "five six"]) + # model = models.Sequential( + # [ + # layers.Input(shape=(), dtype="string"), + # text_vectorization, + # layers.Embedding(10, 32), + # layers.Dense(1), + # ] + # ) + # ref_input = tf.convert_to_tensor(["one two three four"]) + # ref_output = model(ref_input).numpy() + + # export_lib.export_model(model, temp_filepath) + # revived_model = tf.saved_model.load(temp_filepath) + # self.assertAllClose( + # ref_output, revived_model.serve(ref_input).numpy(), atol=1e-6 + # ) + + def test_track_multiple_layers(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + layer_1 = layers.Dense(2) + ref_input_1 = tf.random.normal((3, 4)) + ref_output_1 = layer_1(ref_input_1).numpy() + layer_2 = layers.Dense(3) + ref_input_2 = tf.random.normal((3, 5)) + ref_output_2 = layer_2(ref_input_2).numpy() + + export_archive = export_lib.ExportArchive() + export_archive.add_endpoint( + "call_1", + layer_1.call, + input_signature=[ + tf.TensorSpec( + shape=(None, 4), + dtype=tf.float32, + ), + ], + ) + export_archive.add_endpoint( + "call_2", + layer_2.call, + input_signature=[ + tf.TensorSpec( + shape=(None, 5), + dtype=tf.float32, + ), + ], + ) + export_archive.write_out(temp_filepath) + revived_layer = tf.saved_model.load(temp_filepath) + self.assertAllClose( + ref_output_1, + revived_layer.call_1(ref_input_1).numpy(), + atol=1e-6, + ) + self.assertAllClose( + ref_output_2, + revived_layer.call_2(ref_input_2).numpy(), + atol=1e-6, + ) + + def test_non_standard_layer_signature(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_layer") + + layer = layers.MultiHeadAttention(2, 2) + x1 = tf.random.normal((3, 2, 2)) + x2 = tf.random.normal((3, 2, 2)) + ref_output = layer(x1, x2).numpy() # Build layer (important) + export_archive = export_lib.ExportArchive() + export_archive.track(layer) + 
export_archive.add_endpoint( + "call", + layer.call, + input_signature=[ + tf.TensorSpec( + shape=(None, 2, 2), + dtype=tf.float32, + ), + tf.TensorSpec( + shape=(None, 2, 2), + dtype=tf.float32, + ), + ], + ) + export_archive.write_out(temp_filepath) + revived_layer = tf.saved_model.load(temp_filepath) + self.assertAllClose( + ref_output, + revived_layer.call(query=x1, value=x2).numpy(), + atol=1e-6, + ) + + def test_variable_collection(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + + model = models.Sequential( + [ + layers.Input((10,)), + layers.Dense(2), + layers.Dense(2), + ] + ) + + # Test variable tracking + export_archive = export_lib.ExportArchive() + export_archive.track(model) + export_archive.add_endpoint( + "call", + model.call, + input_signature=[ + tf.TensorSpec( + shape=(None, 10), + dtype=tf.float32, + ) + ], + ) + export_archive.add_variable_collection( + "my_vars", model.layers[1].weights + ) + self.assertLen(export_archive.my_vars, 2) + export_archive.write_out(temp_filepath) + revived_model = tf.saved_model.load(temp_filepath) + self.assertLen(revived_model.my_vars, 2) + + def test_export_model_errors(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + + # Model has not been built + model = models.Sequential([layers.Dense(2)]) + with self.assertRaisesRegex(ValueError, "It must be built"): + export_lib.export_model(model, temp_filepath) + + # Subclassed model has not been called + class MyModel(models.Model): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.dense = layers.Dense(2) + + def build(self, input_shape): + self.dense.build(input_shape) + self.built = True + + def call(self, x): + return self.dense(x) + + model = MyModel() + model.build((2, 3)) + with self.assertRaisesRegex(ValueError, "It must be called"): + export_lib.export_model(model, temp_filepath) + + def test_export_archive_errors(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = models.Sequential([layers.Dense(2)]) + model(tf.random.normal((2, 3))) + + # Endpoint name reuse + export_archive = export_lib.ExportArchive() + export_archive.track(model) + export_archive.add_endpoint( + "call", + model.call, + input_signature=[ + tf.TensorSpec( + shape=(None, 3), + dtype=tf.float32, + ) + ], + ) + with self.assertRaisesRegex(ValueError, "already taken"): + export_archive.add_endpoint( + "call", + model.call, + input_signature=[ + tf.TensorSpec( + shape=(None, 3), + dtype=tf.float32, + ) + ], + ) + + # Write out with no endpoints + export_archive = export_lib.ExportArchive() + export_archive.track(model) + with self.assertRaisesRegex(ValueError, "No endpoints have been set"): + export_archive.write_out(temp_filepath) + + # Invalid object type + with self.assertRaisesRegex(ValueError, "Invalid resource type"): + export_archive = export_lib.ExportArchive() + export_archive.track("model") + + # Set endpoint with no input signature + export_archive = export_lib.ExportArchive() + export_archive.track(model) + with self.assertRaisesRegex( + ValueError, "you must provide an `input_signature`" + ): + export_archive.add_endpoint( + "call", + model.call, + ) + + # Set endpoint that has never been called + export_archive = export_lib.ExportArchive() + export_archive.track(model) + + @tf.function() + def my_endpoint(x): + return model(x) + + export_archive = export_lib.ExportArchive() + export_archive.track(model) + with self.assertRaisesRegex( + ValueError, "you must either provide a function" + ): + 
export_archive.add_endpoint( + "call", + my_endpoint, + ) + + def test_export_no_assets(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + + # Case where there are legitimately no assets. + model = models.Sequential([layers.Flatten()]) + model(tf.random.normal((2, 3))) + export_archive = export_lib.ExportArchive() + export_archive.add_endpoint( + "call", + model.call, + input_signature=[ + tf.TensorSpec( + shape=(None, 3), + dtype=tf.float32, + ) + ], + ) + export_archive.write_out(temp_filepath) + + def test_model_export_method(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model() + ref_input = tf.random.normal((3, 10)) + ref_output = model(ref_input).numpy() + + model.export(temp_filepath) + revived_model = tf.saved_model.load(temp_filepath) + self.assertAllClose( + ref_output, revived_model.serve(ref_input).numpy(), atol=1e-6 + ) + + +@pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Export only currently supports the TF backend.", +) +class TestTFSMLayer(testing.TestCase): + def test_reloading_export_archive(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model() + ref_input = tf.random.normal((3, 10)) + ref_output = model(ref_input).numpy() + + export_lib.export_model(model, temp_filepath) + reloaded_layer = export_lib.TFSMLayer(temp_filepath) + self.assertAllClose( + reloaded_layer(ref_input).numpy(), ref_output, atol=1e-7 + ) + self.assertLen(reloaded_layer.weights, len(model.weights)) + self.assertLen( + reloaded_layer.trainable_weights, len(model.trainable_weights) + ) + self.assertLen( + reloaded_layer.non_trainable_weights, + len(model.non_trainable_weights), + ) + + # TODO(nkovela): Expand test coverage/debug fine-tuning and + # non-trainable use cases here. + + def test_reloading_default_saved_model(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model() + ref_input = tf.random.normal((3, 10)) + ref_output = model(ref_input).numpy() + + tf.saved_model.save(model, temp_filepath) + reloaded_layer = export_lib.TFSMLayer( + temp_filepath, call_endpoint="serving_default" + ) + # The output is a dict, due to the nature of SavedModel saving. 
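+        # (The `serving_default` signature maps output names to tensors,
+        # so compare against the dict's single value.)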
+ new_output = reloaded_layer(ref_input) + self.assertAllClose( + new_output[list(new_output.keys())[0]].numpy(), + ref_output, + atol=1e-7, + ) + self.assertLen(reloaded_layer.weights, len(model.weights)) + self.assertLen( + reloaded_layer.trainable_weights, len(model.trainable_weights) + ) + self.assertLen( + reloaded_layer.non_trainable_weights, + len(model.non_trainable_weights), + ) + + def test_call_training(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + utils.set_random_seed(1337) + model = models.Sequential( + [ + layers.Input((10,)), + layers.Dense(10), + layers.Dropout(0.99999), + ] + ) + export_archive = export_lib.ExportArchive() + export_archive.track(model) + export_archive.add_endpoint( + name="call_inference", + fn=lambda x: model(x, training=False), + input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)], + ) + export_archive.add_endpoint( + name="call_training", + fn=lambda x: model(x, training=True), + input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)], + ) + export_archive.write_out(temp_filepath) + reloaded_layer = export_lib.TFSMLayer( + temp_filepath, + call_endpoint="call_inference", + call_training_endpoint="call_training", + ) + inference_output = reloaded_layer( + tf.random.normal((1, 10)), training=False + ) + training_output = reloaded_layer( + tf.random.normal((1, 10)), training=True + ) + self.assertAllClose(np.mean(training_output), 0.0, atol=1e-7) + self.assertNotAllClose(np.mean(inference_output), 0.0, atol=1e-7) + + def test_serialization(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model() + ref_input = tf.random.normal((3, 10)) + ref_output = model(ref_input).numpy() + + export_lib.export_model(model, temp_filepath) + reloaded_layer = export_lib.TFSMLayer(temp_filepath) + + # Test reinstantiation from config + config = reloaded_layer.get_config() + rereloaded_layer = export_lib.TFSMLayer.from_config(config) + self.assertAllClose( + rereloaded_layer(ref_input).numpy(), ref_output, atol=1e-7 + ) + + # Test whole model saving with reloaded layer inside + model = models.Sequential([reloaded_layer]) + temp_model_filepath = os.path.join(self.get_temp_dir(), "m.keras") + model.save(temp_model_filepath, save_format="keras_v3") + reloaded_model = saving_lib.load_model( + temp_model_filepath, + custom_objects={"TFSMLayer": export_lib.TFSMLayer}, + ) + self.assertAllClose( + reloaded_model(ref_input).numpy(), ref_output, atol=1e-7 + ) + + def test_errors(self): + # Test missing call endpoint + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = models.Sequential([layers.Input((2,)), layers.Dense(3)]) + export_lib.export_model(model, temp_filepath) + with self.assertRaisesRegex(ValueError, "The endpoint 'wrong'"): + export_lib.TFSMLayer(temp_filepath, call_endpoint="wrong") + + # Test missing call training endpoint + with self.assertRaisesRegex(ValueError, "The endpoint 'wrong'"): + export_lib.TFSMLayer( + temp_filepath, + call_endpoint="serve", + call_training_endpoint="wrong", + ) diff --git a/keras_core/layers/preprocessing/hashing.py b/keras_core/layers/preprocessing/hashing.py index 90bdf924d..14f1424dc 100644 --- a/keras_core/layers/preprocessing/hashing.py +++ b/keras_core/layers/preprocessing/hashing.py @@ -2,9 +2,9 @@ import numpy as np from keras_core import backend from keras_core.api_export import keras_core_export +from keras_core.backend.tensorflow import tf_utils from keras_core.layers.layer import Layer from 
keras_core.utils import backend_utils
-from keras_core.utils import tf_utils
 from keras_core.utils.module_utils import tensorflow as tf
diff --git a/keras_core/models/functional.py b/keras_core/models/functional.py
index 91ac1f487..50efae4ac 100644
--- a/keras_core/models/functional.py
+++ b/keras_core/models/functional.py
@@ -101,8 +101,7 @@ class Functional(Function, Model):
                         f"including invalid value {v} of type {type(v)}"
                     )
                 if k != v.name:
-                    # TODO: maybe make this a warning
-                    raise ValueError(
+                    warnings.warn(
                         "When providing `inputs` as a dict, all keys in the "
                         "dict must match the names of the corresponding "
                         f"tensors. Received key '{k}' mapping to value {v} "
diff --git a/keras_core/models/functional_test.py b/keras_core/models/functional_test.py
index d667b921c..3eb6b1ed2 100644
--- a/keras_core/models/functional_test.py
+++ b/keras_core/models/functional_test.py
@@ -87,11 +87,6 @@ class FunctionalTest(testing.TestCase):
         ):
             model = Functional({"aa": [input_a], "bb": input_b}, outputs)
 
-        with self.assertRaisesRegex(
-            ValueError, "all keys in the dict must match the names"
-        ):
-            model = Functional({"aa": input_a, "bb": input_b}, outputs)
-
         model = Functional({"a": input_a, "b": input_b}, outputs)
 
         # Eager call
diff --git a/keras_core/models/model.py b/keras_core/models/model.py
index af75bf165..f27557add 100644
--- a/keras_core/models/model.py
+++ b/keras_core/models/model.py
@@ -448,12 +448,42 @@ class Model(Trainer, Layer):
         return json.dumps(model_config, **kwargs)
 
     def export(self, filepath, format="tf_saved_model"):
-        raise NotImplementedError(
-            "The export() method is not yet supported. It will "
-            "be added in the next version. For the time being, you "
-            "can use `tf.saved_model.save(model)` to save a "
-            "TensorFlow SavedModel for your Keras Core model."
-        )
+        """[TF backend only] Create a TF SavedModel artifact for inference
+        (e.g. via TF-Serving).
+
+        **Note:** This can currently only be used with the TF backend.
+
+        This method lets you export a model to a lightweight SavedModel artifact
+        that contains the model's forward pass only (its `call()` method)
+        and can be served via e.g. TF-Serving. The forward pass is registered
+        under the name `serve()` (see example below).
+
+        The original code of the model (including any custom layers you may
+        have used) is *no longer* necessary to reload the artifact -- it is
+        entirely standalone.
+
+        Args:
+            filepath: `str` or `pathlib.Path` object. Path where to save
+                the artifact.
+
+        Example:
+
+        ```python
+        # Create the artifact
+        model.export("path/to/location")
+
+        # Later, in a different process / environment...
+        reloaded_artifact = tf.saved_model.load("path/to/location")
+        predictions = reloaded_artifact.serve(input_data)
+        ```
+
+        If you would like to customize your serving endpoints, you can
+        use the lower-level `keras_core.export.ExportArchive` class. The
+        `export()` method relies on `ExportArchive` internally.
+        """
+        from keras_core.export import export_lib
+
+        export_lib.export_model(self, filepath)
 
     @classmethod
     def from_config(cls, config, custom_objects=None):