keras/keras_core/operations/math.py

62 lines
1.9 KiB
Python
Raw Normal View History

"""
segment_sum
top_k
in_top_k
logsumexp
"""
from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.backend import any_symbolic_tensors
from keras_core.operations.operation import Operation
class SegmentSum(Operation):
    """`Operation` form of `segment_sum`, used on the symbolic path."""

    def call(self, data, segment_ids, num_segments=None, sorted=False):
        # Hand every argument to the active backend, positionally and in
        # the same order it was received.
        backend_segment_sum = backend.math.segment_sum
        return backend_segment_sum(data, segment_ids, num_segments, sorted)
@keras_core_export("keras_core.operations.segment_sum")
def segment_sum(data, segment_ids, num_segments=None, sorted=False):
    """Sum `data` per segment, dispatching symbolically or eagerly.

    Args:
        data: Tensor whose entries are summed per segment.
        segment_ids: Segment index for each entry of `data` (the exact
            shape contract is defined by the backend — confirm there).
        num_segments: Optional segment count, forwarded unchanged to the
            backend.
        sorted: Forwarded unchanged to the backend; presumably declares
            that `segment_ids` is already sorted (TODO confirm in backend).

    Returns:
        The result of `SegmentSum().symbolic_call(...)` when `data` or
        `segment_ids` is symbolic, otherwise the backend's eager result.
    """
    # Either tensor input may be symbolic. If only `segment_ids` were
    # symbolic, the eager backend call would presumably receive a
    # KerasTensor it cannot handle, so check both (as `in_top_k` does for
    # its tensor inputs).
    if any_symbolic_tensors((data, segment_ids)):
        return SegmentSum().symbolic_call(
            data, segment_ids, num_segments, sorted
        )
    return backend.math.segment_sum(data, segment_ids, num_segments, sorted)
class TopK(Operation):
    """`Operation` form of `top_k`, used on the symbolic path."""

    def call(self, x, k, sorted=True):
        # Forward positionally to the backend implementation.
        implementation = backend.math.top_k
        return implementation(x, k, sorted)
@keras_core_export("keras_core.operations.top_k")
def top_k(x, k, sorted=True):
    """Dispatch `top_k` either to a `TopK` operation or to the backend.

    Args:
        x: Input tensor; only this argument is inspected for symbolic-ness.
        k: Forwarded unchanged.
        sorted: Forwarded unchanged to the backend.

    Returns:
        A symbolic result when `x` is symbolic, otherwise the backend's
        eager result.
    """
    # Eager inputs go straight to the backend.
    if not any_symbolic_tensors((x,)):
        return backend.math.top_k(x, k, sorted)
    op = TopK()
    return op.symbolic_call(x, k, sorted)
class InTopK(Operation):
    """`Operation` form of `in_top_k`, used on the symbolic path."""

    def call(self, targets, predictions, k):
        # Same positional order as the public function's signature.
        forwarded = (targets, predictions, k)
        return backend.math.in_top_k(*forwarded)
@keras_core_export("keras_core.operations.in_top_k")
def in_top_k(targets, predictions, k):
    """Dispatch `in_top_k` either to an `InTopK` operation or to the backend.

    If `targets` or `predictions` is a symbolic tensor, the call is routed
    through `InTopK().symbolic_call`; otherwise the backend computes it
    directly. `k` is forwarded unchanged in both cases.
    """
    tensor_inputs = (targets, predictions)
    if not any_symbolic_tensors(tensor_inputs):
        return backend.math.in_top_k(targets, predictions, k)
    return InTopK().symbolic_call(targets, predictions, k)
class Logsumexp(Operation):
    """`Operation` form of `logsumexp`, used on the symbolic path."""

    def call(self, x, axis=None, keepdims=False):
        # Reduction options are passed to the backend by keyword.
        reduce_kwargs = {"axis": axis, "keepdims": keepdims}
        return backend.math.logsumexp(x, **reduce_kwargs)
@keras_core_export("keras_core.operations.logsumexp")
def logsumexp(x, axis=None, keepdims=False):
    """Dispatch `logsumexp` either to a `Logsumexp` operation or the backend.

    Presumably computes log(sum(exp(x))) over `axis` as implemented by the
    active backend (confirm there). `axis` and `keepdims` are passed by
    keyword on both paths.

    Returns:
        A symbolic result when `x` is symbolic, otherwise the backend's
        eager result.
    """
    symbolic = any_symbolic_tensors((x,))
    if symbolic:
        op = Logsumexp()
        return op.symbolic_call(x, axis=axis, keepdims=keepdims)
    return backend.math.logsumexp(x, axis=axis, keepdims=keepdims)