update
train.py
@@ -20,10 +20,11 @@ from models.modeling import VisionTransformer, CONFIGS
from utils.scheduler import WarmupLinearSchedule, WarmupCosineSchedule
from utils.data_utils import get_loader
from utils.dist_util import get_world_size
import pdb

logger = logging.getLogger(__name__)


os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
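The os.environ line above restricts which GPUs the process can see. A minimal, self-contained sketch of that pattern, separate from train.py (the device mask and printed count are placeholders):

import os
# Must be set before torch creates a CUDA context, otherwise it has no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"
import torch
# The visible GPUs are re-indexed, so torch sees them as devices 0, 1, 2.
print(torch.cuda.device_count())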
@@ -92,12 +93,14 @@ def setup(args):
    if args.pretrained_model is not None:
        pretrained_model = torch.load(args.pretrained_model)['model']
        model.load_state_dict(pretrained_model)
    model.to(args.device)
    #model.to(args.device)
    #pdb.set_trace()
    num_params = count_parameters(model)

    logger.info("{}".format(config))
    logger.info("Training parameters %s", args)
    logger.info("Total Parameter: \t%2.1fM" % num_params)
    model = torch.nn.DataParallel(model, device_ids=[0,1]).cuda()
    return args, model

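For reference, a minimal sketch of the checkpoint-loading and DataParallel wrapping done in setup(); the toy model and checkpoint path are placeholders, not the VisionTransformer or args.pretrained_model from the script:

import torch
import torch.nn as nn
model = nn.Linear(10, 2)            # stand-in for VisionTransformer(config, ...)
checkpoint_path = "checkpoint.pth"  # stand-in for args.pretrained_model
# The checkpoint is assumed to be saved as {"model": state_dict}, matching torch.load(...)['model'] above.
state = torch.load(checkpoint_path, map_location="cpu")["model"]
model.load_state_dict(state)
if torch.cuda.is_available():
    # Replicate the model on GPUs 0 and 1; input batches are split along dim 0 at forward time.
    model = nn.DataParallel(model, device_ids=[0, 1]).cuda()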
@@ -351,6 +354,7 @@ def main():
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print('torch.cuda.device_count()>>>>>>>>>>>>>>>>>>>>>>>>>', torch.cuda.device_count())
        args.n_gpu = torch.cuda.device_count()
        print('torch.cuda.device_count()>>>>>>>>>>>>>>>>>>>>>>>>>', torch.cuda.device_count())
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
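A compact sketch of the two device-setup paths shown in main(); the init_process_group call is an assumption about how the distributed branch is completed elsewhere in the script, not something shown in this hunk:

import torch
import torch.distributed as dist
def setup_device(local_rank=-1):
    if local_rank == -1:
        # Single-process path: one process drives all visible GPUs (e.g. via DataParallel).
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        # Distributed path: one process per GPU, each pinned to its local rank.
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
        dist.init_process_group(backend="nccl")  # assumed init; needs torchrun-style env vars
        n_gpu = 1
    return device, n_gpu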