Fix some tests

This commit is contained in:
Francois Chollet 2023-04-12 20:40:23 -07:00
parent 242f984013
commit 1fc98ab59b
5 changed files with 23 additions and 9 deletions

@ -19,7 +19,7 @@ class InitializersTest(testing.TestCase):
self.assertEqual(initializer.stddev, stddev)
self.assertEqual(initializer.seed, seed)
self.assertEqual(values.shape, shape)
self.test_load_external_config(initializer, external_config)
self.assert_idempotent_config(initializer, external_config)
def test_random_uniform(self):
shape = (5, 5)
@ -35,11 +35,11 @@ class InitializersTest(testing.TestCase):
self.assertEqual(initializer.maxval, maxval)
self.assertEqual(initializer.seed, seed)
self.assertEqual(values.shape, shape)
self.test_load_external_config(initializer, external_config)
self.assert_idempotent_config(initializer, external_config)
values = values.numpy()
self.assertGreaterEqual(np.min(values), minval)
self.assertLess(np.max(values), maxval)
def test_load_external_config(self, initializer, config):
def assert_idempotent_config(self, initializer, config):
initializer = initializer.from_config(config)
self.assertEqual(initializer.get_config(), config)

@ -11,10 +11,10 @@ class ExampleMetric(Metric):
def __init__(self, name="mean_square_error", dtype=None):
super().__init__(name=name, dtype=dtype)
self.sum = self.add_variable(
name="sum", initializer=initializers.Zeros()
name="sum", shape=(), initializer=initializers.Zeros()
)
self.total = self.add_variable(
name="total", initializer=initializers.Zeros(), dtype="int32"
name="total", shape=(), initializer=initializers.Zeros(), dtype="int32"
)
def update_state(self, y_true, y_pred):

@ -7,10 +7,10 @@ class MeanSquareError(Metric):
def __init__(self, name="mean_square_error", dtype=None):
super().__init__(name=name, dtype=dtype)
self.sum = self.add_variable(
name="sum", initializer=initializers.Zeros()
name="sum", shape=(), initializer=initializers.Zeros()
)
self.total = self.add_variable(
name="total", initializer=initializers.Zeros()
name="total", shape=(), initializer=initializers.Zeros()
)
def update_state(self, y_true, y_pred):

@ -996,6 +996,20 @@ def square(x):
return backend.execute("square", x)
class Sqrt(Operation):
def call(self, x):
return backend.execute("sqrt", x)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
def sqrt(x):
if any_symbolic_tensors((x,)):
return Sqrt().symbolic_call(x)
return backend.execute("sqrt", x)
class Squeeze(Operation):
def __init__(self, axis=None):
super().__init__()

@ -345,6 +345,6 @@ def validate_float_arg(value, name):
return float(value)
def l2_normalize(x):
l2_norm = ops.sqrt(ops.sum(ops.square(x)))
def l2_normalize(x, axis=0):
l2_norm = ops.sqrt(ops.sum(ops.square(x), axis=axis))
return x / l2_norm