Add core ops shape (#306)
This commit is contained in:
parent
636337fd06
commit
0490665dc9
@ -8,7 +8,6 @@ from keras_core.backend import convert_to_tensor
|
|||||||
from keras_core.backend import is_tensor
|
from keras_core.backend import is_tensor
|
||||||
from keras_core.backend import name_scope
|
from keras_core.backend import name_scope
|
||||||
from keras_core.backend import random
|
from keras_core.backend import random
|
||||||
from keras_core.backend import shape
|
|
||||||
from keras_core.operations import image
|
from keras_core.operations import image
|
||||||
from keras_core.operations import operation_utils
|
from keras_core.operations import operation_utils
|
||||||
from keras_core.operations.core import * # noqa: F403
|
from keras_core.operations.core import * # noqa: F403
|
||||||
|
@ -222,3 +222,11 @@ def stop_gradient(variable):
|
|||||||
>>> var = keras_core.operations.stop_gradient(var)
|
>>> var = keras_core.operations.stop_gradient(var)
|
||||||
"""
|
"""
|
||||||
return backend.core.stop_gradient(variable)
|
return backend.core.stop_gradient(variable)
|
||||||
|
|
||||||
|
|
||||||
|
@keras_core_export("keras_core.operations.shape")
|
||||||
|
def shape(x):
|
||||||
|
"""Gets the shape of the tensor input."""
|
||||||
|
if any_symbolic_tensors((x,)):
|
||||||
|
return x.shape
|
||||||
|
return backend.core.shape(x)
|
||||||
|
@ -184,3 +184,10 @@ class CoreOpsCorrectnessTest(testing.TestCase):
|
|||||||
model.fit(x, y, epochs=1, batch_size=2)
|
model.fit(x, y, epochs=1, batch_size=2)
|
||||||
self.assertEqual(model.layers[0].w.numpy(), 0.0)
|
self.assertEqual(model.layers[0].w.numpy(), 0.0)
|
||||||
self.assertNotEqual(model.layers[0].b.numpy(), 0.0)
|
self.assertNotEqual(model.layers[0].b.numpy(), 0.0)
|
||||||
|
|
||||||
|
def test_shape(self):
|
||||||
|
x = np.ones((2, 3, 7, 1))
|
||||||
|
self.assertAllEqual(core.shape(x), (2, 3, 7, 1))
|
||||||
|
|
||||||
|
x = KerasTensor((None, 3, None, 1))
|
||||||
|
self.assertAllEqual(core.shape(x), (None, 3, None, 1))
|
||||||
|
Loading…
Reference in New Issue
Block a user