68 lines
2.2 KiB
Python
68 lines
2.2 KiB
Python
import numpy as np
|
|
|
|
from keras_core import layers
|
|
from keras_core import testing
|
|
from keras_core.layers.core.input_layer import Input
|
|
from keras_core.models.functional import Functional
|
|
from keras_core.models.model import Model
|
|
from keras_core.models.model import model_from_json
|
|
|
|
|
|
class ModelTest(testing.TestCase):
|
|
def _get_model(self):
|
|
input_a = Input(shape=(3,), batch_size=2, name="input_a")
|
|
input_b = Input(shape=(3,), batch_size=2, name="input_b")
|
|
x = input_a + input_b
|
|
x = layers.Dense(5)(x)
|
|
outputs = layers.Dense(4)(x)
|
|
model = Model([input_a, input_b], outputs)
|
|
return model
|
|
|
|
def test_functional_rerouting(self):
|
|
model = self._get_model()
|
|
self.assertTrue(isinstance(model, Functional))
|
|
|
|
def test_json_serialization(self):
|
|
model = self._get_model()
|
|
json_string = model.to_json()
|
|
new_model = model_from_json(json_string)
|
|
self.assertEqual(json_string, new_model.to_json())
|
|
|
|
def test_tuple_input_model_subclass(self):
|
|
# https://github.com/keras-team/keras-core/issues/324
|
|
|
|
class MultiInputModel(Model):
|
|
def __init__(self, **kwargs):
|
|
super().__init__(**kwargs)
|
|
self.dense1 = layers.Dense(4)
|
|
|
|
def call(self, inputs):
|
|
a, b = inputs
|
|
r = self.dense1(a)
|
|
return layers.concatenate([r, b])
|
|
|
|
model = MultiInputModel()
|
|
x1 = np.random.rand(3, 3)
|
|
x2 = np.random.rand(3, 2)
|
|
out = model((x1, x2))
|
|
self.assertEqual(out.shape, (3, 6))
|
|
|
|
def test_reviving_functional_from_config_custom_layer(self):
|
|
class CustomDense(layers.Layer):
|
|
def __init__(self, units, **kwargs):
|
|
super().__init__(**kwargs)
|
|
self.dense = layers.Dense(units)
|
|
|
|
def call(self, x):
|
|
return self.dense(x)
|
|
|
|
inputs = layers.Input((4,))
|
|
outputs = CustomDense(10)(inputs)
|
|
model = Model(inputs, outputs)
|
|
config = model.get_config()
|
|
|
|
new_model = Model.from_config(
|
|
config, custom_objects={"CustomDense": CustomDense}
|
|
)
|
|
self.assertTrue(isinstance(new_model, Functional))
|