lichen
2022-06-02 16:23:27 +08:00
parent 0b34aa05e0
commit b84a92f67a
8 changed files with 82 additions and 31 deletions

@@ -20,10 +20,11 @@ from models.modeling import VisionTransformer, CONFIGS
 from utils.scheduler import WarmupLinearSchedule, WarmupCosineSchedule
 from utils.data_utils import get_loader
 from utils.dist_util import get_world_size
+import pdb
 logger = logging.getLogger(__name__)
 os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"
 class AverageMeter(object):
     """Computes and stores the average and current value"""
     def __init__(self):
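A note on the CUDA_VISIBLE_DEVICES line above: it only takes effect if it is set before CUDA is first initialized, and once it applies, PyTorch renumbers the visible GPUs from 0. A minimal standalone sketch (not part of this commit) illustrating the behavior:

import os

# Must run before torch initializes CUDA; afterwards only these three
# physical GPUs are visible, renumbered as cuda:0, cuda:1, cuda:2.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"

import torch

print(torch.cuda.device_count())  # prints 3 on a machine with >= 3 GPUs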
@@ -92,12 +93,14 @@ def setup(args):
     if args.pretrained_model is not None:
         pretrained_model = torch.load(args.pretrained_model)['model']
         model.load_state_dict(pretrained_model)
-    model.to(args.device)
+    #model.to(args.device)
+    #pdb.set_trace()
     num_params = count_parameters(model)
     logger.info("{}".format(config))
     logger.info("Training parameters %s", args)
     logger.info("Total Parameter: \t%2.1fM" % num_params)
+    model = torch.nn.DataParallel(model, device_ids=[0,1]).cuda()
     return args, model
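The new DataParallel line replicates the module across GPUs 0 and 1 (two of the three made visible above) and splits each input batch along dim 0. A minimal sketch of the same pattern, with a toy module standing in for the VisionTransformer:

import torch

# Toy stand-in for the real model. DataParallel expects the module's
# parameters to live on device_ids[0]; the trailing .cuda() (default
# device, cuda:0) takes care of that.
model = torch.nn.Linear(16, 4)
model = torch.nn.DataParallel(model, device_ids=[0, 1]).cuda()

x = torch.randn(8, 16).cuda()  # the batch is scattered across GPUs 0 and 1
y = model(x)                   # per-GPU outputs are gathered back on cuda:0

Since model.to(args.device) is commented out in this commit, the .cuda() call on the DataParallel wrapper is now what actually moves the weights onto the GPU.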
@@ -351,6 +354,7 @@ def main():
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         print('torch.cuda.device_count()>>>>>>>>>>>>>>>>>>>>>>>>>', torch.cuda.device_count())
         args.n_gpu = torch.cuda.device_count()
+        print('torch.cuda.device_count()>>>>>>>>>>>>>>>>>>>>>>>>>', torch.cuda.device_count())
     else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
         torch.cuda.set_device(args.local_rank)
         device = torch.device("cuda", args.local_rank)
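For context, this hunk sits in the standard two-path device setup: local_rank == -1 means a single process driving all visible GPUs (via DataParallel, as in setup() above), while a non-negative local_rank means one process per GPU joined through a distributed backend. A hedged sketch of the surrounding pattern; the helper name is hypothetical, the args fields mirror the script's, and the init_process_group call is the usual companion to the else branch rather than something shown in this hunk:

import torch
import torch.distributed as dist

def setup_device(args):
    # `args` is assumed to be an argparse.Namespace with a local_rank
    # field (-1 by default; set per process by torch.distributed.launch).
    if args.local_rank == -1:
        # Single process: one device object, DataParallel fans out later.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:
        # One process per GPU: pin this process to its GPU, then join the
        # process group (NCCL is the usual backend for multi-GPU training).
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        dist.init_process_group(backend="nccl")
        args.n_gpu = 1
    args.device = device
    return device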