Add model.to_json()
This commit is contained in:
parent
c63de3adca
commit
10e49077a7
@ -33,7 +33,7 @@ class InputLayer(Layer):
|
||||
if shape:
|
||||
shape = backend.standardize_shape(shape)
|
||||
batch_shape = (batch_size,) + shape
|
||||
self.batch_shape = batch_shape
|
||||
self.batch_shape = tuple(batch_shape)
|
||||
self._dtype = backend.standardize_dtype(dtype)
|
||||
|
||||
if input_tensor is not None:
|
||||
|
@ -1,3 +1,4 @@
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
|
||||
@ -341,11 +342,58 @@ class Model(Trainer, Layer):
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
def to_json(self, **kwargs):
|
||||
"""Returns a JSON string containing the network configuration.
|
||||
|
||||
To load a network from a JSON save file, use
|
||||
`keras.models.model_from_json(json_string, custom_objects={...})`.
|
||||
|
||||
Args:
|
||||
**kwargs: Additional keyword arguments to be passed to
|
||||
`json.dumps()`.
|
||||
|
||||
Returns:
|
||||
A JSON string.
|
||||
"""
|
||||
from keras_core.saving import serialization_lib
|
||||
|
||||
model_config = serialization_lib.serialize_keras_object(self)
|
||||
return json.dumps(model_config, **kwargs)
|
||||
|
||||
@traceback_utils.filter_traceback
|
||||
def export(self, filepath):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@keras_core_export("keras_core.models.model_from_json")
|
||||
def model_from_json(json_string, custom_objects=None):
|
||||
"""Parses a JSON model configuration string and returns a model instance.
|
||||
|
||||
Usage:
|
||||
|
||||
>>> model = keras_core.Sequential([
|
||||
... keras_core.layers.Dense(5, input_shape=(3,)),
|
||||
... keras_core.layers.Softmax()])
|
||||
>>> config = model.to_json()
|
||||
>>> loaded_model = keras_core.models.model_from_json(config)
|
||||
|
||||
Args:
|
||||
json_string: JSON string encoding a model configuration.
|
||||
custom_objects: Optional dictionary mapping names
|
||||
(strings) to custom classes or functions to be
|
||||
considered during deserialization.
|
||||
|
||||
Returns:
|
||||
A Keras model instance (uncompiled).
|
||||
"""
|
||||
from keras_core.saving import serialization_lib
|
||||
|
||||
model_config = json.loads(json_string)
|
||||
return serialization_lib.deserialize_keras_object(
|
||||
model_config, custom_objects=custom_objects
|
||||
)
|
||||
|
||||
|
||||
def functional_init_arguments(args, kwargs):
|
||||
return (
|
||||
(len(args) == 2)
|
||||
|
@ -3,14 +3,25 @@ from keras_core import testing
|
||||
from keras_core.layers.core.input_layer import Input
|
||||
from keras_core.models.functional import Functional
|
||||
from keras_core.models.model import Model
|
||||
from keras_core.models.model import model_from_json
|
||||
|
||||
|
||||
class ModelTest(testing.TestCase):
|
||||
def test_functional_rerouting(self):
|
||||
def _get_model(self):
|
||||
input_a = Input(shape=(3,), batch_size=2, name="input_a")
|
||||
input_b = Input(shape=(3,), batch_size=2, name="input_b")
|
||||
x = input_a + input_b
|
||||
x = layers.Dense(5)(x)
|
||||
outputs = layers.Dense(4)(x)
|
||||
model = Model([input_a, input_b], outputs)
|
||||
return model
|
||||
|
||||
def test_functional_rerouting(self):
|
||||
model = self._get_model()
|
||||
self.assertTrue(isinstance(model, Functional))
|
||||
|
||||
def test_json_serialization(self):
|
||||
model = self._get_model()
|
||||
json_string = model.to_json()
|
||||
new_model = model_from_json(json_string)
|
||||
self.assertEqual(json_string, new_model.to_json())
|
||||
|
Loading…
Reference in New Issue
Block a user