Merge branch 'main' of github.com:keras-team/keras-core
This commit is contained in:
parent
b720bba5f5
commit
68178e9cb2
@ -30,7 +30,7 @@ class InputLayer(Layer):
|
||||
if shape is None and batch_shape is None:
|
||||
raise ValueError("You must pass a `shape` argument.")
|
||||
|
||||
if shape:
|
||||
if shape is not None:
|
||||
shape = backend.standardize_shape(shape)
|
||||
batch_shape = (batch_size,) + shape
|
||||
self.batch_shape = tuple(batch_shape)
|
||||
|
@ -34,6 +34,17 @@ class FunctionalTest(testing.TestCase):
|
||||
out_val = model(in_val)
|
||||
self.assertEqual(out_val.shape, (2, 4))
|
||||
|
||||
def test_scalar_input(self):
|
||||
input_a = Input(shape=(3,), batch_size=2, name="input_a")
|
||||
input_b = Input(shape=(), batch_size=2, name="input_b")
|
||||
outputs = input_a + input_b[:, None]
|
||||
model = Functional([input_a, input_b], outputs)
|
||||
model.summary()
|
||||
|
||||
in_val = [np.zeros((2, 3)), np.ones((2,))]
|
||||
out_val = model(in_val)
|
||||
self.assertAllClose(out_val, np.ones((2, 3)))
|
||||
|
||||
def test_basic_flow_multi_output(self):
|
||||
inputs = Input(shape=(3,), batch_size=2, name="input")
|
||||
x = layers.Dense(5)(inputs)
|
||||
|
Loading…
Reference in New Issue
Block a user