Merge branch 'main' of github.com:keras-team/keras-core

commit 46b8fd2d6d
parent 76a225d370
Author: Francois Chollet
Date:   2023-07-07 16:34:09 -07:00

2 changed files with 13 additions and 12 deletions

@@ -113,7 +113,7 @@ def train(model, train_loader, num_epochs, optimizer, loss_fn):
 def setup(current_gpu_index, num_gpu):
     # Device setup
-    os.environ["MASTER_ADDR"] = "keras-core-torch"
+    os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "56492"
     device = torch.device("cuda:{}".format(current_gpu_index))
     dist.init_process_group(
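Note: this hunk points the rendezvous address at the local machine, since all worker processes in this guide run on a single host; the old value, "keras-core-torch", looks like an environment-specific hostname and would presumably only resolve inside that environment. A minimal, self-contained sketch of the same setup pattern (the `torch.cuda.set_device` call and the explicit `rank` argument are assumptions drawn from the surrounding guide, not part of this hunk):

    import os

    import torch
    import torch.distributed as dist

    def setup(current_gpu_index, num_gpu):
        # Single-node training: every worker rendezvous on this machine.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "56492"
        # Pin this process to its own GPU before creating the process group.
        torch.cuda.set_device(current_gpu_index)
        dist.init_process_group(
            backend="nccl",        # swap for "gloo" on CPU-only machines
            init_method="env://",  # read MASTER_ADDR / MASTER_PORT from the env
            world_size=num_gpu,
            rank=current_gpu_index,
        )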

@@ -52,7 +52,7 @@ import keras_core as keras
 def get_model():
     # Make a simple convnet with batch normalization and dropout.
     inputs = keras.Input(shape=(28, 28, 1))
-    x = keras.layers.Rescaling(1.0 / 255.0)(x)
+    x = keras.layers.Rescaling(1.0 / 255.0)(inputs)
     x = keras.layers.Conv2D(
         filters=12, kernel_size=3, padding="same", use_bias=False
     )(x)
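Note: this hunk fixes a bug rather than a style issue: the old code applied the `Rescaling` layer to `x`, which is not defined at that point, so the functional graph was never wired to `inputs`. A short sketch of the corrected pattern (the pooling/classification head below is an illustrative completion, not part of the diff):

    import keras_core as keras

    def get_model():
        inputs = keras.Input(shape=(28, 28, 1))
        # The fix: start the layer chain from `inputs`; `x` does not exist yet here.
        x = keras.layers.Rescaling(1.0 / 255.0)(inputs)
        x = keras.layers.Conv2D(
            filters=12, kernel_size=3, padding="same", use_bias=False
        )(x)
        # Illustrative head so the sketch builds end to end (not from the diff).
        x = keras.layers.GlobalAveragePooling2D()(x)
        outputs = keras.layers.Dense(10, activation="softmax")(x)
        return keras.Model(inputs, outputs)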
@@ -172,7 +172,7 @@ Here's how it works:
 - We use `torch.multiprocessing.spawn` to spawn multiple Python processes, one
 per device. Each process will run the `per_device_launch_fn` function.
 - The `per_device_launch_fn` function does the following:
-    - It uses `torch.distributed.dist.init_process_group` and `torch.cuda.set_device`
+    - It uses `torch.distributed.init_process_group` and `torch.cuda.set_device`
     to configure the device to be used for that process.
    - It uses `torch.utils.data.distributed.DistributedSampler`
    and `torch.utils.data.DataLoader` to turn our data into a distributed data loader.
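Note: the list in this hunk describes `per_device_launch_fn` in prose; a hedged sketch of how those pieces could fit together is below. It leans on the guide's own helpers (`setup_device`, `prepare_dataloader`, `get_model`, `train`, `cleanup`), plus a hypothetical `get_dataset()` and made-up hyperparameters:

    import torch

    def per_device_launch_fn(current_gpu_index, num_gpu):
        # Configure the process group and bind this worker to its GPU.
        setup_device(current_gpu_index, num_gpu)

        # get_dataset() is a hypothetical stand-in for the guide's data loading code.
        dataset = get_dataset()
        dataloader = prepare_dataloader(dataset, current_gpu_index, num_gpu, batch_size=64)

        # Move the model to this worker's GPU and wrap it in DDP so gradients
        # are all-reduced across processes during the backward pass.
        model = get_model().to(current_gpu_index)
        ddp_model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[current_gpu_index], output_device=current_gpu_index
        )

        # Made-up training hyperparameters, for illustration only.
        optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
        loss_fn = torch.nn.CrossEntropyLoss()
        train(ddp_model, dataloader, num_epochs=2, optimizer=optimizer, loss_fn=loss_fn)

        cleanup()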
@@ -194,10 +194,10 @@ print(f"Running on {num_gpu} GPUs")
 def setup_device(current_gpu_index, num_gpus):
     # Device setup
-    os.environ["MASTER_ADDR"] = "keras-core-torch"
+    os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "56492"
     device = torch.device("cuda:{}".format(current_gpu_index))
-    torch.distributed.dist.init_process_group(
+    torch.distributed.init_process_group(
         backend="nccl",
         init_method="env://",
         world_size=num_gpus,
@@ -207,7 +207,7 @@ def setup_device(current_gpu_index, num_gpus):
 def cleanup():
-    torch.distributed.dist.dist.destroy_process_group()
+    torch.distributed.destroy_process_group()

 def prepare_dataloader(dataset, current_gpu_index, num_gpus, batch_size):
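Note: the two hunks above repeat the `localhost` fix in the second code listing and remove a doubled attribute access (`torch.distributed.dist.dist.destroy_process_group`), which would raise an AttributeError at shutdown because `torch.distributed` has no `dist` attribute. A small sketch of the corrected teardown; the barrier is an extra illustration, not part of the guide:

    import torch.distributed as dist

    def cleanup():
        # Optional: wait for every worker before tearing the group down.
        dist.barrier()
        # destroy_process_group lives directly on torch.distributed.
        dist.destroy_process_group()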
@@ -259,12 +259,13 @@ def per_device_launch_fn(current_gpu_index, num_gpu):
 Time to spawn:
 """

-torch.multiprocessing.spawn(
-    per_device_launch_fn,
-    args=(num_gpu,),
-    nprocs=num_gpu,
-    join=True,
-)
+if __name__ == '__main__':
+    torch.multiprocessing.spawn(
+        per_device_launch_fn,
+        args=(num_gpu,),
+        nprocs=num_gpu,
+        join=True,
+    )

 """
 That's it!
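Note: wrapping the spawn call in an `if __name__ == '__main__':` guard matters because `torch.multiprocessing.spawn` starts workers that re-import the launching module; without the guard, each worker would attempt to spawn again. A self-contained sketch of the entry point, assuming the guide's `per_device_launch_fn` is defined above (the explicit `device_count()` query is only here to make the sketch standalone):

    import torch
    import torch.multiprocessing

    if __name__ == "__main__":
        num_gpu = torch.cuda.device_count()
        print(f"Running on {num_gpu} GPUs")

        # One worker per GPU; each worker receives its index as the first
        # argument, followed by the contents of `args`.
        torch.multiprocessing.spawn(
            per_device_launch_fn,
            args=(num_gpu,),
            nprocs=num_gpu,
            join=True,
        )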