Fix some tests
This commit is contained in:
parent
242f984013
commit
1fc98ab59b
@ -19,7 +19,7 @@ class InitializersTest(testing.TestCase):
|
||||
self.assertEqual(initializer.stddev, stddev)
|
||||
self.assertEqual(initializer.seed, seed)
|
||||
self.assertEqual(values.shape, shape)
|
||||
self.test_load_external_config(initializer, external_config)
|
||||
self.assert_idempotent_config(initializer, external_config)
|
||||
|
||||
def test_random_uniform(self):
|
||||
shape = (5, 5)
|
||||
@ -35,11 +35,11 @@ class InitializersTest(testing.TestCase):
|
||||
self.assertEqual(initializer.maxval, maxval)
|
||||
self.assertEqual(initializer.seed, seed)
|
||||
self.assertEqual(values.shape, shape)
|
||||
self.test_load_external_config(initializer, external_config)
|
||||
self.assert_idempotent_config(initializer, external_config)
|
||||
values = values.numpy()
|
||||
self.assertGreaterEqual(np.min(values), minval)
|
||||
self.assertLess(np.max(values), maxval)
|
||||
|
||||
def test_load_external_config(self, initializer, config):
|
||||
def assert_idempotent_config(self, initializer, config):
|
||||
initializer = initializer.from_config(config)
|
||||
self.assertEqual(initializer.get_config(), config)
|
||||
|
@ -11,10 +11,10 @@ class ExampleMetric(Metric):
|
||||
def __init__(self, name="mean_square_error", dtype=None):
|
||||
super().__init__(name=name, dtype=dtype)
|
||||
self.sum = self.add_variable(
|
||||
name="sum", initializer=initializers.Zeros()
|
||||
name="sum", shape=(), initializer=initializers.Zeros()
|
||||
)
|
||||
self.total = self.add_variable(
|
||||
name="total", initializer=initializers.Zeros(), dtype="int32"
|
||||
name="total", shape=(), initializer=initializers.Zeros(), dtype="int32"
|
||||
)
|
||||
|
||||
def update_state(self, y_true, y_pred):
|
||||
|
@ -7,10 +7,10 @@ class MeanSquareError(Metric):
|
||||
def __init__(self, name="mean_square_error", dtype=None):
|
||||
super().__init__(name=name, dtype=dtype)
|
||||
self.sum = self.add_variable(
|
||||
name="sum", initializer=initializers.Zeros()
|
||||
name="sum", shape=(), initializer=initializers.Zeros()
|
||||
)
|
||||
self.total = self.add_variable(
|
||||
name="total", initializer=initializers.Zeros()
|
||||
name="total", shape=(), initializer=initializers.Zeros()
|
||||
)
|
||||
|
||||
def update_state(self, y_true, y_pred):
|
||||
|
@ -996,6 +996,20 @@ def square(x):
|
||||
return backend.execute("square", x)
|
||||
|
||||
|
||||
class Sqrt(Operation):
|
||||
def call(self, x):
|
||||
return backend.execute("sqrt", x)
|
||||
|
||||
def compute_output_spec(self, x):
|
||||
return KerasTensor(x.shape, dtype=x.dtype)
|
||||
|
||||
|
||||
def sqrt(x):
|
||||
if any_symbolic_tensors((x,)):
|
||||
return Sqrt().symbolic_call(x)
|
||||
return backend.execute("sqrt", x)
|
||||
|
||||
|
||||
class Squeeze(Operation):
|
||||
def __init__(self, axis=None):
|
||||
super().__init__()
|
||||
|
@ -345,6 +345,6 @@ def validate_float_arg(value, name):
|
||||
return float(value)
|
||||
|
||||
|
||||
def l2_normalize(x):
|
||||
l2_norm = ops.sqrt(ops.sum(ops.square(x)))
|
||||
def l2_normalize(x, axis=0):
|
||||
l2_norm = ops.sqrt(ops.sum(ops.square(x), axis=axis))
|
||||
return x / l2_norm
|
||||
|
Loading…
Reference in New Issue
Block a user