diff --git a/keras_core/backend/jax/core.py b/keras_core/backend/jax/core.py index ea9463c51..92f9b1e9a 100644 --- a/keras_core/backend/jax/core.py +++ b/keras_core/backend/jax/core.py @@ -205,35 +205,3 @@ def block_update(inputs, start_indices, updates): ] inputs[tuple(slices)] = updates return inputs - - -def while_loop( - cond, - body, - loop_vars, - maximum_iterations=None, -): - loop_vars = tuple(loop_vars) - if maximum_iterations is not None: - current_iter = 0 - loop_vars = loop_vars + (current_iter,) - - # Unpack list/tuple args. The last argument is `current_iter`. - def _cond(args): - return cond(*args[:-1]) & (args[-1] < maximum_iterations) - - def _body(args): - return tuple(body(*args[:-1])) + (args[-1] + 1,) - - else: - - def _cond(args): - return cond(*args) - - def _body(args): - return tuple(body(*args)) - - outputs = jax.lax.while_loop(_cond, _body, loop_vars) - if maximum_iterations is not None: - outputs = outputs[:-1] - return outputs diff --git a/keras_core/backend/tensorflow/core.py b/keras_core/backend/tensorflow/core.py index 0c92df315..34918fdb0 100644 --- a/keras_core/backend/tensorflow/core.py +++ b/keras_core/backend/tensorflow/core.py @@ -111,17 +111,3 @@ def scatter_update(inputs, indices, updates): def block_update(inputs, start_indices, updates): return dynamic_update_slice(inputs, updates, start_indices) - - -def while_loop( - cond, - body, - loop_vars, - maximum_iterations=None, -): - return tf.while_loop( - cond, - body, - loop_vars, - maximum_iterations=maximum_iterations, - ) diff --git a/keras_core/backend/torch/core.py b/keras_core/backend/torch/core.py index 396834474..cfc214019 100644 --- a/keras_core/backend/torch/core.py +++ b/keras_core/backend/torch/core.py @@ -130,23 +130,3 @@ def block_update(inputs, start_indices, updates): ] inputs[slices] = updates return inputs - - -def while_loop( - cond, - body, - loop_vars, - maximum_iterations=None, -): - current_iter = 0 - iteration_check = ( - lambda iter: maximum_iterations is None or iter < maximum_iterations - ) - loop_vars = tuple([convert_to_tensor(v) for v in loop_vars]) - while cond(*loop_vars) and iteration_check(current_iter): - loop_vars = body(*loop_vars) - if not isinstance(loop_vars, (list, tuple)): - loop_vars = (loop_vars,) - loop_vars = tuple(loop_vars) - current_iter += 1 - return loop_vars diff --git a/keras_core/losses/losses_test.py b/keras_core/losses/losses_test.py index 6401b48b6..29417c835 100644 --- a/keras_core/losses/losses_test.py +++ b/keras_core/losses/losses_test.py @@ -881,6 +881,14 @@ class BinaryCrossentropyTest(testing.TestCase): self.assertAlmostEqual(loss, 0.0) def test_unweighted(self): + y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype="float32") + y_pred = np.array( + [[0.9, 0.1, 0.2], [0.3, 0.8, 0.1], [0.1, 0.2, 0.7]], dtype="float32" + ) + bce_obj = losses.BinaryCrossentropy() + loss = bce_obj(y_true, y_pred) + self.assertAllClose(loss, 0.20046903) + y_true = np.array([1, 0, 1, 0]).reshape([2, 2]) y_pred = np.array([1, 1, 1, 0], dtype=np.float32).reshape([2, 2]) bce_obj = losses.BinaryCrossentropy() diff --git a/keras_core/metrics/accuracy_metrics.py b/keras_core/metrics/accuracy_metrics.py index 179c80f92..f42964245 100644 --- a/keras_core/metrics/accuracy_metrics.py +++ b/keras_core/metrics/accuracy_metrics.py @@ -57,12 +57,15 @@ class Accuracy(reduction_metrics.MeanMetricWrapper): return {"name": self.name, "dtype": self.dtype} +@keras_core_export("keras_core.metrics.binary_accuracy") def binary_accuracy(y_true, y_pred, threshold=0.5): + y_true = ops.convert_to_tensor(y_true) y_pred = ops.convert_to_tensor(y_pred) + y_true, y_pred = squeeze_to_same_rank(y_true, y_pred) threshold = ops.cast(threshold, y_pred.dtype) - y_pred = ops.cast(y_pred > threshold, y_pred.dtype) + y_pred = ops.cast(y_pred > threshold, y_true.dtype) return ops.mean( - ops.cast(ops.equal(y_true, y_pred), backend.floatx()), + ops.cast(ops.equal(y_true, y_pred), dtype=backend.floatx()), axis=-1, ) @@ -114,12 +117,14 @@ class BinaryAccuracy(reduction_metrics.MeanMetricWrapper): return {"name": self.name, "dtype": self.dtype} +@keras_core_export("keras_core.metrics.categorical_accuracy") def categorical_accuracy(y_true, y_pred): y_true = ops.argmax(y_true, axis=-1) reshape_matches = False y_pred = ops.convert_to_tensor(y_pred) y_true = ops.convert_to_tensor(y_true, dtype=y_true.dtype) + y_true_org_shape = ops.shape(y_true) y_pred_rank = len(y_pred.shape) y_true_rank = len(y_true.shape) @@ -198,6 +203,7 @@ class CategoricalAccuracy(reduction_metrics.MeanMetricWrapper): return {"name": self.name, "dtype": self.dtype} +@keras_core_export("keras_core.metrics.sparse_categorical_accuracy") def sparse_categorical_accuracy(y_true, y_pred): reshape_matches = False y_pred = ops.convert_to_tensor(y_pred) @@ -281,6 +287,7 @@ class SparseCategoricalAccuracy(reduction_metrics.MeanMetricWrapper): return {"name": self.name, "dtype": self.dtype} +@keras_core_export("keras_core.metrics.top_k_categorical_accuracy") def top_k_categorical_accuracy(y_true, y_pred, k=5): reshape_matches = False y_pred = ops.convert_to_tensor(y_pred) @@ -357,6 +364,7 @@ class TopKCategoricalAccuracy(reduction_metrics.MeanMetricWrapper): return {"name": self.name, "dtype": self.dtype, "k": self.k} +@keras_core_export("keras_core.metrics.sparse_top_k_categorical_accuracy") def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5): reshape_matches = False y_pred = ops.convert_to_tensor(y_pred) diff --git a/keras_core/operations/core.py b/keras_core/operations/core.py index 9f972f642..76a720f9c 100644 --- a/keras_core/operations/core.py +++ b/keras_core/operations/core.py @@ -126,63 +126,3 @@ def block_update(inputs, start_indices, updates): if any_symbolic_tensors((inputs, start_indices, updates)): return BlockUpdate().symbolic_call(inputs, start_indices, updates) return backend.core.block_update(inputs, start_indices, updates) - - -class WhileLoop(Operation): - def __init__(self, cond, body, maximum_iterations): - super().__init__() - self.cond = cond - self.body = body - self.maximum_iterations = maximum_iterations - - def call(self, loop_vars): - return backend.core.while_loop( - self.cond, - self.body, - loop_vars, - maximum_iterations=self.maximum_iterations, - ) - - def compute_output_spec(self, loop_vars): - return [KerasTensor(v.shape, dtype=v.dtype) for v in loop_vars] - - -@keras_core_export("keras_core.operations.while_loop") -def while_loop( - cond, - body, - loop_vars, - maximum_iterations=None, -): - """While loop implemetation. - - Args: - cond: A callable that represents the termination condition of the loop. - Must have the same number of args as `loop_vars`, and return a bool. - body: A callable that represents the loop body. Must have the same - number of args as `loop_vars`, and return a list/tuple of the same - length, shape and dtype as `loop_vars`. - loop_vars: A list/tuple of tensors, the loop variables. - maximum_iterations: Optional maximum number of iterations of the while - loop to run. If provided, the `cond` output is AND-ed with an - additional condition ensuring the number of iterations executed is - no greater than `maximum_iterations`. - - Returns: - A list/tuple of tensors, has the same shape and dtype as `inputs`. - - Examples: - - >>> i = 0 - >>> cond = lambda i: i < 10 - >>> body = lambda i: i + 1 - >>> keras_core.operations.while_loop(cond, body, [i])[0] - 10 - """ - - return backend.core.while_loop( - cond, - body, - loop_vars, - maximum_iterations=maximum_iterations, - ) diff --git a/keras_core/operations/core_test.py b/keras_core/operations/core_test.py index 486df50ec..3a0f7aca4 100644 --- a/keras_core/operations/core_test.py +++ b/keras_core/operations/core_test.py @@ -139,22 +139,3 @@ class CoreOpsCorrectnessTest(testing.TestCase): updates = np.zeros([2, 2, 2, 2]) outputs = core.block_update(inputs, start_indices, updates) self.assertAllClose(outputs[1:3, 1:3, 2:4, 2:4], np.zeros([2, 2, 2, 2])) - - def test_while_loop(self): - def cond(x, y): - return x[0, 0] < 10 - - def body(x, y): - return x + 1, y + 1 - - x = np.ones((2, 3)) - y = np.ones((3, 2)) - x, y = core.while_loop(cond, body, (x, y)) - self.assertAllClose(x, np.ones((2, 3)) * 10) - self.assertAllClose(y, np.ones((3, 2)) * 10) - - x = np.ones((2, 3)) - y = np.ones((3, 2)) - x, y = core.while_loop(cond, body, (x, y), maximum_iterations=5) - self.assertAllClose(x, np.ones((2, 3)) * 6) - self.assertAllClose(y, np.ones((3, 2)) * 6)