From 0490665dc93079ae044934bdf9ec0d2ba787cde6 Mon Sep 17 00:00:00 2001 From: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Date: Thu, 8 Jun 2023 17:13:29 -0700 Subject: [PATCH] Add core ops shape (#306) --- keras_core/operations/__init__.py | 1 - keras_core/operations/core.py | 8 ++++++++ keras_core/operations/core_test.py | 7 +++++++ 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/keras_core/operations/__init__.py b/keras_core/operations/__init__.py index 945356e31..5777ea576 100644 --- a/keras_core/operations/__init__.py +++ b/keras_core/operations/__init__.py @@ -8,7 +8,6 @@ from keras_core.backend import convert_to_tensor from keras_core.backend import is_tensor from keras_core.backend import name_scope from keras_core.backend import random -from keras_core.backend import shape from keras_core.operations import image from keras_core.operations import operation_utils from keras_core.operations.core import * # noqa: F403 diff --git a/keras_core/operations/core.py b/keras_core/operations/core.py index d8b5bbc7a..efdba8600 100644 --- a/keras_core/operations/core.py +++ b/keras_core/operations/core.py @@ -222,3 +222,11 @@ def stop_gradient(variable): >>> var = keras_core.operations.stop_gradient(var) """ return backend.core.stop_gradient(variable) + + +@keras_core_export("keras_core.operations.shape") +def shape(x): + """Gets the shape of the tensor input.""" + if any_symbolic_tensors((x,)): + return x.shape + return backend.core.shape(x) diff --git a/keras_core/operations/core_test.py b/keras_core/operations/core_test.py index 60bb2b1c1..732fc9725 100644 --- a/keras_core/operations/core_test.py +++ b/keras_core/operations/core_test.py @@ -184,3 +184,10 @@ class CoreOpsCorrectnessTest(testing.TestCase): model.fit(x, y, epochs=1, batch_size=2) self.assertEqual(model.layers[0].w.numpy(), 0.0) self.assertNotEqual(model.layers[0].b.numpy(), 0.0) + + def test_shape(self): + x = np.ones((2, 3, 7, 1)) + self.assertAllEqual(core.shape(x), (2, 3, 7, 1)) + + x = KerasTensor((None, 3, None, 1)) + self.assertAllEqual(core.shape(x), (None, 3, None, 1))