masterUpdate

This commit is contained in:
2022-12-29 13:45:07 +08:00
parent b84a92f67a
commit 615c91feb6
2 changed files with 18 additions and 6 deletions

View File

@ -63,6 +63,15 @@ def save_model(args, model):
torch.save(checkpoint, model_checkpoint)
logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir)
def save_eve_model(args, model, eve_name):
model_to_save = model.module if hasattr(model, 'module') else model
model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.bin" % eve_name)
checkpoint = {
'model': model_to_save.state_dict(),
}
torch.save(checkpoint, model_checkpoint)
logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir)
def setup(args):
# Prepare model
@ -265,6 +274,8 @@ def train(args, model):
save_model(args, model)
best_acc = accuracy
logger.info("best accuracy so far: %f" % best_acc)
if int(global_step)%10000 == 0:
save_eve_model(args, model, str(global_step))
model.train()
if global_step % t_total == 0:
@ -295,7 +306,7 @@ def main():
help="Name of this run. Used for monitoring.")
parser.add_argument("--dataset", choices=["CUB_200_2011", "car", "dog", "nabirds", "INat2017", "emptyJudge5", "emptyJudge4"],
default="CUB_200_2011", help="Which dataset.")
parser.add_argument('--data_root', type=str, default='/data/fineGrained')
parser.add_argument('--data_root', type=str, default='/data/pfc/fineGrained')
parser.add_argument("--model_type", choices=["ViT-B_16", "ViT-B_32", "ViT-L_16", "ViT-L_32", "ViT-H_14"],
default="ViT-B_16",help="Which variant to use.")
parser.add_argument("--pretrained_dir", type=str, default="ckpts/ViT-B_16.npz",
@ -309,7 +320,7 @@ def main():
help="Total batch size for training.")
parser.add_argument("--eval_batch_size", default=16, type=int,
help="Total batch size for eval.")
parser.add_argument("--eval_every", default=200, type=int,
parser.add_argument("--eval_every", default=200, type=int, #200
help="Run prediction on validation set every so many steps."
"Will always run one evaluation at the end of training.")
@ -317,7 +328,7 @@ def main():
help="The initial learning rate for SGD.")
parser.add_argument("--weight_decay", default=0, type=float,
help="Weight deay if we apply some.")
parser.add_argument("--num_steps", default=8000, type=int, #100000
parser.add_argument("--num_steps", default=40000, type=int, #100000
help="Total number of training epochs to perform.")
parser.add_argument("--decay_type", choices=["cosine", "linear"], default="cosine",
help="How to decay the learning rate.")
@ -352,9 +363,9 @@ def main():
# Setup CUDA, GPU & distributed training
if args.local_rank == -1:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('torch.cuda.device_count()>>>>>>>>>>>>>>>>>>>>>>>>>', torch.cuda.device_count())
#print('torch.cuda.device_count()>>>>>>>>>>>>>>>>>>>>>>>>>', torch.cuda.device_count())
args.n_gpu = torch.cuda.device_count()
print('torch.cuda.device_count()>>>>>>>>>>>>>>>>>>>>>>>>>', torch.cuda.device_count())
#print('torch.cuda.device_count()>>>>>>>>>>>>>>>>>>>>>>>>>', torch.cuda.device_count())
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)