Add core ops shape (#306)

This commit is contained in:
Matt Watson 2023-06-08 17:13:29 -07:00 committed by Francois Chollet
parent 636337fd06
commit 0490665dc9
3 changed files with 15 additions and 1 deletions

@ -8,7 +8,6 @@ from keras_core.backend import convert_to_tensor
from keras_core.backend import is_tensor from keras_core.backend import is_tensor
from keras_core.backend import name_scope from keras_core.backend import name_scope
from keras_core.backend import random from keras_core.backend import random
from keras_core.backend import shape
from keras_core.operations import image from keras_core.operations import image
from keras_core.operations import operation_utils from keras_core.operations import operation_utils
from keras_core.operations.core import * # noqa: F403 from keras_core.operations.core import * # noqa: F403

@ -222,3 +222,11 @@ def stop_gradient(variable):
>>> var = keras_core.operations.stop_gradient(var) >>> var = keras_core.operations.stop_gradient(var)
""" """
return backend.core.stop_gradient(variable) return backend.core.stop_gradient(variable)
@keras_core_export("keras_core.operations.shape")
def shape(x):
"""Gets the shape of the tensor input."""
if any_symbolic_tensors((x,)):
return x.shape
return backend.core.shape(x)

@ -184,3 +184,10 @@ class CoreOpsCorrectnessTest(testing.TestCase):
model.fit(x, y, epochs=1, batch_size=2) model.fit(x, y, epochs=1, batch_size=2)
self.assertEqual(model.layers[0].w.numpy(), 0.0) self.assertEqual(model.layers[0].w.numpy(), 0.0)
self.assertNotEqual(model.layers[0].b.numpy(), 0.0) self.assertNotEqual(model.layers[0].b.numpy(), 0.0)
def test_shape(self):
x = np.ones((2, 3, 7, 1))
self.assertAllEqual(core.shape(x), (2, 3, 7, 1))
x = KerasTensor((None, 3, None, 1))
self.assertAllEqual(core.shape(x), (None, 3, None, 1))