# keras/keras_core/models/model_test.py
# 2023-06-14 11:06:44 -07:00
#
# 68 lines
# 2.2 KiB
# Python

import numpy as np
from keras_core import layers
from keras_core import testing
from keras_core.layers.core.input_layer import Input
from keras_core.models.functional import Functional
from keras_core.models.model import Model
from keras_core.models.model import model_from_json
class ModelTest(testing.TestCase):
    """Tests for `Model`: functional rerouting, JSON/config round-trips and
    subclassed models taking tuple inputs."""

    def _get_model(self):
        """Build a small two-input functional model.

        Two `Input`s of shape `(3,)` (batch size 2) are summed element-wise
        and fed through `Dense(5)` then `Dense(4)`.

        Returns:
            The model created by `Model([input_a, input_b], outputs)`.
        """
        input_a = Input(shape=(3,), batch_size=2, name="input_a")
        input_b = Input(shape=(3,), batch_size=2, name="input_b")
        x = input_a + input_b
        x = layers.Dense(5)(x)
        outputs = layers.Dense(4)(x)
        model = Model([input_a, input_b], outputs)
        return model

    def test_functional_rerouting(self):
        # Constructing `Model(inputs, outputs)` is expected to yield a
        # `Functional` instance. `assertIsInstance` reports the actual type
        # on failure, unlike `assertTrue(isinstance(...))`.
        model = self._get_model()
        self.assertIsInstance(model, Functional)

    def test_json_serialization(self):
        # Serializing, deserializing and re-serializing must reproduce the
        # exact same JSON string.
        model = self._get_model()
        json_string = model.to_json()
        new_model = model_from_json(json_string)
        self.assertEqual(json_string, new_model.to_json())

    def test_tuple_input_model_subclass(self):
        # https://github.com/keras-team/keras-core/issues/324
        class MultiInputModel(Model):
            def __init__(self, **kwargs):
                super().__init__(**kwargs)
                self.dense1 = layers.Dense(4)

            def call(self, inputs):
                # `inputs` arrives as a tuple and is unpacked positionally.
                a, b = inputs
                r = self.dense1(a)
                return layers.concatenate([r, b])

        model = MultiInputModel()
        x1 = np.random.rand(3, 3)
        x2 = np.random.rand(3, 2)
        out = model((x1, x2))
        # Dense(4) output (3, 4) joined with x2 (3, 2) -> expected (3, 6).
        self.assertEqual(out.shape, (3, 6))

    def test_reviving_functional_from_config_custom_layer(self):
        # A functional model containing a user-defined layer should be
        # rebuilt as `Functional` from its config when the custom class is
        # supplied through `custom_objects`.
        class CustomDense(layers.Layer):
            def __init__(self, units, **kwargs):
                super().__init__(**kwargs)
                self.dense = layers.Dense(units)

            def call(self, x):
                return self.dense(x)

        inputs = layers.Input((4,))
        outputs = CustomDense(10)(inputs)
        model = Model(inputs, outputs)
        config = model.get_config()
        new_model = Model.from_config(
            config, custom_objects={"CustomDense": CustomDense}
        )
        self.assertIsInstance(new_model, Functional)