support keras.Input in get_source_inputs (#732)

Co-authored-by: Haifeng Jin <haifeng-jin@users.noreply.github.com>
This commit is contained in:
Haifeng Jin 2023-08-15 14:40:46 -07:00 committed by Francois Chollet
parent d9f0a5ab9e
commit 4513bf7ba0
4 changed files with 10 additions and 34 deletions

@ -88,8 +88,10 @@ class LayerNormalizationTest(testing.TestCase):
def test_output(self):
layer = layers.LayerNormalization(
dtype="float32", beta_initializer="ones", gamma_initializer="ones",
dtype="float32",
beta_initializer="ones",
gamma_initializer="ones",
)
inputs = np.arange(5).astype("float32")[None, :]
out = layer(inputs)
self.assertAllClose(out, [[-0.41386, 0.29307, 1., 1.70693, 2.41386]])
self.assertAllClose(out, [[-0.41386, 0.29307, 1.0, 1.70693, 2.41386]])

@ -481,34 +481,3 @@ def rsqrt(x):
return Rsqrt().symbolic_call(x)
x = backend.convert_to_tensor(x)
return backend.math.rsqrt(x)
class Rsqrt(Operation):
"""Computes reciprocal of square root of x element-wise.
Args:
x: input tensor
Returns:
A tensor with the same type as `x`.
Example:
>>> x = keras_core.ops.convert_to_tensor([2., 3., -2.])
>>> rsqrt(x)
"""
def call(self, x):
x = backend.convert_to_tensor(x)
return backend.math.rsqrt(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
@keras_core_export("keras_core.ops.rsqrt")
def rsqrt(x):
if any_symbolic_tensors((x,)):
return Rsqrt().symbolic_call(x)
x = backend.convert_to_tensor(x)
return backend.math.rsqrt(x)

@ -261,7 +261,7 @@ def get_source_inputs(tensor):
node = operation._inbound_nodes[node_index]
if node.is_input:
# Reached input node, stop recursion.
return tree.flatten(node.input_tensors)
return tree.flatten(node.output_tensors)
else:
source_tensors = []
for tensor in node.input_tensors:

@ -1,6 +1,7 @@
from keras_core import backend
from keras_core import ops
from keras_core import testing
from keras_core.layers.core import input_layer
from keras_core.ops import operation_utils
@ -12,3 +13,7 @@ class OperationUtilsTest(testing.TestCase):
x += 2
x = ops.square(x)
self.assertEqual(operation_utils.get_source_inputs(x), [x1, x2])
def test_get_source_inputs_return_input_tensor(self):
inputs = input_layer.Input(shape=(10,))
self.assertIs(operation_utils.get_source_inputs(inputs)[0], inputs)