import os

# Select the torch backend for Keras Core; this must be set before keras_core is imported.
os.environ["KERAS_BACKEND"] = "torch"

import torch
import torch.nn as nn
import torch.optim as optim
from keras_core import layers
import keras_core
import numpy as np

import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

# Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)
learning_rate = 0.01
batch_size = 128
num_epochs = 1


def get_data():
    # Load the data and split it between train and test sets
    (x_train, y_train), (x_test, y_test) = keras_core.datasets.mnist.load_data()

    # Scale images to the [0, 1] range
    x_train = x_train.astype("float32") / 255
    x_test = x_test.astype("float32") / 255
    # Make sure images have shape (28, 28, 1)
    x_train = np.expand_dims(x_train, -1)
    x_test = np.expand_dims(x_test, -1)
    print("x_train shape:", x_train.shape)
    print(x_train.shape[0], "train samples")
    print(x_test.shape[0], "test samples")

    # Create a TensorDataset. Labels are cast to int64 so they are valid
    # class-index targets for torch's CrossEntropyLoss.
    dataset = torch.utils.data.TensorDataset(
        torch.from_numpy(x_train), torch.from_numpy(y_train).long()
    )
    return dataset


def get_model():
    # Create the Keras model
    model = keras_core.Sequential(
        [
            layers.Input(shape=(28, 28, 1)),
            layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Flatten(),
            layers.Dropout(0.5),
            layers.Dense(num_classes),
        ]
    )
    return model


class MyModel(nn.Module):
    # A torch Module that wraps Keras layers, so they can be used
    # inside a regular torch workflow (including DDP).
    def __init__(self):
        super().__init__()
        self.model = keras_core.Sequential(
            [
                layers.Input(shape=(28, 28, 1)),
                layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
                layers.MaxPooling2D(pool_size=(2, 2)),
                layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
                layers.MaxPooling2D(pool_size=(2, 2)),
                layers.Flatten(),
                layers.Dropout(0.5),
                layers.Dense(num_classes),
            ]
        )

    def forward(self, x):
        return self.model(x)


def train(model, train_loader, num_epochs, optimizer, loss_fn):
    for epoch in range(num_epochs):
        running_loss = 0.0
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True)

            # Forward pass
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Print loss statistics
            if (batch_idx + 1) % 10 == 0:
                print(
                    f"Epoch [{epoch+1}/{num_epochs}], "
                    f"Batch [{batch_idx+1}/{len(train_loader)}], "
                    f"Loss: {running_loss / 10}"
                )
                running_loss = 0.0


def setup(current_gpu_index, num_gpu):
    # Device setup: MASTER_ADDR must resolve to the rank-0 host
    # (here, the name of the container running the job).
    os.environ["MASTER_ADDR"] = "keras-core-torch"
    os.environ["MASTER_PORT"] = "56492"
    device = torch.device("cuda:{}".format(current_gpu_index))
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=num_gpu,
        rank=current_gpu_index,
    )
    torch.cuda.set_device(device)


def prepare(dataset, current_gpu_index, num_gpu, batch_size):
    # Shard the dataset across processes
    sampler = DistributedSampler(
        dataset,
        num_replicas=num_gpu,
        rank=current_gpu_index,
        shuffle=False,
    )

    # Create a DataLoader
    train_loader = DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        shuffle=False,
    )

    return train_loader


def cleanup():
    # Cleanup
    dist.destroy_process_group()


def main(current_gpu_index, num_gpu):
    # Set up the process group
    setup(current_gpu_index, num_gpu)

    #################################################################
    ######## Writing a torch training loop for a Keras model ########
    #################################################################

    dataset = get_data()
    model = get_model()

    # Prepare the dataloader
    dataloader = prepare(dataset, current_gpu_index, num_gpu, batch_size)

    # Instantiate the torch optimizer
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Instantiate the torch loss function
    loss_fn = nn.CrossEntropyLoss()

    # Put the model on the device and wrap it with DistributedDataParallel
    model = model.to(current_gpu_index)
    ddp_model = DDP(
        model, device_ids=[current_gpu_index], output_device=current_gpu_index
    )

    train(ddp_model, dataloader, num_epochs, optimizer, loss_fn)

    ################################################################
    ######## Using a Keras model or layer in a torch Module ########
    ################################################################

    torch_module = MyModel().to(current_gpu_index)
    ddp_torch_module = DDP(
        torch_module,
        device_ids=[current_gpu_index],
        output_device=current_gpu_index,
    )

    # Instantiate the torch optimizer
    optimizer = optim.Adam(torch_module.parameters(), lr=learning_rate)

    # Instantiate the torch loss function
    loss_fn = nn.CrossEntropyLoss()

    train(ddp_torch_module, dataloader, num_epochs, optimizer, loss_fn)

    cleanup()


if __name__ == "__main__":
    # GPU parameters
    num_gpu = torch.cuda.device_count()

    print(f"Running on {num_gpu} GPUs")

    # Spawn one process per GPU; each process receives its rank as the first argument
    torch.multiprocessing.spawn(
        main,
        args=(num_gpu,),
        nprocs=num_gpu,
        join=True,
    )