Add compute output spec test
This commit is contained in:
parent
2bab1b1923
commit
427562e42f
44
keras_core/backend/tests/compute_output_spec_test.py
Normal file
44
keras_core/backend/tests/compute_output_spec_test.py
Normal file
@ -0,0 +1,44 @@
|
||||
import unittest
|
||||
|
||||
from keras_core import backend
|
||||
from keras_core.backend.common.keras_tensor import KerasTensor
|
||||
|
||||
|
||||
def single_arg_test_fn(x):
|
||||
return backend.numpy.concatenate([(x + 1) ** 2, x], axis=-1)
|
||||
|
||||
|
||||
def three_args_2_kwarg_test_fn(x1, x2, x3=None):
|
||||
x1 = backend.numpy.max(x1, axis=1)
|
||||
x2 = backend.numpy.max(x2, axis=1)
|
||||
if x3 is not None:
|
||||
x1 += backend.numpy.max(x3, axis=1)
|
||||
return x1 + x2
|
||||
|
||||
|
||||
class ComputeOutputSpecTest(unittest.TestCase):
|
||||
def test_dynamic_batch_size(self):
|
||||
x = KerasTensor(shape=(None, 3, 5))
|
||||
y = backend.compute_output_spec(single_arg_test_fn, x)
|
||||
self.assertEqual(y.shape, (None, 3, 10))
|
||||
|
||||
x1 = KerasTensor(shape=(None, 3, 5))
|
||||
x2 = KerasTensor(shape=(None, 3, 5))
|
||||
x3 = KerasTensor(shape=(None, 3, 5))
|
||||
y = backend.compute_output_spec(
|
||||
three_args_2_kwarg_test_fn, x1, x2, x3=x3
|
||||
)
|
||||
self.assertEqual(y.shape, (None, 5))
|
||||
|
||||
def test_dynamic_everything(self):
|
||||
x = KerasTensor(shape=(2, None, 3))
|
||||
y = backend.compute_output_spec(single_arg_test_fn, x)
|
||||
self.assertEqual(y.shape, (2, None, 6))
|
||||
|
||||
x1 = KerasTensor(shape=(None, None, 5))
|
||||
x2 = KerasTensor(shape=(None, None, 5))
|
||||
x3 = KerasTensor(shape=(None, None, 5))
|
||||
y = backend.compute_output_spec(
|
||||
three_args_2_kwarg_test_fn, x1, x2, x3=x3
|
||||
)
|
||||
self.assertEqual(y.shape, (None, 5))
|
Loading…
Reference in New Issue
Block a user