2023-04-12 21:27:30 +00:00
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
from keras_core import backend
|
|
|
|
from keras_core import layers
|
|
|
|
from keras_core import testing
|
|
|
|
from keras_core.layers.core.input_layer import Input
|
2023-04-21 22:01:17 +00:00
|
|
|
from keras_core.models import Functional
|
|
|
|
from keras_core.models import Model
|
2023-04-12 21:27:30 +00:00
|
|
|
|
|
|
|
|
|
|
|
class FunctionalTest(testing.TestCase):
    """Tests for constructing and calling `Functional` models."""

    def test_basic_flow_multi_input(self):
        """Two-input, one-output model: eager and symbolic calls."""
        input_a = Input(shape=(3,), batch_size=2, name="input_a")
        input_b = Input(shape=(3,), batch_size=2, name="input_b")
        x = input_a + input_b
        x = layers.Dense(5)(x)
        outputs = layers.Dense(4)(x)
        model = Functional([input_a, input_b], outputs, name="basic")
        # Called only to check that printing the summary does not raise.
        model.summary()

        self.assertEqual(model.name, "basic")
        # assertIsInstance reports the actual type when the check fails.
        self.assertIsInstance(model, Functional)
        self.assertIsInstance(model, Model)

        # Eager call
        in_val = [np.random.random((2, 3)), np.random.random((2, 3))]
        out_val = model(in_val)
        self.assertEqual(out_val.shape, (2, 4))

        # Symbolic call
        input_a_2 = Input(shape=(3,), batch_size=2, name="input_a_2")
        input_b_2 = Input(shape=(3,), batch_size=2, name="input_b_2")
        in_val = [input_a_2, input_b_2]
        out_val = model(in_val)
        self.assertEqual(out_val.shape, (2, 4))

    def test_basic_flow_multi_output(self):
        """One-input, two-output model returns a list of two outputs."""
        inputs = Input(shape=(3,), batch_size=2, name="input")
        x = layers.Dense(5)(inputs)
        output_a = layers.Dense(4)(x)
        output_b = layers.Dense(5)(x)
        model = Functional(inputs, [output_a, output_b])

        # Eager call
        in_val = np.random.random((2, 3))
        out_val = model(in_val)
        self.assertIsInstance(out_val, list)
        self.assertEqual(len(out_val), 2)
        self.assertEqual(out_val[0].shape, (2, 4))
        self.assertEqual(out_val[1].shape, (2, 5))

        # Symbolic call
        out_val = model(Input(shape=(3,), batch_size=2))
        self.assertIsInstance(out_val, list)
        self.assertEqual(len(out_val), 2)
        self.assertEqual(out_val[0].shape, (2, 4))
        self.assertEqual(out_val[1].shape, (2, 5))

    def test_basic_flow_dict_io(self):
        """Dict inputs: constructor validation, then eager/symbolic calls."""
        input_a = Input(shape=(3,), batch_size=2, name="a")
        input_b = Input(shape=(3,), batch_size=2, name="b")
        x = input_a + input_b
        x = layers.Dense(5)(x)
        outputs = layers.Dense(4)(x)

        # A dict value that is a list rather than a KerasTensor is rejected.
        # The result is not bound: the constructor is expected to raise.
        with self.assertRaisesRegex(
            ValueError, "all values in the dict must be KerasTensors"
        ):
            Functional({"aa": [input_a], "bb": input_b}, outputs)

        # Dict keys that differ from the input names are rejected.
        with self.assertRaisesRegex(
            ValueError, "all keys in the dict must match the names"
        ):
            Functional({"aa": input_a, "bb": input_b}, outputs)

        model = Functional({"a": input_a, "b": input_b}, outputs)

        # Eager call
        in_val = {"a": np.random.random((2, 3)), "b": np.random.random((2, 3))}
        out_val = model(in_val)
        self.assertEqual(out_val.shape, (2, 4))

        # Symbolic call
        input_a_2 = Input(shape=(3,), batch_size=2)
        input_b_2 = Input(shape=(3,), batch_size=2)
        in_val = {"a": input_a_2, "b": input_b_2}
        out_val = model(in_val)
        self.assertEqual(out_val.shape, (2, 4))

    def test_layer_getters(self):
        """`layers`, `_operations` and `get_layer` on a mixed op/layer graph."""
        # Test mixing ops and layers
        input_a = Input(shape=(3,), batch_size=2, name="input_a")
        input_b = Input(shape=(3,), batch_size=2, name="input_b")
        x = input_a + input_b
        x = layers.Dense(5, name="dense_1")(x)
        outputs = layers.Dense(4, name="dense_2")(x)
        model = Functional([input_a, input_b], outputs)

        # Four layers are reported but five operations; presumably the
        # extra operation is the `+` above -- TODO confirm.
        self.assertEqual(len(model.layers), 4)
        self.assertEqual(len(model._operations), 5)

        # Lookup by index.
        self.assertEqual(model.get_layer(index=0).name, "input_a")
        self.assertEqual(model.get_layer(index=1).name, "input_b")
        self.assertEqual(model.get_layer(index=2).name, "dense_1")
        self.assertEqual(model.get_layer(index=3).name, "dense_2")
        # Lookup by name.
        self.assertEqual(model.get_layer(name="dense_1").name, "dense_1")

    def test_training_arg(self):
        """`training=True` passed to the model reaches the inner layer."""

        class Canary(layers.Layer):
            """Identity layer whose `call` fails unless `training` is truthy."""

            def call(self, x, training=False):
                # Fails the test if the flag was not forwarded.
                assert training
                return x

            def compute_output_spec(self, x, training=False):
                # Symbolic output has the same shape and dtype as the input.
                return backend.KerasTensor(x.shape, dtype=x.dtype)

        inputs = Input(shape=(3,), batch_size=2)
        outputs = Canary()(inputs)
        model = Functional(inputs, outputs)
        model(np.random.random((2, 3)), training=True)

    def test_mask_arg(self):
        """Placeholder; not implemented yet."""
        # TODO
        pass

    def test_passing_inputs_by_name(self):
        """A model built from a list of inputs accepts a dict keyed by name."""
        input_a = Input(shape=(3,), batch_size=2, name="input_a")
        input_b = Input(shape=(3,), batch_size=2, name="input_b")
        x = input_a + input_b
        x = layers.Dense(5)(x)
        outputs = layers.Dense(4)(x)
        model = Functional([input_a, input_b], outputs)

        # Eager call
        in_val = {
            "input_a": np.random.random((2, 3)),
            "input_b": np.random.random((2, 3)),
        }
        out_val = model(in_val)
        self.assertEqual(out_val.shape, (2, 4))

        # Symbolic call
        input_a_2 = Input(shape=(3,), batch_size=2, name="input_a_2")
        input_b_2 = Input(shape=(3,), batch_size=2, name="input_b_2")
        in_val = {"input_a": input_a_2, "input_b": input_b_2}
        out_val = model(in_val)
        self.assertEqual(out_val.shape, (2, 4))

    def test_rank_standardization(self):
        """Inputs whose rank is off by a trailing unit axis are accepted."""
        # Downranking
        inputs = Input(shape=(3,), batch_size=2)
        outputs = layers.Dense(3)(inputs)
        model = Functional(inputs, outputs)
        # Data is (2, 3, 1) while the model input is declared as (2, 3).
        out_val = model(np.random.random((2, 3, 1)))
        self.assertEqual(out_val.shape, (2, 3))

        # Upranking
        inputs = Input(shape=(3, 1), batch_size=2)
        outputs = layers.Dense(3)(inputs)
        model = Functional(inputs, outputs)
        # Data is (2, 3) while the model input is declared as (2, 3, 1).
        out_val = model(np.random.random((2, 3)))
        self.assertEqual(out_val.shape, (2, 3, 3))

    def test_serialization(self):
        """Run the class serialization check on several model topologies."""
        # Test basic model
        inputs = Input(shape=(3,), batch_size=2)
        outputs = layers.Dense(3)(inputs)
        model = Functional(inputs, outputs)
        self.run_class_serialization_test(model)

        # Test multi-io model
        input_a = Input(shape=(3,), batch_size=2, name="input_a")
        input_b = Input(shape=(3,), batch_size=2, name="input_b")
        xa = layers.Dense(5, name="middle_a")(input_a)
        xb = layers.Dense(5, name="middle_b")(input_b)
        output_a = layers.Dense(4, name="output_a")(xa)
        output_b = layers.Dense(4, name="output_b")(xb)
        model = Functional(
            [input_a, input_b], [output_a, output_b], name="func"
        )
        self.run_class_serialization_test(model)

        # Test model that includes floating ops
        input_a = Input(shape=(3,), batch_size=2, name="input_a")
        input_b = Input(shape=(3,), batch_size=2, name="input_b")
        x = input_a + input_b
        x = layers.Dense(5, name="middle")(x)
        output_a = layers.Dense(4, name="output_a")(x)
        output_b = layers.Dense(4, name="output_b")(x)
        model = Functional(
            [input_a, input_b], [output_a, output_b], name="func"
        )
        self.run_class_serialization_test(model)

        # Test model with dict i/o
        input_a = Input(shape=(3,), batch_size=2, name="a")
        input_b = Input(shape=(3,), batch_size=2, name="b")
        x = input_a + input_b
        x = layers.Dense(5)(x)
        outputs = layers.Dense(4)(x)
        model = Functional({"a": input_a, "b": input_b}, outputs)
        self.run_class_serialization_test(model)

    def test_add_loss(self):
        """Placeholder; not implemented yet."""
        # TODO
        pass
|