6f2d44e3ef
* Add put operation * Trainer format * Add scatter ops * Revamp for scatter op and 3D test * Formatting * Shape tests * Formatting
65 lines
2.1 KiB
Python
65 lines
2.1 KiB
Python
import numpy as np
|
|
import pytest
|
|
|
|
from keras_core import backend
|
|
from keras_core import testing
|
|
from keras_core.operations import core
|
|
|
|
|
|
@pytest.mark.skipif(
    not backend.DYNAMIC_SHAPES_OK,
    reason="Backend does not support dynamic shapes",
)
class CoreOpsDynamicShapeTest(testing.TestCase):
    """Placeholder for core-op tests that rely on dynamic shapes.

    The whole class is skipped when `backend.DYNAMIC_SHAPES_OK` is falsy.
    No test methods are defined here yet.
    """
|
|
|
|
|
|
class CoreOpsStaticShapeTest(testing.TestCase):
    """Shape checks for core ops called with fully specified shapes."""

    def test_scatter(self):
        # The result of `core.scatter` must report the requested shape.
        target_shape = (8,)
        # Both inputs are given an explicit int32 dtype (the original note
        # here says the op requires a dtype).
        scatter_indices = np.array([[0]], dtype="int32")
        scatter_values = np.array([0], dtype="int32")
        result = core.scatter(scatter_indices, scatter_values, target_shape)
        self.assertEqual(result.shape, (8,))
|
|
|
|
|
|
class CoreOpsCorrectnessTest(testing.TestCase):
    """Value checks for core ops against hand-written expected results."""

    def test_scatter(self):
        # Every case scatters `updates` into a tensor of the requested shape
        # and expects zeros wherever nothing was written.

        # 1D: four scalars placed at positions 1, 3, 4 and 7.
        positions = np.array([[1], [3], [4], [7]])
        updates = np.array([9, 10, 11, 12])
        expected = [0, 9, 0, 10, 11, 0, 0, 12]
        self.assertAllClose(core.scatter(positions, updates, (8,)), expected)

        # 2D: each index row is a full (row, column) coordinate.
        positions = np.array([[0, 1], [2, 0]])
        updates = np.array([5, 10])
        expected = [[0, 5], [0, 0], [10, 0]]
        self.assertAllClose(core.scatter(positions, updates, (3, 2)), expected)

        # 3D: each index selects a whole 4x4 plane along the first axis.
        plane = [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]]
        empty = [[0, 0, 0, 0]] * 4
        positions = np.array([[1], [3]])
        updates = np.array([plane, plane])
        expected = [empty, plane, empty, plane]
        self.assertAllClose(
            core.scatter(positions, updates, (4, 4, 4)), expected
        )

        # Slices: each index selects a whole row of a (6, 3) output.
        positions = np.array([[2], [4]])
        updates = np.array([[1, 2, 3], [4, 5, 6]])
        expected = [
            [0, 0, 0],
            [0, 0, 0],
            [1, 2, 3],
            [0, 0, 0],
            [4, 5, 6],
            [0, 0, 0],
        ]
        self.assertAllClose(core.scatter(positions, updates, (6, 3)), expected)
|