keras/integration_tests/distribute_training_test.py
Qianli Scott Zhu 46378e883f Add distribution strategy support for model.fit/eval/predict. (#119)
* Add unit/integration test for tf.distribute.

* Fix format

* Skip the test case for non-tf backend

* Fix typo

* Fix format and unit test context config.

* Address review comments.

* Add support for h5 weights loading.

* Fix test

* Add support for a -1 dimension in the `Reshape` operation. (#103)

The code to compute the output shape is now shared between the `Reshape` operation and the `Reshape` layer.
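Resolving a `-1` entry in a target shape amounts to dividing the total element count by the product of the known dimensions. A minimal plain-Python sketch of that computation (the helper name is illustrative, not the shared keras_core function):

```python
def compute_reshape_output_shape(input_shape, target_shape):
    """Resolve a single -1 entry in target_shape so that the total
    element count of input_shape is preserved."""
    total = 1
    for dim in input_shape:
        total *= dim
    known = 1
    unknown_idx = None
    for i, dim in enumerate(target_shape):
        if dim == -1:
            unknown_idx = i
        else:
            known *= dim
    if unknown_idx is None:
        # No -1 present: the target shape is already fully specified.
        return tuple(target_shape)
    resolved = list(target_shape)
    resolved[unknown_idx] = total // known
    return tuple(resolved)

print(compute_reshape_output_shape((8, 4), (-1, 2)))  # -> (16, 2)
```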

* Added ReLU activation layer (#104)

* added relu

* add relu

* added correctness test

* reformated

* updates based on review

* Fix docstring
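The ReLU activation added here computes `max(0, x)` elementwise. A minimal numeric sketch in plain Python (not the keras_core implementation):

```python
def relu(values):
    # ReLU: pass positive values through, clamp negatives to zero.
    return [max(0.0, v) for v in values]

print(relu([-2.0, -0.5, 0.0, 1.5]))  # -> [0.0, 0.0, 0.0, 1.5]
```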

* Added R2score (#106)

* Add meanX metrics

* All regression metrics except for root mean squared error

* Formatting issues

* Add RootMeanSquaredError

* Docstring spacing

* Line too long fix

* Add R2Score

* Docstring fixes

* Fix test

* Fix tests
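The R2 score (coefficient of determination) added above is defined as 1 minus the ratio of residual to total sum of squares. A minimal plain-Python sketch of the formula (not the keras_core `R2Score` metric, which accumulates state across batches):

```python
def r2_score(y_true, y_pred):
    """R^2 = 1 - SS_res / SS_tot, where SS_res is the residual sum of
    squares and SS_tot the total sum of squares around the mean."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot

print(r2_score([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))  # -> 1.0 (perfect fit)
```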

* Adds RemoteMonitor Callback (#108)

* Add Remote Monitor Callback

* Add Remote Monitor Callback

* Add Remote Monitor Callback

* Add Remote Monitor

* Add wrapper layer.

* Add learning rate schedules (#102)

* Add learning rate schedules

* Some review comments

* Use fancy new serialization tests

* s/TensorFlow/backend in docstring

* Update docstrings

* More review comments
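The learning rate schedules mentioned above follow the usual pattern of mapping a step count to a learning rate. A minimal sketch of exponential decay in plain Python (illustrative, not the keras_core API; parameter names are assumptions):

```python
def exponential_decay(initial_lr, decay_rate, decay_steps, step):
    # lr(step) = initial_lr * decay_rate ** (step / decay_steps)
    return initial_lr * decay_rate ** (step / decay_steps)

print(exponential_decay(0.1, 0.5, 100, 0))    # -> 0.1 (no decay yet)
print(exponential_decay(0.1, 0.5, 100, 100))  # halved after decay_steps
```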

* Added LeakyReLU activation layer (#109)

* added LeakyReLu

* update docstring

* reformat

* update config

* updated test name

* Fix docstrings
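LeakyReLU behaves like ReLU but scales negative inputs by a small slope instead of zeroing them, which keeps a nonzero gradient for negative activations. A minimal numeric sketch (plain Python; the default slope of 0.2 here is illustrative, not necessarily the layer's default):

```python
def leaky_relu(values, negative_slope=0.2):
    # Positive inputs pass through; negative inputs are scaled by
    # negative_slope rather than clamped to zero.
    return [v if v > 0 else negative_slope * v for v in values]

print(leaky_relu([-1.0, 0.5]))  # -> [-0.2, 0.5]
```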

* Fix init and update tests to import from correct path (#110)

* Add distribute support for tensorflow trainer.

* Revert the previous merge edit.

* Fix lint issue

* Address review comments.

* Add TPU strategy support

* Fix lint

---------

Co-authored-by: Francois Chollet <francois.chollet@gmail.com>
Co-authored-by: hertschuh <1091026+hertschuh@users.noreply.github.com>
Co-authored-by: divyasreepat <divyashreepathihalli@gmail.com>
Co-authored-by: Gabriel Rasskin <43894452+grasskin@users.noreply.github.com>
Co-authored-by: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com>
Co-authored-by: Ian Stenbit <3072903+ianstenbit@users.noreply.github.com>
2023-05-09 16:46:44 -07:00


import numpy as np
import tensorflow as tf

from keras_core import layers
from keras_core import losses
from keras_core import metrics
from keras_core import models
from keras_core import optimizers
from keras_core.utils import rng_utils


def test_model_fit():
    # Split the single physical CPU into two logical devices so that
    # MirroredStrategy has two replicas to mirror across.
    cpus = tf.config.list_physical_devices("CPU")
    tf.config.set_logical_device_configuration(
        cpus[0],
        [
            tf.config.LogicalDeviceConfiguration(),
            tf.config.LogicalDeviceConfiguration(),
        ],
    )
    rng_utils.set_random_seed(1337)
    strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])
    with strategy.scope():
        inputs = layers.Input((100,), batch_size=32)
        x = layers.Dense(256, activation="relu")(inputs)
        x = layers.Dense(256, activation="relu")(x)
        x = layers.Dense(256, activation="relu")(x)
        x = layers.BatchNormalization()(x)
        outputs = layers.Dense(16)(x)
        model = models.Model(inputs, outputs)
        model.summary()

        x = np.random.random((50000, 100))
        y = np.random.random((50000, 16))
        batch_size = 32
        epochs = 5
        model.compile(
            optimizer=optimizers.SGD(learning_rate=0.001),
            loss=losses.MeanSquaredError(),
            metrics=[metrics.MeanSquaredError()],
            # TODO(scottzhu): Find out which variable is not created
            # eagerly and breaks the usage of XLA.
            jit_compile=False,
        )
        history = model.fit(
            x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2
        )
        print("History:")
        print(history.history)


if __name__ == "__main__":
    test_model_fit()
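Under data parallelism, the global `batch_size=32` used above is split evenly across the two mirrored CPU replicas. A minimal pure-Python sketch of that bookkeeping (the helper is hypothetical, not part of tf.distribute):

```python
def split_global_batch(global_batch_size, num_replicas):
    """Split a global batch across replicas the way a data-parallel
    strategy such as MirroredStrategy does: each replica gets an equal
    share, with any remainder spread over the first few replicas."""
    base, rem = divmod(global_batch_size, num_replicas)
    return [base + (1 if i < rem else 0) for i in range(num_replicas)]

# With the test's batch_size=32 and two logical CPUs ("CPU:0", "CPU:1"):
print(split_global_batch(32, 2))  # -> [16, 16]
```

Each replica then computes gradients on its shard, and the strategy all-reduces them before applying the update, which is why the model and optimizer variables must be created inside `strategy.scope()`.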