2d40cb20b9
* Adds unit normalization and tests * Adds layer normalization and initial tests * Fixes formatting in docstrings * Fix type issues for JAX * Fix nits * Initial stash for group_normalization and spectral_normalization * Adds spectral normalization and tests * Adds group normalization and tests * Formatting fixes * Fix small nit in docstring * Fix docstring and tests * Adds RandomContrast and associated tests * Remove arithmetic comment * Adds RandomBrightness and tests * Fix docstring and format * Fix nits and add backend generator * Inlines random_contrast helper * Add bincount op * Add CategoryEncoding layer and tests * Fix formatting * Fix JAX issues * Fix JAX bincount * Formatting and small fix * Fix nits and docstrings * Add args to bincount op test
23 lines
656 B
Python
23 lines
656 B
Python
"""
scatter
"""
|
|
|
|
from keras_core import backend
|
|
from keras_core.backend import KerasTensor
|
|
from keras_core.backend import any_symbolic_tensors
|
|
from keras_core.operations.operation import Operation
|
|
|
|
|
|
class Scatter(Operation):
    """Operation wrapper that lets `scatter` participate in symbolic graphs.

    `call` runs the active backend's `backend.core.scatter` on concrete
    inputs. `compute_output_spec` describes the result for symbolic inputs
    without computing anything.
    """

    def call(self, indices, values, shape):
        """Execute the backend scatter kernel on concrete inputs.

        Args:
            indices: Positions at which `values` are placed.
                NOTE(review): the expected index layout is defined by
                `backend.core.scatter`; confirm against each backend.
            values: The values to scatter.
            shape: Shape of the output tensor.

        Returns:
            The result of `backend.core.scatter(indices, values, shape)`,
            unchanged.
        """
        scatter_fn = backend.core.scatter
        return scatter_fn(indices, values, shape)

    def compute_output_spec(self, indices, values, shape):
        """Describe the output of `call` without computing it.

        The spec uses the requested `shape` verbatim and inherits the dtype
        of `values`. `indices` does not influence the spec.

        Args:
            indices: Unused here; accepted to mirror the `call` signature.
            values: Must expose a `.dtype` attribute; reading it raises
                `AttributeError` otherwise (e.g. for a plain Python list).
            shape: Shape of the output tensor, passed straight to
                `KerasTensor`.

        Returns:
            A `KerasTensor` with shape `shape` and dtype `values.dtype`.
        """
        output_dtype = values.dtype
        return KerasTensor(shape, dtype=output_dtype)
|
|
|
|
|
|
def scatter(indices, values, shape):
    """Scatter `values` at `indices` into a tensor of shape `shape`.

    When none of the inputs is symbolic, the work is delegated directly to
    `backend.core.scatter`. When any of `indices`, `values` or `shape` is a
    symbolic tensor, a `Scatter` operation is recorded through
    `symbolic_call` instead. Its output spec presumably comes from
    `Scatter.compute_output_spec`: shape `shape`, dtype `values.dtype`.

    NOTE(review): what fills positions that receive no value, and how
    duplicate indices are combined, is decided by `backend.core.scatter`.
    Confirm per backend before relying on either.

    Args:
        indices: Positions at which `values` are placed.
        values: The values to scatter.
        shape: Shape of the output tensor.

    Returns:
        The backend result for concrete inputs, or the result of
        `Scatter().symbolic_call(...)` for symbolic inputs.
    """
    operands = (indices, values, shape)
    if not any_symbolic_tensors(operands):
        # Eager path: concrete inputs go straight to the active backend.
        return backend.core.scatter(indices, values, shape)
    # Symbolic path: record the operation so its output can be inferred.
    return Scatter().symbolic_call(indices, values, shape)
|