Add a few more tests.
This commit is contained in:
parent
b804a5b608
commit
08b67d8f1e
@ -35,11 +35,15 @@ class TensorFlowTrainer(base_trainer.Trainer):
|
||||
|
||||
# Compute gradients
|
||||
# TODO: move value conversion to TF
|
||||
trainable_weights = [v.value for v in self.trainable_weights]
|
||||
gradients = tape.gradient(loss, trainable_weights)
|
||||
if self.trainable_weights:
|
||||
trainable_weights = [v.value for v in self.trainable_weights]
|
||||
gradients = tape.gradient(loss, trainable_weights)
|
||||
|
||||
# Update weights
|
||||
self.optimizer.apply_gradients(zip(gradients, trainable_weights))
|
||||
else:
|
||||
warnings.warn("The model does not have any trainable weights.")
|
||||
|
||||
# Update weights
|
||||
self.optimizer.apply_gradients(zip(gradients, trainable_weights))
|
||||
return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)
|
||||
|
||||
def test_step(self, data):
|
||||
|
@ -42,7 +42,12 @@ from keras_core.utils.tracking import Tracker
|
||||
@keras_core_export(["keras_core.Layer", "keras_core.layers.Layer"])
|
||||
class Layer(Operation):
|
||||
def __init__(
|
||||
self, activity_regularizer=None, trainable=True, dtype=None, name=None
|
||||
self,
|
||||
*,
|
||||
activity_regularizer=None,
|
||||
trainable=True,
|
||||
dtype=None,
|
||||
name=None,
|
||||
):
|
||||
super().__init__(name=name)
|
||||
self.activity_regularizer = regularizers.get(activity_regularizer)
|
||||
@ -594,7 +599,11 @@ class Layer(Operation):
|
||||
try:
|
||||
self.call(input_tensors)
|
||||
return True
|
||||
except:
|
||||
except Exception as e:
|
||||
warnings.warn(
|
||||
"Error when attempting to automatically build "
|
||||
f"the layer by tracing it: {e}"
|
||||
)
|
||||
return False
|
||||
|
||||
def _build_by_run_for_kwargs(self, shapes_dict):
|
||||
@ -602,12 +611,18 @@ class Layer(Operation):
|
||||
if all(is_shape_tuple(s) for s in shapes_dict.values()):
|
||||
# Case: all input keyword arguments were plain tensors.
|
||||
input_tensors = {
|
||||
k: backend.traceable_tensor(shape)
|
||||
# We strip the `_shape` suffix to recover kwarg names.
|
||||
k[:-6]: backend.traceable_tensor(shape)
|
||||
for k, shape in shapes_dict.items()
|
||||
}
|
||||
try:
|
||||
self.call(**input_tensors)
|
||||
except:
|
||||
return True
|
||||
except Exception as e:
|
||||
warnings.warn(
|
||||
"Error when attempting to automatically build "
|
||||
f"the layer by tracing it: {e}"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
# Not supported: nested input keyword arguments.
|
||||
|
@ -1,13 +1,13 @@
|
||||
import numpy as np
|
||||
|
||||
from keras_core import backend
|
||||
from keras_core import layers
|
||||
from keras_core import testing
|
||||
from keras_core.layers.layer import Layer
|
||||
|
||||
|
||||
class LayerTest(testing.TestCase):
|
||||
def test_positional_arg_error(self):
|
||||
class SomeLayer(Layer):
|
||||
class SomeLayer(layers.Layer):
|
||||
def call(self, x, bool_arg):
|
||||
if bool_arg:
|
||||
return x
|
||||
@ -23,7 +23,7 @@ class LayerTest(testing.TestCase):
|
||||
SomeLayer()(x, bool_arg=True)
|
||||
|
||||
def test_rng_seed_tracking(self):
|
||||
class RNGLayer(Layer):
|
||||
class RNGLayer(layers.Layer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.seed_gen = backend.random.SeedGenerator(seed=1337)
|
||||
@ -38,7 +38,7 @@ class LayerTest(testing.TestCase):
|
||||
self.assertAllClose(layer.variables[0], [1337, 1])
|
||||
|
||||
# Test tracking in list attributes.
|
||||
class RNGListLayer(Layer):
|
||||
class RNGListLayer(layers.Layer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.seed_gens = []
|
||||
@ -61,5 +61,65 @@ class LayerTest(testing.TestCase):
|
||||
self.assertAllClose(layer.variables[0], [1, 1])
|
||||
self.assertAllClose(layer.variables[1], [10, 1])
|
||||
|
||||
def test_layer_tracking(self):
    """Sub-layers are tracked whether held as a plain attribute, inside
    a dict attribute, inside a list attribute, or appended during
    `build()` — and their weights are all collected."""

    class NestedLayer(layers.Layer):
        def __init__(self, units):
            super().__init__()
            self.units = units
            # One sub-layer per container kind: attribute, dict, list.
            self.dense1 = layers.Dense(units)
            self.layer_dict = {"dense2": layers.Dense(units)}
            self.layer_list = [layers.Dense(units)]

        def build(self, input_shape):
            # A layer added after __init__ must also be tracked.
            self.layer_list.append(layers.Dense(self.units))

        def call(self, x):
            h = self.dense1(x)
            h = self.layer_dict["dense2"](h)
            h = self.layer_list[0](h)
            return self.layer_list[1](h)

    layer = NestedLayer(3)
    layer.build((1, 3))
    # All four Dense sub-layers should be registered after build().
    self.assertLen(layer._layers, 4)
    layer(np.zeros((1, 3)))
    # Each Dense contributes kernel + bias -> 4 * 2 = 8 weights.
    self.assertLen(layer.weights, 8)
||||
def test_build_on_call(self):
    """Calling a layer on symbolic tensors builds its unbuilt
    sub-layers, for both positional and keyword call signatures."""

    class LayerWithUnbuiltState(layers.Layer):
        def __init__(self, units):
            super().__init__()
            self.dense1 = layers.Dense(units)

        def call(self, x):
            return self.dense1(x)

    layer = LayerWithUnbuiltState(2)
    layer(backend.KerasTensor((3, 4)))
    # dense1 built on first call: kernel + bias.
    self.assertLen(layer.weights, 2)

    class KwargsLayerWithUnbuiltState(layers.Layer):
        def __init__(self, units):
            super().__init__()
            self.dense1 = layers.Dense(units)
            self.dense2 = layers.Dense(units)

        def call(self, x1, x2):
            return self.dense1(x1) + self.dense2(x2)

    # Positional call path.
    layer = KwargsLayerWithUnbuiltState(2)
    layer(backend.KerasTensor((3, 4)), backend.KerasTensor((3, 4)))
    self.assertLen(layer.weights, 4)

    # Keyword call path must build the same state.
    layer = KwargsLayerWithUnbuiltState(2)
    layer(x1=backend.KerasTensor((3, 4)), x2=backend.KerasTensor((3, 4)))
    self.assertLen(layer.weights, 4)
||||
def test_activity_regularization(self):
    """Placeholder — activity-regularization coverage not written yet."""
    pass
||||
def test_add_loss(self):
    """Placeholder — `add_loss` coverage not written yet."""
    pass
|
||||
|
Loading…
Reference in New Issue
Block a user