import numpy as np
import pytest

from keras_core import backend
from keras_core import testing
from keras_core.operations import core


@pytest.mark.skipif(
    not backend.DYNAMIC_SHAPES_OK,
    reason="Backend does not support dynamic shapes",
)
class CoreOpsDynamicShapeTest(testing.TestCase):
    """Holder for dynamic-shape tests of the core ops (currently empty).

    The whole class is skipped when `backend.DYNAMIC_SHAPES_OK` is falsy.
    """


class CoreOpsStaticShapeTest(testing.TestCase):
    """Checks the static output shape reported by the core ops."""

    def test_scatter(self):
        # Requires dtype: both inputs are built with an explicit int32 dtype.
        scatter_indices = np.array([[0]], dtype="int32")
        scatter_values = np.array([0], dtype="int32")
        target_shape = (8,)

        result = core.scatter(scatter_indices, scatter_values, target_shape)
        self.assertEqual(result.shape, (8,))


class CoreOpsCorrectnessTest(testing.TestCase):
    """Checks the values produced by the core ops against literal results."""

    def test_scatter(self):
        # 1D: four scalars written into an otherwise-zero vector of length 8.
        idx_1d = np.array([[1], [3], [4], [7]])
        vals_1d = np.array([9, 10, 11, 12])
        out_1d = core.scatter(idx_1d, vals_1d, (8,))
        self.assertAllClose(out_1d, [0, 9, 0, 10, 11, 0, 0, 12])

        # 2D: each index row is a full (row, col) coordinate in a 3x2 grid.
        idx_2d = np.array([[0, 1], [2, 0]])
        vals_2d = np.array([5, 10])
        out_2d = core.scatter(idx_2d, vals_2d, (3, 2))
        self.assertAllClose(out_2d, [[0, 5], [0, 0], [10, 0]])

        # 3D: the same 4x4 slab is written at positions 1 and 3 of the
        # leading axis; positions 0 and 2 are expected to stay all-zero.
        slab = [[fill] * 4 for fill in (5, 6, 7, 8)]
        zero_slab = [[0] * 4 for _ in range(4)]
        idx_3d = np.array([[1], [3]])
        vals_3d = np.array([slab, slab])
        out_3d = core.scatter(idx_3d, vals_3d, (4, 4, 4))
        self.assertAllClose(out_3d, [zero_slab, slab, zero_slab, slab])

        # Slices: whole rows of length 3 are written at rows 2 and 4 of a
        # 6x3 result.
        idx_rows = np.array([[2], [4]])
        vals_rows = np.array([[1, 2, 3], [4, 5, 6]])
        out_rows = core.scatter(idx_rows, vals_rows, (6, 3))
        self.assertAllClose(
            out_rows,
            [
                [0, 0, 0],
                [0, 0, 0],
                [1, 2, 3],
                [0, 0, 0],
                [4, 5, 6],
                [0, 0, 0],
            ],
        )