keras/examples/demo_torch_multi_gpu.py
# flake8: noqa
import os

# Set backend env to torch
os.environ["KERAS_BACKEND"] = "torch"

import torch
import torch.nn as nn
import torch.optim as optim
from keras import layers
import keras
import numpy as np
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

# Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)
learning_rate = 0.01
batch_size = 128
num_epochs = 1


def get_data():
    # Load the data and split it between train and test sets
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

    # Scale images to the [0, 1] range
    x_train = x_train.astype("float32") / 255
    x_test = x_test.astype("float32") / 255
    # Make sure images have shape (28, 28, 1)
    x_train = np.expand_dims(x_train, -1)
    x_test = np.expand_dims(x_test, -1)
    print("x_train shape:", x_train.shape)
    print(x_train.shape[0], "train samples")
    print(x_test.shape[0], "test samples")

    # Create a TensorDataset; cast the uint8 labels to int64, since
    # nn.CrossEntropyLoss expects class-index targets as a LongTensor
    dataset = torch.utils.data.TensorDataset(
        torch.from_numpy(x_train), torch.from_numpy(y_train).long()
    )
    return dataset


def get_model():
    # Create the Keras model
    model = keras.Sequential(
        [
            layers.Input(shape=(28, 28, 1)),
            layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Flatten(),
            layers.Dropout(0.5),
            layers.Dense(num_classes),
        ]
    )
    return model


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = keras.Sequential(
            [
                layers.Input(shape=(28, 28, 1)),
                layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
                layers.MaxPooling2D(pool_size=(2, 2)),
                layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
                layers.MaxPooling2D(pool_size=(2, 2)),
                layers.Flatten(),
                layers.Dropout(0.5),
                layers.Dense(num_classes),
            ]
        )

    def forward(self, x):
        return self.model(x)


def train(model, train_loader, num_epochs, optimizer, loss_fn):
    for epoch in range(num_epochs):
        running_loss = 0.0

        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True)

            # Forward pass
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Print loss statistics
            if (batch_idx + 1) % 10 == 0:
                print(
                    f"Epoch [{epoch+1}/{num_epochs}], "
                    f"Batch [{batch_idx+1}/{len(train_loader)}], "
                    f"Loss: {running_loss / 10}"
                )
                running_loss = 0.0
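

# Note: when the model passed to `train()` is wrapped in DistributedDataParallel,
# `loss.backward()` also all-reduces (averages) the gradients across replicas,
# so every rank applies the same update in `optimizer.step()`. Nothing
# DDP-specific is needed inside the loop itself.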


def setup(current_gpu_index, num_gpu):
    # Device setup
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "56492"
    device = torch.device("cuda:{}".format(current_gpu_index))
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=num_gpu,
        rank=current_gpu_index,
    )
    torch.cuda.set_device(device)
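

# Note: with `init_method="env://"`, process-group initialization reads
# MASTER_ADDR / MASTER_PORT (set above) plus the rank and world size passed
# explicitly. The port value is arbitrary as long as it is free on the host.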


def prepare(dataset, current_gpu_index, num_gpu, batch_size):
    sampler = DistributedSampler(
        dataset,
        num_replicas=num_gpu,
        rank=current_gpu_index,
        shuffle=False,
    )

    # Create a DataLoader
    train_loader = DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        shuffle=False,
    )
    return train_loader
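

# Note: the sampler above uses `shuffle=False`. If shuffling were enabled, the
# usual pattern is to call `train_loader.sampler.set_epoch(epoch)` at the start
# of each epoch so that every replica reshuffles consistently, e.g. (sketch):
#
#     for epoch in range(num_epochs):
#         train_loader.sampler.set_epoch(epoch)
#         ...  # iterate over train_loader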


def cleanup():
    # Cleanup
    dist.destroy_process_group()


def main(current_gpu_index, num_gpu):
    # setup the process groups
    setup(current_gpu_index, num_gpu)

    #################################################################
    ######## Writing a torch training loop for a Keras model ########
    #################################################################

    dataset = get_data()
    model = get_model()

    # prepare the dataloader
    dataloader = prepare(dataset, current_gpu_index, num_gpu, batch_size)

    # Instantiate the torch optimizer
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Instantiate the torch loss function
    loss_fn = nn.CrossEntropyLoss()

    # Put model on device
    model = model.to(current_gpu_index)
    ddp_model = DDP(
        model, device_ids=[current_gpu_index], output_device=current_gpu_index
    )

    train(ddp_model, dataloader, num_epochs, optimizer, loss_fn)

    ################################################################
    ######## Using a Keras model or layer in a torch Module ########
    ################################################################

    torch_module = MyModel().to(current_gpu_index)
    ddp_torch_module = DDP(
        torch_module,
        device_ids=[current_gpu_index],
        output_device=current_gpu_index,
    )

    # Instantiate the torch optimizer
    optimizer = optim.Adam(torch_module.parameters(), lr=learning_rate)

    # Instantiate the torch loss function
    loss_fn = nn.CrossEntropyLoss()

    train(ddp_torch_module, dataloader, num_epochs, optimizer, loss_fn)

    cleanup()


if __name__ == "__main__":
    # GPU parameters
    num_gpu = torch.cuda.device_count()

    print(f"Running on {num_gpu} GPUs")

    torch.multiprocessing.spawn(
        main,
        args=(num_gpu,),
        nprocs=num_gpu,
        join=True,
    )
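
# The script spawns one worker process per visible GPU via
# `torch.multiprocessing.spawn`, so it can be launched directly, e.g.:
#
#     python demo_torch_multi_gpu.py
#
# No `torchrun` (or `torch.distributed.launch`) wrapper is required.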