keras/keras_core/operations/core.py
Neel Kovelamudi 2d40cb20b9 Adds CategoryEncoding layer, bincount op, and tests (#161)
* Adds unit normalization and tests

* Adds layer normalization and initial tests

* Fixes formatting in docstrings

* Fix type issues for JAX

* Fix nits

* Initial stash for group_normalization and spectral_normalization

* Adds spectral normalization and tests

* Adds group normalization and tests

* Formatting fixes

* Fix small nit in docstring

* Fix docstring and tests

* Adds RandomContrast and associated tests

* Remove arithmetic comment

* Adds RandomBrightness and tests

* Fix docstring and format

* Fix nits and add backend generator

* Inlines random_contrast helper

* Add bincount op

* Add CategoryEncoding layer and tests

* Fix formatting

* Fix JAX issues

* Fix JAX bincount

* Formatting and small fix

* Fix nits and docstrings

* Add args to bincount op test
2023-05-14 00:07:43 +00:00

23 lines
656 B
Python

"""
scatter
"""
from keras_core import backend
from keras_core.backend import KerasTensor
from keras_core.backend import any_symbolic_tensors
from keras_core.operations.operation import Operation
class Scatter(Operation):
    """Operation wrapper around `backend.core.scatter`.

    Eager execution is delegated to the active backend. For symbolic
    (graph-building) execution the op only has to describe its result,
    which is a `KerasTensor` whose shape is the requested `shape` and whose
    dtype is taken from `values`.
    """

    def call(self, indices, values, shape):
        """Run the backend scatter implementation on concrete inputs.

        Args:
            indices: Positions to write to. The expected layout is defined
                by `backend.core.scatter` (not visible here — confirm per
                backend).
            values: Values to place at `indices`.
            shape: Target output shape; `compute_output_spec` uses it
                verbatim as the shape of the symbolic result.

        Returns:
            The value returned by `backend.core.scatter(indices, values, shape)`.
        """
        scatter_fn = backend.core.scatter
        return scatter_fn(indices, values, shape)

    def compute_output_spec(self, indices, values, shape):
        """Describe the symbolic result without computing anything.

        `indices` is accepted for signature compatibility but is not
        inspected. The spec's shape is exactly `shape`, and its dtype is
        inherited from `values.dtype`.
        """
        output_dtype = values.dtype
        return KerasTensor(shape, dtype=output_dtype)
def scatter(indices, values, shape):
    """Apply the scatter op to `indices`, `values` and `shape`.

    If any of the three arguments is a symbolic tensor, no computation
    happens. Instead a `Scatter` operation is recorded via `symbolic_call`,
    and the symbolic result's spec comes from
    `Scatter.compute_output_spec`. Otherwise the arguments go straight to
    `backend.core.scatter` and its result is returned.

    NOTE(review): the exact semantics are defined by `backend.core.scatter`
    and are not visible here. This includes the index layout, what happens
    with duplicate indices, and the fill value of positions that receive no
    value. Confirm them per backend.

    Args:
        indices: Positions to write to.
        values: Values to place at `indices`.
        shape: Shape of the output tensor.

    Returns:
        A symbolic tensor when any input is symbolic, otherwise the backend
        result.
    """
    operands = (indices, values, shape)
    if not any_symbolic_tensors(operands):
        # All inputs are concrete: compute eagerly on the active backend.
        return backend.core.scatter(*operands)
    # At least one input is symbolic: record a graph node instead.
    return Scatter().symbolic_call(*operands)