""" segment_sum top_k in_top_k logsumexp """ from keras_core import backend from keras_core.api_export import keras_core_export from keras_core.backend import KerasTensor from keras_core.backend import any_symbolic_tensors from keras_core.operations.operation import Operation from keras_core.operations.operation_utils import reduce_shape class SegmentSum(Operation): def __init__(self, num_segments=None, sorted=False): super().__init__() self.num_segments = num_segments self.sorted = sorted def compute_output_spec(self, data, segment_ids): num_segments = self.num_segments output_shape = (num_segments,) + tuple(data.shape[1:]) return KerasTensor(shape=output_shape, dtype=data.dtype) def call(self, data, segment_ids): return backend.math.segment_sum( data, segment_ids, num_segments=self.num_segments, sorted=self.sorted, ) @keras_core_export("keras_core.operations.segment_sum") def segment_sum(data, segment_ids, num_segments=None, sorted=False): if any_symbolic_tensors((data,)): return SegmentSum(num_segments, sorted).symbolic_call(data, segment_ids) return backend.math.segment_sum( data, segment_ids, num_segments=num_segments, sorted=sorted ) class TopK(Operation): def __init__(self, k, sorted=False): super().__init__() self.k = k self.sorted = sorted def compute_output_spec(self, x): output_shape = list(x.shape) output_shape[-1] = self.k # Return a tuple (values, indices). return ( KerasTensor(shape=output_shape, dtype=x.dtype), KerasTensor(shape=output_shape, dtype="int32"), ) def call(self, x): return backend.math.top_k(x, self.k, self.sorted) @keras_core_export("keras_core.operations.top_k") def top_k(x, k, sorted=True): if any_symbolic_tensors((x,)): return TopK(k, sorted).symbolic_call(x) return backend.math.top_k(x, k, sorted) class InTopK(Operation): def __init__(self, k): super().__init__() self.k = k def compute_output_spec(self, targets, predictions): return KerasTensor(shape=targets.shape, dtype="bool") def call(self, targets, predictions): return backend.math.in_top_k(targets, predictions, self.k) @keras_core_export("keras_core.operations.in_top_k") def in_top_k(targets, predictions, k): if any_symbolic_tensors((targets, predictions)): return InTopK(k).symbolic_call(targets, predictions) return backend.math.in_top_k(targets, predictions, k) class Logsumexp(Operation): def __init__(self, axis=None, keepdims=False): super().__init__() self.axis = axis self.keepdims = keepdims def compute_output_spec(self, x): output_shape = reduce_shape(x.shape, self.axis, self.keepdims) return KerasTensor(shape=output_shape) def call(self, x): return backend.math.logsumexp(x, axis=self.axis, keepdims=self.keepdims) @keras_core_export("keras_core.operations.logsumexp") def logsumexp(x, axis=None, keepdims=False): if any_symbolic_tensors((x,)): return Logsumexp(axis, keepdims).symbolic_call(x) return backend.math.logsumexp(x, axis=axis, keepdims=keepdims)