2023-04-12 18:31:58 +00:00
|
|
|
import numpy as np
|
|
|
|
|
2023-04-19 17:29:25 +00:00
|
|
|
from keras_core import backend
|
2023-04-26 20:46:23 +00:00
|
|
|
from keras_core import layers
|
2023-04-19 18:09:43 +00:00
|
|
|
from keras_core import testing
|
2023-04-09 19:21:45 +00:00
|
|
|
|
|
|
|
|
2023-04-16 01:51:10 +00:00
|
|
|
class LayerTest(testing.TestCase):
    """Tests for core `layers.Layer` behavior.

    Covers call-argument validation, tracking of seed-generator state and
    nested sublayers, and building sublayer weights on first call.
    """

    def test_positional_arg_error(self):
        """A non-tensor positional call argument raises; as a kwarg it works."""

        class SomeLayer(layers.Layer):
            def call(self, x, bool_arg):
                if bool_arg:
                    return x
                return x + 1

        x = backend.KerasTensor(shape=(2, 3), name="x")
        # Passing the non-tensor `bool_arg` positionally must be rejected.
        with self.assertRaisesRegex(
            ValueError, "Only input tensors may be passed as"
        ):
            SomeLayer()(x, True)

        # The same value passed by keyword is accepted.
        SomeLayer()(x, bool_arg=True)

    def test_rng_seed_tracking(self):
        """SeedGenerator state is exposed via `layer.variables` and updates
        when the layer is called, both for direct attributes and for
        generators stored inside a list attribute."""

        class RNGLayer(layers.Layer):
            def __init__(self):
                super().__init__()
                self.seed_gen = backend.random.SeedGenerator(seed=1337)

            def call(self, x):
                return backend.random.dropout(x, rate=0.5, seed=self.seed_gen)

        layer = RNGLayer()
        self.assertEqual(layer.variables, [layer.seed_gen.state])
        # Before the call the state is [seed, 0]; after a single call the
        # second entry is expected to be 1.
        self.assertAllClose(layer.variables[0], [1337, 0])
        layer(np.ones((3, 4)))
        self.assertAllClose(layer.variables[0], [1337, 1])

        # Test tracking in list attributes.
        class RNGListLayer(layers.Layer):
            def __init__(self):
                super().__init__()
                self.seed_gens = []
                self.seed_gens.append(backend.random.SeedGenerator(seed=1))
                self.seed_gens.append(backend.random.SeedGenerator(seed=10))

            def call(self, x):
                x = backend.random.dropout(x, rate=0.5, seed=self.seed_gens[0])
                x = backend.random.dropout(x, rate=0.5, seed=self.seed_gens[1])
                return x

        layer = RNGListLayer()
        # Generators appended to a list attribute must still be tracked,
        # in insertion order.
        self.assertEqual(
            layer.variables,
            [layer.seed_gens[0].state, layer.seed_gens[1].state],
        )
        self.assertAllClose(layer.variables[0], [1, 0])
        self.assertAllClose(layer.variables[1], [10, 0])
        layer(np.ones((3, 4)))
        self.assertAllClose(layer.variables[0], [1, 1])
        self.assertAllClose(layer.variables[1], [10, 1])

    def test_layer_tracking(self):
        """Sublayers held in attributes, dicts and lists (including one added
        in `build()`) are tracked, and their weights are collected, also
        through an extra level of nesting."""

        class NestedLayer(layers.Layer):
            def __init__(self, units):
                super().__init__()
                self.dense1 = layers.Dense(units)
                self.layer_dict = {
                    "dense2": layers.Dense(units),
                }
                self.layer_list = [layers.Dense(units)]
                self.units = units

            def build(self, input_shape):
                # A sublayer added after construction must also be tracked.
                self.layer_list.append(layers.Dense(self.units))

            def call(self, x):
                x = self.dense1(x)
                x = self.layer_dict["dense2"](x)
                x = self.layer_list[0](x)
                x = self.layer_list[1](x)
                return x

        class DoubleNestedLayer(layers.Layer):
            def __init__(self, units):
                super().__init__()
                self.inner_layer = NestedLayer(units)

            def call(self, x):
                return self.inner_layer(x)

        layer = NestedLayer(3)
        layer.build((1, 3))
        # dense1 + layer_dict["dense2"] + both entries of layer_list.
        self.assertLen(layer._layers, 4)
        layer(np.zeros((1, 3)))
        # Four Dense sublayers with two weights each.
        self.assertLen(layer.weights, 8)

        layer = DoubleNestedLayer(3)
        # Only the direct child is listed in `_layers`...
        self.assertLen(layer._layers, 1)
        layer(np.zeros((1, 3)))
        # ...but weights are gathered recursively from the inner layer.
        self.assertLen(layer.inner_layer.weights, 8)
        self.assertLen(layer.weights, 8)

    def test_build_on_call(self):
        """Calling a layer whose sublayers are unbuilt creates their weights,
        whether inputs are passed positionally or by keyword."""

        class LayerWithUnbuiltState(layers.Layer):
            def __init__(self, units):
                super().__init__()
                self.dense1 = layers.Dense(units)

            def call(self, x):
                return self.dense1(x)

        layer = LayerWithUnbuiltState(2)
        layer(backend.KerasTensor((3, 4)))
        self.assertLen(layer.weights, 2)

        class KwargsLayerWithUnbuiltState(layers.Layer):
            def __init__(self, units):
                super().__init__()
                self.dense1 = layers.Dense(units)
                self.dense2 = layers.Dense(units)

            def call(self, x1, x2):
                return self.dense1(x1) + self.dense2(x2)

        # Positional inputs.
        layer = KwargsLayerWithUnbuiltState(2)
        layer(backend.KerasTensor((3, 4)), backend.KerasTensor((3, 4)))
        self.assertLen(layer.weights, 4)

        # Keyword inputs.
        layer = KwargsLayerWithUnbuiltState(2)
        layer(x1=backend.KerasTensor((3, 4)), x2=backend.KerasTensor((3, 4)))
        self.assertLen(layer.weights, 4)

    def test_activity_regularization(self):
        # Placeholder: skip explicitly so missing coverage is reported as a
        # skip instead of a vacuous pass.
        self.skipTest("Not yet implemented.")

    def test_add_loss(self):
        # Placeholder: skip explicitly so missing coverage is reported as a
        # skip instead of a vacuous pass.
        self.skipTest("Not yet implemented.")
|