Merge branch 'main' of github.com:keras-team/keras-core

This commit is contained in:
Francois Chollet 2023-06-06 16:01:35 -07:00
parent b720bba5f5
commit 68178e9cb2
2 changed files with 12 additions and 1 deletions

@ -30,7 +30,7 @@ class InputLayer(Layer):
if shape is None and batch_shape is None: if shape is None and batch_shape is None:
raise ValueError("You must pass a `shape` argument.") raise ValueError("You must pass a `shape` argument.")
if shape: if shape is not None:
shape = backend.standardize_shape(shape) shape = backend.standardize_shape(shape)
batch_shape = (batch_size,) + shape batch_shape = (batch_size,) + shape
self.batch_shape = tuple(batch_shape) self.batch_shape = tuple(batch_shape)

@ -34,6 +34,17 @@ class FunctionalTest(testing.TestCase):
out_val = model(in_val) out_val = model(in_val)
self.assertEqual(out_val.shape, (2, 4)) self.assertEqual(out_val.shape, (2, 4))
def test_scalar_input(self):
input_a = Input(shape=(3,), batch_size=2, name="input_a")
input_b = Input(shape=(), batch_size=2, name="input_b")
outputs = input_a + input_b[:, None]
model = Functional([input_a, input_b], outputs)
model.summary()
in_val = [np.zeros((2, 3)), np.ones((2,))]
out_val = model(in_val)
self.assertAllClose(out_val, np.ones((2, 3)))
def test_basic_flow_multi_output(self): def test_basic_flow_multi_output(self):
inputs = Input(shape=(3,), batch_size=2, name="input") inputs = Input(shape=(3,), batch_size=2, name="input")
x = layers.Dense(5)(inputs) x = layers.Dense(5)(inputs)