Fix demos

This commit is contained in:
Francois Chollet 2023-05-17 16:06:18 -07:00
parent 4679cdd3ab
commit f94df3479a
3 changed files with 13 additions and 11 deletions

@ -58,20 +58,19 @@ def loss_fn(y_true, y_pred):
model = MyModel(hidden_dim=256, output_dim=4)
optimizer = optimizers.SGD(learning_rate=0.0001)
optimizer = optimizers.SGD(learning_rate=0.001)
dataset = Dataset()
######### Custom JAX workflow ###############
# Build model
x = ops.convert_to_tensor(np.random.random((1, 128)))
x = np.random.random((1, 128))
model(x)
# Build optimizer
optimizer.build(model.trainable_variables)
######### Custom JAX workflow ###############
def compute_loss_and_updates(
trainable_variables, non_trainable_variables, x, y
):

@ -1,5 +1,6 @@
import numpy as np
import keras_core
from keras_core import Model
from keras_core import backend
from keras_core import initializers
@ -42,11 +43,11 @@ class MyDropout(layers.Layer):
# Use seed_generator for managing RNG state.
# It is a state element and its seed variable is
# tracked as part of `layer.variables`.
self.seed_generator = backend.random.SeedGenerator(1337)
self.seed_generator = keras_core.random.SeedGenerator(1337)
def call(self, inputs):
# Use `backend.random` for random ops.
return backend.random.dropout(
# Use `keras_core.random` for random ops.
return keras_core.random.dropout(
inputs, self.rate, seed=self.seed_generator
)

@ -8,8 +8,10 @@ from keras_core import optimizers
inputs = layers.Input((100,), batch_size=32)
x = layers.Dense(256, activation="relu")(inputs)
residual = x
x = layers.Dense(256, activation="relu")(x)
x = layers.Dense(256, activation="relu")(x)
x += residual
outputs = layers.Dense(16)(x)
model = Model(inputs, outputs)
@ -21,9 +23,9 @@ batch_size = 32
epochs = 5
model.compile(
optimizer=optimizers.SGD(learning_rate=0.001),
optimizer=optimizers.Adam(learning_rate=0.001),
loss=losses.MeanSquaredError(),
metrics=[metrics.MeanSquaredError()],
metrics=[metrics.CategoricalAccuracy(name="acc"), metrics.MeanSquaredError(name="mse")],
)
history = model.fit(
x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2