Multi-machine parallel computing

Author: lee
Date: 2025-08-14 10:09:54 +08:00
Parent: bc896fc688
Commit: 99a204ee22
18 changed files with 105 additions and 55 deletions


```diff
@@ -266,11 +266,15 @@ def main():
     if distributed:
         # Distributed training: use mp.spawn to launch multiple processes
-        world_size = torch.cuda.device_count()
+        local_size = torch.cuda.device_count()
+        world_size = int(conf['distributed']['node_num']) * local_size
         mp.spawn(
             run_training,
-            args=(world_size, conf),
-            nprocs=world_size,
+            args=(conf['distributed']['node_rank'],
+                  local_size,
+                  world_size,
+                  conf),
+            nprocs=local_size,
             join=True
         )
     else:
```
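
The `conf['distributed']` lookups above assume the config file carries the cluster shape. A hypothetical sketch of that section follows; only the key names `node_num` and `node_rank` come from the diff, while the values and comments are illustrative:

```python
# Hypothetical config shape assumed by the diff above.
# Only the keys 'node_num' and 'node_rank' appear in the commit.
conf = {
    'distributed': {
        'node_num': 2,   # total number of machines in the job
        'node_rank': 0,  # this machine's index (0..node_num-1); differs per machine
    },
}
```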
```diff
@@ -279,11 +283,12 @@ def main():
         run_training_loop(components)

-def run_training(rank, world_size, conf):
+def run_training(local_rank, node_rank, local_size, world_size, conf):
     """The function that actually runs training; called by mp.spawn"""
     # Initialize the distributed environment
+    rank = local_rank + node_rank * local_size
     os.environ['RANK'] = str(rank)
-    os.environ['WORLD_SIZE'] = str(world_size)
+    os.environ['WORLD_SIZE'] = str(local_size)
     os.environ['MASTER_ADDR'] = 'localhost'
     os.environ['MASTER_PORT'] = '12355'
```
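
Taken together, the two hunks give each process a globally unique rank: process `local_rank` on node `node_rank` gets rank `local_rank + node_rank * local_size`. Below is a minimal, standalone sketch of that pattern, not the repository's code; the gloo backend, the `env://` rendezvous, and reading `MASTER_ADDR` / `NODE_NUM` / `NODE_RANK` from environment variables are assumptions. In a real multi-node run, every machine executes the same entry point with its own node rank, `MASTER_ADDR` points at node 0, and `WORLD_SIZE` is the total process count across all nodes:

```python
# Minimal self-contained sketch of the multi-node mp.spawn pattern.
# Assumptions not taken from the diff: gloo backend, env:// rendezvous,
# and environment variables instead of a config file.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run_training(local_rank, node_rank, local_size, world_size):
    # Processes on node k occupy global ranks [k*local_size, (k+1)*local_size).
    rank = local_rank + node_rank * local_size
    os.environ['RANK'] = str(rank)
    # WORLD_SIZE is the total process count across all nodes.
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')  # node 0's address in a real cluster
    os.environ.setdefault('MASTER_PORT', '12355')
    dist.init_process_group(backend='gloo', init_method='env://')
    print(f'rank {dist.get_rank()} of {dist.get_world_size()} is up')
    dist.destroy_process_group()


if __name__ == '__main__':
    node_num = int(os.environ.get('NODE_NUM', '1'))    # machines in the job
    node_rank = int(os.environ.get('NODE_RANK', '0'))  # this machine's index
    local_size = torch.cuda.device_count() or 1        # one process per local GPU
    world_size = node_num * local_size
    # Every node runs this same script with its own NODE_RANK; mp.spawn passes
    # the process index (local_rank) as the first argument to run_training.
    mp.spawn(run_training,
             args=(node_rank, local_size, world_size),
             nprocs=local_size,
             join=True)
```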