diff --git a/train.py b/train.py
index bf63476..dff4905 100755
--- a/train.py
+++ b/train.py
@@ -24,7 +24,9 @@ import pdb
 
 logger = logging.getLogger(__name__)
 
-os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
+# Compute and store average values
 class AverageMeter(object):
     """Computes and stores the average and current value"""
     def __init__(self):
@@ -42,57 +44,39 @@ class AverageMeter(object):
         self.count += n
         self.avg = self.sum / self.count
 
-
+# Simple accuracy
 def simple_accuracy(preds, labels):
     return (preds == labels).mean()
 
-
+# Average a tensor across processes
 def reduce_mean(tensor, nprocs):
     rt = tensor.clone()
     dist.all_reduce(rt, op=dist.ReduceOp.SUM)
     rt /= nprocs
     return rt
 
-
+# Save the model
 def save_model(args, model):
-    model_to_save = model.module if hasattr(model, 'module') else model
-    model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.bin" % args.name)
-    checkpoint = {
-        'model': model_to_save.state_dict(),
-    }
-    torch.save(checkpoint, model_checkpoint)
-    logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir)
-
+    model_checkpoint = os.path.join("../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin")
+    torch.save(model, model_checkpoint)
+    logger.info("Saved model checkpoint to [File: %s]", "../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin")
 
+# Configure the model for the dataset
 def setup(args):
     # Prepare model
     config = CONFIGS[args.model_type]
     config.split = args.split
     config.slide_step = args.slide_step
-
-    if args.dataset == "CUB_200_2011":
-        num_classes = 200
-    elif args.dataset == "car":
-        num_classes = 196
-    elif args.dataset == "nabirds":
-        num_classes = 555
-    elif args.dataset == "dog":
-        num_classes = 120
-    elif args.dataset == "INat2017":
-        num_classes = 5089
-    elif args.dataset == "emptyJudge5":
+
+    if args.dataset == "emptyJudge5":
         num_classes = 5
-    elif args.dataset == "emptyJudge4":
-        num_classes = 4
-    elif args.dataset == "emptyJudge3":
-        num_classes = 3
-
+
     model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes, smoothing_value=args.smoothing_value)
 
-    model.load_from(np.load(args.pretrained_dir))
+    if args.pretrained_dir is not None:
+        model.load_from(np.load(args.pretrained_dir))  # third-party pretrained weights
     if args.pretrained_model is not None:
-        pretrained_model = torch.load(args.pretrained_model)['model']
-        model.load_state_dict(pretrained_model)
+        model = torch.load(args.pretrained_model)  # our own pretrained model
     #model.to(args.device)
     #pdb.set_trace()
     num_params = count_parameters(model)
@@ -100,15 +84,15 @@ def setup(args):
     logger.info("{}".format(config))
     logger.info("Training parameters %s", args)
     logger.info("Total Parameter: \t%2.1fM" % num_params)
-    model = torch.nn.DataParallel(model, device_ids=[0,1]).cuda()
+    model = torch.nn.DataParallel(model, device_ids=[0]).cuda()
     return args, model
 
-
+# Count the number of model parameters (in millions)
 def count_parameters(model):
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return params/1000000
 
-
+# Random seed
 def set_seed(args):
     random.seed(args.seed)
     np.random.seed(args.seed)
@@ -116,7 +100,7 @@ def set_seed(args):
     if args.n_gpu > 0:
         torch.cuda.manual_seed_all(args.seed)
 
-
+# Model validation
 def valid(args, model, writer, test_loader, global_step):
 
     eval_losses = AverageMeter()
@@ -173,7 +157,7 @@ def valid(args, model, writer, test_loader, global_step):
     writer.add_scalar("test/accuracy", scalar_value=val_accuracy, global_step=global_step)
     return val_accuracy
 
-
+# Model training
 def train(args, model):
     """ Train the model """
     if args.local_rank in [-1, 0]:
@@ -287,37 +271,41 @@ def train(args, model):
     end_time = time.time()
     logger.info("Total Training Time: \t%f" % ((end_time - start_time) / 3600))
 
 
+
+# Main function
 def main():
     parser = argparse.ArgumentParser()
     # Required parameters
-    parser.add_argument("--name", required=True,
+    parser.add_argument("--name", type=str, default='ieemooempty',
                         help="Name of this run. Used for monitoring.")
     parser.add_argument("--dataset", choices=["CUB_200_2011", "car", "dog", "nabirds", "INat2017", "emptyJudge5", "emptyJudge4"],
-                        default="CUB_200_2011", help="Which dataset.")
-    parser.add_argument('--data_root', type=str, default='/data/fineGrained')
+                        default="emptyJudge5", help="Which dataset.")
+    parser.add_argument('--data_root', type=str, default='./')
     parser.add_argument("--model_type", choices=["ViT-B_16", "ViT-B_32", "ViT-L_16", "ViT-L_32", "ViT-H_14"],
                         default="ViT-B_16",help="Which variant to use.")
-    parser.add_argument("--pretrained_dir", type=str, default="ckpts/ViT-B_16.npz",
+    parser.add_argument("--pretrained_dir", type=str, default="./preckpts/ViT-B_16.npz",
                         help="Where to search for pretrained ViT models.")
-    parser.add_argument("--pretrained_model", type=str, default="output/emptyjudge5_checkpoint.bin", help="load pretrained model")
-    #parser.add_argument("--pretrained_model", type=str, default=None, help="load pretrained model")
+    #parser.add_argument("--pretrained_model", type=str, default="./output/ieemooempty_checkpoint_good.pth", help="load pretrained model") #None
+    # parser.add_argument("--pretrained_dir", type=str, default=None,
+    #                     help="Where to search for pretrained ViT models.")
+    parser.add_argument("--pretrained_model", type=str, default=None, help="load pretrained model") #None
     parser.add_argument("--output_dir", default="./output", type=str,
                         help="The output directory where checkpoints will be written.")
-    parser.add_argument("--img_size", default=448, type=int, help="Resolution size")
-    parser.add_argument("--train_batch_size", default=64, type=int,
+    parser.add_argument("--img_size", default=320, type=int, help="Resolution size")
+    parser.add_argument("--train_batch_size", default=8, type=int,
                         help="Total batch size for training.")
-    parser.add_argument("--eval_batch_size", default=16, type=int,
+    parser.add_argument("--eval_batch_size", default=8, type=int,
                         help="Total batch size for eval.")
-    parser.add_argument("--eval_every", default=200, type=int,
+    parser.add_argument("--eval_every", default=786, type=int,
                         help="Run prediction on validation set every so many steps."
                              "Will always run one evaluation at the end of training.")
-    parser.add_argument("--learning_rate", default=3e-2, type=float,
+    parser.add_argument("--learning_rate", default=3e-2, type=float,
                         help="The initial learning rate for SGD.")
-    parser.add_argument("--weight_decay", default=0, type=float,
+    parser.add_argument("--weight_decay", default=0.00001, type=float,
                         help="Weight deay if we apply some.")
-    parser.add_argument("--num_steps", default=8000, type=int, #100000
+    parser.add_argument("--num_steps", default=78600, type=int, #100000
                         help="Total number of training epochs to perform.")
     parser.add_argument("--decay_type", choices=["cosine", "linear"], default="cosine",
                         help="How to decay the learning rate.")
@@ -332,15 +320,6 @@ def main():
                         help="random seed for initialization")
     parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                         help="Number of updates steps to accumulate before performing a backward/update pass.")
-    parser.add_argument('--fp16', action='store_true',
-                        help="Whether to use 16-bit float precision instead of 32-bit")
-    parser.add_argument('--fp16_opt_level', type=str, default='O2',
-                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
-                             "See details at https://nvidia.github.io/apex/amp.html")
-    parser.add_argument('--loss_scale', type=float, default=0,
-                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
-                             "0 (default value): dynamic loss scaling.\n"
-                             "Positive power of 2: static loss scaling value.\n")
     parser.add_argument('--smoothing_value', type=float, default=0.0,
                         help="Label smoothing value\n")
 
@@ -352,7 +331,6 @@ def main():
     # Setup CUDA, GPU & distributed training
     if args.local_rank == -1:
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        print('torch.cuda.device_count()>>>>>>>>>>>>>>>>>>>>>>>>>', torch.cuda.device_count())
         args.n_gpu = torch.cuda.device_count()
         print('torch.cuda.device_count()>>>>>>>>>>>>>>>>>>>>>>>>>', torch.cuda.device_count())
     else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
@@ -367,8 +345,8 @@ def main():
     logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                         datefmt='%m/%d/%Y %H:%M:%S',
                         level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
-    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s" %
-                   (args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1), args.fp16))
+    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s" %
+                   (args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1)))
 
     # Set seed
     set_seed(args)
@@ -380,4 +358,5 @@
 
 
 if __name__ == "__main__":
+    torch.cuda.empty_cache()
     main()