d6bcc56001
* Added confusion metrics -- still using TF ops * Fixed structure + tests pass for TF (still need to port to multi-backend) * Got rid of most tf deps, still a few more to go * Full removal of TF. Tests pass for both Jax and TF * Full removal of TF. Tests pass for both Jax and TF * Formatting * Formatting * Review comments * More review comments + formatting
13 lines
330 B
Python
13 lines
330 B
Python
import tensorflow as tf
|
|
|
|
|
|
def segment_sum(data, segment_ids, num_segments=None, sorted=False):
    """Sum the entries of `data` that share the same segment id.

    Args:
        data: Tensor holding the values to reduce.
        segment_ids: Integer tensor (or array-like) mapping entries of
            `data` to segments.
        num_segments: Number of output segments for the unsorted path.
            When `None`, it is inferred as `max(segment_ids) + 1`, the
            smallest count that covers every id present. Ignored when
            `sorted` is True, because `tf.math.segment_sum` takes no
            segment count.
        sorted: Whether `segment_ids` is already sorted in ascending
            order. When True, the sorted TensorFlow kernel is used.

    Returns:
        The tensor produced by `tf.math.segment_sum` (sorted path) or
        `tf.math.unsorted_segment_sum` (unsorted path).
    """
    if sorted:
        # NOTE: `num_segments` is intentionally not forwarded here; the
        # sorted op derives its output size from `segment_ids` itself.
        return tf.math.segment_sum(data, segment_ids)

    if num_segments is None:
        # `tf.math.unsorted_segment_sum` requires an explicit segment count,
        # so passing `None` through would fail. Derive the minimal count
        # instead. The clamp keeps the count non-negative if `segment_ids`
        # is empty (the max of an empty integer tensor is expected to be the
        # dtype's lowest value -- TODO confirm).
        segment_ids = tf.convert_to_tensor(segment_ids)
        num_segments = tf.maximum(tf.reduce_max(segment_ids) + 1, 0)

    return tf.math.unsorted_segment_sum(data, segment_ids, num_segments)
|
|
|
|
|
|
def top_k(x, k, sorted=False):
    """Select the `k` largest entries of `x` via `tf.math.top_k`.

    Args:
        x: Input tensor to search.
        k: Number of top entries to return.
        sorted: Forwarded to `tf.math.top_k`; when True the returned
            entries are ordered from largest to smallest.

    Returns:
        The result of `tf.math.top_k` (its values/indices pair),
        returned unchanged.
    """
    top_entries = tf.math.top_k(x, k=k, sorted=sorted)
    return top_entries
|