Merge branch 'main' of github.com:keras-team/keras-core
This commit is contained in:
parent
b720bba5f5
commit
68178e9cb2
@ -30,7 +30,7 @@ class InputLayer(Layer):
|
|||||||
if shape is None and batch_shape is None:
|
if shape is None and batch_shape is None:
|
||||||
raise ValueError("You must pass a `shape` argument.")
|
raise ValueError("You must pass a `shape` argument.")
|
||||||
|
|
||||||
if shape:
|
if shape is not None:
|
||||||
shape = backend.standardize_shape(shape)
|
shape = backend.standardize_shape(shape)
|
||||||
batch_shape = (batch_size,) + shape
|
batch_shape = (batch_size,) + shape
|
||||||
self.batch_shape = tuple(batch_shape)
|
self.batch_shape = tuple(batch_shape)
|
||||||
|
@ -34,6 +34,17 @@ class FunctionalTest(testing.TestCase):
|
|||||||
out_val = model(in_val)
|
out_val = model(in_val)
|
||||||
self.assertEqual(out_val.shape, (2, 4))
|
self.assertEqual(out_val.shape, (2, 4))
|
||||||
|
|
||||||
|
def test_scalar_input(self):
|
||||||
|
input_a = Input(shape=(3,), batch_size=2, name="input_a")
|
||||||
|
input_b = Input(shape=(), batch_size=2, name="input_b")
|
||||||
|
outputs = input_a + input_b[:, None]
|
||||||
|
model = Functional([input_a, input_b], outputs)
|
||||||
|
model.summary()
|
||||||
|
|
||||||
|
in_val = [np.zeros((2, 3)), np.ones((2,))]
|
||||||
|
out_val = model(in_val)
|
||||||
|
self.assertAllClose(out_val, np.ones((2, 3)))
|
||||||
|
|
||||||
def test_basic_flow_multi_output(self):
|
def test_basic_flow_multi_output(self):
|
||||||
inputs = Input(shape=(3,), batch_size=2, name="input")
|
inputs = Input(shape=(3,), batch_size=2, name="input")
|
||||||
x = layers.Dense(5)(inputs)
|
x = layers.Dense(5)(inputs)
|
||||||
|
Loading…
Reference in New Issue
Block a user