keras/keras_core/operations/math.py
Chen Qian 752f4b581b Add torch.math.* (#189)
* initial

* Add ops

* fix comments
2023-05-18 18:04:25 -07:00

109 lines
3.3 KiB
Python

"""
segment_sum
top_k
in_top_k
logsumexp
"""
from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.backend import KerasTensor
from keras_core.backend import any_symbolic_tensors
from keras_core.operations.operation import Operation
from keras_core.operations.operation_utils import reduce_shape
class SegmentSum(Operation):
    """Operation wrapper around `backend.math.segment_sum`.

    Args:
        num_segments: Stored and forwarded to the backend; also used as the
            leading dimension of the symbolic output shape. Defaults to
            `None`.
        sorted: Stored and forwarded to the backend unchanged. Defaults to
            `False`.
    """

    def __init__(self, num_segments=None, sorted=False):
        super().__init__()
        self.num_segments = num_segments
        self.sorted = sorted

    def compute_output_spec(self, data, segment_ids):
        """Describe the symbolic output for `data`.

        The leading dimension of `data` is replaced by `self.num_segments`
        (which may be `None`); the remaining dimensions and the dtype are
        taken from `data`. `segment_ids` is not consulted here.
        """
        trailing_dims = tuple(data.shape[1:])
        return KerasTensor(
            shape=(self.num_segments, *trailing_dims),
            dtype=data.dtype,
        )

    def call(self, data, segment_ids):
        """Run the backend implementation with the stored configuration."""
        backend_kwargs = {
            "num_segments": self.num_segments,
            "sorted": self.sorted,
        }
        return backend.math.segment_sum(data, segment_ids, **backend_kwargs)
@keras_core_export("keras_core.operations.segment_sum")
def segment_sum(data, segment_ids, num_segments=None, sorted=False):
    """Dispatch `segment_sum` to the symbolic op or to the backend.

    Args:
        data: Input tensor (symbolic or concrete).
        segment_ids: Segment id tensor (symbolic or concrete).
        num_segments: Forwarded unchanged to `SegmentSum` / the backend.
            Defaults to `None`.
        sorted: Forwarded unchanged to `SegmentSum` / the backend.
            Defaults to `False`.

    Returns:
        The result of `SegmentSum(...).symbolic_call(...)` when any input
        is symbolic, otherwise the result of `backend.math.segment_sum`.
    """
    # Inspect BOTH tensor inputs: a symbolic `segment_ids` paired with a
    # concrete `data` must not be handed to the eager backend. This mirrors
    # how `in_top_k` checks all of its tensor arguments.
    if any_symbolic_tensors((data, segment_ids)):
        return SegmentSum(num_segments, sorted).symbolic_call(data, segment_ids)
    return backend.math.segment_sum(
        data, segment_ids, num_segments=num_segments, sorted=sorted
    )
class TopK(Operation):
    """Operation wrapper around `backend.math.top_k`.

    Args:
        k: Stored and forwarded to the backend; also the size of the last
            axis of both symbolic outputs.
        sorted: Stored and forwarded to the backend. Defaults to `False`.
    """

    # NOTE(review): `sorted` defaults to False here, whereas the functional
    # `top_k` wrapper in this file defaults to True -- confirm whether the
    # mismatch is intentional.
    def __init__(self, k, sorted=False):
        super().__init__()
        self.k = k
        self.sorted = sorted

    def compute_output_spec(self, x):
        """Return the `(values, indices)` pair of symbolic outputs.

        Both outputs share the shape of `x` with the last axis set to
        `self.k`; values keep the dtype of `x`, indices are `"int32"`.
        """
        dims = list(x.shape)
        dims[-1] = self.k
        values_spec = KerasTensor(shape=dims, dtype=x.dtype)
        indices_spec = KerasTensor(shape=dims, dtype="int32")
        return (values_spec, indices_spec)

    def call(self, x):
        """Run the backend implementation with the stored configuration."""
        k, is_sorted = self.k, self.sorted
        return backend.math.top_k(x, k, is_sorted)
@keras_core_export("keras_core.operations.top_k")
def top_k(x, k, sorted=True):
    """Dispatch `top_k` to the backend or to the symbolic `TopK` op.

    Args:
        x: Input tensor (symbolic or concrete).
        k: Forwarded unchanged to `TopK` / the backend.
        sorted: Forwarded unchanged to `TopK` / the backend. Defaults to
            `True`.

    Returns:
        The result of `backend.math.top_k` for concrete input, otherwise
        the result of `TopK(k, sorted).symbolic_call(x)`.
    """
    # Eager path first: concrete inputs go straight to the backend.
    if not any_symbolic_tensors((x,)):
        return backend.math.top_k(x, k, sorted)
    symbolic_op = TopK(k, sorted)
    return symbolic_op.symbolic_call(x)
class InTopK(Operation):
    """Operation wrapper around `backend.math.in_top_k`.

    Args:
        k: Stored and forwarded to the backend.
    """

    def __init__(self, k):
        super().__init__()
        self.k = k

    def compute_output_spec(self, targets, predictions):
        """Describe the symbolic output: shape of `targets`, dtype `"bool"`.

        `predictions` is not consulted here.
        """
        result_shape = targets.shape
        return KerasTensor(shape=result_shape, dtype="bool")

    def call(self, targets, predictions):
        """Run the backend implementation with the stored `k`."""
        k = self.k
        return backend.math.in_top_k(targets, predictions, k)
@keras_core_export("keras_core.operations.in_top_k")
def in_top_k(targets, predictions, k):
    """Dispatch `in_top_k` to the backend or to the symbolic `InTopK` op.

    Args:
        targets: Targets tensor (symbolic or concrete).
        predictions: Predictions tensor (symbolic or concrete).
        k: Forwarded unchanged to `InTopK` / the backend.

    Returns:
        The result of `backend.math.in_top_k` when both inputs are
        concrete, otherwise the result of `InTopK(k).symbolic_call(...)`.
    """
    tensor_inputs = (targets, predictions)
    # Eager path first: concrete inputs go straight to the backend.
    if not any_symbolic_tensors(tensor_inputs):
        return backend.math.in_top_k(targets, predictions, k)
    symbolic_op = InTopK(k)
    return symbolic_op.symbolic_call(targets, predictions)
class Logsumexp(Operation):
    """Operation wrapper around `backend.math.logsumexp`.

    Args:
        axis: Stored and forwarded to the backend; also passed to
            `reduce_shape` for shape inference. Defaults to `None`.
        keepdims: Stored and forwarded to the backend; also passed to
            `reduce_shape` for shape inference. Defaults to `False`.
    """

    def __init__(self, axis=None, keepdims=False):
        super().__init__()
        self.axis = axis
        self.keepdims = keepdims

    def compute_output_spec(self, x):
        """Describe the symbolic output for `x`.

        The shape comes from `reduce_shape(x.shape, axis, keepdims)`.
        Floating-point input dtypes are propagated to the output so that
        e.g. a float16/float64 input is not reported as the `KerasTensor`
        default dtype; any other input dtype keeps the previous behavior
        (the `KerasTensor` default).
        """
        output_shape = reduce_shape(x.shape, self.axis, self.keepdims)
        # NOTE(review): assumes backends keep the floating dtype of the
        # input for logsumexp -- confirm against each backend.
        if "float" in str(x.dtype):
            return KerasTensor(shape=output_shape, dtype=x.dtype)
        return KerasTensor(shape=output_shape)

    def call(self, x):
        """Run the backend implementation with the stored configuration."""
        return backend.math.logsumexp(x, axis=self.axis, keepdims=self.keepdims)
@keras_core_export("keras_core.operations.logsumexp")
def logsumexp(x, axis=None, keepdims=False):
    """Dispatch `logsumexp` to the backend or to the symbolic op.

    Args:
        x: Input tensor (symbolic or concrete).
        axis: Forwarded unchanged to `Logsumexp` / the backend. Defaults
            to `None`.
        keepdims: Forwarded unchanged to `Logsumexp` / the backend.
            Defaults to `False`.

    Returns:
        The result of `backend.math.logsumexp` for concrete input,
        otherwise the result of `Logsumexp(axis, keepdims).symbolic_call(x)`.
    """
    # Eager path first: concrete inputs go straight to the backend.
    if not any_symbolic_tensors((x,)):
        return backend.math.logsumexp(x, axis=axis, keepdims=keepdims)
    symbolic_op = Logsumexp(axis, keepdims)
    return symbolic_op.symbolic_call(x)