diff --git a/keras_core/initializers/random_initializers_test.py b/keras_core/initializers/random_initializers_test.py index 38ab70bcc..a2ad291d0 100644 --- a/keras_core/initializers/random_initializers_test.py +++ b/keras_core/initializers/random_initializers_test.py @@ -19,7 +19,7 @@ class InitializersTest(testing.TestCase): self.assertEqual(initializer.stddev, stddev) self.assertEqual(initializer.seed, seed) self.assertEqual(values.shape, shape) - self.test_load_external_config(initializer, external_config) + self.assert_idempotent_config(initializer, external_config) def test_random_uniform(self): shape = (5, 5) @@ -35,11 +35,11 @@ class InitializersTest(testing.TestCase): self.assertEqual(initializer.maxval, maxval) self.assertEqual(initializer.seed, seed) self.assertEqual(values.shape, shape) - self.test_load_external_config(initializer, external_config) + self.assert_idempotent_config(initializer, external_config) values = values.numpy() self.assertGreaterEqual(np.min(values), minval) self.assertLess(np.max(values), maxval) - def test_load_external_config(self, initializer, config): + def assert_idempotent_config(self, initializer, config): initializer = initializer.from_config(config) self.assertEqual(initializer.get_config(), config) diff --git a/keras_core/metrics/metric_test.py b/keras_core/metrics/metric_test.py index 878928c09..19a02fa68 100644 --- a/keras_core/metrics/metric_test.py +++ b/keras_core/metrics/metric_test.py @@ -11,10 +11,10 @@ class ExampleMetric(Metric): def __init__(self, name="mean_square_error", dtype=None): super().__init__(name=name, dtype=dtype) self.sum = self.add_variable( - name="sum", initializer=initializers.Zeros() + name="sum", shape=(), initializer=initializers.Zeros() ) self.total = self.add_variable( - name="total", initializer=initializers.Zeros(), dtype="int32" + name="total", shape=(), initializer=initializers.Zeros(), dtype="int32" ) def update_state(self, y_true, y_pred): diff --git a/keras_core/metrics/regression_metrics.py b/keras_core/metrics/regression_metrics.py index ffcbee696..e64a85ab0 100644 --- a/keras_core/metrics/regression_metrics.py +++ b/keras_core/metrics/regression_metrics.py @@ -7,10 +7,10 @@ class MeanSquareError(Metric): def __init__(self, name="mean_square_error", dtype=None): super().__init__(name=name, dtype=dtype) self.sum = self.add_variable( - name="sum", initializer=initializers.Zeros() + name="sum", shape=(), initializer=initializers.Zeros() ) self.total = self.add_variable( - name="total", initializer=initializers.Zeros() + name="total", shape=(), initializer=initializers.Zeros() ) def update_state(self, y_true, y_pred): diff --git a/keras_core/operations/numpy.py b/keras_core/operations/numpy.py index a027ddcc6..904134bd8 100644 --- a/keras_core/operations/numpy.py +++ b/keras_core/operations/numpy.py @@ -996,6 +996,20 @@ def square(x): return backend.execute("square", x) +class Sqrt(Operation): + def call(self, x): + return backend.execute("sqrt", x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +def sqrt(x): + if any_symbolic_tensors((x,)): + return Sqrt().symbolic_call(x) + return backend.execute("sqrt", x) + + class Squeeze(Operation): def __init__(self, axis=None): super().__init__() diff --git a/keras_core/regularizers/regularizers.py b/keras_core/regularizers/regularizers.py index 03cd8847d..a0d28edac 100644 --- a/keras_core/regularizers/regularizers.py +++ b/keras_core/regularizers/regularizers.py @@ -345,6 +345,6 @@ def validate_float_arg(value, name): return float(value) -def l2_normalize(x): - l2_norm = ops.sqrt(ops.sum(ops.square(x))) +def l2_normalize(x, axis=0): + l2_norm = ops.sqrt(ops.sum(ops.square(x), axis=axis)) return x / l2_norm