Merge branch 'main' of github.com:keras-team/keras-core

Author: Francois Chollet
Date: 2023-07-07 16:34:09 -07:00
Parent: 76a225d370
Commit: 46b8fd2d6d
2 changed files with 13 additions and 12 deletions

@@ -113,7 +113,7 @@ def train(model, train_loader, num_epochs, optimizer, loss_fn):
 def setup(current_gpu_index, num_gpu):
     # Device setup
-    os.environ["MASTER_ADDR"] = "keras-core-torch"
+    os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "56492"
     device = torch.device("cuda:{}".format(current_gpu_index))
     dist.init_process_group(
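
For context on this fix: with `init_method="env://"`, `torch.distributed` reads the rendezvous address from the `MASTER_ADDR`/`MASTER_PORT` environment variables, so `"localhost"` is the right value for single-host multi-GPU training (the old `"keras-core-torch"` hostname would only resolve inside a container with that name). A minimal single-process sanity check of this rendezvous, using the CPU-only `gloo` backend here purely for illustration, not part of the commit:

```python
import os
import torch.distributed as dist

# Same env:// rendezvous the guide configures, on a single host.
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "56492"  # any free TCP port

# world_size=1 / rank=0: a one-process group, just to verify the setup.
dist.init_process_group(backend="gloo", init_method="env://", world_size=1, rank=0)
print(dist.get_world_size())  # -> 1
dist.destroy_process_group()
```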

@@ -52,7 +52,7 @@ import keras_core as keras
 def get_model():
     # Make a simple convnet with batch normalization and dropout.
     inputs = keras.Input(shape=(28, 28, 1))
-    x = keras.layers.Rescaling(1.0 / 255.0)(x)
+    x = keras.layers.Rescaling(1.0 / 255.0)(inputs)
     x = keras.layers.Conv2D(
         filters=12, kernel_size=3, padding="same", use_bias=False
     )(x)
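
For context: in the Keras functional API the first layer must consume the `Input` tensor, so the old `(x)` referenced a name that didn't exist yet. A sketch of how the corrected chain closes into a model; the layers after `Conv2D` are illustrative, not necessarily the guide's exact stack:

```python
import keras_core as keras

inputs = keras.Input(shape=(28, 28, 1))
x = keras.layers.Rescaling(1.0 / 255.0)(inputs)  # first call consumes `inputs`
x = keras.layers.Conv2D(filters=12, kernel_size=3, padding="same", use_bias=False)(x)
x = keras.layers.BatchNormalization(scale=False, center=True)(x)
x = keras.layers.ReLU()(x)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(10)(x)
model = keras.Model(inputs, outputs)  # every later layer chains off `x`
```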
@@ -172,7 +172,7 @@ Here's how it works:
 - We use `torch.multiprocessing.spawn` to spawn multiple Python processes, one
 per device. Each process will run the `per_device_launch_fn` function.
 - The `per_device_launch_fn` function does the following:
-    - It uses `torch.distributed.dist.init_process_group` and `torch.cuda.set_device`
+    - It uses `torch.distributed.init_process_group` and `torch.cuda.set_device`
     to configure the device to be used for that process.
     - It uses `torch.utils.data.distributed.DistributedSampler`
     and `torch.utils.data.DataLoader` to turn our data into a distributed data loader.
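
A sketch of what such a `per_device_launch_fn` body can look like, pulling the pieces above together; the `dataset` and `batch_size` parameters are added here for self-containedness, while the guide's actual function takes only the GPU index and count:

```python
import torch

def per_device_launch_fn(current_gpu_index, num_gpus, dataset, batch_size=64):
    # Join the process group, then pin this process to one GPU.
    torch.distributed.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=num_gpus,
        rank=current_gpu_index,
    )
    torch.cuda.set_device(current_gpu_index)

    # Shard the dataset so each process sees a distinct slice.
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=num_gpus, rank=current_gpu_index
    )
    dataloader = torch.utils.data.DataLoader(
        dataset, sampler=sampler, batch_size=batch_size, shuffle=False
    )

    # ... build the model, wrap it in DistributedDataParallel, train ...

    torch.distributed.destroy_process_group()
```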
@@ -194,10 +194,10 @@ print(f"Running on {num_gpu} GPUs")
 def setup_device(current_gpu_index, num_gpus):
     # Device setup
-    os.environ["MASTER_ADDR"] = "keras-core-torch"
+    os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "56492"
     device = torch.device("cuda:{}".format(current_gpu_index))
-    torch.distributed.dist.init_process_group(
+    torch.distributed.init_process_group(
         backend="nccl",
         init_method="env://",
         world_size=num_gpus,
@@ -207,7 +207,7 @@ def setup_device(current_gpu_index, num_gpus):
 def cleanup():
-    torch.distributed.dist.dist.destroy_process_group()
+    torch.distributed.destroy_process_group()


 def prepare_dataloader(dataset, current_gpu_index, num_gpus, batch_size):
@@ -259,6 +259,7 @@ def per_device_launch_fn(current_gpu_index, num_gpu):
 Time to spawn:
 """
+if __name__ == '__main__':
 torch.multiprocessing.spawn(
     per_device_launch_fn,
     args=(num_gpu,),
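
The hunk cuts off mid-call; for reference, `torch.multiprocessing.spawn` takes `nprocs` and `join` arguments to complete the call, assuming `num_gpu` is defined earlier in the script as the hunk header suggests:

```python
if __name__ == "__main__":
    torch.multiprocessing.spawn(
        per_device_launch_fn,
        args=(num_gpu,),
        nprocs=num_gpu,  # one process per GPU
        join=True,       # block until all processes finish
    )
```

The `if __name__ == '__main__':` guard added by this commit matters because spawned child processes re-import the main module; without the guard, each child would recursively try to spawn its own workers.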