Adds adam optimizer

Neel Kovelamudi 2023-04-21 22:01:17 +00:00 committed by Francois Chollet
parent ce72fe1f42
commit f06a9ce037
10 changed files with 53 additions and 26 deletions

@@ -4,4 +4,7 @@
     "editor.rulers": [
         80
     ],
-}
+    "python.linting.pylintEnabled": false,
+    "python.linting.flake8Enabled": true,
+    "python.linting.enabled": true,
+}

@@ -1,30 +1,33 @@
 import numpy as np
-from keras_core import Model
 from keras_core import layers
 from keras_core import losses
 from keras_core import metrics
 from keras_core import optimizers
+from keras_core.models import Model
 inputs = layers.Input((128,), batch_size=32)
-x = layers.Dense(256)(inputs)
-x = layers.Dense(256)(x)
-x = layers.Dense(256)(x)
+x = layers.Dense(256, activation="relu")(inputs)
+x = layers.Dense(256, activation="relu")(x)
+x = layers.Dense(256, activation="relu")(x)
 outputs = layers.Dense(16)(x)
 model = Model(inputs, outputs)
 model.summary()
 x = np.random.random((50000, 128))
 y = np.random.random((50000, 16))
 batch_size = 32
-epochs = 10
+epochs = 6
 model.compile(
     optimizer=optimizers.SGD(learning_rate=0.001),
     loss=losses.MeanSquaredError(),
     metrics=[metrics.MeanSquaredError()],
 )
-history = model.fit(x, y, batch_size=batch_size, epochs=epochs)
+history = model.fit(
+    x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2
+)
 print("History:")
 print(history.history)
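
Note on the validation_split=0.2 added above: in Keras, the validation fraction is sliced from the end of the arrays before any shuffling, so the change is roughly equivalent to holding out the last 20% by hand. A minimal self-contained sketch (model and shapes copied from the demo; not part of the diff):

import numpy as np
from keras_core import layers, losses, metrics, optimizers
from keras_core.models import Model

inputs = layers.Input((128,), batch_size=32)
x = layers.Dense(256, activation="relu")(inputs)
outputs = layers.Dense(16)(x)
model = Model(inputs, outputs)
model.compile(
    optimizer=optimizers.SGD(learning_rate=0.001),
    loss=losses.MeanSquaredError(),
    metrics=[metrics.MeanSquaredError()],
)

x_data = np.random.random((50000, 128))
y_data = np.random.random((50000, 16))

# Hold out the last 20% manually; fit(..., validation_split=0.2) uses the
# same slice, taken from the end of the arrays before shuffling.
split = int(len(x_data) * 0.8)
history = model.fit(
    x_data[:split],
    y_data[:split],
    batch_size=32,
    epochs=6,
    validation_data=(x_data[split:], y_data[split:]),
)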

@@ -10,8 +10,8 @@ from keras_core import optimizers
 class MyModel(Model):
     def __init__(self, hidden_dim, output_dim):
         super().__init__()
-        self.dense1 = layers.Dense(hidden_dim)
-        self.dense2 = layers.Dense(hidden_dim)
+        self.dense1 = layers.Dense(hidden_dim, activation="relu")
+        self.dense2 = layers.Dense(hidden_dim, activation="relu")
         self.dense3 = layers.Dense(output_dim)
     def call(self, x):
@@ -21,19 +21,22 @@ class MyModel(Model):
 model = MyModel(hidden_dim=256, output_dim=16)
 model.summary()
 x = np.random.random((50000, 128))
 y = np.random.random((50000, 16))
 batch_size = 32
-epochs = 10
+epochs = 6
 model.compile(
     optimizer=optimizers.SGD(learning_rate=0.001),
     loss=losses.MeanSquaredError(),
     metrics=[metrics.MeanSquaredError()],
 )
-history = model.fit(x, y, batch_size=batch_size, epochs=epochs)
+history = model.fit(
+    x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2
+)
 print("History:")
 print(history.history)
 model.summary()
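
For reference, a self-contained sketch of the subclassed demo after this change. The body of call() is not shown in the hunk; chaining dense1 -> dense2 -> dense3 is an assumption based on the layers defined in __init__:

from keras_core import layers
from keras_core.models import Model


class MyModel(Model):
    def __init__(self, hidden_dim, output_dim):
        super().__init__()
        self.dense1 = layers.Dense(hidden_dim, activation="relu")
        self.dense2 = layers.Dense(hidden_dim, activation="relu")
        self.dense3 = layers.Dense(output_dim)

    def call(self, x):
        # Assumed body: pass the input through the three Dense layers in order.
        x = self.dense1(x)
        x = self.dense2(x)
        return self.dense3(x)


model = MyModel(hidden_dim=256, output_dim=16)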

@@ -52,7 +52,7 @@ class MiniBatchNorm(Layer):
             initializers.Zeros()(shape), trainable=False, name="mean"
         )
         self.variance = backend.Variable(
-            initializers.GlorotUniform()(shape),
+            initializers.Ones()(shape),
             trainable=False,
             name="variance",
         )
@@ -61,8 +61,8 @@ class MiniBatchNorm(Layer):
     def call(self, inputs, training=False):
         if training:
-            mean = jnp.mean(inputs, axis=(0,))  # TODO: extend to rank 3+
-            variance = jnp.var(inputs, axis=(0,))
+            mean = ops.mean(inputs, axis=(0,))  # TODO: extend to rank 3+
+            variance = ops.var(inputs, axis=(0,))
             outputs = (inputs - mean) / (variance + self.epsilon)
             self.variance.assign(
                 self.variance * self.momentum + variance * (1.0 - self.momentum)
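
The assign() call above keeps an exponential moving average of the batch variance, which is presumably why the running variance now starts at ones rather than a random GlorotUniform draw. A plain-NumPy sketch of that update (the momentum and epsilon values are assumed, not taken from the hunk):

import numpy as np

momentum = 0.99  # assumed value
epsilon = 1e-5   # assumed value

running_variance = np.ones(256)  # matches initializers.Ones()(shape)
inputs = np.random.random((32, 256))

mean = inputs.mean(axis=0)
variance = inputs.var(axis=0)
outputs = (inputs - mean) / (variance + epsilon)

# Exponential moving average, as in self.variance.assign(...) above.
running_variance = running_variance * momentum + variance * (1.0 - momentum)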

@@ -1,3 +1,10 @@
 from keras_core import backend
+def relu(x):
+    return backend.nn.relu(x)
 def identity(x):
     return x
@@ -5,4 +12,6 @@ def identity(x):
 def get(identifier):
     if identifier is None:
         return identity
+    if identifier == "relu":
+        return relu
     return identifier
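
A standalone sketch of the lookup behaviour added here, with a NumPy relu standing in for backend.nn.relu so it runs without keras_core: None resolves to identity, the string "relu" resolves to the relu function, and anything already callable passes through unchanged.

import numpy as np


def relu(x):
    return np.maximum(x, 0)


def identity(x):
    return x


def get(identifier):
    if identifier is None:
        return identity
    if identifier == "relu":
        return relu
    return identifier


assert get(None) is identity
assert get("relu") is relu
assert get(relu) is relu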

@@ -238,7 +238,6 @@ class JAXTrainer(base_trainer.Trainer):
                     y=val_y,
                     sample_weight=val_sample_weight,
                     batch_size=validation_batch_size or batch_size,
-                    epochs=1,
                 )
             val_logs = self.evaluate(
                 x=val_x,

@@ -231,12 +231,11 @@ class TensorFlowTrainer(base_trainer.Trainer):
         if validation_data and self._should_eval(epoch, validation_freq):
             # Create EpochIterator for evaluation and cache it.
             if getattr(self, "_eval_epoch_iterator", None) is None:
-                self._eval_epoch_iterator = EpochIterator(
+                self._eval_epoch_iterator = TFEpochIterator(
                     x=val_x,
                     y=val_y,
                     sample_weight=val_sample_weight,
                     batch_size=validation_batch_size or batch_size,
-                    epochs=1,
                 )
             val_logs = self.evaluate(
                 x=val_x,
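
Both trainer hunks follow the same lazy-caching pattern: the evaluation iterator is created once, on the first epoch that runs validation, and reused afterwards. An illustrative sketch of that pattern only (DummyIterator and Trainer are made-up names, not keras_core classes):

class DummyIterator:
    def __init__(self, x, batch_size):
        self.x = x
        self.batch_size = batch_size


class Trainer:
    def eval_iterator(self, val_x, batch_size):
        # Created on first use, then cached on the instance and reused.
        if getattr(self, "_eval_epoch_iterator", None) is None:
            self._eval_epoch_iterator = DummyIterator(val_x, batch_size)
        return self._eval_epoch_iterator


trainer = Trainer()
it1 = trainer.eval_iterator([1, 2, 3], batch_size=2)
it2 = trainer.eval_iterator([1, 2, 3], batch_size=2)
assert it1 is it2  # same cached object on the second call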

@@ -126,8 +126,10 @@ class MeanSquaredLogarithmicError(LossFunctionWrapper):
 @keras_core_export(
-    "keras_core.metrics.mean_squared_error",
-    "keras_core.losses.mean_squared_error",
+    [
+        "keras_core.metrics.mean_squared_error",
+        "keras_core.losses.mean_squared_error",
+    ]
 )
 def mean_squared_error(y_true, y_pred):
     """Computes the mean squared error between labels and predictions.
@@ -158,8 +160,10 @@ def mean_squared_error(y_true, y_pred):
 @keras_core_export(
-    "keras_core.metrics.mean_absolute_error",
-    "keras_core.losses.mean_absolute_error",
+    [
+        "keras_core.metrics.mean_absolute_error",
+        "keras_core.losses.mean_absolute_error",
+    ]
 )
 def mean_absolute_error(y_true, y_pred):
     """Computes the mean absolute error between labels and predictions.
@@ -188,8 +192,10 @@ def mean_absolute_error(y_true, y_pred):
 @keras_core_export(
-    "keras_core.metrics.mean_absolute_percentage_error",
-    "keras_core.losses.mean_absolute_percentage_error",
+    [
+        "keras_core.metrics.mean_absolute_percentage_error",
+        "keras_core.losses.mean_absolute_percentage_error",
+    ]
 )
 def mean_absolute_percentage_error(y_true, y_pred):
     """Computes the mean absolute percentage error between `y_true` & `y_pred`.
@@ -223,8 +229,10 @@ def mean_absolute_percentage_error(y_true, y_pred):
 @keras_core_export(
-    "keras_core.metrics.mean_squared_logarithmic_error",
-    "keras_core.losses.mean_squared_logarithmic_error",
+    [
+        "keras_core.metrics.mean_squared_logarithmic_error",
+        "keras_core.losses.mean_squared_logarithmic_error",
+    ]
 )
 def mean_squared_logarithmic_error(y_true, y_pred):
     """Computes the mean squared logarithmic error between `y_true` & `y_pred`.

@@ -21,6 +21,7 @@ class MeanSquaredError(reduction_metrics.MeanMetricWrapper):
     >>> m.result()
     0.25
     """
+
     def __init__(self, name="mean_squared_error", dtype=None):
         super().__init__(fn=mean_squared_error, name=name, dtype=dtype)
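
The docstring excerpt above ends with result() returning 0.25; assuming the standard Keras metric API (update_state / result) that keras_core mirrors, the full usage looks like this:

from keras_core import metrics

m = metrics.MeanSquaredError()
m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])
print(m.result())  # 0.25: one of the four entries is off by 1, squared, averaged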

pyproject.toml (new file)

@@ -0,0 +1,2 @@
+[tool.black]
+line-length = 80