From 99f6ee4298defc4c61b6e4af9366feb58901dc58 Mon Sep 17 00:00:00 2001
From: Brainway
Date: Tue, 27 Sep 2022 02:09:58 +0000
Subject: [PATCH] update

---
 traintry.py | 362 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 362 insertions(+)
 create mode 100644 traintry.py

diff --git a/traintry.py b/traintry.py
new file mode 100644
index 0000000..d0ef4c9
--- /dev/null
+++ b/traintry.py
@@ -0,0 +1,362 @@
+# coding=utf-8
+from __future__ import absolute_import, division, print_function
+
+import logging
+import argparse
+import os
+import random
+import numpy as np
+import time
+
+from datetime import timedelta
+
+import torch
+import torch.distributed as dist
+
+from tqdm import tqdm
+from torch.utils.tensorboard import SummaryWriter
+
+from models.modeling import VisionTransformer, CONFIGS
+from utils.scheduler import WarmupLinearSchedule, WarmupCosineSchedule
+from utils.data_utils import get_loader
+from utils.dist_util import get_world_size
+import pdb
+
+logger = logging.getLogger(__name__)
+
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
+# Computes and stores running averages
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+# Simple accuracy
+def simple_accuracy(preds, labels):
+    return (preds == labels).mean()
+
+# Average a tensor across distributed processes
+def reduce_mean(tensor, nprocs):
+    rt = tensor.clone()
+    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
+    rt /= nprocs
+    return rt
+
+# Save the model
+def save_model(args, model):
+    model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.pth" % args.name)
+    torch.save(model, model_checkpoint)
+    logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir)
+
+# Configure the model for the chosen dataset
+def setup(args):
+    # Prepare model
+    config = CONFIGS[args.model_type]
+    config.split = args.split
+    config.slide_step = args.slide_step
+
+    if args.dataset == "emptyJudge5":
+        num_classes = 5
+    else:
+        # only emptyJudge5 is wired up in this script; fail loudly for anything else
+        raise ValueError("Unsupported dataset: %s" % args.dataset)
+
+    model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes, smoothing_value=args.smoothing_value)
+
+    if args.pretrained_dir is not None:
+        model.load_from(np.load(args.pretrained_dir))  # third-party pretrained ViT weights
+    if args.pretrained_model is not None:
+        model = torch.load(args.pretrained_model)  # our own pretrained checkpoint
+    # model.to(args.device)
+    # pdb.set_trace()
+    num_params = count_parameters(model)
+
+    logger.info("{}".format(config))
+    logger.info("Training parameters %s", args)
+    logger.info("Total Parameter: \t%2.1fM" % num_params)
+    model = torch.nn.DataParallel(model, device_ids=[0]).cuda()
+    return args, model
+
+# Count trainable parameters (in millions)
+def count_parameters(model):
+    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    return params/1000000
+
+# Random seed
+def set_seed(args):
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    if args.n_gpu > 0:
+        torch.cuda.manual_seed_all(args.seed)
+
+# Validation
+def valid(args, model, writer, test_loader, global_step):
+    eval_losses = AverageMeter()
+
+    logger.info("***** Running Validation *****")
+    # logger.info("val Num steps = %d", len(test_loader))
+    # logger.info("val Batch size = %d", args.eval_batch_size)
+
+    model.eval()
+    all_preds, all_label = [], []
+    epoch_iterator = tqdm(test_loader,
+                          desc="Validating... (loss=X.X)",
+                          bar_format="{l_bar}{r_bar}",
+                          dynamic_ncols=True,
+                          disable=args.local_rank not in [-1, 0])
+    loss_fct = torch.nn.CrossEntropyLoss()
+    for step, batch in enumerate(epoch_iterator):
+        batch = tuple(t.to(args.device) for t in batch)
+        x, y = batch
+        with torch.no_grad():
+            logits = model(x)
+
+            eval_loss = loss_fct(logits, y)
+            eval_loss = eval_loss.mean()
+            eval_losses.update(eval_loss.item())
+
+            preds = torch.argmax(logits, dim=-1)
+
+        if len(all_preds) == 0:
+            all_preds.append(preds.detach().cpu().numpy())
+            all_label.append(y.detach().cpu().numpy())
+        else:
+            all_preds[0] = np.append(
+                all_preds[0], preds.detach().cpu().numpy(), axis=0
+            )
+            all_label[0] = np.append(
+                all_label[0], y.detach().cpu().numpy(), axis=0
+            )
+        epoch_iterator.set_description("Validating... (loss=%2.5f)" % eval_losses.val)
+
+    all_preds, all_label = all_preds[0], all_label[0]
+    accuracy = simple_accuracy(all_preds, all_label)
+    accuracy = torch.tensor(accuracy).to(args.device)
+    # dist.barrier()
+    # val_accuracy = reduce_mean(accuracy, args.nprocs)
+    # val_accuracy = val_accuracy.detach().cpu().numpy()
+    val_accuracy = accuracy.detach().cpu().numpy()
+
+    logger.info("\n")
+    logger.info("Validation Results")
+    logger.info("Global Steps: %d" % global_step)
+    logger.info("Valid Loss: %2.5f" % eval_losses.avg)
+    logger.info("Valid Accuracy: %2.5f" % val_accuracy)
+    if args.local_rank in [-1, 0]:
+        writer.add_scalar("test/accuracy", scalar_value=val_accuracy, global_step=global_step)
+    return val_accuracy
+
+# Training
+def train(args, model):
+    """ Train the model """
+    if args.local_rank in [-1, 0]:
+        os.makedirs(args.output_dir, exist_ok=True)
+        writer = SummaryWriter(log_dir=os.path.join("logs", args.name))
+
+    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
+
+    # Prepare dataset
+    train_loader, test_loader = get_loader(args)
+    logger.info("train Num steps = %d", len(train_loader))
+    logger.info("val Num steps = %d", len(test_loader))
+    # Prepare optimizer and scheduler
+    optimizer = torch.optim.SGD(model.parameters(),
+                                lr=args.learning_rate,
+                                momentum=0.9,
+                                weight_decay=args.weight_decay)
+    t_total = args.num_steps
+    if args.decay_type == "cosine":
+        scheduler = WarmupCosineSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
+    else:
+        scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
+
+    # Train!
+    logger.info("***** Running training *****")
+    logger.info("  Total optimization steps = %d", args.num_steps)
+    logger.info("  Instantaneous batch size per GPU = %d", args.train_batch_size)
+    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
+                args.train_batch_size * args.gradient_accumulation_steps * (
+                    torch.distributed.get_world_size() if args.local_rank != -1 else 1))
+    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
+
+    model.zero_grad()
+    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
+    losses = AverageMeter()
+    global_step, best_acc = 0, 0
+    start_time = time.time()
+    while True:
+        model.train()
+        epoch_iterator = tqdm(train_loader,
+                              desc="Training (X / X Steps) (loss=X.X)",
+                              bar_format="{l_bar}{r_bar}",
+                              dynamic_ncols=True,
+                              disable=args.local_rank not in [-1, 0])
+        all_preds, all_label = [], []
+        for step, batch in enumerate(epoch_iterator):
+            batch = tuple(t.to(args.device) for t in batch)
+            x, y = batch
+
+            loss, logits = model(x, y)
+            loss = loss.mean()
+
+            preds = torch.argmax(logits, dim=-1)
+
+            if len(all_preds) == 0:
+                all_preds.append(preds.detach().cpu().numpy())
+                all_label.append(y.detach().cpu().numpy())
+            else:
+                all_preds[0] = np.append(
+                    all_preds[0], preds.detach().cpu().numpy(), axis=0
+                )
+                all_label[0] = np.append(
+                    all_label[0], y.detach().cpu().numpy(), axis=0
+                )
+
+            if args.gradient_accumulation_steps > 1:
+                loss = loss / args.gradient_accumulation_steps
+            loss.backward()
+
+            if (step + 1) % args.gradient_accumulation_steps == 0:
+                losses.update(loss.item()*args.gradient_accumulation_steps)
+                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+                scheduler.step()
+                optimizer.step()
+                optimizer.zero_grad()
+                global_step += 1
+
+                epoch_iterator.set_description(
+                    "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, t_total, losses.val)
+                )
+                if args.local_rank in [-1, 0]:
+                    writer.add_scalar("train/loss", scalar_value=losses.val, global_step=global_step)
+                    writer.add_scalar("train/lr", scalar_value=scheduler.get_lr()[0], global_step=global_step)
+                if global_step % args.eval_every == 0:
+                    with torch.no_grad():
+                        accuracy = valid(args, model, writer, test_loader, global_step)
+                    if args.local_rank in [-1, 0]:
+                        if best_acc < accuracy:
+                            save_model(args, model)
+                            best_acc = accuracy
+                        logger.info("best accuracy so far: %f" % best_acc)
+                    model.train()
+
+                if global_step % t_total == 0:
+                    break
+        all_preds, all_label = all_preds[0], all_label[0]
+        accuracy = simple_accuracy(all_preds, all_label)
+        accuracy = torch.tensor(accuracy).to(args.device)
+        # dist.barrier()
+        # train_accuracy = reduce_mean(accuracy, args.nprocs)
+        # train_accuracy = train_accuracy.detach().cpu().numpy()
+        train_accuracy = accuracy.detach().cpu().numpy()
+        logger.info("train accuracy so far: %f" % train_accuracy)
+        losses.reset()
+        if global_step % t_total == 0:
+            break
+
+    writer.close()
+    logger.info("Best Accuracy: \t%f" % best_acc)
+    logger.info("End Training!")
+    end_time = time.time()
+    logger.info("Total Training Time: \t%f hours" % ((end_time - start_time) / 3600))
+
+
+# Main entry point
+def main():
+    parser = argparse.ArgumentParser()
+    # Required parameters
+    parser.add_argument("--name", type=str, default='ieemooempty',
+                        help="Name of this run. Used for monitoring.")
+    parser.add_argument("--dataset", choices=["CUB_200_2011", "car", "dog", "nabirds", "INat2017", "emptyJudge5", "emptyJudge4"],
+                        default="emptyJudge5", help="Which dataset.")
+    parser.add_argument('--data_root', type=str, default='./')
+    parser.add_argument("--model_type", choices=["ViT-B_16", "ViT-B_32", "ViT-L_16", "ViT-L_32", "ViT-H_14"],
+                        default="ViT-B_16", help="Which variant to use.")
+    # parser.add_argument("--pretrained_dir", type=str, default="./preckpts/ViT-B_16.npz",
+    #                     help="Where to search for pretrained ViT models.")
+    # parser.add_argument("--pretrained_model", type=str, default="./output/ieemooempty_checkpoint_good.pth", help="load pretrained model")  # None
+    parser.add_argument("--pretrained_dir", type=str, default=None,
+                        help="Where to search for pretrained ViT models.")
+    parser.add_argument("--pretrained_model", type=str, default=None, help="load pretrained model")  # None
+    parser.add_argument("--output_dir", default="./output", type=str,
+                        help="The output directory where checkpoints will be written.")
+    parser.add_argument("--img_size", default=600, type=int, help="Resolution size")
+    parser.add_argument("--train_batch_size", default=8, type=int,
+                        help="Total batch size for training.")
+    parser.add_argument("--eval_batch_size", default=8, type=int,
+                        help="Total batch size for eval.")
+    parser.add_argument("--eval_every", default=786, type=int,
+                        help="Run prediction on validation set every so many steps. "
+                             "Will always run one evaluation at the end of training.")
+
+    parser.add_argument("--learning_rate", default=3e-2, type=float,
+                        help="The initial learning rate for SGD.")
+    parser.add_argument("--weight_decay", default=0.00001, type=float,
+                        help="Weight decay if we apply some.")
+    parser.add_argument("--num_steps", default=78600, type=int,  # 100000
+                        help="Total number of training steps to perform.")
+    parser.add_argument("--decay_type", choices=["cosine", "linear"], default="cosine",
+                        help="How to decay the learning rate.")
+    parser.add_argument("--warmup_steps", default=500, type=int,
+                        help="Step of training to perform learning rate warmup for.")
+    parser.add_argument("--max_grad_norm", default=1.0, type=float,
+                        help="Max gradient norm.")
+
+    parser.add_argument("--local_rank", type=int, default=-1,
+                        help="local_rank for distributed training on gpus")
+    parser.add_argument('--seed', type=int, default=42,
+                        help="random seed for initialization")
+    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
+                        help="Number of updates steps to accumulate before performing a backward/update pass.")
+
+    parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value")
+
+    parser.add_argument('--split', type=str, default='overlap', help="Split method")  # non-overlap
+    parser.add_argument('--slide_step', type=int, default=12, help="Slide step for overlap split")
+    args = parser.parse_args()
+
+    args.data_root = '{}/{}'.format(args.data_root, args.dataset)
+    # Setup CUDA, GPU & distributed training
+    if args.local_rank == -1:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        args.n_gpu = torch.cuda.device_count()
+        print('torch.cuda.device_count()>>>>>>>>>>>>>>>>>>>>>>>>>', torch.cuda.device_count())
+    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
+        torch.cuda.set_device(args.local_rank)
+        device = torch.device("cuda", args.local_rank)
+        torch.distributed.init_process_group(backend='nccl', timeout=timedelta(minutes=60))
+        args.n_gpu = 1
+    args.device = device
+    args.nprocs = torch.cuda.device_count()
+
+    # Setup logging
+    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+                        datefmt='%m/%d/%Y %H:%M:%S',
+                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
+    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s" %
+                   (args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1)))
+
+    # Set seed
+    set_seed(args)
+
+    # Model & Tokenizer Setup
+    args, model = setup(args)
+    # Training
+    train(args, model)
+
+
+if __name__ == "__main__":
+    torch.cuda.empty_cache()
+    main()
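+
+# Example invocation (a sketch; assumes the TransFG-style repo layout this script
+# imports from, an emptyJudge5 dataset under --data_root, and the ImageNet-pretrained
+# ViT-B_16 .npz checkpoint referenced in the commented-out default above):
+#   python traintry.py --name ieemooempty --dataset emptyJudge5 --model_type ViT-B_16 \
+#       --pretrained_dir ./preckpts/ViT-B_16.npz --img_size 600 --train_batch_size 8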