Multi-machine parallel computing
@@ -266,11 +266,15 @@ def main():
     if distributed:
         # Distributed training: use mp.spawn to launch multiple processes
-        world_size = torch.cuda.device_count()
+        local_size = torch.cuda.device_count()
+        world_size = int(conf['distributed']['node_num'])*local_size
         mp.spawn(
             run_training,
-            args=(world_size, conf),
-            nprocs=world_size,
+            args=(conf['distributed']['node_rank'],
+                  local_size,
+                  world_size,
+                  conf),
+            nprocs=local_size,
             join=True
         )
     else:
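For reference, torch.multiprocessing.spawn calls its target as fn(i, *args), supplying the per-node process index i automatically, so with the arguments above each of the local_size workers receives (local_rank, node_rank, local_size, world_size, conf). Below is a minimal launch-side sketch under that reading; the config keys follow the diff, while the surrounding function is illustrative rather than the repository's actual main():

    import torch
    import torch.multiprocessing as mp

    def launch(conf, run_training):
        # One worker process per local GPU on this machine.
        local_size = torch.cuda.device_count()
        # Total number of processes across all participating nodes.
        world_size = int(conf['distributed']['node_num']) * local_size

        # mp.spawn invokes run_training(local_rank, *args) for local_rank in 0..local_size-1.
        mp.spawn(
            run_training,
            args=(conf['distributed']['node_rank'], local_size, world_size, conf),
            nprocs=local_size,
            join=True,
        )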
@@ -279,11 +283,12 @@ def main():
         run_training_loop(components)


-def run_training(rank, world_size, conf):
+def run_training(local_rank, node_rank, local_size, world_size, conf):
     """Function that actually runs the training; called by mp.spawn"""
     # Initialize the distributed environment
+    rank = local_rank + node_rank * local_size
     os.environ['RANK'] = str(rank)
-    os.environ['WORLD_SIZE'] = str(world_size)
+    os.environ['WORLD_SIZE'] = str(local_size)
     os.environ['MASTER_ADDR'] = 'localhost'
     os.environ['MASTER_PORT'] = '12355'
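On the worker side, the global rank is composed from the node offset plus the local GPU index, so node k owns ranks k*local_size through k*local_size + local_size - 1. A minimal sketch of that setup follows, assuming env:// initialization; the NCCL backend, the master-address handling, and setting WORLD_SIZE to the global process count (which is what env:// initialization expects) are common conventions assumed here, not details taken from this commit:

    import os

    import torch
    import torch.distributed as dist

    def setup_distributed(local_rank, node_rank, local_size, world_size,
                          master_addr='localhost', master_port='12355'):
        # Global rank: node 0 owns ranks 0..local_size-1, node 1 the next block, and so on.
        rank = local_rank + node_rank * local_size

        # env:// initialization reads these four variables from the environment.
        os.environ['RANK'] = str(rank)
        os.environ['WORLD_SIZE'] = str(world_size)      # total processes across all nodes
        os.environ['MASTER_ADDR'] = master_addr         # must be reachable from every node
        os.environ['MASTER_PORT'] = master_port

        dist.init_process_group(backend='nccl', init_method='env://')
        torch.cuda.set_device(local_rank)               # bind this process to its GPU
        return rank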