Add model.to_json()

This commit is contained in:
Francois Chollet 2023-06-04 12:00:10 -07:00
parent c63de3adca
commit 10e49077a7
3 changed files with 61 additions and 2 deletions

@ -33,7 +33,7 @@ class InputLayer(Layer):
if shape:
shape = backend.standardize_shape(shape)
batch_shape = (batch_size,) + shape
self.batch_shape = batch_shape
self.batch_shape = tuple(batch_shape)
self._dtype = backend.standardize_dtype(dtype)
if input_tensor is not None:

@ -1,3 +1,4 @@
import json
import os
import warnings
@ -341,11 +342,58 @@ class Model(Trainer, Layer):
stacklevel=2,
)
def to_json(self, **kwargs):
"""Returns a JSON string containing the network configuration.
To load a network from a JSON save file, use
`keras.models.model_from_json(json_string, custom_objects={...})`.
Args:
**kwargs: Additional keyword arguments to be passed to
`json.dumps()`.
Returns:
A JSON string.
"""
from keras_core.saving import serialization_lib
model_config = serialization_lib.serialize_keras_object(self)
return json.dumps(model_config, **kwargs)
@traceback_utils.filter_traceback
def export(self, filepath):
raise NotImplementedError
@keras_core_export("keras_core.models.model_from_json")
def model_from_json(json_string, custom_objects=None):
"""Parses a JSON model configuration string and returns a model instance.
Usage:
>>> model = keras_core.Sequential([
... keras_core.layers.Dense(5, input_shape=(3,)),
... keras_core.layers.Softmax()])
>>> config = model.to_json()
>>> loaded_model = keras_core.models.model_from_json(config)
Args:
json_string: JSON string encoding a model configuration.
custom_objects: Optional dictionary mapping names
(strings) to custom classes or functions to be
considered during deserialization.
Returns:
A Keras model instance (uncompiled).
"""
from keras_core.saving import serialization_lib
model_config = json.loads(json_string)
return serialization_lib.deserialize_keras_object(
model_config, custom_objects=custom_objects
)
def functional_init_arguments(args, kwargs):
return (
(len(args) == 2)

@ -3,14 +3,25 @@ from keras_core import testing
from keras_core.layers.core.input_layer import Input
from keras_core.models.functional import Functional
from keras_core.models.model import Model
from keras_core.models.model import model_from_json
class ModelTest(testing.TestCase):
def test_functional_rerouting(self):
def _get_model(self):
input_a = Input(shape=(3,), batch_size=2, name="input_a")
input_b = Input(shape=(3,), batch_size=2, name="input_b")
x = input_a + input_b
x = layers.Dense(5)(x)
outputs = layers.Dense(4)(x)
model = Model([input_a, input_b], outputs)
return model
def test_functional_rerouting(self):
model = self._get_model()
self.assertTrue(isinstance(model, Functional))
def test_json_serialization(self):
model = self._get_model()
json_string = model.to_json()
new_model = model_from_json(json_string)
self.assertEqual(json_string, new_model.to_json())