2023-04-30 16:10:07 +00:00
|
|
|
"""
|
|
|
|
segment_sum
|
|
|
|
top_k
|
2023-05-17 03:35:11 +00:00
|
|
|
in_top_k
|
|
|
|
logsumexp
|
2023-04-30 16:10:07 +00:00
|
|
|
"""
|
|
|
|
|
|
|
|
from keras_core import backend
|
2023-05-17 03:35:11 +00:00
|
|
|
from keras_core.api_export import keras_core_export
|
2023-04-30 16:10:07 +00:00
|
|
|
from keras_core.backend import any_symbolic_tensors
|
|
|
|
from keras_core.operations.operation import Operation
|
|
|
|
|
|
|
|
|
|
|
|
class SegmentSum(Operation):
    """Operation wrapper that forwards to `backend.math.segment_sum`."""

    def call(self, data, segment_ids, num_segments=None, sorted=False):
        """Run the backend segment-sum on the given inputs.

        All four arguments are handed to `backend.math.segment_sum`
        positionally and unchanged; their exact semantics are defined by
        the active backend.

        Args:
            data: Input values to be reduced.
            segment_ids: Segment assignment for `data`.
                NOTE(review): presumably one id per leading entry of
                `data` -- confirm against the backend implementation.
            num_segments: Optional segment count; defaults to `None`.
            sorted: Flag forwarded to the backend; defaults to `False`.

        Returns:
            Whatever `backend.math.segment_sum` returns.
        """
        summed = backend.math.segment_sum(
            data, segment_ids, num_segments, sorted
        )
        return summed
|
2023-04-30 16:10:07 +00:00
|
|
|
|
|
|
|
|
2023-05-17 03:35:11 +00:00
|
|
|
@keras_core_export("keras_core.operations.segment_sum")
def segment_sum(data, segment_ids, num_segments=None, sorted=False):
    """Dispatch a segment-sum to the symbolic or the backend path.

    If either `data` or `segment_ids` is reported as symbolic by
    `any_symbolic_tensors`, the call is routed through
    `SegmentSum().symbolic_call`; otherwise `backend.math.segment_sum`
    is invoked directly.

    Args:
        data: Input values to be reduced.
        segment_ids: Segment assignment for `data`.
            NOTE(review): presumably one id per leading entry of `data`
            -- confirm against the backend implementation.
        num_segments: Optional segment count; defaults to `None`.
        sorted: Flag forwarded to the backend; defaults to `False`.

    Returns:
        The result of `SegmentSum().symbolic_call(...)` on the symbolic
        path, or of `backend.math.segment_sum(...)` otherwise.
    """
    # Inspect both tensor inputs (mirroring `in_top_k`): a symbolic
    # `segment_ids` paired with concrete `data` must not reach the eager
    # backend call.
    if any_symbolic_tensors((data, segment_ids)):
        return SegmentSum().symbolic_call(
            data, segment_ids, num_segments, sorted
        )
    return backend.math.segment_sum(data, segment_ids, num_segments, sorted)
|
2023-04-30 16:10:07 +00:00
|
|
|
|
|
|
|
|
|
|
|
class TopK(Operation):
    """Operation wrapper that forwards to `backend.math.top_k`."""

    def call(self, x, k, sorted=True):
        """Run the backend top-k on `x`.

        Args:
            x: Input passed through to the backend.
            k: Passed through to the backend.
                NOTE(review): presumably the number of entries to keep
                -- confirm against the backend implementation.
            sorted: Flag forwarded to the backend; defaults to `True`.

        Returns:
            Whatever `backend.math.top_k` returns.
        """
        top = backend.math.top_k(x, k, sorted)
        return top
|
|
|
|
|
|
|
|
|
2023-05-17 03:35:11 +00:00
|
|
|
@keras_core_export("keras_core.operations.top_k")
def top_k(x, k, sorted=True):
    """Dispatch a top-k to the symbolic or the backend path.

    Only `x` is inspected by `any_symbolic_tensors`. A symbolic `x` is
    routed through `TopK().symbolic_call`; anything else goes straight
    to `backend.math.top_k`.

    Args:
        x: Input passed through to the chosen path.
        k: Passed through unchanged.
            NOTE(review): presumably the number of entries to keep --
            confirm against the backend implementation.
        sorted: Flag forwarded unchanged; defaults to `True`.

    Returns:
        The result of `TopK().symbolic_call(...)` on the symbolic path,
        or of `backend.math.top_k(...)` otherwise.
    """
    # Eager path first: nothing symbolic, so call the backend directly.
    if not any_symbolic_tensors((x,)):
        return backend.math.top_k(x, k, sorted)
    return TopK().symbolic_call(x, k, sorted)
|
2023-05-01 18:17:26 +00:00
|
|
|
|
|
|
|
|
|
|
|
class InTopK(Operation):
    """Operation wrapper that forwards to `backend.math.in_top_k`."""

    def call(self, targets, predictions, k):
        """Run the backend in-top-k check.

        Args:
            targets: Passed through to the backend as the first argument.
            predictions: Passed through to the backend as the second
                argument.
            k: Passed through to the backend as the third argument.

        Returns:
            Whatever `backend.math.in_top_k` returns.
        """
        membership = backend.math.in_top_k(targets, predictions, k)
        return membership
|
|
|
|
|
|
|
|
|
2023-05-17 03:35:11 +00:00
|
|
|
@keras_core_export("keras_core.operations.in_top_k")
def in_top_k(targets, predictions, k):
    """Dispatch an in-top-k check to the symbolic or the backend path.

    Both `targets` and `predictions` are inspected by
    `any_symbolic_tensors`. If either is symbolic the call is routed
    through `InTopK().symbolic_call`; otherwise `backend.math.in_top_k`
    is invoked directly.

    Args:
        targets: Passed through unchanged as the first argument.
        predictions: Passed through unchanged as the second argument.
        k: Passed through unchanged as the third argument.

    Returns:
        The result of `InTopK().symbolic_call(...)` on the symbolic
        path, or of `backend.math.in_top_k(...)` otherwise.
    """
    # Eager path first: no symbolic input, so call the backend directly.
    if not any_symbolic_tensors((targets, predictions)):
        return backend.math.in_top_k(targets, predictions, k)
    return InTopK().symbolic_call(targets, predictions, k)
|
2023-05-09 23:03:15 +00:00
|
|
|
|
|
|
|
|
2023-05-17 03:35:11 +00:00
|
|
|
class Logsumexp(Operation):
    """Operation wrapper that forwards to `backend.math.logsumexp`."""

    def call(self, x, axis=None, keepdims=False):
        """Run the backend logsumexp on `x`.

        Args:
            x: Input passed positionally to the backend.
            axis: Forwarded as the `axis` keyword; defaults to `None`.
            keepdims: Forwarded as the `keepdims` keyword; defaults to
                `False`.

        Returns:
            Whatever `backend.math.logsumexp` returns.
        """
        # Same keyword call as `logsumexp(x, axis=axis, keepdims=keepdims)`.
        reduce_options = {"axis": axis, "keepdims": keepdims}
        return backend.math.logsumexp(x, **reduce_options)
|
|
|
|
|
|
|
|
|
|
|
|
@keras_core_export("keras_core.operations.logsumexp")
def logsumexp(x, axis=None, keepdims=False):
    """Dispatch a logsumexp to the symbolic or the backend path.

    Only `x` is inspected by `any_symbolic_tensors`. A symbolic `x` is
    routed through `Logsumexp().symbolic_call`; anything else goes
    straight to `backend.math.logsumexp`. Both paths receive `axis` and
    `keepdims` as keyword arguments.

    Args:
        x: Input passed positionally to the chosen path.
        axis: Forwarded as the `axis` keyword; defaults to `None`.
        keepdims: Forwarded as the `keepdims` keyword; defaults to
            `False`.

    Returns:
        The result of `Logsumexp().symbolic_call(...)` on the symbolic
        path, or of `backend.math.logsumexp(...)` otherwise.
    """
    # One keyword bundle serves both branches so they cannot drift apart.
    reduce_options = {"axis": axis, "keepdims": keepdims}
    if any_symbolic_tensors((x,)):
        return Logsumexp().symbolic_call(x, **reduce_options)
    return backend.math.logsumexp(x, **reduce_options)
|