keras/integration_tests/distribute_training_test.py


Add distribution strategy support for model.fit/eval/predict. (#119) * Add unit/integration test for tf.distribute. * Fix format * Skip the test case for non-tf backend * Fix typo * Fix format and unit test context config. * Address review comments. * Add support for h5 weights loading. * Fix test * Add support for a -1 dimension in the `Reshape` operation. (#103) The code to compute the output shape is now shared between the `Reshape` operation and the `Reshape` layer. * Added ReLU activation layer (#104) * added relu * add relu * added correctness test * reformated * updates based on review * Fix docstring * Added R2score (#106) * Add meanX metrics * All regression metrics except for root mean squared error * Formatting issues * Add RootMeanSquaredError * Docstring spacing * Line too long fix * Add R2Score * Docstring fixes * Fix test * Fix tests * Adds RemoteMonitor Callback (#108) * Add Remote Monitor Callback * Add Remote Monitor Callback * Add Remote Monitor Callback * Add Remote Monitor * Add wrapper layer. * Add learning rate schedules (#102) * Add learning rate schedules * Some review comments * Use fancy new serialization tests * s/TensorFlow/backend in docstring * Update docstrings * More review comments * Added LeakyReLU activation layer (#109) * added LeakyReLu * update docstring * reformat * update config * updated test name * Fix docstrings * Fix init and update tests to import from correct path (#110) * Add distribute support for tensorflow trainer. * Revert the previous merge edit. * Fix lint issue * Address review comments. 
* Add TPU strategy support * Fix lint --------- Co-authored-by: Francois Chollet <francois.chollet@gmail.com> Co-authored-by: hertschuh <1091026+hertschuh@users.noreply.github.com> Co-authored-by: divyasreepat <divyashreepathihalli@gmail.com> Co-authored-by: Gabriel Rasskin <43894452+grasskin@users.noreply.github.com> Co-authored-by: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com> Co-authored-by: Ian Stenbit <3072903+ianstenbit@users.noreply.github.com>
2023-05-09 23:46:44 +00:00
import numpy as np
import tensorflow as tf

from keras_core import layers
from keras_core import losses
from keras_core import models
from keras_core import metrics
from keras_core import optimizers
from keras_core.utils import rng_utils


def test_model_fit():
    # Split the first physical CPU into two logical devices so that
    # MirroredStrategy has two replicas to mirror variables across.
    cpus = tf.config.list_physical_devices("CPU")
    tf.config.set_logical_device_configuration(
        cpus[0],
        [
            tf.config.LogicalDeviceConfiguration(),
            tf.config.LogicalDeviceConfiguration(),
        ],
    )

    # Seed the random number generators for reproducibility.
    rng_utils.set_random_seed(1337)

    strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])

    # Build the model inside the strategy scope so its variables are
    # created as distributed (mirrored) variables.
    with strategy.scope():
        inputs = layers.Input((100,), batch_size=32)
        x = layers.Dense(256, activation="relu")(inputs)
        x = layers.Dense(256, activation="relu")(x)
        x = layers.Dense(256, activation="relu")(x)
        x = layers.BatchNormalization()(x)
        outputs = layers.Dense(16)(x)
        model = models.Model(inputs, outputs)

    model.summary()

    # Random regression data: 50000 samples, 100 features, 16 targets.
    x = np.random.random((50000, 100))
    y = np.random.random((50000, 16))
    batch_size = 32
    epochs = 5

    with strategy.scope():
        model.compile(
            optimizer=optimizers.SGD(learning_rate=0.001, momentum=0.01),
            loss=losses.MeanSquaredError(),
            metrics=[metrics.MeanSquaredError()],
            # TODO(scottzhu): Find out which variable is not created eagerly
            # and breaks the usage of XLA.
            jit_compile=False,
        )
        history = model.fit(
            x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2
        )

    print("History:")
    print(history.history)


if __name__ == "__main__":
    test_model_fit()