From 10e49077a75d444e66f18c4b5c58b99b29da2037 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Sun, 4 Jun 2023 12:00:10 -0700 Subject: [PATCH] Add model.to_json() --- keras_core/layers/core/input_layer.py | 2 +- keras_core/models/model.py | 48 +++++++++++++++++++++++++++ keras_core/models/model_test.py | 13 +++++++- 3 files changed, 61 insertions(+), 2 deletions(-) diff --git a/keras_core/layers/core/input_layer.py b/keras_core/layers/core/input_layer.py index 883074626..ee09a3ff0 100644 --- a/keras_core/layers/core/input_layer.py +++ b/keras_core/layers/core/input_layer.py @@ -33,7 +33,7 @@ class InputLayer(Layer): if shape: shape = backend.standardize_shape(shape) batch_shape = (batch_size,) + shape - self.batch_shape = batch_shape + self.batch_shape = tuple(batch_shape) self._dtype = backend.standardize_dtype(dtype) if input_tensor is not None: diff --git a/keras_core/models/model.py b/keras_core/models/model.py index 03344fce5..d1ada91eb 100644 --- a/keras_core/models/model.py +++ b/keras_core/models/model.py @@ -1,3 +1,4 @@ +import json import os import warnings @@ -341,11 +342,58 @@ class Model(Trainer, Layer): stacklevel=2, ) + def to_json(self, **kwargs): + """Returns a JSON string containing the network configuration. + + To load a network from a JSON save file, use + `keras.models.model_from_json(json_string, custom_objects={...})`. + + Args: + **kwargs: Additional keyword arguments to be passed to + `json.dumps()`. + + Returns: + A JSON string. + """ + from keras_core.saving import serialization_lib + + model_config = serialization_lib.serialize_keras_object(self) + return json.dumps(model_config, **kwargs) + @traceback_utils.filter_traceback def export(self, filepath): raise NotImplementedError +@keras_core_export("keras_core.models.model_from_json") +def model_from_json(json_string, custom_objects=None): + """Parses a JSON model configuration string and returns a model instance. + + Usage: + + >>> model = keras_core.Sequential([ + ... keras_core.layers.Dense(5, input_shape=(3,)), + ... keras_core.layers.Softmax()]) + >>> config = model.to_json() + >>> loaded_model = keras_core.models.model_from_json(config) + + Args: + json_string: JSON string encoding a model configuration. + custom_objects: Optional dictionary mapping names + (strings) to custom classes or functions to be + considered during deserialization. + + Returns: + A Keras model instance (uncompiled). + """ + from keras_core.saving import serialization_lib + + model_config = json.loads(json_string) + return serialization_lib.deserialize_keras_object( + model_config, custom_objects=custom_objects + ) + + def functional_init_arguments(args, kwargs): return ( (len(args) == 2) diff --git a/keras_core/models/model_test.py b/keras_core/models/model_test.py index 32f51c61f..370d1c00b 100644 --- a/keras_core/models/model_test.py +++ b/keras_core/models/model_test.py @@ -3,14 +3,25 @@ from keras_core import testing from keras_core.layers.core.input_layer import Input from keras_core.models.functional import Functional from keras_core.models.model import Model +from keras_core.models.model import model_from_json class ModelTest(testing.TestCase): - def test_functional_rerouting(self): + def _get_model(self): input_a = Input(shape=(3,), batch_size=2, name="input_a") input_b = Input(shape=(3,), batch_size=2, name="input_b") x = input_a + input_b x = layers.Dense(5)(x) outputs = layers.Dense(4)(x) model = Model([input_a, input_b], outputs) + return model + + def test_functional_rerouting(self): + model = self._get_model() self.assertTrue(isinstance(model, Functional)) + + def test_json_serialization(self): + model = self._get_model() + json_string = model.to_json() + new_model = model_from_json(json_string) + self.assertEqual(json_string, new_model.to_json())