Multi-node parallel computing
@@ -288,13 +288,13 @@ def run_training(local_rank, node_rank, local_size, world_size, conf):
     # Initialize the distributed environment
     rank = local_rank + node_rank * local_size
     os.environ['RANK'] = str(rank)
-    os.environ['WORLD_SIZE'] = str(local_size)
+    os.environ['WORLD_SIZE'] = str(world_size)
     os.environ['MASTER_ADDR'] = 'localhost'
     os.environ['MASTER_PORT'] = '12355'

     dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
-    torch.cuda.set_device(rank)
-    device = torch.device('cuda', rank)
+    torch.cuda.set_device(local_rank)
+    device = torch.device('cuda', local_rank)

     # Load the dataset rather than a DataLoader
     train_dataset, class_num = load_data(training=True, cfg=conf, return_dataset=True)
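The first hunk fixes two things: WORLD_SIZE must be the total process count across all nodes (world_size), not the per-node count (local_size), and the CUDA device must be selected by local_rank, since the global rank can exceed the number of GPUs on a single node. The global rank is derived as local_rank + node_rank * local_size. Note that MASTER_ADDR is still 'localhost' here; for a real multi-node run it has to point at the rank-0 node or be overridden by the launcher.

Below is a minimal per-node launcher sketch, not part of the commit, showing how run_training could be spawned with these arguments. The names launch and nnodes, and the assumption that every node has the same GPU count, are mine, not from the repository.

import torch
import torch.multiprocessing as mp

def launch(node_rank, nnodes, conf):
    # One worker process per local GPU; assumes all nodes have the same GPU count.
    local_size = torch.cuda.device_count()
    world_size = local_size * nnodes
    # mp.spawn passes the process index (the local_rank) as the first argument,
    # matching run_training(local_rank, node_rank, local_size, world_size, conf).
    mp.spawn(run_training,
             args=(node_rank, local_size, world_size, conf),
             nprocs=local_size,
             join=True)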
@@ -306,8 +306,8 @@ def run_training(local_rank, node_rank, local_size, world_size, conf):
     metric = metric.to(device)

     # Wrap with DistributedDataParallel
-    model = DDP(model, device_ids=[rank], output_device=rank)
-    metric = DDP(metric, device_ids=[rank], output_device=rank)
+    model = DDP(model, device_ids=[local_rank], output_device=local_rank)
+    metric = DDP(metric, device_ids=[local_rank], output_device=local_rank)

     # Set up the loss function, optimizer, and scheduler
     criterion = setup_loss_function(conf)
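The same local_rank fix applies to the DDP wrappers above: device_ids and output_device expect a device index on the current node, so the global rank is replaced by local_rank. Since the first hunk now loads a raw dataset instead of a DataLoader, each rank presumably builds its own sharded loader. The following is a sketch of that step, assuming the usual DistributedSampler pattern; batch_size, num_workers, and the epoch loop are placeholders, not values from the repository.

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# DistributedSampler reads rank/world_size from the initialized process group,
# so every process iterates over a disjoint shard of train_dataset.
train_sampler = DistributedSampler(train_dataset)
train_loader = DataLoader(train_dataset,
                          batch_size=128,        # placeholder; likely taken from conf
                          sampler=train_sampler,
                          num_workers=4,
                          pin_memory=True)

for epoch in range(epochs):                      # epochs is a placeholder
    train_sampler.set_epoch(epoch)               # reshuffle the shards each epoch
    for batch in train_loader:
        ...                                      # forward/backward as in run_training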