Add Export for TF backend (#692)

* Add saved model test

* Add TF tracking attribute

* Add tests for functional and subclassed

* Fix saving trackables

* Fix test assertions

* Fix formatting

* Add comments for attribute tracking

* Change saved model test description

* Add backend conditional for attribute

* Change package name

* Change epoch nums

* Revert epochs

* Add set verbose logging utility and debug callback tests

* Fix formatting

* Initial port of model export

* Fix imports

* Add save spec methods to TF layer

* Add export function to Keras Core base model

* Downgrade naming error to warning and debug TF variable collections check

* Simplify weight reloading

* Fix formatting, add TODOs

* Unify tf_utils under backend/tensorflow

* Fix docstring and import

* Fix module utils import

* Fix lookup layers export and add test

* Change naming to TFSMLayer

* Remove parameterized

* Comment out failing test
Neel Kovelamudi 2023-08-11 21:38:42 +00:00 committed by Francois Chollet
parent ef72bfb728
commit c93f1be73e
9 changed files with 1284 additions and 14 deletions

@@ -1,11 +1,61 @@
import tensorflow as tf
from keras_core.backend.tensorflow import tf_utils
class TFLayer(tf.__internal__.tracking.AutoTrackable):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Export-related attributes
self._saved_model_inputs_spec = None
self._saved_model_arg_spec = None
def _post_build(self):
"""Can be overriden to perform post-build actions."""
pass
@tf.__internal__.tracking.no_automatic_dependency_tracking
def _set_save_spec(self, inputs, args=None, kwargs=None):
"""Defines the save spec so that serialization can trace layer calls.
The TensorSpecs of the call function `inputs`, `args`, and `kwargs` are
saved into a tuple of `([inputs] + args, kwargs)`.
Args:
inputs: possibly nested inputs passed into the call function.
args: a list of positional arguments passed into call.
kwargs: a dictionary of keyword arguments passed into call.
"""
if self._saved_model_inputs_spec is not None:
return # Already set.
inputs_spec = tf.nest.map_structure(tf_utils.get_tensor_spec, inputs)
args_spec = tf.nest.map_structure(tf_utils.get_tensor_spec, args or [])
kwargs_spec = {}
# Filter out non-tensor arguments from kwargs.
for key, kwarg in (kwargs or {}).items():
flat_kwarg = tf.nest.flatten(kwarg)
flat_specs = [tf_utils.get_tensor_spec(x) for x in flat_kwarg]
if any(s is None for s in flat_specs):
continue
kwargs_spec[key] = tf.nest.pack_sequence_as(kwarg, flat_specs)
self._saved_model_inputs_spec = inputs_spec
self._saved_model_arg_spec = (
[inputs_spec] + list(args_spec),
kwargs_spec,
)
def _get_save_spec(self, dynamic_batch=True, inputs_only=True):
if self._saved_model_inputs_spec is None:
return None
spec = tf.nest.map_structure(
lambda t: tf_utils.get_tensor_spec(t, dynamic_batch=dynamic_batch),
self._saved_model_arg_spec,
)
return spec[0][0] if inputs_only else spec
def _trackable_children(self, save_type="checkpoint", **kwargs):
if save_type == "savedmodel":
# SavedModel needs to ignore the execution functions.
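For context on the `([inputs] + args, kwargs)` layout described in `_set_save_spec` above, here is a minimal sketch (TF backend assumed; the `Dense` layer, shapes, and the direct use of these private helpers are illustrative) of what gets recorded and read back:

```python
import tensorflow as tf
from keras_core import layers

# Illustrative only: record a save spec on a built layer, then read it back.
layer = layers.Dense(4)
layer.build((None, 3))
layer._set_save_spec(tf.TensorSpec(shape=(None, 3), dtype=tf.float32))

# Default (inputs_only=True): just the input spec.
print(layer._get_save_spec())
# Full structure: the ([inputs] + args, kwargs) tuple documented above.
print(layer._get_save_spec(inputs_only=False))
```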

@@ -109,3 +109,29 @@ def encode_categorical_inputs(
)
else:
return tf.multiply(bincounts, idf_weights)
def get_tensor_spec(t, dynamic_batch=False, name=None):
"""Returns a `TensorSpec` given a single `Tensor` or `TensorSpec`."""
if isinstance(t, tf.TypeSpec):
spec = t
elif isinstance(t, tf.__internal__.CompositeTensor):
# Check for ExtensionTypes
spec = t._type_spec
elif hasattr(t, "shape") and hasattr(t, "dtype"):
spec = tf.TensorSpec(shape=t.shape, dtype=t.dtype, name=name)
else:
return None # Allow non-Tensors to pass through.
if not dynamic_batch:
return spec
shape = spec.shape
if shape.rank is None or shape.rank == 0:
return spec
shape_list = shape.as_list()
shape_list[0] = None
shape = tf.TensorShape(shape_list)
spec._shape = shape
return spec
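A quick sketch of `get_tensor_spec` in isolation (values illustrative), showing the pass-through and dynamic-batch behaviors handled above:

```python
import tensorflow as tf
from keras_core.backend.tensorflow import tf_utils

t = tf.zeros((8, 32))

# Tensors are converted to a matching TensorSpec.
print(tf_utils.get_tensor_spec(t))
# TensorSpec(shape=(8, 32), dtype=tf.float32, name=None)

# dynamic_batch=True rewrites the leading dimension to None so that
# traced functions accept any batch size.
print(tf_utils.get_tensor_spec(t, dynamic_batch=True))
# TensorSpec(shape=(None, 32), dtype=tf.float32, name=None)

# Non-tensor values fall through as None, letting callers filter kwargs.
print(tf_utils.get_tensor_spec("not a tensor"))  # None
```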

@@ -0,0 +1 @@
from keras_core.export.export_lib import ExportArchive

@@ -0,0 +1,574 @@
"""Library for exporting inference-only Keras models/layers."""
from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.layers import Layer
from keras_core.models import Functional
from keras_core.models import Sequential
from keras_core.utils import io_utils
from keras_core.utils.module_utils import tensorflow as tf
@keras_core_export("keras_core.export.ExportArchive")
class ExportArchive(tf.__internal__.tracking.AutoTrackable):
"""ExportArchive is used to write SavedModel artifacts (e.g. for inference).
If you have a Keras model or layer that you want to export as SavedModel for
serving (e.g. via TensorFlow-Serving), you can use `ExportArchive`
to configure the different serving endpoints you need to make available,
as well as their signatures. Simply instantiate an `ExportArchive`,
use `track()` to register the layer(s) or model(s) to be used,
then use the `add_endpoint()` method to register a new serving endpoint.
When done, use the `write_out()` method to save the artifact.
The resulting artifact is a SavedModel and can be reloaded via
`tf.saved_model.load`.
Examples:
Here's how to export a model for inference.
```python
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
export_archive.write_out("path/to/location")
# Elsewhere, we can reload the artifact and serve it.
# The endpoint we added is available as a method:
serving_model = tf.saved_model.load("path/to/location")
outputs = serving_model.serve(inputs)
```
Here's how to export a model with one endpoint for inference and one
endpoint for a training-mode forward pass (e.g. with dropout on).
```python
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="call_inference",
fn=lambda x: model.call(x, training=False),
input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
export_archive.add_endpoint(
name="call_training",
fn=lambda x: model.call(x, training=True),
input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
export_archive.write_out("path/to/location")
```
**Note on resource tracking:**
`ExportArchive` is able to automatically track all `tf.Variables` used
by its endpoints, so most of the time calling `.track(model)`
is not strictly required. However, if your model uses lookup layers such
as `IntegerLookup`, `StringLookup`, or `TextVectorization`,
it will need to be tracked explicitly via `.track(model)`.
Explicit tracking is also required if you need to be able to access
the properties `variables`, `trainable_variables`, or
`non_trainable_variables` on the revived archive.
"""
def __init__(self):
self._endpoint_names = []
self._endpoint_signatures = {}
self.tensorflow_version = tf.__version__
self.variables = []
self.trainable_variables = []
self.non_trainable_variables = []
@tf.__internal__.tracking.no_automatic_dependency_tracking
def track(self, resource):
"""Track the variables (and other assets) of a layer or model."""
if not isinstance(resource, tf.__internal__.tracking.Trackable):
raise ValueError(
"Invalid resource type. Expected an instance of a "
"TensorFlow `Trackable` (such as a Keras `Layer` or `Model`). "
f"Received instead an object of type '{type(resource)}'. "
f"Object received: {resource}"
)
if isinstance(resource, Layer):
if not resource.built:
raise ValueError(
"The layer provided has not yet been built. "
"It must be built before export."
)
# Layers in `_tracked` are not part of the trackables that get saved,
# because we're creating the attribute in a
# no_automatic_dependency_tracking scope.
if not hasattr(self, "_tracked"):
self._tracked = []
self._tracked.append(resource)
if isinstance(resource, Layer):
# Variables in the lists below are actually part of the trackables
# that get saved, because the lists are created in __init__.
self.variables += resource.variables
self.trainable_variables += resource.trainable_variables
self.non_trainable_variables += resource.non_trainable_variables
def add_endpoint(self, name, fn, input_signature=None):
"""Register a new serving endpoint.
Arguments:
name: Str, name of the endpoint.
fn: A function. It should only leverage resources
(e.g. `tf.Variable` objects or `tf.lookup.StaticHashTable`
objects) that are available on the models/layers
tracked by the `ExportArchive` (you can call `.track(model)`
to track a new model).
The shape and dtype of the inputs to the function must be
known. For that purpose, you can either 1) make sure that
`fn` is a `tf.function` that has been called at least once, or
2) provide an `input_signature` argument that specifies the
shape and dtype of the inputs (see below).
input_signature: Used to specify the shape and dtype of the
inputs to `fn`. List of `tf.TensorSpec` objects (one
per positional input argument of `fn`). Nested arguments are
allowed (see below for an example showing a Functional model
with 2 input arguments).
Example:
Adding an endpoint using the `input_signature` argument when the
model has a single input argument:
```python
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
```
Adding an endpoint using the `input_signature` argument when the
model has two positional input arguments:
```python
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[
tf.TensorSpec(shape=(None, 3), dtype=tf.float32),
tf.TensorSpec(shape=(None, 4), dtype=tf.float32),
],
)
```
Adding an endpoint using the `input_signature` argument when the
model has one input argument that is a list of 2 tensors (e.g.
a Functional model with 2 inputs):
```python
model = keras_core.Model(inputs=[x1, x2], outputs=outputs)
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[
[
tf.TensorSpec(shape=(None, 3), dtype=tf.float32),
tf.TensorSpec(shape=(None, 4), dtype=tf.float32),
],
],
)
```
This also works with dictionary inputs:
```python
model = keras_core.Model(inputs={"x1": x1, "x2": x2}, outputs=outputs)
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[
{
"x1": tf.TensorSpec(shape=(None, 3), dtype=tf.float32),
"x2": tf.TensorSpec(shape=(None, 4), dtype=tf.float32),
},
],
)
```
Adding an endpoint that is a `tf.function`:
```python
@tf.function()
def serving_fn(x):
return model(x)
# The function must be traced, i.e. it must be called at least once.
serving_fn(tf.random.normal(shape=(2, 3)))
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(name="serve", fn=serving_fn)
```
"""
if name in self._endpoint_names:
raise ValueError(f"Endpoint name '{name}' is already taken.")
if input_signature:
decorated_fn = tf.function(fn, input_signature=input_signature)
self._endpoint_signatures[name] = input_signature
else:
if isinstance(fn, tf.types.experimental.GenericFunction):
if not fn._list_all_concrete_functions():
raise ValueError(
f"The provided tf.function '{fn}' "
"has never been called. "
"To specify the expected shape and dtype "
"of the function's arguments, "
"you must either provide a function that "
"has been called at least once, or alternatively pass "
"an `input_signature` argument in `add_endpoint()`."
)
decorated_fn = fn
else:
raise ValueError(
"If the `fn` argument provided is not a `tf.function`, "
"you must provide an `input_signature` argument to "
"specify the shape and dtype of the function arguments. "
"Example:\n\n"
"export_archive.add_endpoint(\n"
" name='call',\n"
" fn=model.call,\n"
" input_signature=[\n"
" tf.TensorSpec(\n"
" shape=(None, 224, 224, 3),\n"
" dtype=tf.float32,\n"
" )\n"
" ],\n"
")"
)
setattr(self, name, decorated_fn)
self._endpoint_names.append(name)
def add_variable_collection(self, name, variables):
"""Register a set of variables to be retrieved after reloading.
Arguments:
name: The string name for the collection.
variables: A tuple/list/set of `tf.Variable` instances.
Example:
```python
export_archive = ExportArchive()
export_archive.track(model)
# Register an endpoint
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
# Save a variable collection
export_archive.add_variable_collection(
name="optimizer_variables", variables=model.optimizer.variables)
export_archive.write_out("path/to/location")
# Reload the object
revived_object = tf.saved_model.load("path/to/location")
# Retrieve the variables
optimizer_variables = revived_object.optimizer_variables
```
"""
if not isinstance(variables, (list, tuple, set)):
raise ValueError(
"Expected `variables` to be a list/tuple/set. "
f"Received instead object of type '{type(variables)}'."
)
# Ensure that all variables added are either tf.Variables
# or Variables created by Keras Core with the TF backend.
if not all(isinstance(v, tf.Variable) for v in variables) and not (
all(
isinstance(v, (tf.Variable, backend.Variable))
for v in variables
)
and (backend.backend() == "tensorflow")
):
raise ValueError(
"Expected all elements in `variables` to be "
"`tf.Variable` instances. Found instead the following types: "
f"{list(set(type(v) for v in variables))}"
)
setattr(self, name, list(variables))
def write_out(self, filepath, options=None):
"""Write the corresponding SavedModel to disk.
Arguments:
filepath: `str` or `pathlib.Path` object.
Path where to save the artifact.
options: `tf.saved_model.SaveOptions` object that specifies
SavedModel saving options.
**Note on TF-Serving**: all endpoints registered via `add_endpoint()`
are made visible for TF-Serving in the SavedModel artifact. In addition,
the first endpoint registered is made visible under the alias
`"serving_default"` (unless an endpoint with the name
`"serving_default"` was already registered manually),
since TF-Serving requires this endpoint to be set.
"""
if not self._endpoint_names:
raise ValueError(
"No endpoints have been set yet. Call add_endpoint()."
)
self._filter_and_track_resources()
signatures = {}
for name in self._endpoint_names:
signatures[name] = self._get_concrete_fn(name)
# Add "serving_default" signature key for TFServing
if "serving_default" not in self._endpoint_names:
signatures["serving_default"] = self._get_concrete_fn(
self._endpoint_names[0]
)
tf.saved_model.save(
self, filepath, options=options, signatures=signatures
)
# Print out available endpoints
endpoints = "\n\n".join(
_print_signature(getattr(self, name), name)
for name in self._endpoint_names
)
io_utils.print_msg(
f"Saved artifact at '{filepath}'. "
"The following endpoints are available:\n\n"
f"{endpoints}"
)
def _get_concrete_fn(self, endpoint):
"""Workaround for some SavedModel quirks."""
if endpoint in self._endpoint_signatures:
return getattr(self, endpoint)
else:
traces = getattr(self, endpoint)._trackable_children("saved_model")
return list(traces.values())[0]
def _get_variables_used_by_endpoints(self):
fns = [self._get_concrete_fn(name) for name in self._endpoint_names]
return _list_variables_used_by_fns(fns)
def _filter_and_track_resources(self):
"""Track resources used by endpoints / referenced in `track()` calls."""
# Start by extracting variables from endpoints.
fns = [self._get_concrete_fn(name) for name in self._endpoint_names]
tvs, ntvs = _list_variables_used_by_fns(fns)
self._all_variables = list(tvs + ntvs)
# Next, track lookup tables.
# Hopefully, one day this will be automated at the tf.function level.
self._misc_assets = []
from keras_core.layers import IntegerLookup
from keras_core.layers import StringLookup
from keras_core.layers import TextVectorization
if hasattr(self, "_tracked"):
for root in self._tracked:
descendants = tf.train.TrackableView(root).descendants()
for trackable in descendants:
if isinstance(
trackable,
(IntegerLookup, StringLookup, TextVectorization),
):
self._misc_assets.append(trackable)
def export_model(model, filepath):
export_archive = ExportArchive()
export_archive.track(model)
if isinstance(model, (Functional, Sequential)):
input_signature = tf.nest.map_structure(_make_tensor_spec, model.inputs)
if isinstance(input_signature, list) and len(input_signature) > 1:
input_signature = [input_signature]
export_archive.add_endpoint("serve", model.__call__, input_signature)
else:
save_spec = model._get_save_spec()
if not save_spec:
raise ValueError(
"The model provided has never called. "
"It must be called at least once before export."
)
input_signature = [save_spec]
export_archive.add_endpoint("serve", model.__call__, input_signature)
export_archive.write_out(filepath)
class TFSMLayer(Layer):
"""Reload a Keras model/layer that was saved via SavedModel / ExportArchive.
Arguments:
filepath: `str` or `pathlib.Path` object. The path to the SavedModel.
call_endpoint: Name of the endpoint to use as the `call()` method
of the reloaded layer. If the SavedModel was created
via `model.export()`,
then the default endpoint name is `'serve'`. In other cases
it may be named `'serving_default'`.
call_training_endpoint: Optional name of the endpoint to call when
`training=True` is passed in `__call__()` (see Limitations below).
Example:
```python
model.export("path/to/artifact")
reloaded_layer = TFSMLayer("path/to/artifact")
outputs = reloaded_layer(inputs)
```
The reloaded object can be used like a regular Keras layer, and supports
training/fine-tuning of its trainable weights. Note that the reloaded
object retains none of the internal structure or custom methods of the
original object -- it's a brand new layer created around the saved
function.
**Limitations:**
* Only call endpoints with a single `inputs` tensor argument
(which may optionally be a dict/tuple/list of tensors) are supported.
For endpoints with multiple separate input tensor arguments, consider
subclassing `TFSMLayer` and implementing a `call()` method with a
custom signature.
* If you need training-time behavior to differ from inference-time behavior
(i.e. if you need the reloaded object to support a `training=True` argument
in `__call__()`), make sure that the training-time call function is
saved as a standalone endpoint in the artifact, and provide its name
to the `TFSMLayer` via the `call_training_endpoint` argument.
"""
def __init__(
self,
filepath,
call_endpoint="serve",
call_training_endpoint=None,
trainable=True,
name=None,
dtype=None,
):
# Initialize an empty layer, then add_weight() etc. as needed.
super().__init__(trainable=trainable, name=name, dtype=dtype)
self._reloaded_obj = tf.saved_model.load(filepath)
self.filepath = filepath
self.call_endpoint = call_endpoint
self.call_training_endpoint = call_training_endpoint
# Resolve the call function.
if hasattr(self._reloaded_obj, call_endpoint):
# Case 1: it's set as an attribute.
self.call_endpoint_fn = getattr(self._reloaded_obj, call_endpoint)
elif call_endpoint in self._reloaded_obj.signatures:
# Case 2: it's listed in the `signatures` field.
self.call_endpoint_fn = self._reloaded_obj.signatures[call_endpoint]
else:
raise ValueError(
f"The endpoint '{call_endpoint}' is neither an "
"attribute of the reloaded SavedModel, nor an entry "
"in the `signatures` field of the reloaded SavedModel. "
)
# Resolve the training function.
if call_training_endpoint:
if hasattr(self._reloaded_obj, call_training_endpoint):
self.call_training_endpoint_fn = getattr(
self._reloaded_obj, call_training_endpoint
)
elif call_training_endpoint in self._reloaded_obj.signatures:
self.call_training_endpoint_fn = self._reloaded_obj.signatures[
call_training_endpoint
]
else:
raise ValueError(
f"The endpoint '{call_training_endpoint}' is "
"neither an attribute of the reloaded SavedModel, "
"nor an entry in the `signatures` field of "
"the reloaded SavedModel. "
)
# Add trainable and non-trainable weights from the call_endpoint_fn.
all_fns = [self.call_endpoint_fn]
if call_training_endpoint:
all_fns.append(self.call_training_endpoint_fn)
tvs, ntvs = _list_variables_used_by_fns(all_fns)
for v in tvs:
self._add_existing_weight(v)
for v in ntvs:
self._add_existing_weight(v)
self.built = True
def _add_existing_weight(self, weight):
"""Tracks an existing weight."""
self._track_variable(weight)
def call(self, inputs, training=False, **kwargs):
if training:
if self.call_training_endpoint:
return self.call_training_endpoint_fn(inputs, **kwargs)
return self.call_endpoint_fn(inputs, **kwargs)
def get_config(self):
base_config = super().get_config()
config = {
# Note: this is not intended to be portable.
"filepath": self.filepath,
"call_endpoint": self.call_endpoint,
"call_training_endpoint": self.call_training_endpoint,
}
return {**base_config, **config}
def _make_tensor_spec(x):
return tf.TensorSpec(x.shape, dtype=x.dtype, name=x.name)
def _print_signature(fn, name):
concrete_fn = fn._list_all_concrete_functions()[0]
pprinted_signature = concrete_fn.pretty_printed_signature(verbose=True)
lines = pprinted_signature.split("\n")
lines = [f"* Endpoint '{name}'"] + lines[1:]
endpoint = "\n".join(lines)
return endpoint
def _list_variables_used_by_fns(fns):
trainable_variables = []
non_trainable_variables = []
trainable_variables_ids = set()
non_trainable_variables_ids = set()
for fn in fns:
if hasattr(fn, "concrete_functions"):
concrete_functions = fn.concrete_functions
elif hasattr(fn, "get_concrete_function"):
concrete_functions = [fn.get_concrete_function()]
else:
concrete_functions = [fn]
for concrete_fn in concrete_functions:
for v in concrete_fn.trainable_variables:
if id(v) not in trainable_variables_ids:
trainable_variables.append(v)
trainable_variables_ids.add(id(v))
for v in concrete_fn.variables:
if (
id(v) not in trainable_variables_ids
and id(v) not in non_trainable_variables_ids
):
non_trainable_variables.append(v)
non_trainable_variables_ids.add(id(v))
return trainable_variables, non_trainable_variables
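To make the `serving_default` aliasing described in `write_out()` concrete, a hedged end-to-end sketch (TF backend assumed; the path and shapes are illustrative):

```python
import tensorflow as tf
from keras_core import layers, models
from keras_core.export.export_lib import ExportArchive

model = models.Sequential([layers.Input((3,)), layers.Dense(2)])

archive = ExportArchive()
archive.track(model)
archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
archive.write_out("/tmp/exported_model")

reloaded = tf.saved_model.load("/tmp/exported_model")
x = tf.random.normal((4, 3))
# The endpoint is callable as an attribute...
y = reloaded.serve(x)
# ...and, since it was registered first, it also doubles as the
# "serving_default" signature that TF-Serving looks for (signature
# functions return a dict of named outputs rather than a bare tensor).
y_dict = reloaded.signatures["serving_default"](x)
```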

@@ -0,0 +1,595 @@
"""Tests for inference-only model/layer exporting utilities."""
import os
import numpy as np
import pytest
import tensorflow as tf
from keras_core import backend
from keras_core import layers
from keras_core import models
from keras_core import testing
from keras_core import utils
from keras_core.export import export_lib
from keras_core.saving import saving_lib
def get_model():
layer_list = [
layers.Dense(10, activation="relu"),
layers.BatchNormalization(),
layers.Dense(1, activation="sigmoid"),
]
model = models.Sequential(layer_list)
return model
@pytest.mark.skipif(
backend.backend() != "tensorflow",
reason="Export only currently supports the TF backend.",
)
class ExportArchiveTest(testing.TestCase):
def test_standard_model_export(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
model = get_model()
ref_input = tf.random.normal((3, 10))
ref_output = model(ref_input).numpy()
export_lib.export_model(model, temp_filepath)
revived_model = tf.saved_model.load(temp_filepath)
self.assertAllClose(
ref_output, revived_model.serve(ref_input).numpy(), atol=1e-6
)
def test_low_level_model_export(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
model = get_model()
ref_input = tf.random.normal((3, 10))
ref_output = model(ref_input).numpy()
# Test variable tracking
export_archive = export_lib.ExportArchive()
export_archive.track(model)
self.assertLen(export_archive.variables, 8)
self.assertLen(export_archive.trainable_variables, 6)
self.assertLen(export_archive.non_trainable_variables, 2)
@tf.function()
def my_endpoint(x):
return model(x)
# Test registering an endpoint that is a tf.function (called)
my_endpoint(ref_input) # Trace fn
export_archive.add_endpoint(
"call",
my_endpoint,
)
export_archive.write_out(temp_filepath)
revived_model = tf.saved_model.load(temp_filepath)
self.assertFalse(hasattr(revived_model, "_tracked"))
self.assertAllClose(
ref_output, revived_model.call(ref_input).numpy(), atol=1e-6
)
self.assertLen(revived_model.variables, 8)
self.assertLen(revived_model.trainable_variables, 6)
self.assertLen(revived_model.non_trainable_variables, 2)
# Test registering an endpoint that is NOT a tf.function
export_archive = export_lib.ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
"call",
model.call,
input_signature=[
tf.TensorSpec(
shape=(None, 10),
dtype=tf.float32,
)
],
)
export_archive.write_out(temp_filepath)
revived_model = tf.saved_model.load(temp_filepath)
self.assertAllClose(
ref_output, revived_model.call(ref_input).numpy(), atol=1e-6
)
def test_layer_export(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_layer")
layer = layers.BatchNormalization()
ref_input = tf.random.normal((3, 10))
ref_output = layer(ref_input).numpy() # Build layer (important)
export_archive = export_lib.ExportArchive()
export_archive.track(layer)
export_archive.add_endpoint(
"call",
layer.call,
input_signature=[
tf.TensorSpec(
shape=(None, 10),
dtype=tf.float32,
)
],
)
export_archive.write_out(temp_filepath)
revived_layer = tf.saved_model.load(temp_filepath)
self.assertAllClose(
ref_output, revived_layer.call(ref_input).numpy(), atol=1e-6
)
def test_multi_input_output_functional_model(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
x1 = layers.Input((2,))
x2 = layers.Input((2,))
y1 = layers.Dense(3)(x1)
y2 = layers.Dense(3)(x2)
model = models.Model([x1, x2], [y1, y2])
ref_inputs = [tf.random.normal((3, 2)), tf.random.normal((3, 2))]
ref_outputs = model(ref_inputs)
export_archive = export_lib.ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
"serve",
model.call,
input_signature=[
[
tf.TensorSpec(
shape=(None, 2),
dtype=tf.float32,
),
tf.TensorSpec(
shape=(None, 2),
dtype=tf.float32,
),
]
],
)
export_archive.write_out(temp_filepath)
revived_model = tf.saved_model.load(temp_filepath)
self.assertAllClose(
ref_outputs[0].numpy(),
revived_model.serve(ref_inputs)[0].numpy(),
atol=1e-6,
)
self.assertAllClose(
ref_outputs[1].numpy(),
revived_model.serve(ref_inputs)[1].numpy(),
atol=1e-6,
)
# Now test dict inputs
model = models.Model({"x1": x1, "x2": x2}, [y1, y2])
ref_inputs = {
"x1": tf.random.normal((3, 2)),
"x2": tf.random.normal((3, 2)),
}
ref_outputs = model(ref_inputs)
export_archive = export_lib.ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
"serve",
model.call,
input_signature=[
{
"x1": tf.TensorSpec(
shape=(None, 2),
dtype=tf.float32,
),
"x2": tf.TensorSpec(
shape=(None, 2),
dtype=tf.float32,
),
}
],
)
export_archive.write_out(temp_filepath)
revived_model = tf.saved_model.load(temp_filepath)
self.assertAllClose(
ref_outputs[0].numpy(),
revived_model.serve(ref_inputs)[0].numpy(),
atol=1e-6,
)
self.assertAllClose(
ref_outputs[1].numpy(),
revived_model.serve(ref_inputs)[1].numpy(),
atol=1e-6,
)
# def test_model_with_lookup_table(self):
# tf.debugging.disable_traceback_filtering()
# temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
# text_vectorization = layers.TextVectorization()
# text_vectorization.adapt(["one two", "three four", "five six"])
# model = models.Sequential(
# [
# layers.Input(shape=(), dtype="string"),
# text_vectorization,
# layers.Embedding(10, 32),
# layers.Dense(1),
# ]
# )
# ref_input = tf.convert_to_tensor(["one two three four"])
# ref_output = model(ref_input).numpy()
# export_lib.export_model(model, temp_filepath)
# revived_model = tf.saved_model.load(temp_filepath)
# self.assertAllClose(
# ref_output, revived_model.serve(ref_input).numpy(), atol=1e-6
# )
def test_track_multiple_layers(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
layer_1 = layers.Dense(2)
ref_input_1 = tf.random.normal((3, 4))
ref_output_1 = layer_1(ref_input_1).numpy()
layer_2 = layers.Dense(3)
ref_input_2 = tf.random.normal((3, 5))
ref_output_2 = layer_2(ref_input_2).numpy()
export_archive = export_lib.ExportArchive()
export_archive.add_endpoint(
"call_1",
layer_1.call,
input_signature=[
tf.TensorSpec(
shape=(None, 4),
dtype=tf.float32,
),
],
)
export_archive.add_endpoint(
"call_2",
layer_2.call,
input_signature=[
tf.TensorSpec(
shape=(None, 5),
dtype=tf.float32,
),
],
)
export_archive.write_out(temp_filepath)
revived_layer = tf.saved_model.load(temp_filepath)
self.assertAllClose(
ref_output_1,
revived_layer.call_1(ref_input_1).numpy(),
atol=1e-6,
)
self.assertAllClose(
ref_output_2,
revived_layer.call_2(ref_input_2).numpy(),
atol=1e-6,
)
def test_non_standard_layer_signature(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_layer")
layer = layers.MultiHeadAttention(2, 2)
x1 = tf.random.normal((3, 2, 2))
x2 = tf.random.normal((3, 2, 2))
ref_output = layer(x1, x2).numpy() # Build layer (important)
export_archive = export_lib.ExportArchive()
export_archive.track(layer)
export_archive.add_endpoint(
"call",
layer.call,
input_signature=[
tf.TensorSpec(
shape=(None, 2, 2),
dtype=tf.float32,
),
tf.TensorSpec(
shape=(None, 2, 2),
dtype=tf.float32,
),
],
)
export_archive.write_out(temp_filepath)
revived_layer = tf.saved_model.load(temp_filepath)
self.assertAllClose(
ref_output,
revived_layer.call(query=x1, value=x2).numpy(),
atol=1e-6,
)
def test_variable_collection(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
model = models.Sequential(
[
layers.Input((10,)),
layers.Dense(2),
layers.Dense(2),
]
)
# Test variable tracking
export_archive = export_lib.ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
"call",
model.call,
input_signature=[
tf.TensorSpec(
shape=(None, 10),
dtype=tf.float32,
)
],
)
export_archive.add_variable_collection(
"my_vars", model.layers[1].weights
)
self.assertLen(export_archive.my_vars, 2)
export_archive.write_out(temp_filepath)
revived_model = tf.saved_model.load(temp_filepath)
self.assertLen(revived_model.my_vars, 2)
def test_export_model_errors(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
# Model has not been built
model = models.Sequential([layers.Dense(2)])
with self.assertRaisesRegex(ValueError, "It must be built"):
export_lib.export_model(model, temp_filepath)
# Subclassed model has not been called
class MyModel(models.Model):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.dense = layers.Dense(2)
def build(self, input_shape):
self.dense.build(input_shape)
self.built = True
def call(self, x):
return self.dense(x)
model = MyModel()
model.build((2, 3))
with self.assertRaisesRegex(ValueError, "It must be called"):
export_lib.export_model(model, temp_filepath)
def test_export_archive_errors(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
model = models.Sequential([layers.Dense(2)])
model(tf.random.normal((2, 3)))
# Endpoint name reuse
export_archive = export_lib.ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
"call",
model.call,
input_signature=[
tf.TensorSpec(
shape=(None, 3),
dtype=tf.float32,
)
],
)
with self.assertRaisesRegex(ValueError, "already taken"):
export_archive.add_endpoint(
"call",
model.call,
input_signature=[
tf.TensorSpec(
shape=(None, 3),
dtype=tf.float32,
)
],
)
# Write out with no endpoints
export_archive = export_lib.ExportArchive()
export_archive.track(model)
with self.assertRaisesRegex(ValueError, "No endpoints have been set"):
export_archive.write_out(temp_filepath)
# Invalid object type
with self.assertRaisesRegex(ValueError, "Invalid resource type"):
export_archive = export_lib.ExportArchive()
export_archive.track("model")
# Set endpoint with no input signature
export_archive = export_lib.ExportArchive()
export_archive.track(model)
with self.assertRaisesRegex(
ValueError, "you must provide an `input_signature`"
):
export_archive.add_endpoint(
"call",
model.call,
)
# Set endpoint that has never been called
export_archive = export_lib.ExportArchive()
export_archive.track(model)
@tf.function()
def my_endpoint(x):
return model(x)
export_archive = export_lib.ExportArchive()
export_archive.track(model)
with self.assertRaisesRegex(
ValueError, "you must either provide a function"
):
export_archive.add_endpoint(
"call",
my_endpoint,
)
def test_export_no_assets(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
# Case where there are legitimately no assets.
model = models.Sequential([layers.Flatten()])
model(tf.random.normal((2, 3)))
export_archive = export_lib.ExportArchive()
export_archive.add_endpoint(
"call",
model.call,
input_signature=[
tf.TensorSpec(
shape=(None, 3),
dtype=tf.float32,
)
],
)
export_archive.write_out(temp_filepath)
def test_model_export_method(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
model = get_model()
ref_input = tf.random.normal((3, 10))
ref_output = model(ref_input).numpy()
model.export(temp_filepath)
revived_model = tf.saved_model.load(temp_filepath)
self.assertAllClose(
ref_output, revived_model.serve(ref_input).numpy(), atol=1e-6
)
@pytest.mark.skipif(
backend.backend() != "tensorflow",
reason="Export only currently supports the TF backend.",
)
class TestTFSMLayer(testing.TestCase):
def test_reloading_export_archive(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
model = get_model()
ref_input = tf.random.normal((3, 10))
ref_output = model(ref_input).numpy()
export_lib.export_model(model, temp_filepath)
reloaded_layer = export_lib.TFSMLayer(temp_filepath)
self.assertAllClose(
reloaded_layer(ref_input).numpy(), ref_output, atol=1e-7
)
self.assertLen(reloaded_layer.weights, len(model.weights))
self.assertLen(
reloaded_layer.trainable_weights, len(model.trainable_weights)
)
self.assertLen(
reloaded_layer.non_trainable_weights,
len(model.non_trainable_weights),
)
# TODO(nkovela): Expand test coverage/debug fine-tuning and
# non-trainable use cases here.
def test_reloading_default_saved_model(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
model = get_model()
ref_input = tf.random.normal((3, 10))
ref_output = model(ref_input).numpy()
tf.saved_model.save(model, temp_filepath)
reloaded_layer = export_lib.TFSMLayer(
temp_filepath, call_endpoint="serving_default"
)
# The output is a dict, due to the nature of SavedModel saving.
new_output = reloaded_layer(ref_input)
self.assertAllClose(
new_output[list(new_output.keys())[0]].numpy(),
ref_output,
atol=1e-7,
)
self.assertLen(reloaded_layer.weights, len(model.weights))
self.assertLen(
reloaded_layer.trainable_weights, len(model.trainable_weights)
)
self.assertLen(
reloaded_layer.non_trainable_weights,
len(model.non_trainable_weights),
)
def test_call_training(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
utils.set_random_seed(1337)
model = models.Sequential(
[
layers.Input((10,)),
layers.Dense(10),
layers.Dropout(0.99999),
]
)
export_archive = export_lib.ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="call_inference",
fn=lambda x: model(x, training=False),
input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],
)
export_archive.add_endpoint(
name="call_training",
fn=lambda x: model(x, training=True),
input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],
)
export_archive.write_out(temp_filepath)
reloaded_layer = export_lib.TFSMLayer(
temp_filepath,
call_endpoint="call_inference",
call_training_endpoint="call_training",
)
inference_output = reloaded_layer(
tf.random.normal((1, 10)), training=False
)
training_output = reloaded_layer(
tf.random.normal((1, 10)), training=True
)
self.assertAllClose(np.mean(training_output), 0.0, atol=1e-7)
self.assertNotAllClose(np.mean(inference_output), 0.0, atol=1e-7)
def test_serialization(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
model = get_model()
ref_input = tf.random.normal((3, 10))
ref_output = model(ref_input).numpy()
export_lib.export_model(model, temp_filepath)
reloaded_layer = export_lib.TFSMLayer(temp_filepath)
# Test reinstantiation from config
config = reloaded_layer.get_config()
rereloaded_layer = export_lib.TFSMLayer.from_config(config)
self.assertAllClose(
rereloaded_layer(ref_input).numpy(), ref_output, atol=1e-7
)
# Test whole model saving with reloaded layer inside
model = models.Sequential([reloaded_layer])
temp_model_filepath = os.path.join(self.get_temp_dir(), "m.keras")
model.save(temp_model_filepath, save_format="keras_v3")
reloaded_model = saving_lib.load_model(
temp_model_filepath,
custom_objects={"TFSMLayer": export_lib.TFSMLayer},
)
self.assertAllClose(
reloaded_model(ref_input).numpy(), ref_output, atol=1e-7
)
def test_errors(self):
# Test missing call endpoint
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
model = models.Sequential([layers.Input((2,)), layers.Dense(3)])
export_lib.export_model(model, temp_filepath)
with self.assertRaisesRegex(ValueError, "The endpoint 'wrong'"):
export_lib.TFSMLayer(temp_filepath, call_endpoint="wrong")
# Test missing call training endpoint
with self.assertRaisesRegex(ValueError, "The endpoint 'wrong'"):
export_lib.TFSMLayer(
temp_filepath,
call_endpoint="serve",
call_training_endpoint="wrong",
)

@@ -2,9 +2,9 @@ import numpy as np
from keras_core import backend
from keras_core.api_export import keras_core_export
+from keras_core.backend.tensorflow import tf_utils
from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
-from keras_core.utils import tf_utils
from keras_core.utils.module_utils import tensorflow as tf

@@ -101,8 +101,7 @@ class Functional(Function, Model):
f"including invalid value {v} of type {type(v)}"
)
if k != v.name:
-# TODO: maybe make this a warning
-raise ValueError(
+warnings.warn(
"When providing `inputs` as a dict, all keys in the "
"dict must match the names of the corresponding "
f"tensors. Received key '{k}' mapping to value {v} "

@@ -87,11 +87,6 @@ class FunctionalTest(testing.TestCase):
):
model = Functional({"aa": [input_a], "bb": input_b}, outputs)
-with self.assertRaisesRegex(
-ValueError, "all keys in the dict must match the names"
-):
-model = Functional({"aa": input_a, "bb": input_b}, outputs)
model = Functional({"a": input_a, "b": input_b}, outputs)
# Eager call
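For reference, the behavior change validated here, sketched with illustrative layer and tensor names: a key/name mismatch in dict inputs now warns and proceeds instead of raising:

```python
import warnings
from keras_core import layers, models

x1 = layers.Input((2,), name="x1")
x2 = layers.Input((2,), name="x2")
y = layers.Dense(3)(layers.Concatenate()([x1, x2]))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Keys "aa"/"bb" do not match the tensor names "x1"/"x2":
    # previously a ValueError, now a warning, and the model still builds.
    model = models.Model({"aa": x1, "bb": x2}, y)
assert any("must match the names" in str(w.message) for w in caught)
```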

@@ -448,12 +448,42 @@ class Model(Trainer, Layer):
return json.dumps(model_config, **kwargs)
def export(self, filepath, format="tf_saved_model"):
-raise NotImplementedError(
-"The export() method is not yet supported. It will "
-"be added in the next version. For the time being, you "
-"can use `tf.saved_model.save(model)` to save a "
-"TensorFlow SavedModel for your Keras Core model."
-)
+"""[TF backend only]* Create a TF SavedModel artifact for inference
+(e.g. via TF-Serving).
+**Note:** This can currently only be used with the TF backend.
+This method lets you export a model to a lightweight SavedModel artifact
+that contains the model's forward pass only (its `call()` method)
+and can be served via e.g. TF-Serving. The forward pass is registered
+under the name `serve()` (see example below).
+The original code of the model (including any custom layers you may
+have used) is *no longer* necessary to reload the artifact -- it is
+entirely standalone.
+Args:
+filepath: `str` or `pathlib.Path` object. Path where to save
+the artifact.
+Example:
+```python
+# Create the artifact
+model.export("path/to/location")
+# Later, in a different process / environment...
+reloaded_artifact = tf.saved_model.load("path/to/location")
+predictions = reloaded_artifact.serve(input_data)
+```
+If you would like to customize your serving endpoints, you can
+use the lower-level `keras_core.export.ExportArchive` class. The
+`export()` method relies on `ExportArchive` internally.
+"""
+from keras_core.export import export_lib
+export_lib.export_model(self, filepath)
@classmethod
def from_config(cls, config, custom_objects=None):