keras/keras_core/operations/core.py
Chen Qian 59876afd28 Export operations (#176)
* Export operations

* dual export
2023-05-16 20:35:11 -07:00

25 lines
760 B
Python

"""
scatter
"""
from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.backend import KerasTensor
from keras_core.backend import any_symbolic_tensors
from keras_core.operations.operation import Operation
class Scatter(Operation):
    """Operation wrapper around `backend.core.scatter`.

    `call` forwards concrete inputs to the backend implementation, while
    `compute_output_spec` describes the result without computing it
    (presumably consulted by `Operation.symbolic_call` for symbolic
    inputs — that machinery lives outside this file).
    """

    def call(self, indices, values, shape):
        """Run the backend scatter on `indices`, `values` and `shape`."""
        scatter_fn = backend.core.scatter
        return scatter_fn(indices, values, shape)

    def compute_output_spec(self, indices, values, shape):
        """Return a `KerasTensor` spec with shape `shape` and `values`' dtype.

        `indices` is accepted to mirror `call`'s signature but does not
        influence the spec.
        """
        output_dtype = values.dtype
        return KerasTensor(shape, dtype=output_dtype)
@keras_core_export("keras_core.operations.scatter")
def scatter(indices, values, shape):
    """Build a tensor of shape `shape` from `values` written at `indices`.

    The precise scatter semantics (index layout, handling of duplicate
    indices, fill value) are defined by the active backend's
    `backend.core.scatter` and are not visible here.

    Args:
        indices: Positions at which `values` are placed (layout is
            backend-defined — confirm against the backend implementation).
        values: Values to scatter; their dtype becomes the output dtype in
            the symbolic path.
        shape: Shape of the output. In the symbolic path it is handed
            directly to `KerasTensor`, so it is presumably a tuple of ints.

    Returns:
        When any of `indices`, `values` or `shape` is symbolic, the result
        of `Scatter().symbolic_call(...)`; otherwise the value returned by
        `backend.core.scatter`.
    """
    if not any_symbolic_tensors((indices, values, shape)):
        # Every input is concrete: compute eagerly with the backend.
        return backend.core.scatter(indices, values, shape)
    # At least one input is symbolic: route through the Operation so the
    # output spec is produced instead of a real value.
    return Scatter().symbolic_call(indices, values, shape)