diff --git a/examples/demo_custom_jax_workflow.py b/examples/demo_custom_jax_workflow.py index c5651942b..204173e0f 100644 --- a/examples/demo_custom_jax_workflow.py +++ b/examples/demo_custom_jax_workflow.py @@ -58,20 +58,19 @@ def loss_fn(y_true, y_pred): model = MyModel(hidden_dim=256, output_dim=4) -optimizer = optimizers.SGD(learning_rate=0.0001) +optimizer = optimizers.SGD(learning_rate=0.001) dataset = Dataset() - -######### Custom JAX workflow ############### - - # Build model -x = ops.convert_to_tensor(np.random.random((1, 128))) +x = np.random.random((1, 128)) model(x) # Build optimizer optimizer.build(model.trainable_variables) +######### Custom JAX workflow ############### + + def compute_loss_and_updates( trainable_variables, non_trainable_variables, x, y ): diff --git a/examples/demo_custom_layer_backend_agnostic.py b/examples/demo_custom_layer_backend_agnostic.py index 924fc25ec..8083dbcac 100644 --- a/examples/demo_custom_layer_backend_agnostic.py +++ b/examples/demo_custom_layer_backend_agnostic.py @@ -1,5 +1,6 @@ import numpy as np +import keras_core from keras_core import Model from keras_core import backend from keras_core import initializers @@ -42,11 +43,11 @@ class MyDropout(layers.Layer): # Use seed_generator for managing RNG state. # It is a state element and its seed variable is # tracked as part of `layer.variables`. - self.seed_generator = backend.random.SeedGenerator(1337) + self.seed_generator = keras_core.random.SeedGenerator(1337) def call(self, inputs): - # Use `backend.random` for random ops. - return backend.random.dropout( + # Use `keras_core.random` for random ops. + return keras_core.random.dropout( inputs, self.rate, seed=self.seed_generator ) diff --git a/examples/demo_functional.py b/examples/demo_functional.py index 04a733fb5..ba607fd0c 100644 --- a/examples/demo_functional.py +++ b/examples/demo_functional.py @@ -8,8 +8,10 @@ from keras_core import optimizers inputs = layers.Input((100,), batch_size=32) x = layers.Dense(256, activation="relu")(inputs) +residual = x x = layers.Dense(256, activation="relu")(x) x = layers.Dense(256, activation="relu")(x) +x += residual outputs = layers.Dense(16)(x) model = Model(inputs, outputs) @@ -21,9 +23,9 @@ batch_size = 32 epochs = 5 model.compile( - optimizer=optimizers.SGD(learning_rate=0.001), + optimizer=optimizers.Adam(learning_rate=0.001), loss=losses.MeanSquaredError(), - metrics=[metrics.MeanSquaredError()], + metrics=[metrics.CategoricalAccuracy(name="acc"), metrics.MeanSquaredError(name="mse")], ) history = model.fit( x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2