Merge branch 'main' of github.com:keras-team/keras-core
This commit is contained in:
parent
76a225d370
commit
46b8fd2d6d
@@ -113,7 +113,7 @@ def train(model, train_loader, num_epochs, optimizer, loss_fn):
|
|||||||
|
|
||||||
def setup(current_gpu_index, num_gpu):
|
def setup(current_gpu_index, num_gpu):
|
||||||
# Device setup
|
# Device setup
|
||||||
os.environ["MASTER_ADDR"] = "keras-core-torch"
|
os.environ["MASTER_ADDR"] = "localhost"
|
||||||
os.environ["MASTER_PORT"] = "56492"
|
os.environ["MASTER_PORT"] = "56492"
|
||||||
device = torch.device("cuda:{}".format(current_gpu_index))
|
device = torch.device("cuda:{}".format(current_gpu_index))
|
||||||
dist.init_process_group(
|
dist.init_process_group(
|
||||||
|
@@ -52,7 +52,7 @@ import keras_core as keras
|
|||||||
def get_model():
|
def get_model():
|
||||||
# Make a simple convnet with batch normalization and dropout.
|
# Make a simple convnet with batch normalization and dropout.
|
||||||
inputs = keras.Input(shape=(28, 28, 1))
|
inputs = keras.Input(shape=(28, 28, 1))
|
||||||
x = keras.layers.Rescaling(1.0 / 255.0)(x)
|
x = keras.layers.Rescaling(1.0 / 255.0)(inputs)
|
||||||
x = keras.layers.Conv2D(
|
x = keras.layers.Conv2D(
|
||||||
filters=12, kernel_size=3, padding="same", use_bias=False
|
filters=12, kernel_size=3, padding="same", use_bias=False
|
||||||
)(x)
|
)(x)
|
||||||
@@ -172,7 +172,7 @@ Here's how it works:
|
|||||||
- We use `torch.multiprocessing.spawn` to spawn multiple Python processes, one
|
- We use `torch.multiprocessing.spawn` to spawn multiple Python processes, one
|
||||||
per device. Each process will run the `per_device_launch_fn` function.
|
per device. Each process will run the `per_device_launch_fn` function.
|
||||||
- The `per_device_launch_fn` function does the following:
|
- The `per_device_launch_fn` function does the following:
|
||||||
- It uses `torch.distributed.dist.init_process_group` and `torch.cuda.set_device`
|
- It uses `torch.distributed.init_process_group` and `torch.cuda.set_device`
|
||||||
to configure the device to be used for that process.
|
to configure the device to be used for that process.
|
||||||
- It uses `torch.utils.data.distributed.DistributedSampler`
|
- It uses `torch.utils.data.distributed.DistributedSampler`
|
||||||
and `torch.utils.data.DataLoader` to turn our data into a distributed data loader.
|
and `torch.utils.data.DataLoader` to turn our data into a distributed data loader.
|
||||||
@@ -194,10 +194,10 @@ print(f"Running on {num_gpu} GPUs")
|
|||||||
|
|
||||||
def setup_device(current_gpu_index, num_gpus):
|
def setup_device(current_gpu_index, num_gpus):
|
||||||
# Device setup
|
# Device setup
|
||||||
os.environ["MASTER_ADDR"] = "keras-core-torch"
|
os.environ["MASTER_ADDR"] = "localhost"
|
||||||
os.environ["MASTER_PORT"] = "56492"
|
os.environ["MASTER_PORT"] = "56492"
|
||||||
device = torch.device("cuda:{}".format(current_gpu_index))
|
device = torch.device("cuda:{}".format(current_gpu_index))
|
||||||
torch.distributed.dist.init_process_group(
|
torch.distributed.init_process_group(
|
||||||
backend="nccl",
|
backend="nccl",
|
||||||
init_method="env://",
|
init_method="env://",
|
||||||
world_size=num_gpus,
|
world_size=num_gpus,
|
||||||
@@ -207,7 +207,7 @@ def setup_device(current_gpu_index, num_gpus):
|
|||||||
|
|
||||||
|
|
||||||
def cleanup():
|
def cleanup():
|
||||||
torch.distributed.dist.dist.destroy_process_group()
|
torch.distributed.destroy_process_group()
|
||||||
|
|
||||||
|
|
||||||
def prepare_dataloader(dataset, current_gpu_index, num_gpus, batch_size):
|
def prepare_dataloader(dataset, current_gpu_index, num_gpus, batch_size):
|
||||||
@@ -259,12 +259,13 @@ def per_device_launch_fn(current_gpu_index, num_gpu):
|
|||||||
Time to spawn:
|
Time to spawn:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
torch.multiprocessing.spawn(
|
if __name__ == '__main__':
|
||||||
per_device_launch_fn,
|
torch.multiprocessing.spawn(
|
||||||
args=(num_gpu,),
|
per_device_launch_fn,
|
||||||
nprocs=num_gpu,
|
args=(num_gpu,),
|
||||||
join=True,
|
nprocs=num_gpu,
|
||||||
)
|
join=True,
|
||||||
|
)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
That's it!
|
That's it!
|
||||||
|
Loading…
Reference in New Issue
Block a user