# coding=utf-8
from __future__ import absolute_import, division, print_function

import logging
import argparse
import os
import random
import numpy as np
import time

from datetime import timedelta

import torch
import torch.distributed as dist

from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

from models.modeling import VisionTransformer, CONFIGS
from utils.scheduler import WarmupLinearSchedule, WarmupCosineSchedule
from utils.data_utils import get_loader
from utils.dist_util import get_world_size

logger = logging.getLogger(__name__)

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"


class AverageMeter(object):
    """Computes and stores the average and current value."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def simple_accuracy(preds, labels):
    return (preds == labels).mean()


def reduce_mean(tensor, nprocs):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= nprocs
    return rt


def save_model(args, model):
    model_to_save = model.module if hasattr(model, 'module') else model
    model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.bin" % args.name)
    checkpoint = {
        'model': model_to_save.state_dict(),
    }
    torch.save(checkpoint, model_checkpoint)
    logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir)


def save_eve_model(args, model, eve_name):
    model_to_save = model.module if hasattr(model, 'module') else model
    model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.bin" % eve_name)
    checkpoint = {
        'model': model_to_save.state_dict(),
    }
    torch.save(checkpoint, model_checkpoint)
    logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir)


def setup(args):
    # Prepare model
    config = CONFIGS[args.model_type]
    config.split = args.split
    config.slide_step = args.slide_step

    if args.dataset == "CUB_200_2011":
        num_classes = 200
    elif args.dataset == "car":
        num_classes = 196
    elif args.dataset == "nabirds":
        num_classes = 555
    elif args.dataset == "dog":
        num_classes = 120
    elif args.dataset == "INat2017":
        num_classes = 5089
    elif args.dataset == "emptyJudge5":
        num_classes = 5
    elif args.dataset == "emptyJudge4":
        num_classes = 4
    elif args.dataset == "emptyJudge3":
        num_classes = 3
    else:
        raise ValueError("Unknown dataset: %s" % args.dataset)

    model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes,
                              smoothing_value=args.smoothing_value)

    model.load_from(np.load(args.pretrained_dir))
    if args.pretrained_model is not None:
        pretrained_model = torch.load(args.pretrained_model)['model']
        model.load_state_dict(pretrained_model)

    num_params = count_parameters(model)

    logger.info("{}".format(config))
    logger.info("Training parameters %s", args)
    logger.info("Total Parameter: \t%2.1fM" % num_params)
    model = torch.nn.DataParallel(model, device_ids=[0, 1]).cuda()
    return args, model


def count_parameters(model):
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return params / 1000000


def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


def valid(args, model, writer, test_loader, global_step):
    eval_losses = AverageMeter()

    logger.info("***** Running Validation *****")
    # logger.info("val Num steps = %d", len(test_loader))
    # logger.info("val Batch size = %d", args.eval_batch_size)

    model.eval()
    all_preds, all_label = [], []
    epoch_iterator = tqdm(test_loader,
desc="Validating... (loss=X.X)", bar_format="{l_bar}{r_bar}", dynamic_ncols=True, disable=args.local_rank not in [-1, 0]) loss_fct = torch.nn.CrossEntropyLoss() for step, batch in enumerate(epoch_iterator): batch = tuple(t.to(args.device) for t in batch) x, y = batch with torch.no_grad(): logits = model(x) eval_loss = loss_fct(logits, y) eval_loss = eval_loss.mean() eval_losses.update(eval_loss.item()) preds = torch.argmax(logits, dim=-1) if len(all_preds) == 0: all_preds.append(preds.detach().cpu().numpy()) all_label.append(y.detach().cpu().numpy()) else: all_preds[0] = np.append( all_preds[0], preds.detach().cpu().numpy(), axis=0 ) all_label[0] = np.append( all_label[0], y.detach().cpu().numpy(), axis=0 ) epoch_iterator.set_description("Validating... (loss=%2.5f)" % eval_losses.val) all_preds, all_label = all_preds[0], all_label[0] accuracy = simple_accuracy(all_preds, all_label) accuracy = torch.tensor(accuracy).to(args.device) # dist.barrier() # val_accuracy = reduce_mean(accuracy, args.nprocs) # val_accuracy = val_accuracy.detach().cpu().numpy() val_accuracy = accuracy.detach().cpu().numpy() logger.info("\n") logger.info("Validation Results") logger.info("Global Steps: %d" % global_step) logger.info("Valid Loss: %2.5f" % eval_losses.avg) logger.info("Valid Accuracy: %2.5f" % val_accuracy) if args.local_rank in [-1, 0]: writer.add_scalar("test/accuracy", scalar_value=val_accuracy, global_step=global_step) return val_accuracy def train(args, model): """ Train the model """ if args.local_rank in [-1, 0]: os.makedirs(args.output_dir, exist_ok=True) writer = SummaryWriter(log_dir=os.path.join("logs", args.name)) args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps # Prepare dataset train_loader, test_loader = get_loader(args) logger.info("train Num steps = %d", len(train_loader)) logger.info("val Num steps = %d", len(test_loader)) # Prepare optimizer and scheduler optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.weight_decay) t_total = args.num_steps if args.decay_type == "cosine": scheduler = WarmupCosineSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) else: scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) # Train! logger.info("***** Running training *****") logger.info(" Total optimization steps = %d", args.num_steps) logger.info(" Instantaneous batch size per GPU = %d", args.train_batch_size) logger.info(" Total train batch size (w. 
parallel, distributed & accumulation) = %d", args.train_batch_size * args.gradient_accumulation_steps * ( torch.distributed.get_world_size() if args.local_rank != -1 else 1)) logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) model.zero_grad() set_seed(args) # Added here for reproducibility (even between python 2 and 3) losses = AverageMeter() global_step, best_acc = 0, 0 start_time = time.time() while True: model.train() epoch_iterator = tqdm(train_loader, desc="Training (X / X Steps) (loss=X.X)", bar_format="{l_bar}{r_bar}", dynamic_ncols=True, disable=args.local_rank not in [-1, 0]) all_preds, all_label = [], [] for step, batch in enumerate(epoch_iterator): batch = tuple(t.to(args.device) for t in batch) x, y = batch loss, logits = model(x, y) loss = loss.mean() preds = torch.argmax(logits, dim=-1) if len(all_preds) == 0: all_preds.append(preds.detach().cpu().numpy()) all_label.append(y.detach().cpu().numpy()) else: all_preds[0] = np.append( all_preds[0], preds.detach().cpu().numpy(), axis=0 ) all_label[0] = np.append( all_label[0], y.detach().cpu().numpy(), axis=0 ) if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps loss.backward() if (step + 1) % args.gradient_accumulation_steps == 0: losses.update(loss.item()*args.gradient_accumulation_steps) torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) scheduler.step() optimizer.step() optimizer.zero_grad() global_step += 1 epoch_iterator.set_description( "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, t_total, losses.val) ) if args.local_rank in [-1, 0]: writer.add_scalar("train/loss", scalar_value=losses.val, global_step=global_step) writer.add_scalar("train/lr", scalar_value=scheduler.get_lr()[0], global_step=global_step) if global_step % args.eval_every == 0: with torch.no_grad(): accuracy = valid(args, model, writer, test_loader, global_step) if args.local_rank in [-1, 0]: if best_acc < accuracy: save_model(args, model) best_acc = accuracy logger.info("best accuracy so far: %f" % best_acc) if int(global_step)%10000 == 0: save_eve_model(args, model, str(global_step)) model.train() if global_step % t_total == 0: break all_preds, all_label = all_preds[0], all_label[0] accuracy = simple_accuracy(all_preds, all_label) accuracy = torch.tensor(accuracy).to(args.device) # dist.barrier() # train_accuracy = reduce_mean(accuracy, args.nprocs) # train_accuracy = train_accuracy.detach().cpu().numpy() train_accuracy = accuracy.detach().cpu().numpy() logger.info("train accuracy so far: %f" % train_accuracy) losses.reset() if global_step % t_total == 0: break writer.close() logger.info("Best Accuracy: \t%f" % best_acc) logger.info("End Training!") end_time = time.time() logger.info("Total Training Time: \t%f" % ((end_time - start_time) / 3600)) def main(): parser = argparse.ArgumentParser() # Required parameters parser.add_argument("--name", required=True, help="Name of this run. 
Used for monitoring.") parser.add_argument("--dataset", choices=["CUB_200_2011", "car", "dog", "nabirds", "INat2017", "emptyJudge5", "emptyJudge4"], default="CUB_200_2011", help="Which dataset.") parser.add_argument('--data_root', type=str, default='/data/pfc/fineGrained') parser.add_argument("--model_type", choices=["ViT-B_16", "ViT-B_32", "ViT-L_16", "ViT-L_32", "ViT-H_14"], default="ViT-B_16",help="Which variant to use.") parser.add_argument("--pretrained_dir", type=str, default="ckpts/ViT-B_16.npz", help="Where to search for pretrained ViT models.") parser.add_argument("--pretrained_model", type=str, default="output/emptyjudge5_checkpoint.bin", help="load pretrained model") #parser.add_argument("--pretrained_model", type=str, default=None, help="load pretrained model") parser.add_argument("--output_dir", default="./output", type=str, help="The output directory where checkpoints will be written.") parser.add_argument("--img_size", default=448, type=int, help="Resolution size") parser.add_argument("--train_batch_size", default=64, type=int, help="Total batch size for training.") parser.add_argument("--eval_batch_size", default=16, type=int, help="Total batch size for eval.") parser.add_argument("--eval_every", default=200, type=int, #200 help="Run prediction on validation set every so many steps." "Will always run one evaluation at the end of training.") parser.add_argument("--learning_rate", default=3e-2, type=float, help="The initial learning rate for SGD.") parser.add_argument("--weight_decay", default=0, type=float, help="Weight deay if we apply some.") parser.add_argument("--num_steps", default=40000, type=int, #100000 help="Total number of training epochs to perform.") parser.add_argument("--decay_type", choices=["cosine", "linear"], default="cosine", help="How to decay the learning rate.") parser.add_argument("--warmup_steps", default=500, type=int, help="Step of training to perform learning rate warmup for.") parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") parser.add_argument('--seed', type=int, default=42, help="random seed for initialization") parser.add_argument('--gradient_accumulation_steps', type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.") parser.add_argument('--fp16', action='store_true', help="Whether to use 16-bit float precision instead of 32-bit") parser.add_argument('--fp16_opt_level', type=str, default='O2', help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." "See details at https://nvidia.github.io/apex/amp.html") parser.add_argument('--loss_scale', type=float, default=0, help="Loss scaling to improve fp16 numeric stability. 
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--smoothing_value', type=float, default=0.0,
                        help="Label smoothing value\n")
    parser.add_argument('--split', type=str, default='overlap',
                        help="Split method")  # non-overlap
    parser.add_argument('--slide_step', type=int, default=12,
                        help="Slide step for overlap split")
    args = parser.parse_args()

    args.data_root = '{}/{}'.format(args.data_root, args.dataset)

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             timeout=timedelta(minutes=60))
        args.n_gpu = 1
    args.device = device
    args.nprocs = torch.cuda.device_count()

    # Setup logging
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s" %
                   (args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1), args.fp16))

    # Set seed
    set_seed(args)

    # Model Setup
    args, model = setup(args)

    # Training
    train(args, model)


if __name__ == "__main__":
    main()
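
# A minimal usage sketch, kept as a comment so the module stays import-safe.
# Assumptions (not stated in the script itself): the file is saved as "train.py"
# and the dataset folders live under --data_root/<dataset>; adjust both to your
# setup. Only --name is required; every other flag shown overrides a default
# defined in main() above.
#
#   python train.py --name cub_run1 \
#       --dataset CUB_200_2011 \
#       --data_root /data/pfc/fineGrained \
#       --pretrained_dir ckpts/ViT-B_16.npz \
#       --train_batch_size 64 --num_steps 40000
#
# Note: setup() wraps the model in torch.nn.DataParallel(device_ids=[0, 1]), so
# the default single-process path (local_rank == -1) already uses the visible
# GPUs; the --local_rank branch is only exercised under a torch.distributed
# launcher.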