update
This commit is contained in:
30
utils/dist_util.py
Executable file
30
utils/dist_util.py
Executable file
@ -0,0 +1,30 @@
|
||||
import torch.distributed as dist
|
||||
|
||||
def get_rank():
|
||||
if not dist.is_available():
|
||||
return 0
|
||||
if not dist.is_initialized():
|
||||
return 0
|
||||
return dist.get_rank()
|
||||
|
||||
def get_world_size():
|
||||
if not dist.is_available():
|
||||
return 1
|
||||
if not dist.is_initialized():
|
||||
return 1
|
||||
return dist.get_world_size()
|
||||
|
||||
def is_main_process():
|
||||
return get_rank() == 0
|
||||
|
||||
def format_step(step):
|
||||
if isinstance(step, str):
|
||||
return step
|
||||
s = ""
|
||||
if len(step) > 0:
|
||||
s += "Training Epoch: {} ".format(step[0])
|
||||
if len(step) > 1:
|
||||
s += "Training Iteration: {} ".format(step[1])
|
||||
if len(step) > 2:
|
||||
s += "Validation Iteration: {} ".format(step[2])
|
||||
return s
|
Reference in New Issue
Block a user