keras/integration_tests/distribute_training_test.py
Neel Kovelamudi 2d40cb20b9 Adds CategoryEncoding layer, bincount op, and tests (#161)
* Adds unit normalization and tests

* Adds layer normalization and initial tests

* Fixes formatting in docstrings

* Fix type issues for JAX

* Fix nits

* Initial stash for group_normalization and spectral_normalization

* Adds spectral normalization and tests

* Adds group normalization and tests

* Formatting fixes

* Fix small nit in docstring

* Fix docstring and tests

* Adds RandomContrast and associated tests

* Remove arithmetic comment

* Adds RandomBrightness and tests

* Fix docstring and format

* Fix nits and add backend generator

* Inlines random_contrast helper

* Add bincount op

* Add CategoryEncoding layer and tests

* Fix formatting

* Fix JAX issues

* Fix JAX bincount

* Formatting and small fix

* Fix nits and docstrings

* Add args to bincount op test
2023-05-14 00:07:43 +00:00

59 lines
1.7 KiB
Python

import numpy as np
import tensorflow as tf
from keras_core import layers
from keras_core import losses
from keras_core import models
from keras_core import metrics
from keras_core import optimizers
from keras_core.utils import rng_utils
def test_model_fit():
    """Fit a small dense model under `MirroredStrategy` on two logical CPUs.

    The first physical CPU is split into two logical devices so that a
    two-replica `tf.distribute.MirroredStrategy` can be used. The model is
    built, compiled and fitted inside the strategy scope on random
    regression data, and the recorded loss history is then checked.

    Raises:
        AssertionError: If the loss history does not hold exactly one
            finite value per epoch.
    """
    # Split the first physical CPU into two logical devices, one per replica.
    cpus = tf.config.list_physical_devices("CPU")
    tf.config.set_logical_device_configuration(
        cpus[0],
        [
            tf.config.LogicalDeviceConfiguration(),
            tf.config.LogicalDeviceConfiguration(),
        ],
    )

    rng_utils.set_random_seed(1337)

    strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])
    with strategy.scope():
        inputs = layers.Input((100,), batch_size=32)
        hidden = layers.Dense(256, activation="relu")(inputs)
        hidden = layers.Dense(256, activation="relu")(hidden)
        hidden = layers.Dense(256, activation="relu")(hidden)
        hidden = layers.BatchNormalization()(hidden)
        outputs = layers.Dense(16)(hidden)
        model = models.Model(inputs, outputs)

    model.summary()

    # Random regression data: 100 input features, 16 targets per sample.
    x_train = np.random.random((50000, 100))
    y_train = np.random.random((50000, 16))
    batch_size = 32
    epochs = 5

    with strategy.scope():
        model.compile(
            optimizer=optimizers.SGD(learning_rate=0.001, momentum=0.01),
            loss=losses.MeanSquaredError(),
            metrics=[metrics.MeanSquaredError()],
            # TODO(scottzhu): Find out where is the variable that is not
            # created eagerly and break the usage of XLA.
            jit_compile=False,
        )
        history = model.fit(
            x_train,
            y_train,
            batch_size=batch_size,
            epochs=epochs,
            validation_split=0.2,
        )

    print("History:")
    print(history.history)

    # Printing alone cannot fail the test: require that every epoch ran and
    # that training never diverged to NaN/inf.
    loss_per_epoch = history.history["loss"]
    assert len(loss_per_epoch) == epochs
    assert np.all(np.isfinite(loss_per_epoch))
# Allow running this integration test directly as a script, without pytest.
if __name__ == "__main__":
    test_model_fit()