2d40cb20b9
* Adds unit normalization and tests * Adds layer normalization and initial tests * Fixes formatting in docstrings * Fix type issues for JAX * Fix nits * Initial stash for group_normalization and spectral_normalization * Adds spectral normalization and tests * Adds group normalization and tests * Formatting fixes * Fix small nit in docstring * Fix docstring and tests * Adds RandomContrast and associated tests * Remove arithmetic comment * Adds RandomBrightness and tests * Fix docstring and format * Fix nits and add backend generator * Inlines random_contrast helper * Add bincount op * Add CategoryEncoding layer and tests * Fix formatting * Fix JAX issues * Fix JAX bincount * Formatting and small fix * Fix nits and docstrings * Add args to bincount op test
59 lines
1.7 KiB
Python
59 lines
1.7 KiB
Python
import numpy as np
|
|
import tensorflow as tf
|
|
|
|
from keras_core import layers
|
|
from keras_core import losses
|
|
from keras_core import models
|
|
from keras_core import metrics
|
|
from keras_core import optimizers
|
|
from keras_core.utils import rng_utils
|
|
|
|
|
|
def test_model_fit():
    """Smoke-test distributed training of a small keras_core MLP.

    Splits the host CPU into two logical devices, builds and trains a
    fully-connected regression model under
    ``tf.distribute.MirroredStrategy``, and prints the training history.
    """
    # Carve the single physical CPU into two logical devices so that
    # MirroredStrategy has two replicas to mirror variables across.
    physical_cpus = tf.config.list_physical_devices("CPU")
    tf.config.set_logical_device_configuration(
        physical_cpus[0],
        [
            tf.config.LogicalDeviceConfiguration(),
            tf.config.LogicalDeviceConfiguration(),
        ],
    )

    # Fixed seed for reproducible initialization and data shuffling.
    rng_utils.set_random_seed(1337)

    strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])
    with strategy.scope():
        # Three hidden Dense layers + BatchNormalization,
        # mapping 100 input features to 16 regression targets.
        inputs = layers.Input((100,), batch_size=32)
        hidden = layers.Dense(256, activation="relu")(inputs)
        hidden = layers.Dense(256, activation="relu")(hidden)
        hidden = layers.Dense(256, activation="relu")(hidden)
        hidden = layers.BatchNormalization()(hidden)
        outputs = layers.Dense(16)(hidden)
        model = models.Model(inputs, outputs)

    model.summary()

    # Synthetic regression data; only the shapes matter for this test.
    x = np.random.random((50000, 100))
    y = np.random.random((50000, 16))
    batch_size = 32
    epochs = 5

    # Compile and fit inside the strategy scope so training state is
    # created per-replica.
    with strategy.scope():
        model.compile(
            optimizer=optimizers.SGD(learning_rate=0.001, momentum=0.01),
            loss=losses.MeanSquaredError(),
            metrics=[metrics.MeanSquaredError()],
            # TODO(scottzhu): Find out where is the variable that is not created eagerly
            # and break the usage of XLA.
            jit_compile=False,
        )
        history = model.fit(
            x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2
        )

    print("History:")
    print(history.history)
|
|
# Allow running this smoke test directly as a script.
if __name__ == "__main__":
    test_model_fit()