keras/keras_core/layers/layer_test.py

63 lines
2.1 KiB
Python

import numpy as np
from keras_core import backend
from keras_core import testing
from keras_core.layers.layer import Layer
class LayerTest(testing.TestCase):
    """Tests for argument handling and seed-state tracking in `Layer`."""

    def test_positional_arg_error(self):
        """Non-tensor arguments are rejected positionally, accepted by name."""

        class SomeLayer(Layer):
            def call(self, x, bool_arg):
                return x if bool_arg else x + 1

        symbolic_input = backend.KerasTensor(shape=(2, 3), name="x")

        # A plain Python bool in a positional slot must raise.
        expected_message = "Only input tensors may be passed as"
        with self.assertRaisesRegex(ValueError, expected_message):
            SomeLayer()(symbolic_input, True)

        # This works
        SomeLayer()(symbolic_input, bool_arg=True)

    def test_rng_seed_tracking(self):
        """Seed generator states are exposed as layer variables and updated.

        Covers a generator stored directly as an attribute and generators
        appended to a list attribute.
        """

        class RNGLayer(Layer):
            def __init__(self):
                super().__init__()
                self.seed_gen = backend.random.SeedGenerator(seed=1337)

            def call(self, x):
                return backend.random.dropout(x, rate=0.5, seed=self.seed_gen)

        rng_layer = RNGLayer()
        tracked_state = [rng_layer.seed_gen.state]
        self.assertEqual(rng_layer.variables, tracked_state)
        self.assertAllClose(rng_layer.variables[0], [1337, 0])

        # One forward pass moves the second state entry from 0 to 1.
        rng_layer(np.ones((3, 4)))
        self.assertAllClose(rng_layer.variables[0], [1337, 1])

        # Test tracking in list attributes.
        class RNGListLayer(Layer):
            def __init__(self):
                super().__init__()
                self.seed_gens = []
                # Append after assignment so that the list attribute is
                # populated only once it is already attached to the layer.
                for seed in (1, 10):
                    self.seed_gens.append(
                        backend.random.SeedGenerator(seed=seed)
                    )

            def call(self, x):
                for seed_gen in self.seed_gens:
                    x = backend.random.dropout(x, rate=0.5, seed=seed_gen)
                return x

        list_layer = RNGListLayer()
        expected_states = [gen.state for gen in list_layer.seed_gens]
        self.assertEqual(list_layer.variables, expected_states)
        for index, initial in enumerate(([1, 0], [10, 0])):
            self.assertAllClose(list_layer.variables[index], initial)

        # Each generator is used once per call, so both states advance.
        list_layer(np.ones((3, 4)))
        for index, advanced in enumerate(([1, 1], [10, 1])):
            self.assertAllClose(list_layer.variables[index], advanced)