Multi-machine parallel computing

lee
2025-08-18 10:14:05 +08:00
parent 99a204ee22
commit c978787ff8
8 changed files with 43 additions and 25 deletions


@@ -288,13 +288,13 @@ def run_training(local_rank, node_rank, local_size, world_size, conf):
     # Initialize the distributed environment
     rank = local_rank + node_rank * local_size
     os.environ['RANK'] = str(rank)
-    os.environ['WORLD_SIZE'] = str(local_size)
+    os.environ['WORLD_SIZE'] = str(world_size)
     os.environ['MASTER_ADDR'] = 'localhost'
     os.environ['MASTER_PORT'] = '12355'
     dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
-    torch.cuda.set_device(rank)
-    device = torch.device('cuda', rank)
+    torch.cuda.set_device(local_rank)
+    device = torch.device('cuda', local_rank)
     # Get the dataset rather than a DataLoader
     train_dataset, class_num = load_data(training=True, cfg=conf, return_dataset=True)
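
The hunk above separates two indices that were previously conflated. WORLD_SIZE must be the total process count across all machines (world_size); with only the per-node count, the global ranks computed on every node after the first would exceed it. Conversely, torch.cuda.set_device must receive the per-node GPU index (local_rank), because each machine numbers its GPUs from 0 and the global rank is out of range there. A minimal sketch of the mapping, assuming a hypothetical 2-node, 4-GPU-per-node layout (the counts are illustrative, not taken from the repository's config):

```python
# Illustrative only: how global rank, local rank, and world size relate.
# A 2-node x 4-GPU layout is assumed here purely for the example.
local_size = 4                          # GPUs (processes) per node
node_count = 2                          # participating machines
world_size = local_size * node_count    # total processes across all nodes = 8

for node_rank in range(node_count):
    for local_rank in range(local_size):
        rank = local_rank + node_rank * local_size   # global rank, 0..7
        # 'rank' goes into RANK / init_process_group; 'local_rank' selects
        # the CUDA device, because each node numbers its GPUs from 0.
        print(f"node {node_rank}, gpu {local_rank} -> global rank {rank}")
```
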
@@ -306,8 +306,8 @@ def run_training(local_rank, node_rank, local_size, world_size, conf):
     metric = metric.to(device)
     # Wrap with DistributedDataParallel
-    model = DDP(model, device_ids=[rank], output_device=rank)
-    metric = DDP(metric, device_ids=[rank], output_device=rank)
+    model = DDP(model, device_ids=[local_rank], output_device=local_rank)
+    metric = DDP(metric, device_ids=[local_rank], output_device=local_rank)
     # Set up the loss function, optimizer, and scheduler
     criterion = setup_loss_function(conf)
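
The second hunk applies the same correction to DDP: device_ids and output_device refer to GPUs on the current machine, so they take local_rank, not the global rank. The run_training signature suggests one process is started per GPU on each node; below is a minimal per-node launcher sketch under that assumption, using torch.multiprocessing.spawn. The flag names, import path, and defaults are hypothetical, not the repository's actual entry point, and a real multi-machine run would also need MASTER_ADDR to point at the rank-0 node rather than 'localhost'.

```python
# Hypothetical per-node launcher; flag names, module path, and defaults
# are assumptions for illustration, not the repository's real CLI.
import argparse
import torch
import torch.multiprocessing as mp

from train import run_training  # import path is an assumption


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--node-rank', type=int, default=0)  # index of this machine
    parser.add_argument('--nodes', type=int, default=1)      # total number of machines
    args = parser.parse_args()

    local_size = torch.cuda.device_count()    # processes to spawn on this machine
    world_size = local_size * args.nodes      # total processes across all machines
    conf = None                               # stand-in for the real config object

    # spawn() passes the process index (0..local_size-1) as the first argument,
    # which run_training receives as local_rank.
    mp.spawn(run_training,
             args=(args.node_rank, local_size, world_size, conf),
             nprocs=local_size)


if __name__ == '__main__':
    main()
```

Run once per machine, e.g. with --nodes 2 and --node-rank 0 on the first host and --node-rank 1 on the second, so that every process ends up with a distinct global rank while reusing local GPU indices 0..local_size-1 on each node.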