keras/keras_core/operations/operation_test.py
2023-04-23 11:05:04 -07:00

151 lines
5.0 KiB
Python

import numpy as np
from keras_core import backend
from keras_core import testing
from keras_core.backend import keras_tensor
from keras_core.operations import numpy as knp
from keras_core.operations import operation
class OpWithMultipleInputs(operation.Operation):
    """Test operation combining up to three inputs as `x + 2 * y + 3 * z`.

    `z` is optional. When it is left as `None`, `call` returns `x + 2 * y`,
    so `call` accepts the same arguments as `compute_output_spec`.
    """

    def call(self, x, y, z=None):
        """Compute `x + 2 * y + 3 * z`; a `None` `z` contributes nothing."""
        out = x + 2 * y
        if z is None:
            # The signature advertises `z` as optional, but `3 * None`
            # raises a `TypeError`; leave the term out instead.
            return out
        return out + 3 * z

    def compute_output_spec(self, x, y, z=None):
        """Return a `KerasTensor` with the shape and dtype of `x`."""
        return keras_tensor.KerasTensor(x.shape, x.dtype)
class OpWithMultipleOutputs(operation.Operation):
    """Test operation producing two outputs: `x` and `x + 1`."""

    def call(self, x):
        """Return the pair `(x, x + 1)`."""
        incremented = x + 1
        return x, incremented

    def compute_output_spec(self, x):
        """Return two `KerasTensor` specs, each with `x`'s shape and dtype."""
        first = keras_tensor.KerasTensor(x.shape, x.dtype)
        second = keras_tensor.KerasTensor(x.shape, x.dtype)
        return first, second
class OpWithCustomConstructor(operation.Operation):
    """Test operation whose constructor takes extra arguments.

    In `"foo"` mode the input is passed through unchanged; in any other
    mode it is multiplied by `alpha`.
    """

    def __init__(self, alpha, mode="foo"):
        super().__init__()
        self.mode = mode
        self.alpha = alpha

    def call(self, x):
        """Return `x` unchanged in `"foo"` mode, otherwise `alpha * x`."""
        passthrough = self.mode == "foo"
        return x if passthrough else self.alpha * x

    def compute_output_spec(self, x):
        """Return a `KerasTensor` with the shape and dtype of `x`."""
        spec = keras_tensor.KerasTensor(x.shape, x.dtype)
        return spec
class OperationTest(testing.TestCase):
    """Tests for calling, reusing and serializing the test operations."""

    def _assert_symbolic_result(self, op, result, inputs):
        """Check the first symbolic call of `op` on (2, 3) inputs.

        Verifies that `result` is a (2, 3) `KerasTensor`, that `op` has
        exactly one inbound node, and that `op.input` / `op.output` match
        `inputs` / `result`.
        """
        self.assertIsInstance(result, keras_tensor.KerasTensor)
        self.assertEqual(result.shape, (2, 3))
        self.assertEqual(len(op._inbound_nodes), 1)
        self.assertEqual(op.input, inputs)
        self.assertEqual(op.output, result)

    def test_symbolic_call(self):
        x, y, z = (
            keras_tensor.KerasTensor(shape=(2, 3), name=name)
            for name in ("x", "y", "z")
        )

        # Positional arguments
        op = OpWithMultipleInputs(name="test_op")
        self.assertEqual(op.name, "test_op")
        self._assert_symbolic_result(op, op(x, y, z), [x, y, z])

        # Keyword arguments
        op = OpWithMultipleInputs(name="test_op")
        self._assert_symbolic_result(op, op(x=x, y=y, z=z), [x, y, z])

        # Mix
        op = OpWithMultipleInputs(name="test_op")
        first_out = op(x, y=y, z=z)
        self._assert_symbolic_result(op, first_out, [x, y, z])

        # Test op reuse: a second call adds a node, while `op.output` is
        # still expected to equal the output of the first call.
        second_out = op(x, y=y, z=z)
        self.assertIsInstance(second_out, keras_tensor.KerasTensor)
        self.assertEqual(second_out.shape, (2, 3))
        self.assertEqual(len(op._inbound_nodes), 2)
        self.assertEqual(op.output, first_out)

        # Test multiple outputs
        op = OpWithMultipleOutputs()
        outs = op(x)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 2)
        for tensor in outs:
            self.assertIsInstance(tensor, keras_tensor.KerasTensor)
        for tensor in outs:
            self.assertEqual(tensor.shape, (2, 3))
        self.assertEqual(len(op._inbound_nodes), 1)
        self.assertEqual(op.output, list(outs))

    def test_eager_call(self):
        x = knp.ones((2, 3))
        y = knp.ones((2, 3))
        z = knp.ones((2, 3))
        op = OpWithMultipleInputs(name="test_op")
        self.assertEqual(op.name, "test_op")

        # 1 + 2 * 1 + 3 * 1 == 6 for every element.
        expected = 6 * np.ones((2, 3))
        call_styles = (
            ((x, y, z), {}),  # Positional arguments
            ((), {"x": x, "y": y, "z": z}),  # Keyword arguments
            ((x,), {"y": y, "z": z}),  # Mixed arguments
        )
        for args, kwargs in call_styles:
            out = op(*args, **kwargs)
            self.assertTrue(backend.is_tensor(out))
            self.assertAllClose(out, expected)

        # Test multiple outputs
        op = OpWithMultipleOutputs()
        outs = op(x)
        self.assertEqual(len(outs), 2)
        self.assertTrue(backend.is_tensor(outs[0]))
        self.assertTrue(backend.is_tensor(outs[1]))
        self.assertAllClose(outs[0], np.ones((2, 3)))
        self.assertAllClose(outs[1], np.ones((2, 3)) + 1)

    def test_serialization(self):
        original = OpWithMultipleOutputs(name="test_op")
        config = original.get_config()
        self.assertEqual(config, {"name": "test_op"})
        restored = OpWithMultipleOutputs.from_config(config)
        self.assertEqual(restored.name, "test_op")

    def test_autoconfig(self):
        config = OpWithCustomConstructor(alpha=0.2, mode="bar").get_config()
        self.assertEqual(config, {"alpha": 0.2, "mode": "bar"})
        revived = OpWithCustomConstructor.from_config(config)
        self.assertEqual(revived.get_config(), config)

    def test_input_conversion(self):
        x = np.ones((2,))
        y = np.ones((2,))
        z = knp.ones((2,))  # mix
        out = OpWithMultipleInputs()(x, y, z)
        self.assertTrue(backend.is_tensor(out))
        self.assertAllClose(out, 6 * np.ones((2,)))