199 lines
5.8 KiB
Python
199 lines
5.8 KiB
Python
|
import os
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.optim as optim
|
||
|
from keras_core import layers
|
||
|
import keras_core
|
||
|
import numpy as np
|
||
|
|
||
|
import torch.multiprocessing as mp
|
||
|
import torch.distributed as dist
|
||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||
|
from torch.utils.data import DataLoader, Dataset
|
||
|
from torch.utils.data.distributed import DistributedSampler
|
||
|
|
||
|
# Model / data parameters
num_classes = 10  # MNIST digit classes 0-9
input_shape = (28, 28, 1)  # channels-last grayscale images
learning_rate = 0.01  # Adam learning rate used in main()
batch_size = 128  # per-process batch size (each DDP rank sees this many samples)
num_epochs = 1  # single epoch keeps the demo fast
|
||
|
|
||
|
def get_data():
    """Load MNIST, scale it to [0, 1], and wrap the training split.

    Returns:
        torch.utils.data.TensorDataset of (image, label) pairs, where images
        are float32 tensors of shape (28, 28, 1) and labels are int64.
    """
    # Load the data and split it between train and test sets
    (x_train, y_train), (x_test, y_test) = keras_core.datasets.mnist.load_data()

    # Scale images to the [0, 1] range
    x_train = x_train.astype("float32") / 255
    x_test = x_test.astype("float32") / 255
    # Make sure images have shape (28, 28, 1)
    x_train = np.expand_dims(x_train, -1)
    x_test = np.expand_dims(x_test, -1)
    print("x_train shape:", x_train.shape)
    print(x_train.shape[0], "train samples")
    print(x_test.shape[0], "test samples")

    # Create a TensorDataset. Labels are cast to int64 because the raw MNIST
    # labels load as uint8, while nn.CrossEntropyLoss (used in train())
    # requires long-typed class indices as targets.
    dataset = torch.utils.data.TensorDataset(
        torch.from_numpy(x_train), torch.from_numpy(y_train).long()
    )
    return dataset
|
||
|
|
||
|
def get_model():
    """Build the MNIST convnet as a Keras Sequential model.

    Returns:
        An uncompiled keras_core.Sequential producing raw class logits
        (no softmax) of size num_classes.
    """
    model = keras_core.Sequential()
    model.add(layers.Input(shape=(28, 28, 1)))
    # Two conv + pool stages extract spatial features.
    model.add(layers.Conv2D(32, kernel_size=(3, 3), activation="relu"))
    model.add(layers.MaxPooling2D(pool_size=(2, 2)))
    model.add(layers.Conv2D(64, kernel_size=(3, 3), activation="relu"))
    model.add(layers.MaxPooling2D(pool_size=(2, 2)))
    # Classifier head: flatten, regularize, project to logits.
    model.add(layers.Flatten())
    model.add(layers.Dropout(0.5))
    model.add(layers.Dense(num_classes))
    return model
|
||
|
|
||
|
|
||
|
class MyModel(nn.Module):
    """A torch Module that embeds a Keras convnet.

    Demonstrates using a keras_core model as a submodule of a plain
    nn.Module; forward() simply delegates to the wrapped model.
    """

    def __init__(self):
        super().__init__()
        conv_stack = [
            layers.Input(shape=(28, 28, 1)),
            layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Flatten(),
            layers.Dropout(0.5),
            layers.Dense(num_classes),
        ]
        # Same architecture as get_model(); produces num_classes logits.
        self.model = keras_core.Sequential(conv_stack)

    def forward(self, x):
        # Delegate straight to the wrapped Keras model.
        return self.model(x)
|
||
|
|
||
|
|
||
|
def train(model, train_loader, num_epochs, optimizer, loss_fn):
    """Run a torch training loop, logging the mean loss every 10 batches.

    Args:
        model: Module to optimize; expected to already live on the GPU.
        train_loader: Yields (inputs, targets) CPU batches; both tensors are
            moved to the current CUDA device here.
        num_epochs: Number of full passes over train_loader.
        optimizer: torch optimizer bound to model's parameters.
        loss_fn: Criterion mapping (outputs, targets) to a scalar loss.
    """
    for epoch in range(num_epochs):
        running_loss = 0.0
        # 1-based step counter keeps the "every 10th batch" check simple.
        for step, (inputs, targets) in enumerate(train_loader, start=1):
            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True)

            # Forward pass
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)

            # Backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Report the mean loss over the last 10 batches, then reset.
            if step % 10 == 0:
                print(
                    f"Epoch [{epoch + 1}/{num_epochs}], "
                    f"Batch [{step}/{len(train_loader)}], "
                    f"Loss: {running_loss / 10}"
                )
                running_loss = 0.0
|
||
|
|
||
|
def setup(current_gpu_index, num_gpu):
    """Initialize the NCCL process group and pin this process to its GPU.

    Args:
        current_gpu_index: Rank of this process; also its CUDA device index.
        num_gpu: Total number of processes participating in the job.
    """
    # Rendezvous configuration read by init_method='env://'.
    os.environ['MASTER_ADDR'] = 'keras-core-torch'
    os.environ['MASTER_PORT'] = '56492'
    device = torch.device(f"cuda:{current_gpu_index}")
    # One process per GPU: rank == device index, world size == GPU count.
    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=num_gpu,
        rank=current_gpu_index,
    )
    torch.cuda.set_device(device)
|
||
|
|
||
|
|
||
|
def prepare(dataset, current_gpu_index, num_gpu, batch_size):
    """Shard *dataset* across ranks and return this rank's DataLoader.

    Args:
        dataset: The full dataset; each rank receives a disjoint shard.
        current_gpu_index: This process's rank.
        num_gpu: Total number of ranks (replicas).
        batch_size: Per-rank batch size.

    Returns:
        A DataLoader that iterates only this rank's shard, in a fixed order.
    """
    shard_sampler = DistributedSampler(
        dataset,
        num_replicas=num_gpu,
        rank=current_gpu_index,
        shuffle=False,
    )

    # The sampler performs the sharding; shuffling stays off at both levels
    # so every rank sees a deterministic ordering.
    return DataLoader(
        dataset,
        sampler=shard_sampler,
        batch_size=batch_size,
        shuffle=False,
    )
|
||
|
|
||
|
def cleanup():
    """Tear down the distributed process group created by setup()."""
    # Cleanup
    dist.destroy_process_group()
|
||
|
|
||
|
|
||
|
def main(current_gpu_index, num_gpu):
    """Per-process entry point launched by torch.multiprocessing.spawn.

    Runs two demonstrations on this rank's GPU:
      1. a plain torch training loop over a DDP-wrapped Keras model, and
      2. the same loop over a torch Module that embeds a Keras model.

    Args:
        current_gpu_index: Rank of this process; also its CUDA device index.
        num_gpu: Total number of processes / GPUs in the job.

    NOTE(review): this assumes keras_core is running on its torch backend
    (KERAS_BACKEND env var) — confirm before launching.
    """
    # setup the process groups
    setup(current_gpu_index, num_gpu)

    #################################################################
    ######## Writing a torch training loop for a Keras model ########
    #################################################################

    dataset = get_data()
    model = get_model()

    # prepare the dataloader
    dataloader = prepare(dataset, current_gpu_index, num_gpu, batch_size)

    # Instantiate the torch optimizer
    # NOTE(review): created before the model is moved to the GPU; nn.Module.to()
    # moves parameters in place so the optimizer still tracks them — confirm the
    # Keras model behaves the same way.
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Instantiate the torch loss function
    loss_fn = nn.CrossEntropyLoss()

    # Put model on device
    model = model.to(current_gpu_index)
    ddp_model = DDP(model, device_ids=[current_gpu_index], output_device=current_gpu_index)

    train(ddp_model, dataloader, num_epochs, optimizer, loss_fn)

    ################################################################
    ######## Using a Keras model or layer in a torch Module ########
    ################################################################

    torch_module = MyModel().to(current_gpu_index)
    ddp_torch_module = DDP(torch_module, device_ids=[current_gpu_index], output_device=current_gpu_index)

    # Instantiate the torch optimizer
    # (DDP shares the wrapped module's parameters, so optimizing
    # torch_module.parameters() updates the DDP model as well.)
    optimizer = optim.Adam(torch_module.parameters(), lr=learning_rate)

    # Instantiate the torch loss function
    loss_fn = nn.CrossEntropyLoss()

    train(ddp_torch_module, dataloader, num_epochs, optimizer, loss_fn)

    cleanup()
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # GPU parameters
    num_gpu = torch.cuda.device_count()

    print(f"Running on {num_gpu} GPUs")

    # Launch one worker per visible GPU. spawn() passes each worker its rank
    # as the first positional argument, followed by the entries of `args`.
    mp.spawn(
        main,
        args=(num_gpu,),
        nprocs=num_gpu,
        join=True,
    )
|