1. Introduction
PyTorch provides a few options for multi-GPU/multi-CPU computing, in other words distributed computing. While this is unsurprising for deep learning, what is pleasantly surprising is the support for general-purpose, low-level distributed or parallel computing. Those who have used MPI will find this functionality familiar. PyTorch can be used in the following scenarios:
- Single GPU, single node (multiple CPUs on the same node)
- Single GPU, multiple nodes
- Multiple GPUs, single node
- Multiple GPUs, multiple nodes
PyTorch allows 'Gloo', 'MPI' and 'NCCL' as backends for parallelization. In general, Gloo is available on most Linux distributions and should be used for parallelization on CPUs. An MPI installation, where one is possible, would probably be faster. NCCL should be preferred when using multiple GPUs, since it is optimized for precisely that. The choice of backend determines which of the multiprocessing primitives such as all_reduce, gather, scatter, etc. are available to you on CPUs and GPUs. For example, Gloo supports all the primitives on the CPU, whereas NCCL operates only on GPU tensors; conversely, NCCL supports the full set of collectives (reduce, all_reduce, etc.) on the GPU, where Gloo supports only a subset. Refer to the table here.
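As a quick illustration of these primitives, here is a minimal sketch (the port number and world size are arbitrary choices for this example) that spawns two CPU processes with the Gloo backend and performs an all_reduce:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    # Every process joins the same group; Gloo works with CPU tensors.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12356'
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    t = torch.tensor([float(rank + 1)])
    dist.all_reduce(t, op=dist.ReduceOp.SUM)  # every rank ends up with 1 + 2 = 3
    print(f"rank {rank}: {t.item()}")
    dist.destroy_process_group()

if __name__ == '__main__':
    mp.spawn(worker, args=(2,), nprocs=2, join=True)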
The examples shown use material from the PyTorch website and from here, and have been modified.
2. DataParallel: MNIST on multiple GPUs
This is the easiest way to obtain multi-GPU data parallelism with PyTorch. Model parallelism is another paradigm that PyTorch provides (not covered here). The example below assumes that you have 10 GPUs available on a single node. You can select which GPUs to use with the environment variable
os.environ["CUDA_VISIBLE_DEVICES"] = "6,7,8,9"
Note that the visible GPU devices are renumbered from 0 to 3 even though the physical devices 6 - 9 are selected here.
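A quick way to confirm this renumbering (a minimal sketch; the environment variable must be set before CUDA is first initialized):
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "6,7,8,9"  # set before CUDA is initialized
import torch
print(torch.cuda.device_count())      # prints 4
print(torch.cuda.get_device_name(0))  # physical device 6, now visible as cuda:0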
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "6,7,8,9"
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout2d(0.25)
self.dropout2 = nn.Dropout2d(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
self.batchnorm = nn.BatchNorm1d(128)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = self.batchnorm(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
def main():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=5, metavar='N',
help='number of epochs to train (default: 5)')
parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
help='learning rate (default: 1.0)')
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=200, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--save-model', action='store_true', default=False,
help='For Saving the current Model')
args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.test_batch_size, shuffle=True, **kwargs)
model = Net().to(device)
if torch.cuda.device_count() > 1:
print("We have available ", torch.cuda.device_count(), "GPUs!")
model = nn.DataParallel(model, device_ids=[0,1,2,3])
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
for epoch in range(1, args.epochs + 1):
train(args, model, device, train_loader, optimizer, epoch)
test(model, device, test_loader)
scheduler.step()
if args.save_model:
torch.save(model.state_dict(), "mnist_cnn.pt")
if __name__ == '__main__':
main()
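Under the hood, DataParallel replicates the model on each visible GPU and splits every input batch along dimension 0, gathering the outputs back on the first device. The following standalone sketch (an illustration only, assuming at least two visible GPUs) makes the per-replica chunk sizes visible:
import torch
import torch.nn as nn

class Echo(nn.Module):
    def __init__(self):
        super(Echo, self).__init__()
        self.fc = nn.Linear(10, 5)
    def forward(self, x):
        # Each GPU replica sees only its slice of the batch.
        print("replica on", x.device, "got a batch of size", x.size(0))
        return self.fc(x)

model = nn.DataParallel(Echo().to("cuda"))
out = model(torch.randn(64, 10))        # with 4 GPUs: four chunks of 16
print("gathered output:", out.size())   # (64, 5), gathered on cuda:0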
3. Distributed Data Parallel
a. Minimal Working Example
The following shows the setup for Distributed Data Parallel (DDP) in PyTorch. The syntax offers more fine-grained control over the parallelization, but it also requires a deeper understanding of multiprocessing. This is the preferred approach in PyTorch and is usually faster than DataParallel.
import os
from datetime import datetime
import argparse
import torch.multiprocessing as mp
import torchvision
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
os.environ["CUDA_VISIBLE_DEVICES"] = "6,7,8,9"
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# initialize the process group
dist.init_process_group("nccl", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(10, 10)
self.relu = nn.ReLU()
self.net2 = nn.Linear(10, 5)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
def demo_basic(rank, world_size):
print(f"Running basic DDP example on rank {rank}.")
setup(rank, world_size)
# create model and move it to GPU with id rank
model = ToyModel().to(rank)
ddp_model = DDP(model, device_ids=[rank])
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
optimizer.zero_grad()
outputs = ddp_model(torch.ones(200, 10))
labels = torch.randn(200, 5).to(rank)
loss = loss_fn(outputs, labels)
print("Loss is ",loss.item())
loss.backward()
optimizer.step()
cleanup()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-g', '--gpus', default=2, type=int,
help='number of gpus per node')
parser.add_argument('--epochs', default=2, type=int,
metavar='N',
help='number of total epochs to run')
args = parser.parse_args()
world_size = args.gpus
print("We have available ", torch.cuda.device_count(), "GPUs! but using ",world_size," GPUs")
#########################################################
mp.spawn(demo_basic, args=(world_size,), nprocs=world_size, join=True)
#########################################################
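Run the script above with its --gpus flag (for example, --gpus 2); mp.spawn launches one process per GPU and passes the process rank as the first argument to demo_basic automatically. As an optional sanity check, the hypothetical lines below could be placed in demo_basic just before cleanup(): since DDP all-reduces the gradients before the optimizer update, every rank should print the same parameter sum.
# Hypothetical check, placed just before cleanup() in demo_basic:
with torch.no_grad():
    param_sum = sum(p.sum() for p in ddp_model.parameters())
print("rank", rank, "parameter sum:", param_sum.item())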
b. MNIST
The example shown below uses Distributed Data Parallel to train on MNIST.
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
import torch.distributed as dist
import torchvision
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "6,7,8,9"
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# initialize the process group
dist.init_process_group("nccl", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout2d(0.25)
self.dropout2 = nn.Dropout2d(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
self.batchnorm = nn.BatchNorm1d(128)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = self.batchnorm(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output
def train(args, model, device, train_loader, optimizer, epoch, rank):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(rank), target.to(rank)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
dist.all_reduce(loss, op=dist.ReduceOp.SUM)
if(rank == 0):
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
def test(model, device, test_loader, rank):
model.eval()
test_loss = 0
correct = 0
data_len = 0
test_loss_tensor = 0
correct_tensor = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(rank), target.to(rank)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
test_loss_tensor += F.nll_loss(output, target, reduction='sum')
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
correct_tensor += pred.eq(target.view_as(pred)).sum()
data_len += len(data)
test_loss /= data_len
dist.all_reduce(test_loss_tensor, op=dist.ReduceOp.SUM)
dist.all_reduce(correct_tensor, op=dist.ReduceOp.SUM)
if(rank == 0):
print("Test average loss: {}, correct predictons: {}, total: {}, accuracy: {}% \n".format(test_loss_tensor.item() / len(test_loader.dataset), correct_tensor.item(), len(test_loader.dataset),
100.0 * correct_tensor.item() / len(test_loader.dataset)))
#print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.4f}%)\n'.format(
# test_loss, correct, data_len,
# 100. * correct / data_len))
def demo_basic(rank, world_size, args, use_cuda):
#--------------- Setup -------------#
print(f"Running basic DDP example on rank {rank}.")
setup(rank, world_size)
#-----------------------------------#
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
# ---------------- Data ------------#
train_dataset = torchvision.datasets.MNIST(root='./data',
train=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]),
download=True)
train_sampler = torch.utils.data.distributed.DistributedSampler(
train_dataset,
num_replicas=world_size,
rank=rank
)
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=0,
pin_memory=True,
sampler=train_sampler)
test_dataset = torchvision.datasets.MNIST(root='./data',
train=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]),
download=True)
test_sampler = torch.utils.data.distributed.DistributedSampler(
test_dataset,
num_replicas=world_size,
rank=rank
)
test_loader = torch.utils.data.DataLoader(
dataset=test_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=0,
pin_memory=True,
sampler=test_sampler)
#-----------------------------------#
model = Net().to(rank)
model = DDP(model, device_ids=[rank])
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
for epoch in range(1, args.epochs + 1):
train(args, model, device, train_loader, optimizer, epoch, rank)
test(model, device, test_loader, rank)
scheduler.step()
cleanup()
def main():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=5, metavar='N',
help='number of epochs to train (default: 5)')
parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
help='learning rate (default: 1.0)')
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--gpus', type=int, default=1, metavar='N',
help='Number of GPUs')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=200, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--save-model', action='store_true', default=False,
help='For Saving the current Model')
args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available()
world_size = args.gpus
if torch.cuda.device_count() > 1:
print("We have available ", torch.cuda.device_count(), "GPUs! but using ",world_size," GPUs")
#########################################################
mp.spawn(demo_basic, args=(world_size, args, use_cuda), nprocs=world_size, join=True)
#########################################################
if __name__ == '__main__':
main()
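One refinement worth noting: DistributedSampler shuffles deterministically based on its epoch counter, so it is usually recommended to call set_epoch at the start of each epoch; otherwise every epoch sees the same ordering. A hedged sketch of how the epoch loop in demo_basic could be adjusted:
for epoch in range(1, args.epochs + 1):
    train_sampler.set_epoch(epoch)  # reshuffle differently every epoch
    train(args, model, device, train_loader, optimizer, epoch, rank)
    test(model, device, test_loader, rank)
    scheduler.step()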
4. Saving and Loading Distributed Models
a. DataParallel Models
When you save a DataParallel model, you have to save the state dictionary of the wrapped module (model.module) as shown below; otherwise the keys will be prefixed with 'module.' and your loader will throw an error.
if args.save_model:
torch.save(model.module.state_dict(), "mnist_cnn.pt")
Now load the model as shown below.
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import os
# Model definition has to match the original architecture
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout2d(0.25)
self.dropout2 = nn.Dropout2d(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
self.batchnorm = nn.BatchNorm1d(128)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = self.batchnorm(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output
# Get one batch of data from the test dataset to evaluate the
# loaded model
def get_data(device="cuda", batch_size=16):
use_cuda = True if device == "cuda" else False
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=batch_size, shuffle=True, **kwargs)
for data, target in test_loader:
data, target = data.to(device), target.to(device)
break
return(data, target)
use_cuda = True
device = torch.device("cuda" if use_cuda else "cpu")
model = Net()
model.to(device)
# map_location allows us to use models trained on another device
# For loading dataparallel models, mnist_cnn.pt should be saved using torch.save(model.module.state_dict(), 'mnist_cnn.pt')
model.load_state_dict(torch.load('/home/centos/tf2/srijith/mnist_cnn.pt', map_location=device), strict=False)
model.eval()
data, target = get_data()
output = model(data)
# get the index of the max log-probability
pred = output.argmax(dim=1, keepdim=True)
correct = pred.eq(target.view_as(pred)).sum().item()
print(correct, len(target))
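If, instead, the checkpoint was saved from the wrapped model itself (torch.save(model.state_dict(), ...)), its keys carry a 'module.' prefix. A small workaround (a sketch, assuming the checkpoint file sits next to the script) is to strip the prefix before loading into a plain Net:
# Fallback for checkpoints saved from the DataParallel wrapper itself:
state = torch.load('mnist_cnn.pt', map_location=device)
state = {k.replace('module.', '', 1): v for k, v in state.items()}
model.load_state_dict(state)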
b. Distributed DataParallel Models
Save the model on a single rank (here, rank 0) so that only one process writes the checkpoint.
# Save model on rank 0
if(rank == 0):
torch.save(model.module.state_dict(), "mnist_ddp.pt")
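If the checkpoint were saved and re-loaded within the same run (rather than in a separate script, as below), it is common to add a dist.barrier() after the save so that no rank tries to read the file before rank 0 has finished writing it, for example:
if rank == 0:
    torch.save(model.module.state_dict(), "mnist_ddp.pt")
dist.barrier()  # all ranks wait until the checkpoint exists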
Once it is saved, you can load and reuse the model as shown below.
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
import torch.distributed as dist
import torchvision
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "6,7,8,9"
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# initialize the process group
dist.init_process_group("nccl", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout2d(0.25)
self.dropout2 = nn.Dropout2d(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
self.batchnorm = nn.BatchNorm1d(128)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = self.batchnorm(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output
def train(args, model, device, train_loader, optimizer, epoch, rank):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(rank), target.to(rank)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
dist.all_reduce(loss, op=dist.ReduceOp.SUM)
if(rank == 0):
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
def test(model, device, test_loader, rank):
model.eval()
test_loss = 0
correct = 0
data_len = 0
test_loss_tensor = 0
correct_tensor = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(rank), target.to(rank)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
test_loss_tensor += F.nll_loss(output, target, reduction='sum')
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
correct_tensor += pred.eq(target.view_as(pred)).sum()
data_len += len(data)
test_loss /= data_len
dist.all_reduce(test_loss_tensor, op=dist.ReduceOp.SUM)
dist.all_reduce(correct_tensor, op=dist.ReduceOp.SUM)
if(rank == 0):
print("Test average loss: {}, correct predictons: {}, total: {}, accuracy: {}% \n".format(test_loss_tensor.item() / len(test_loader.dataset), correct_tensor.item(), len(test_loader.dataset),
100.0 * correct_tensor.item() / len(test_loader.dataset)))
#print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.4f}%)\n'.format(
# test_loss, correct, data_len,
# 100. * correct / data_len))
def demo_basic(rank, world_size, args, use_cuda):
#--------------- Setup -------------#
print(f"Running basic DDP example on rank {rank}.")
setup(rank, world_size)
#-----------------------------------#
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
# ---------------- Data ------------#
train_dataset = torchvision.datasets.MNIST(root='./data',
train=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]),
download=True)
train_sampler = torch.utils.data.distributed.DistributedSampler(
train_dataset,
num_replicas=world_size,
rank=rank
)
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=0,
pin_memory=True,
sampler=train_sampler)
test_dataset = torchvision.datasets.MNIST(root='./data',
train=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]),
download=True)
test_sampler = torch.utils.data.distributed.DistributedSampler(
test_dataset,
num_replicas=world_size,
rank=rank
)
test_loader = torch.utils.data.DataLoader(
dataset=test_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=0,
pin_memory=True,
sampler=test_sampler)
#------------- MODEL LOAD----------------------#
model = Net().to(rank)
model.load_state_dict(torch.load('/home/centos/tf2/srijith/mnist_ddp.pt', map_location=torch.device(rank)), strict=False)
model = DDP(model, device_ids=[rank])
model.eval()
test(model, device, test_loader, rank)
#-----------------------------------------------#
cleanup()
def main():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=5, metavar='N',
help='number of epochs to train (default: 5)')
parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
help='learning rate (default: 1.0)')
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--gpus', type=int, default=1, metavar='N',
help='Number of GPUs')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=200, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--save-model', action='store_true', default=False,
help='For Saving the current Model')
args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available()
world_size = args.gpus
if torch.cuda.device_count() > 1:
print("We have available ", torch.cuda.device_count(), "GPUs! but using ",world_size," GPUs")
#########################################################
mp.spawn(demo_basic, args=(world_size, args, use_cuda), nprocs=world_size, join=True)
#########################################################
if __name__ == '__main__':
main()