keras/keras_core/operations/core_test.py

65 lines
2.1 KiB
Python
Raw Normal View History

import numpy as np
import pytest
from keras_core import backend
from keras_core import testing
from keras_core.operations import core
@pytest.mark.skipif(
    not backend.DYNAMIC_SHAPES_OK,
    reason="Backend does not support dynamic shapes",
)
class CoreOpsDynamicShapeTest(testing.TestCase):
    """Placeholder for core-op tests that rely on dynamic shapes.

    The whole class is skipped when ``backend.DYNAMIC_SHAPES_OK`` is falsy.
    No test methods are defined yet.
    """
class CoreOpsStaticShapeTest(testing.TestCase):
    """Checks the output shape reported by core operations."""

    def test_scatter(self):
        """`core.scatter` returns a tensor whose shape is the requested one."""
        target_shape = (8,)
        # The original note says explicit dtypes are required here, so both
        # the index and update arrays are pinned to int32.
        scatter_indices = np.array([[0]], dtype="int32")
        scatter_updates = np.array([0], dtype="int32")
        result = core.scatter(scatter_indices, scatter_updates, target_shape)
        self.assertEqual(result.shape, (8,))
class CoreOpsCorrectnessTest(testing.TestCase):
    """Numerical correctness checks for `keras_core.operations.core`."""

    def test_scatter(self):
        """`core.scatter` writes each update at its index into a zero tensor.

        Each case is a tuple ``(indices, updates, output_shape, expected)``.
        Positions not addressed by ``indices`` must remain zero.
        """
        # One (4, 4) slab whose rows are filled with 5, 6, 7 and 8; it is
        # used twice as the updates of the rank-3 case.
        slab = [[fill] * 4 for fill in (5, 6, 7, 8)]
        zero_slab = [[0] * 4 for _ in range(4)]

        cases = [
            # Rank-1 output: each index row addresses a single element.
            (
                np.array([[1], [3], [4], [7]]),
                np.array([9, 10, 11, 12]),
                (8,),
                [0, 9, 0, 10, 11, 0, 0, 12],
            ),
            # Rank-2 output: each index row is a full (row, col) coordinate.
            (
                np.array([[0, 1], [2, 0]]),
                np.array([5, 10]),
                (3, 2),
                [[0, 5], [0, 0], [10, 0]],
            ),
            # Rank-3 output: each index row addresses a whole (4, 4) slab.
            (
                np.array([[1], [3]]),
                np.array([slab, slab]),
                (4, 4, 4),
                [zero_slab, slab, zero_slab, slab],
            ),
            # Rank-2 output written one row slice at a time.
            (
                np.array([[2], [4]]),
                np.array([[1, 2, 3], [4, 5, 6]]),
                (6, 3),
                [
                    [0, 0, 0],
                    [0, 0, 0],
                    [1, 2, 3],
                    [0, 0, 0],
                    [4, 5, 6],
                    [0, 0, 0],
                ],
            ),
        ]

        for indices, updates, output_shape, expected in cases:
            self.assertAllClose(
                core.scatter(indices, updates, output_shape), expected
            )