From 6769821b1c9ff2b52c7cad2460f481e72f9c9837 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20G=C3=B6rner?= Date: Thu, 22 Jun 2023 18:44:27 +0200 Subject: [PATCH] added Jax distributed training exammple using a Keras model (#384) * added Jax distributed training exammple using a Keras model * fixed file formatting * fixed file formatting --- examples/demo_jax_distributed.py | 337 +++++++++++++++++++++++++++++++ 1 file changed, 337 insertions(+) create mode 100644 examples/demo_jax_distributed.py diff --git a/examples/demo_jax_distributed.py b/examples/demo_jax_distributed.py new file mode 100644 index 000000000..328490f0a --- /dev/null +++ b/examples/demo_jax_distributed.py @@ -0,0 +1,337 @@ +# To run this demo, you will need to spin up a "TPU VM" on Google Cloud. +# Please follow instructions here: https://cloud.google.com/tpu/docs/run-calculation-jax + +# Force a JAX backend +import os, pprint, collections + +os.environ["KERAS_BACKEND"] = "jax" + +pp = pprint.PrettyPrinter() + +import jax +import jax.numpy as jnp +import tensorflow as tf # just for tf.data +import keras_core as keras # Keras multi-backend + +import numpy as np +from tqdm import tqdm + +from jax.experimental import mesh_utils +from jax.sharding import Mesh +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P + +""" Dataset +Classic MNIST, loaded using tf.data +""" + +BATCH_SIZE = 192 + +(x_train, train_labels), ( + x_eval, + eval_labels, +) = keras.datasets.mnist.load_data() +x_train = np.expand_dims(x_train, axis=-1).astype( + np.float32 +) # from 28x28 to 28x28 x 1 color channel (B&W) +x_eval = np.expand_dims(x_eval, axis=-1).astype(np.float32) + +train_data = tf.data.Dataset.from_tensor_slices((x_train, train_labels)) +train_data = train_data.shuffle(5000, reshuffle_each_iteration=True) +train_data = train_data.batch(BATCH_SIZE, drop_remainder=True) +train_data = train_data.repeat() + +eval_data = tf.data.Dataset.from_tensor_slices((x_eval, eval_labels)) +eval_data = eval_data.batch(10000) # everything as one batch + +STEPS_PER_EPOCH = len(train_labels) // BATCH_SIZE + +""" Keras model +Simple but non-trivial model with: +* Batch Normalization (non-trainable state updated during trainig, different training-time and inference behavior) +* Dropout (randomness, different training time and inference behavior) +""" + + +# Keras "sequential" model building style +def make_backbone(): + return keras.Sequential( + [ + keras.layers.Rescaling( + 1.0 / 255.0 + ), # input images are in the range [0, 255] + keras.layers.Conv2D( + filters=12, kernel_size=3, padding="same", use_bias=False + ), + keras.layers.BatchNormalization(scale=False, center=True), + keras.layers.Activation("relu"), + keras.layers.Conv2D( + filters=24, + kernel_size=6, + padding="same", + use_bias=False, + strides=2, + ), + keras.layers.BatchNormalization(scale=False, center=True), + keras.layers.Activation("relu"), + keras.layers.Conv2D( + filters=32, + kernel_size=6, + padding="same", + use_bias=False, + strides=2, + name="large_k", + ), + keras.layers.BatchNormalization(scale=False, center=True), + keras.layers.Activation("relu"), + ], + name="backbone", + ) + + +def make_model(): + input = keras.Input(shape=[28, 28, 1]) + y = make_backbone()(input) + y = keras.layers.Flatten()(y) + y = keras.layers.Dense(200, activation="relu")(y) + y = keras.layers.Dropout(0.4)(y) + y = keras.layers.Dense(10, activation="softmax")(y) + model = keras.Model(inputs=input, outputs=y) + return model + + +""" JAX-native distribution with a Keras model +For now, you have to write a custom training loop for this +Note: The features required by jax.sharding are not supported by the Colab TPU +runtime at this time, but are available on Cloud TPU VMs and Kaggle TPU VMs. +""" + +if len(jax.local_devices()) < 8: + raise Exception("This part requires 8 devices to run") +else: + print("\nIdentified local devices:") + pp.pprint(jax.local_devices()) + +# ----------------- Keras --------------------- + +# instantiate the model +model = make_model() + +# learning rate +lr = keras.optimizers.schedules.ExponentialDecay(0.01, STEPS_PER_EPOCH, 0.6) + +# optimizer +optimizer = keras.optimizers.Adam(lr) + +# initialize all state with .build() +(one_batch, one_batch_labels) = next(iter(train_data)) +model.build(one_batch) +optimizer.build(model.trainable_variables) + +""" Distribution settings + +* Sharding the data on the batch axis +* Replicating all model variables + +Note: this implements standard "data parallel" distributed training + +* Just for show, sharding the largest convolutional kernel along the + "channels" axis 4-ways and replicating 2-ways + +Note: this does not reflect a best practice but is intended to show + that you can split a very large kernel across multiple devices + if you have to +""" + +print( + "\nMostly data-parallel distribution. " + "Data is sharded across devices while the model is replicated. " + "For demo purposes, we split the largest kernel 4-ways " + "(and replicate 2-ways since we have 8 devices)." +) + +# ------------------ Jax ---------------------- + +devices = mesh_utils.create_device_mesh((8,)) + +# data will be split along the batch axis +data_mesh = Mesh(devices, axis_names=("batch",)) # naming axes of the mesh +# naming axes of the sharded partition +data_sharding = NamedSharding(data_mesh,P("batch",),) + +# all variables will be replicated on all devices +var_mesh = Mesh(devices, axis_names=("_")) +# in NamedSharding, axes that are not mentioned are replicated (all axes here) +var_replication = NamedSharding(var_mesh, P()) + +# for the demo, we will split the largest kernel 4-ways (and replicate 2-ways since we have 8 devices) +large_kernel_mesh = Mesh( + devices.reshape((-1, 4)), axis_names=(None, "out_chan") +) # naming axes of the mesh +large_kernel_sharding = NamedSharding( + large_kernel_mesh, P(None, None, None, "out_chan") +) # naming axes of the sharded partition + +# ----------------- Keras --------------------- + +# Use Keras APIs to find the variable of a specific layer (we will be sharding this one in a special way) +# In a Conv2D or Dense layer, the variables are 'kernel' and 'bias' +special_layer_var = model.get_layer("backbone").get_layer("large_k").kernel + +# ------------------ Jax ---------------------- +# - accessing variables in Keras lists model.trainable_variables, +# - model.non_trainable_variables and optimizer.variables + +# Apply the distribution settings to the model variables +non_trainable_variables = jax.device_put( + model.non_trainable_variables, var_replication +) +optimizer_variables = jax.device_put(optimizer.variables, var_replication) +# this is what you would do replicate all trainable variables: +# trainable_variables = jax.device_put(model.trainable_variables, var_replication) + +# For the demo, we split the largest kernel 4-ways instead of replicating it. +# We still replicate all other trainable variables as in standard "data-parallel" +# distributed training. +print_once = True +trainable_variables = model.trainable_variables +for i, v in enumerate(trainable_variables): + if v is special_layer_var: + # Apply distribution settings: sharding + sharded_v = jax.device_put(v, large_kernel_sharding) + trainable_variables[i] = sharded_v + + print("Sharding of convolutional", v.name, v.shape) + jax.debug.visualize_array_sharding( + jnp.reshape(sharded_v, [-1, v.shape[-1]]) + ) + else: + # Apply distribution settings: replication + replicated_v = jax.device_put(v, var_replication) + trainable_variables[i] = replicated_v + + if print_once: + print_once = False + print( + "\nSharding of all other model variables (they are replicated)" + ) + jax.debug.visualize_array_sharding( + jnp.reshape(replicated_v, [-1, v.shape[-1]]) + ) + +# collect state in a handy named tuple +TrainingState = collections.namedtuple( + "TrainingState", + ["trainable_variables", "non_trainable_variables", "optimizer_variables"], +) +device_train_state = TrainingState( + trainable_variables=trainable_variables, + non_trainable_variables=non_trainable_variables, + optimizer_variables=optimizer_variables, +) +# display data sharding +x, y = next(iter(train_data)) +sharded_x = jax.device_put(x.numpy(), data_sharding) +print("Data sharding") +jax.debug.visualize_array_sharding(jnp.reshape(sharded_x, [-1, 28 * 28])) + +# ------------------ Jax ---------------------- +# - Using Keras-provided stateless APIs +# - model.stateless_call +# - optimizer.stateless_apply +# These functions also work on other backends. + +# define loss +loss = keras.losses.SparseCategoricalCrossentropy() + + +# This is the loss function that will be differentiated. +# Keras provides a pure functional forward pass: model.stateless_call +def compute_loss(trainable_variables, non_trainable_variables, x, y): + y_pred, updated_non_trainable_variables = model.stateless_call( + trainable_variables, non_trainable_variables, x + ) + loss_value = loss(y, y_pred) + return loss_value, updated_non_trainable_variables + + +# function to compute gradients +compute_gradients = jax.value_and_grad(compute_loss, has_aux=True) + + +# Trainig step: Keras provides a pure functional optimizer.stateless_apply +@jax.jit +def train_step(train_state, x, y): + (loss_value, non_trainable_variables), grads = compute_gradients( + train_state.trainable_variables, + train_state.non_trainable_variables, + x, + y, + ) + + trainable_variables, optimizer_variables = optimizer.stateless_apply( + grads, train_state.trainable_variables, train_state.optimizer_variables + ) + + return loss_value, TrainingState( + trainable_variables, non_trainable_variables, optimizer_variables + ) + + +# training loop +EPOCHS = 5 +print("\nTrainig:") +data_iter = iter(train_data) +for epoch in range(EPOCHS): + for i in tqdm(range(STEPS_PER_EPOCH)): + x, y = next(data_iter) + sharded_x = jax.device_put(x.numpy(), data_sharding) + loss_value, device_train_state = train_step( + device_train_state, sharded_x, y.numpy() + ) + print("Epoch", epoch, "loss:", loss_value) + +# The output of the model is still sharded. Sharding follows the data. + +data, labels = next(iter(eval_data)) +sharded_data = jax.device_put(data.numpy(), data_sharding) + + +@jax.jit +def predict(data): + predictions, updated_non_trainable_variables = model.stateless_call( + device_train_state.trainable_variables, + device_train_state.non_trainable_variables, + data, + ) + return predictions + + +predictions = predict(sharded_data) +print("\nModel output sharding follows data sharding:") +jax.debug.visualize_array_sharding(predictions) + +# Post-processing model state update to write them back into the model +update = lambda variable, value: variable.assign(value) + +jax.tree_map( + update, model.trainable_variables, device_train_state.trainable_variables +) +jax.tree_map( + update, + model.non_trainable_variables, + device_train_state.non_trainable_variables, +) +jax.tree_map( + update, optimizer.variables, device_train_state.optimizer_variables +) + +# check that the model has the new state by running an eval +# known issue: the optimizer should not be required here +model.compile( + loss=keras.losses.SparseCategoricalCrossentropy(), + metrics=[keras.metrics.SparseCategoricalAccuracy()], +) +print("\nUpdating model and running an eval:") +loss, accuracy = model.evaluate(eval_data) +print("The model achieved an evaluation accuracy of:", accuracy)