Merge branch 'main' of github.com:keras-team/keras-core
parent 76a225d370
commit 46b8fd2d6d
@@ -113,7 +113,7 @@ def train(model, train_loader, num_epochs, optimizer, loss_fn):
 
 def setup(current_gpu_index, num_gpu):
     # Device setup
-    os.environ["MASTER_ADDR"] = "keras-core-torch"
+    os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "56492"
     device = torch.device("cuda:{}".format(current_gpu_index))
     dist.init_process_group(
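For context on what this one-line change touches: with `init_method="env://"`, `torch.distributed` reads the rendezvous address from these environment variables, so `"localhost"` is the value that works for single-machine multi-GPU training (the old `"keras-core-torch"` was presumably a container hostname from the original test environment). A minimal single-process sketch of the rendezvous, using the CPU `gloo` backend so it runs without GPUs:

    import os
    import torch.distributed as dist

    os.environ["MASTER_ADDR"] = "localhost"  # rendezvous host; must be reachable by every worker
    os.environ["MASTER_PORT"] = "56492"      # any free TCP port on that host

    # With init_method="env://", init_process_group reads the two variables above.
    # world_size=1 / rank=0 keeps this runnable as a single process.
    dist.init_process_group(backend="gloo", init_method="env://", world_size=1, rank=0)
    dist.destroy_process_group()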
@@ -52,7 +52,7 @@ import keras_core as keras
 def get_model():
     # Make a simple convnet with batch normalization and dropout.
     inputs = keras.Input(shape=(28, 28, 1))
-    x = keras.layers.Rescaling(1.0 / 255.0)(x)
+    x = keras.layers.Rescaling(1.0 / 255.0)(inputs)
     x = keras.layers.Conv2D(
         filters=12, kernel_size=3, padding="same", use_bias=False
     )(x)
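The change above fixes a `NameError`: in the Keras functional API the first layer has to be applied to `inputs`, because `x` does not exist until a layer call has produced it. A minimal sketch of the corrected wiring (the pooling/`Dense` head is a placeholder added here so the model builds; it is not part of the guide):

    import keras_core as keras

    inputs = keras.Input(shape=(28, 28, 1))
    x = keras.layers.Rescaling(1.0 / 255.0)(inputs)  # first call consumes `inputs`
    x = keras.layers.Conv2D(
        filters=12, kernel_size=3, padding="same", use_bias=False
    )(x)                                             # later calls consume the previous `x`
    x = keras.layers.GlobalAveragePooling2D()(x)     # placeholder head from here on
    outputs = keras.layers.Dense(10)(x)
    model = keras.Model(inputs, outputs)
    model.summary()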
@@ -172,7 +172,7 @@ Here's how it works:
 - We use `torch.multiprocessing.spawn` to spawn multiple Python processes, one
 per device. Each process will run the `per_device_launch_fn` function.
 - The `per_device_launch_fn` function does the following:
-    - It uses `torch.distributed.dist.init_process_group` and `torch.cuda.set_device`
+    - It uses `torch.distributed.init_process_group` and `torch.cuda.set_device`
     to configure the device to be used for that process.
     - It uses `torch.utils.data.distributed.DistributedSampler`
     and `torch.utils.data.DataLoader` to turn our data into a distributed data loader.
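A compact sketch of the launch flow those bullets describe, with a synthetic `TensorDataset` standing in for the guide's real data and a simplified body standing in for the guide's actual `per_device_launch_fn` (assumes `MASTER_ADDR`/`MASTER_PORT` are already set as in `setup_device` below):

    import torch
    import torch.distributed as dist
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    def per_device_launch_fn(current_gpu_index, num_gpu):
        # Join the process group and pin this process to one GPU.
        dist.init_process_group(
            backend="nccl", init_method="env://",
            world_size=num_gpu, rank=current_gpu_index,
        )
        torch.cuda.set_device(current_gpu_index)

        # Shard the data so each process trains on a distinct slice every epoch.
        dataset = TensorDataset(torch.randn(512, 28, 28, 1), torch.randint(0, 10, (512,)))
        sampler = DistributedSampler(dataset, num_replicas=num_gpu, rank=current_gpu_index)
        loader = DataLoader(dataset, batch_size=64, sampler=sampler, shuffle=False)

        # ... build and train the model on `loader` ...
        dist.destroy_process_group()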
@@ -194,10 +194,10 @@ print(f"Running on {num_gpu} GPUs")
 
 def setup_device(current_gpu_index, num_gpus):
     # Device setup
-    os.environ["MASTER_ADDR"] = "keras-core-torch"
+    os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "56492"
     device = torch.device("cuda:{}".format(current_gpu_index))
-    torch.distributed.dist.init_process_group(
+    torch.distributed.init_process_group(
         backend="nccl",
         init_method="env://",
         world_size=num_gpus,
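The second change in this hunk is a namespace fix rather than a behavior change: `dist` in the earlier snippet is just an import alias for the `torch.distributed` module, so `torch.distributed.dist.init_process_group` never existed. A quick check:

    import torch
    import torch.distributed as dist

    # `dist` is the torch.distributed module itself, so these are the same function:
    assert dist.init_process_group is torch.distributed.init_process_group

    # The path removed by this commit fails immediately:
    #   torch.distributed.dist.init_process_group(...)
    #   AttributeError: module 'torch.distributed' has no attribute 'dist'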
@@ -207,7 +207,7 @@ def setup_device(current_gpu_index, num_gpus):
 
 
 def cleanup():
-    torch.distributed.dist.dist.destroy_process_group()
+    torch.distributed.destroy_process_group()
 
 
 def prepare_dataloader(dataset, current_gpu_index, num_gpus, batch_size):
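Same namespace bug, same fix, for teardown. In each worker the corrected call pairs with the setup shown above, roughly like this (`run_worker` is a hypothetical name used only for illustration, not the guide's code):

    def run_worker(current_gpu_index, num_gpus):
        setup_device(current_gpu_index, num_gpus)   # torch.distributed.init_process_group(...)
        try:
            pass  # per-device training loop goes here
        finally:
            cleanup()  # now correctly resolves to torch.distributed.destroy_process_group()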
@@ -259,12 +259,13 @@ def per_device_launch_fn(current_gpu_index, num_gpu):
 Time to spawn:
 """
 
-torch.multiprocessing.spawn(
-    per_device_launch_fn,
-    args=(num_gpu,),
-    nprocs=num_gpu,
-    join=True,
-)
+if __name__ == '__main__':
+    torch.multiprocessing.spawn(
+        per_device_launch_fn,
+        args=(num_gpu,),
+        nprocs=num_gpu,
+        join=True,
+    )
 
 """
 That's it!
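The added guard matters because `torch.multiprocessing.spawn` starts its workers with the `spawn` start method, which re-imports the main module in every child; without `if __name__ == '__main__':`, the top-level spawn call would run again inside each child. A tiny self-contained illustration of the pattern, with a print-only worker standing in for `per_device_launch_fn`:

    import torch.multiprocessing as mp

    def _worker(rank, world_size):
        # spawn passes the process index as the first argument, then everything in `args`.
        print(f"hello from process {rank} of {world_size}")

    if __name__ == "__main__":
        num_gpu = 2  # placeholder; the guide defines num_gpu earlier in the script
        mp.spawn(_worker, args=(num_gpu,), nprocs=num_gpu, join=True)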