support keras.Input in get_source_inputs (#732)
Co-authored-by: Haifeng Jin <haifeng-jin@users.noreply.github.com>
This commit is contained in:
parent
d9f0a5ab9e
commit
4513bf7ba0
@ -88,8 +88,10 @@ class LayerNormalizationTest(testing.TestCase):
|
||||
|
||||
def test_output(self):
|
||||
layer = layers.LayerNormalization(
|
||||
dtype="float32", beta_initializer="ones", gamma_initializer="ones",
|
||||
dtype="float32",
|
||||
beta_initializer="ones",
|
||||
gamma_initializer="ones",
|
||||
)
|
||||
inputs = np.arange(5).astype("float32")[None, :]
|
||||
out = layer(inputs)
|
||||
self.assertAllClose(out, [[-0.41386, 0.29307, 1., 1.70693, 2.41386]])
|
||||
self.assertAllClose(out, [[-0.41386, 0.29307, 1.0, 1.70693, 2.41386]])
|
||||
|
@ -481,34 +481,3 @@ def rsqrt(x):
|
||||
return Rsqrt().symbolic_call(x)
|
||||
x = backend.convert_to_tensor(x)
|
||||
return backend.math.rsqrt(x)
|
||||
|
||||
|
||||
class Rsqrt(Operation):
|
||||
"""Computes reciprocal of square root of x element-wise.
|
||||
|
||||
Args:
|
||||
x: input tensor
|
||||
|
||||
Returns:
|
||||
A tensor with the same type as `x`.
|
||||
|
||||
Example:
|
||||
|
||||
>>> x = keras_core.ops.convert_to_tensor([2., 3., -2.])
|
||||
>>> rsqrt(x)
|
||||
"""
|
||||
|
||||
def call(self, x):
|
||||
x = backend.convert_to_tensor(x)
|
||||
return backend.math.rsqrt(x)
|
||||
|
||||
def compute_output_spec(self, x):
|
||||
return KerasTensor(x.shape, dtype=x.dtype)
|
||||
|
||||
|
||||
@keras_core_export("keras_core.ops.rsqrt")
|
||||
def rsqrt(x):
|
||||
if any_symbolic_tensors((x,)):
|
||||
return Rsqrt().symbolic_call(x)
|
||||
x = backend.convert_to_tensor(x)
|
||||
return backend.math.rsqrt(x)
|
||||
|
@ -261,7 +261,7 @@ def get_source_inputs(tensor):
|
||||
node = operation._inbound_nodes[node_index]
|
||||
if node.is_input:
|
||||
# Reached input node, stop recursion.
|
||||
return tree.flatten(node.input_tensors)
|
||||
return tree.flatten(node.output_tensors)
|
||||
else:
|
||||
source_tensors = []
|
||||
for tensor in node.input_tensors:
|
||||
|
@ -1,6 +1,7 @@
|
||||
from keras_core import backend
|
||||
from keras_core import ops
|
||||
from keras_core import testing
|
||||
from keras_core.layers.core import input_layer
|
||||
from keras_core.ops import operation_utils
|
||||
|
||||
|
||||
@ -12,3 +13,7 @@ class OperationUtilsTest(testing.TestCase):
|
||||
x += 2
|
||||
x = ops.square(x)
|
||||
self.assertEqual(operation_utils.get_source_inputs(x), [x1, x2])
|
||||
|
||||
def test_get_source_inputs_return_input_tensor(self):
|
||||
inputs = input_layer.Input(shape=(10,))
|
||||
self.assertIs(operation_utils.get_source_inputs(inputs)[0], inputs)
|
||||
|
Loading…
Reference in New Issue
Block a user