Adds Adam optimizer
commit f06a9ce037
parent ce72fe1f42
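
The new optimizer code itself does not appear in this excerpt; the page shows only the accompanying file changes below. For context, here is a minimal NumPy sketch of the Adam update rule such an optimizer implements (Kingma & Ba, 2014). This is an illustration only, not the commit's actual implementation:

    import numpy as np


    def adam_step(param, grad, m, v, t,
                  lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-7):
        # Exponential moving averages of the gradient and its square.
        m = beta_1 * m + (1.0 - beta_1) * grad
        v = beta_2 * v + (1.0 - beta_2) * grad**2
        # Bias correction for the zero-initialized moments (t starts at 1).
        m_hat = m / (1.0 - beta_1**t)
        v_hat = v / (1.0 - beta_2**t)
        # Scale the step by the corrected first moment over the RMS of the second.
        param = param - lr * m_hat / (np.sqrt(v_hat) + epsilon)
        return param, m, v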
.vscode/settings.json (vendored) | 5
@@ -4,4 +4,7 @@
     "editor.rulers": [
         80
     ],
-}
+    "python.linting.pylintEnabled": false,
+    "python.linting.flake8Enabled": true,
+    "python.linting.enabled": true,
+}
@@ -1,30 +1,33 @@
 import numpy as np

-from keras_core import Model
 from keras_core import layers
 from keras_core import losses
 from keras_core import metrics
 from keras_core import optimizers
+from keras_core.models import Model

 inputs = layers.Input((128,), batch_size=32)
-x = layers.Dense(256)(inputs)
-x = layers.Dense(256)(x)
-x = layers.Dense(256)(x)
+x = layers.Dense(256, activation="relu")(inputs)
+x = layers.Dense(256, activation="relu")(x)
+x = layers.Dense(256, activation="relu")(x)
 outputs = layers.Dense(16)(x)
 model = Model(inputs, outputs)

 model.summary()

 x = np.random.random((50000, 128))
 y = np.random.random((50000, 16))
 batch_size = 32
-epochs = 10
+epochs = 6

 model.compile(
     optimizer=optimizers.SGD(learning_rate=0.001),
     loss=losses.MeanSquaredError(),
     metrics=[metrics.MeanSquaredError()],
 )
-history = model.fit(x, y, batch_size=batch_size, epochs=epochs)
+history = model.fit(
+    x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2
+)

 print("History:")
 print(history.history)
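
Given the commit title, the demo's compile step could presumably switch from SGD to the newly added optimizer. A hedged example, assuming the class is exposed as `optimizers.Adam`:

    model.compile(
        optimizer=optimizers.Adam(learning_rate=0.001),  # assumed export name
        loss=losses.MeanSquaredError(),
        metrics=[metrics.MeanSquaredError()],
    )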
@@ -10,8 +10,8 @@ from keras_core import optimizers
 class MyModel(Model):
     def __init__(self, hidden_dim, output_dim):
         super().__init__()
-        self.dense1 = layers.Dense(hidden_dim)
-        self.dense2 = layers.Dense(hidden_dim)
+        self.dense1 = layers.Dense(hidden_dim, activation="relu")
+        self.dense2 = layers.Dense(hidden_dim, activation="relu")
         self.dense3 = layers.Dense(output_dim)

     def call(self, x):
@@ -21,19 +21,22 @@ class MyModel(Model):


 model = MyModel(hidden_dim=256, output_dim=16)
 model.summary()

 x = np.random.random((50000, 128))
 y = np.random.random((50000, 16))
 batch_size = 32
-epochs = 10
+epochs = 6

 model.compile(
     optimizer=optimizers.SGD(learning_rate=0.001),
     loss=losses.MeanSquaredError(),
     metrics=[metrics.MeanSquaredError()],
 )
-history = model.fit(x, y, batch_size=batch_size, epochs=epochs)
+history = model.fit(
+    x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2
+)

 print("History:")
 print(history.history)

 model.summary()
@@ -52,7 +52,7 @@ class MiniBatchNorm(Layer):
             initializers.Zeros()(shape), trainable=False, name="mean"
         )
         self.variance = backend.Variable(
-            initializers.GlorotUniform()(shape),
+            initializers.Ones()(shape),
             trainable=False,
             name="variance",
         )
@@ -61,8 +61,8 @@ class MiniBatchNorm(Layer):

     def call(self, inputs, training=False):
         if training:
-            mean = jnp.mean(inputs, axis=(0,))  # TODO: extend to rank 3+
-            variance = jnp.var(inputs, axis=(0,))
+            mean = ops.mean(inputs, axis=(0,))  # TODO: extend to rank 3+
+            variance = ops.var(inputs, axis=(0,))
             outputs = (inputs - mean) / (variance + self.epsilon)
             self.variance.assign(
                 self.variance * self.momentum + variance * (1.0 - self.momentum)
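
The `jnp.*` to `ops.*` swap above moves the layer from JAX-only math onto keras_core's backend-agnostic operations namespace. A minimal sketch of the same pattern, assuming the ops namespace is importable as `keras_core.operations`:

    from keras_core import layers
    from keras_core import operations as ops  # assumed import path


    class Standardize(layers.Layer):
        # Illustrative layer: batch standardization phrased purely in ops.*,
        # mirroring the form used by MiniBatchNorm above, so it runs under
        # any keras_core backend rather than only JAX.
        def call(self, inputs):
            mean = ops.mean(inputs, axis=(0,))
            variance = ops.var(inputs, axis=(0,))
            return (inputs - mean) / (variance + 1e-5)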
@@ -1,3 +1,10 @@
+from keras_core import backend
+
+
+def relu(x):
+    return backend.nn.relu(x)
+
+
 def identity(x):
     return x

@@ -5,4 +12,6 @@ def identity(x):
 def get(identifier):
     if identifier is None:
         return identity
+    if identifier == "relu":
+        return relu
     return identifier
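
The resolver above maps `None` to `identity`, the string `"relu"` to the new `relu` function, and returns any other identifier unchanged (so an already-callable activation passes straight through). Assuming the file is exposed as `keras_core.activations`, usage looks roughly like:

    from keras_core import activations  # assumed module path

    fn = activations.get("relu")   # resolves the string to the relu function
    noop = activations.get(None)   # resolves to identity
    same = activations.get(abs)    # any other identifier is returned as-is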
@@ -238,7 +238,6 @@ class JAXTrainer(base_trainer.Trainer):
                         y=val_y,
                         sample_weight=val_sample_weight,
                         batch_size=validation_batch_size or batch_size,
-                        epochs=1,
                     )
                 val_logs = self.evaluate(
                     x=val_x,
@@ -231,12 +231,11 @@ class TensorFlowTrainer(base_trainer.Trainer):
             if validation_data and self._should_eval(epoch, validation_freq):
                 # Create EpochIterator for evaluation and cache it.
                 if getattr(self, "_eval_epoch_iterator", None) is None:
-                    self._eval_epoch_iterator = EpochIterator(
+                    self._eval_epoch_iterator = TFEpochIterator(
                         x=val_x,
                         y=val_y,
                         sample_weight=val_sample_weight,
                         batch_size=validation_batch_size or batch_size,
-                        epochs=1,
                     )
                 val_logs = self.evaluate(
                     x=val_x,
@@ -126,8 +126,10 @@ class MeanSquaredLogarithmicError(LossFunctionWrapper):


 @keras_core_export(
-    "keras_core.metrics.mean_squared_error",
-    "keras_core.losses.mean_squared_error",
+    [
+        "keras_core.metrics.mean_squared_error",
+        "keras_core.losses.mean_squared_error",
+    ]
 )
 def mean_squared_error(y_true, y_pred):
     """Computes the mean squared error between labels and predictions.

@@ -158,8 +160,10 @@ def mean_squared_error(y_true, y_pred):


 @keras_core_export(
-    "keras_core.metrics.mean_absolute_error",
-    "keras_core.losses.mean_absolute_error",
+    [
+        "keras_core.metrics.mean_absolute_error",
+        "keras_core.losses.mean_absolute_error",
+    ]
 )
 def mean_absolute_error(y_true, y_pred):
     """Computes the mean absolute error between labels and predictions.

@@ -188,8 +192,10 @@ def mean_absolute_error(y_true, y_pred):


 @keras_core_export(
-    "keras_core.metrics.mean_absolute_percentage_error",
-    "keras_core.losses.mean_absolute_percentage_error",
+    [
+        "keras_core.metrics.mean_absolute_percentage_error",
+        "keras_core.losses.mean_absolute_percentage_error",
+    ]
 )
 def mean_absolute_percentage_error(y_true, y_pred):
     """Computes the mean absolute percentage error between `y_true` & `y_pred`.

@@ -223,8 +229,10 @@ def mean_absolute_percentage_error(y_true, y_pred):


 @keras_core_export(
-    "keras_core.metrics.mean_squared_logarithmic_error",
-    "keras_core.losses.mean_squared_logarithmic_error",
+    [
+        "keras_core.metrics.mean_squared_logarithmic_error",
+        "keras_core.losses.mean_squared_logarithmic_error",
+    ]
 )
 def mean_squared_logarithmic_error(y_true, y_pred):
     """Computes the mean squared logarithmic error between `y_true` & `y_pred`.
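
All four hunks above make the same mechanical change: `keras_core_export` now receives one list of API names rather than several positional strings, so a symbol's aliases travel as a single argument. A hypothetical sketch of a decorator with that shape; the real registration logic is not shown in this excerpt:

    def keras_core_export(names):
        # Illustrative only: normalize a single name to a list and record it.
        if isinstance(names, str):
            names = [names]

        def decorator(obj):
            obj._api_export_names = list(names)  # hypothetical attribute
            return obj

        return decorator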
@@ -21,6 +21,7 @@ class MeanSquaredError(reduction_metrics.MeanMetricWrapper):
     >>> m.result()
     0.25
     """
+
     def __init__(self, name="mean_squared_error", dtype=None):
         super().__init__(fn=mean_squared_error, name=name, dtype=dtype)

pyproject.toml (new file) | 2
@@ -0,0 +1,2 @@
+[tool.black]
+line-length = 80
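
With this file in place, running `black` from the repository root formats code to an 80-column line length, matching the `"editor.rulers": [80]` setting in `.vscode/settings.json` above.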