46378e883f
* Add unit/integration test for tf.distribute. * Fix format * Skip the test case for non-tf backend * Fix typo * Fix format and unit test context config. * Address review comments. * Add support for h5 weights loading. * Fix test * Add support for a -1 dimension in the `Reshape` operation. (#103) The code to compute the output shape is now shared between the `Reshape` operation and the `Reshape` layer. * Added ReLU activation layer (#104) * added relu * add relu * added correctness test * reformated * updates based on review * Fix docstring * Added R2score (#106) * Add meanX metrics * All regression metrics except for root mean squared error * Formatting issues * Add RootMeanSquaredError * Docstring spacing * Line too long fix * Add R2Score * Docstring fixes * Fix test * Fix tests * Adds RemoteMonitor Callback (#108) * Add Remote Monitor Callback * Add Remote Monitor Callback * Add Remote Monitor Callback * Add Remote Monitor * Add wrapper layer. * Add learning rate schedules (#102) * Add learning rate schedules * Some review comments * Use fancy new serialization tests * s/TensorFlow/backend in docstring * Update docstrings * More review comments * Added LeakyReLU activation layer (#109) * added LeakyReLu * update docstring * reformat * update config * updated test name * Fix docstrings * Fix init and update tests to import from correct path (#110) * Add distribute support for tensorflow trainer. * Revert the previous merge edit. * Fix lint issue * Address review comments. * Add TPU strategy support * Fix lint --------- Co-authored-by: Francois Chollet <francois.chollet@gmail.com> Co-authored-by: hertschuh <1091026+hertschuh@users.noreply.github.com> Co-authored-by: divyasreepat <divyashreepathihalli@gmail.com> Co-authored-by: Gabriel Rasskin <43894452+grasskin@users.noreply.github.com> Co-authored-by: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com> Co-authored-by: Ian Stenbit <3072903+ianstenbit@users.noreply.github.com>
58 lines
1.6 KiB
Python
58 lines
1.6 KiB
Python
import numpy as np
|
|
import tensorflow as tf
|
|
|
|
from keras_core import layers
|
|
from keras_core import losses
|
|
from keras_core import models
|
|
from keras_core import metrics
|
|
from keras_core import optimizers
|
|
from keras_core.utils import rng_utils
|
|
|
|
|
|
def _configure_virtual_cpus(num_devices=2):
    """Split the first physical CPU into ``num_devices`` logical CPUs.

    ``tf.config.set_logical_device_configuration`` raises ``RuntimeError``
    when the TensorFlow runtime is already initialized (e.g. by an earlier
    test in the same pytest session). In that case the logical devices that
    already exist are reused instead of failing outright.

    Args:
        num_devices: Number of logical CPU devices to request.

    Returns:
        A list with the names of the first ``num_devices`` logical CPUs.

    Raises:
        RuntimeError: If fewer than ``num_devices`` logical CPUs are
            available after configuration.
    """
    cpus = tf.config.list_physical_devices("CPU")
    try:
        tf.config.set_logical_device_configuration(
            cpus[0],
            [
                tf.config.LogicalDeviceConfiguration()
                for _ in range(num_devices)
            ],
        )
    except RuntimeError:
        # Runtime already initialized: the device layout can no longer be
        # changed, so fall through and check what is actually available.
        pass

    device_names = [
        device.name for device in tf.config.list_logical_devices("CPU")
    ]
    if len(device_names) < num_devices:
        raise RuntimeError(
            f"Expected at least {num_devices} logical CPU devices for "
            f"MirroredStrategy, found {len(device_names)}: {device_names}. "
            "Virtual CPUs must be configured before TensorFlow initializes."
        )
    return device_names[:num_devices]


def test_model_fit():
    """Train a small MLP with ``MirroredStrategy`` on two logical CPUs.

    This is an end-to-end check that ``Model.fit`` works with
    ``tf.distribute``. The model is built inside ``strategy.scope()``, then
    compiled and fit on random regression data with a validation split.
    """
    # Device setup must happen before anything else touches the TF runtime.
    devices = _configure_virtual_cpus(num_devices=2)

    rng_utils.set_random_seed(1337)

    strategy = tf.distribute.MirroredStrategy(devices)
    with strategy.scope():
        # NOTE(review): MirroredStrategy splits each global batch of 32
        # across the replicas. Confirm that a fixed batch_size=32 on the
        # Input is intended here.
        inputs = layers.Input((100,), batch_size=32)
        x = layers.Dense(256, activation="relu")(inputs)
        x = layers.Dense(256, activation="relu")(x)
        x = layers.Dense(256, activation="relu")(x)
        x = layers.BatchNormalization()(x)
        outputs = layers.Dense(16)(x)
        model = models.Model(inputs, outputs)

    model.summary()

    x = np.random.random((50000, 100))
    y = np.random.random((50000, 16))
    batch_size = 32
    epochs = 5

    model.compile(
        optimizer=optimizers.SGD(learning_rate=0.001),
        loss=losses.MeanSquaredError(),
        metrics=[metrics.MeanSquaredError()],
        # TODO(scottzhu): Find out where is the variable that is not created
        # eagerly and break the usage of XLA.
        jit_compile=False,
    )
    history = model.fit(
        x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2
    )

    print("History:")
    print(history.history)

    # Without these checks the test would pass even if fit() silently did
    # nothing. Every recorded metric must have one entry per epoch.
    assert history.history, "fit() did not record any metrics"
    for metric_name, values in history.history.items():
        assert len(values) == epochs, (
            f"{metric_name}: expected {epochs} entries, got {len(values)}"
        )
|
|
|
|
if __name__ == "__main__":
    # Script mode: run the distributed-training check once, outside pytest.
    test_model_fit()