31 lines
711 B
Python
Executable File
31 lines
711 B
Python
Executable File
import torch.distributed as dist
|
|
|
|
def get_rank():
    """Return the rank reported by ``torch.distributed``.

    Falls back to 0 when the distributed package is not available or
    has not been initialized, so single-process runs behave as rank 0.
    """
    # Short-circuit keeps the original check order: availability first,
    # then initialization, and only then the actual rank query.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
|
|
|
|
def get_world_size():
    """Return the world size reported by ``torch.distributed``.

    Falls back to 1 when the distributed package is not available or
    has not been initialized.
    """
    # Availability is checked before initialization, as in the original.
    distributed_ready = dist.is_available() and dist.is_initialized()
    return dist.get_world_size() if distributed_ready else 1
|
|
|
|
def is_main_process():
    """Return True when ``get_rank()`` reports rank 0."""
    current_rank = get_rank()
    return current_rank == 0
|
|
|
|
def format_step(step):
    """Render a step identifier as a human-readable label.

    A string ``step`` is returned unchanged. Otherwise ``step`` is
    treated as an indexable, sized sequence whose first three entries
    are labelled as training epoch, training iteration and validation
    iteration, in that order. Entries beyond the third are ignored and
    an empty sequence yields an empty string.
    """
    if isinstance(step, str):
        return step

    # One template per position; each keeps its trailing space so the
    # concatenated result matches the piecewise-built string exactly.
    templates = (
        "Training Epoch: {} ",
        "Training Iteration: {} ",
        "Validation Iteration: {} ",
    )
    available = len(step)
    return "".join(
        template.format(step[position])
        for position, template in enumerate(templates)
        if position < available
    )
|