# ieemoo-ai-isempty/traintry.py
# coding=utf-8
from __future__ import absolute_import, division, print_function
import logging
import argparse
import os
import random
import numpy as np
import time
from datetime import timedelta
import torch
import torch.distributed as dist
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from models.modeling import VisionTransformer, CONFIGS
from utils.scheduler import WarmupLinearSchedule, WarmupCosineSchedule
from utils.data_utils import get_loader
from utils.dist_util import get_world_size
import pdb
logger = logging.getLogger(__name__)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
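# Pin this process to physical GPU 0. Setting CUDA_VISIBLE_DEVICES must happen
# before the first CUDA call, so that torch only ever sees one device and
# "cuda:0" below maps to that GPU.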
# Running-average helper.
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
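# Illustrative usage (not part of the original script):
#   meter = AverageMeter()
#   meter.update(loss.item(), n=x.size(0))
#   meter.avg  # running average, weighted by the n passed to update()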
# Simple accuracy: fraction of predictions matching labels (numpy arrays).
def simple_accuracy(preds, labels):
    return (preds == labels).mean()
# Average a tensor across all distributed processes.
def reduce_mean(tensor, nprocs):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= nprocs
    return rt
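# Note: reduce_mean assumes dist.init_process_group() has been called. It is
# currently unused, since the all-reduce calls in valid() and train() below
# are commented out.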
# Save a model checkpoint.
def save_model(args, model):
    model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.pth" % args.name)
    torch.save(model, model_checkpoint)
    logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir)
# Configure the model for the chosen dataset.
def setup(args):
    # Prepare model
    config = CONFIGS[args.model_type]
    config.split = args.split
    config.slide_step = args.slide_step
    # Class counts inferred from the dataset names (emptyJudge5 -> 5, emptyJudge4 -> 4);
    # the original left num_classes undefined for any other choice.
    if args.dataset == "emptyJudge5":
        num_classes = 5
    elif args.dataset == "emptyJudge4":
        num_classes = 4
    else:
        raise ValueError("Unsupported dataset: %s" % args.dataset)
    model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes, smoothing_value=args.smoothing_value)
    if args.pretrained_dir is not None:
        model.load_from(np.load(args.pretrained_dir))  # third-party pretrained ViT weights (.npz)
    if args.pretrained_model is not None:
        model = torch.load(args.pretrained_model)  # our own pretrained checkpoint
    # model.to(args.device)
    # pdb.set_trace()
    num_params = count_parameters(model)
    logger.info("{}".format(config))
    logger.info("Training parameters %s", args)
    logger.info("Total Parameter: \t%2.1fM" % num_params)
    model = torch.nn.DataParallel(model, device_ids=[0]).cuda()
    return args, model
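# With a single visible device (CUDA_VISIBLE_DEVICES="0" above), DataParallel
# with device_ids=[0] is effectively a single-GPU passthrough; the wrapper's
# main visible effect is that parameters live under model.module.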
# Count trainable parameters, in millions.
def count_parameters(model):
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return params / 1000000
# Seed all RNGs for reproducibility.
def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)
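# For stricter run-to-run determinism one could additionally set
# torch.backends.cudnn.deterministic = True and torch.backends.cudnn.benchmark = False,
# at some cost in speed; the original script does not do this.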
# Validation loop.
def valid(args, model, writer, test_loader, global_step):
    eval_losses = AverageMeter()
    logger.info("***** Running Validation *****")
    # logger.info("val Num steps = %d", len(test_loader))
    # logger.info("val Batch size = %d", args.eval_batch_size)
    model.eval()
    all_preds, all_label = [], []
    epoch_iterator = tqdm(test_loader,
                          desc="Validating... (loss=X.X)",
                          bar_format="{l_bar}{r_bar}",
                          dynamic_ncols=True,
                          disable=args.local_rank not in [-1, 0])
    loss_fct = torch.nn.CrossEntropyLoss()
    for step, batch in enumerate(epoch_iterator):
        batch = tuple(t.to(args.device) for t in batch)
        x, y = batch
        with torch.no_grad():
            logits = model(x)
            eval_loss = loss_fct(logits, y)
            eval_loss = eval_loss.mean()
            eval_losses.update(eval_loss.item())
            preds = torch.argmax(logits, dim=-1)
        if len(all_preds) == 0:
            all_preds.append(preds.detach().cpu().numpy())
            all_label.append(y.detach().cpu().numpy())
        else:
            all_preds[0] = np.append(
                all_preds[0], preds.detach().cpu().numpy(), axis=0
            )
            all_label[0] = np.append(
                all_label[0], y.detach().cpu().numpy(), axis=0
            )
        epoch_iterator.set_description("Validating... (loss=%2.5f)" % eval_losses.val)
    all_preds, all_label = all_preds[0], all_label[0]
    accuracy = simple_accuracy(all_preds, all_label)
    accuracy = torch.tensor(accuracy).to(args.device)
    # dist.barrier()
    # val_accuracy = reduce_mean(accuracy, args.nprocs)
    # val_accuracy = val_accuracy.detach().cpu().numpy()
    val_accuracy = accuracy.detach().cpu().numpy()
    logger.info("\n")
    logger.info("Validation Results")
    logger.info("Global Steps: %d" % global_step)
    logger.info("Valid Loss: %2.5f" % eval_losses.avg)
    logger.info("Valid Accuracy: %2.5f" % val_accuracy)
    if args.local_rank in [-1, 0]:
        writer.add_scalar("test/accuracy", scalar_value=val_accuracy, global_step=global_step)
    return val_accuracy
# Training loop.
def train(args, model):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=os.path.join("logs", args.name))
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
    # Prepare dataset
    train_loader, test_loader = get_loader(args)
    logger.info("train Num steps = %d", len(train_loader))
    logger.info("val Num steps = %d", len(test_loader))
    # Prepare optimizer and scheduler
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate,
                                momentum=0.9,
                                weight_decay=args.weight_decay)
    t_total = args.num_steps
    if args.decay_type == "cosine":
        scheduler = WarmupCosineSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
    else:
        scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
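    # As the class names suggest, both schedules ramp the learning rate up
    # linearly over warmup_steps and then decay it toward 0 at t_total (cosine
    # curve or straight line); see utils/scheduler.py for the exact shapes.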
    # Train!
    logger.info("***** Running training *****")
    logger.info("  Total optimization steps = %d", args.num_steps)
    logger.info("  Instantaneous batch size per GPU = %d", args.train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                args.train_batch_size * args.gradient_accumulation_steps * (
                    torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    model.zero_grad()
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    losses = AverageMeter()
    global_step, best_acc = 0, 0
    start_time = time.time()
    while True:
        model.train()
        epoch_iterator = tqdm(train_loader,
                              desc="Training (X / X Steps) (loss=X.X)",
                              bar_format="{l_bar}{r_bar}",
                              dynamic_ncols=True,
                              disable=args.local_rank not in [-1, 0])
        all_preds, all_label = [], []
        for step, batch in enumerate(epoch_iterator):
            batch = tuple(t.to(args.device) for t in batch)
            x, y = batch
            loss, logits = model(x, y)
            loss = loss.mean()
            preds = torch.argmax(logits, dim=-1)
            if len(all_preds) == 0:
                all_preds.append(preds.detach().cpu().numpy())
                all_label.append(y.detach().cpu().numpy())
            else:
                all_preds[0] = np.append(
                    all_preds[0], preds.detach().cpu().numpy(), axis=0
                )
                all_label[0] = np.append(
                    all_label[0], y.detach().cpu().numpy(), axis=0
                )
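            # Gradient accumulation: the loss is scaled down so that summing
            # gradients over accumulation steps averages them, giving an
            # effective batch of train_batch_size * gradient_accumulation_steps.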
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                losses.update(loss.item() * args.gradient_accumulation_steps)
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                # Step the optimizer before the scheduler (required order since PyTorch 1.1).
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1
                epoch_iterator.set_description(
                    "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, t_total, losses.val)
                )
                if args.local_rank in [-1, 0]:
                    writer.add_scalar("train/loss", scalar_value=losses.val, global_step=global_step)
                    writer.add_scalar("train/lr", scalar_value=scheduler.get_lr()[0], global_step=global_step)
                if global_step % args.eval_every == 0:
                    with torch.no_grad():
                        accuracy = valid(args, model, writer, test_loader, global_step)
                    if args.local_rank in [-1, 0]:
                        if best_acc < accuracy:
                            save_model(args, model)
                            best_acc = accuracy
                        logger.info("best accuracy so far: %f" % best_acc)
                    model.train()
                if global_step % t_total == 0:
                    break
        all_preds, all_label = all_preds[0], all_label[0]
        accuracy = simple_accuracy(all_preds, all_label)
        accuracy = torch.tensor(accuracy).to(args.device)
        # dist.barrier()
        # train_accuracy = reduce_mean(accuracy, args.nprocs)
        # train_accuracy = train_accuracy.detach().cpu().numpy()
        train_accuracy = accuracy.detach().cpu().numpy()
        logger.info("train accuracy so far: %f" % train_accuracy)
        losses.reset()
        if global_step % t_total == 0:
            break
    writer.close()
    logger.info("Best Accuracy: \t%f" % best_acc)
    logger.info("End Training!")
    end_time = time.time()
    logger.info("Total Training Time: \t%f hours" % ((end_time - start_time) / 3600))
# Entry point.
def main():
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--name", type=str, default='ieemooempty',
                        help="Name of this run. Used for monitoring.")
    parser.add_argument("--dataset", choices=["CUB_200_2011", "car", "dog", "nabirds", "INat2017", "emptyJudge5", "emptyJudge4"],
                        default="emptyJudge5", help="Which dataset.")
    parser.add_argument('--data_root', type=str, default='./')
    parser.add_argument("--model_type", choices=["ViT-B_16", "ViT-B_32", "ViT-L_16", "ViT-L_32", "ViT-H_14"],
                        default="ViT-B_16", help="Which variant to use.")
    # parser.add_argument("--pretrained_dir", type=str, default="./preckpts/ViT-B_16.npz",
    #                     help="Where to search for pretrained ViT models.")
    # parser.add_argument("--pretrained_model", type=str, default="./output/ieemooempty_checkpoint_good.pth", help="load pretrained model")
    parser.add_argument("--pretrained_dir", type=str, default=None,
                        help="Where to search for pretrained ViT models.")
    parser.add_argument("--pretrained_model", type=str, default=None, help="load pretrained model")
    parser.add_argument("--output_dir", default="./output", type=str,
                        help="The output directory where checkpoints will be written.")
    parser.add_argument("--img_size", default=600, type=int, help="Resolution size")
    parser.add_argument("--train_batch_size", default=8, type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size", default=8, type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--eval_every", default=786, type=int,
                        help="Run prediction on validation set every so many steps. "
                             "Will always run one evaluation at the end of training.")
    parser.add_argument("--learning_rate", default=3e-2, type=float,
                        help="The initial learning rate for SGD.")
    parser.add_argument("--weight_decay", default=0.00001, type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--num_steps", default=78600, type=int,  # 100000
                        help="Total number of training steps to perform.")
    parser.add_argument("--decay_type", choices=["cosine", "linear"], default="cosine",
                        help="How to decay the learning rate.")
    parser.add_argument("--warmup_steps", default=500, type=int,
                        help="Steps of training to perform learning rate warmup for.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value")
    parser.add_argument('--split', type=str, default='overlap', help="Split method")  # non-overlap
    parser.add_argument('--slide_step', type=int, default=12, help="Slide step for overlap split")
    args = parser.parse_args()
    args.data_root = '{}/{}'.format(args.data_root, args.dataset)
    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        args.n_gpu = torch.cuda.device_count()
        print('torch.cuda.device_count():', torch.cuda.device_count())
    else:  # Initializes the distributed backend, which takes care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl', timeout=timedelta(minutes=60))
        args.n_gpu = 1
    args.device = device
    args.nprocs = torch.cuda.device_count()
    # Setup logging
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s" %
                   (args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1)))
    # Set seed
    set_seed(args)
    # Model Setup
    args, model = setup(args)
    # Training
    train(args, model)
if __name__ == "__main__":
    torch.cuda.empty_cache()
    main()
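# Illustrative launch commands (GPU counts and flags are assumptions, not from the repo):
#   Single GPU:
#     python traintry.py --dataset emptyJudge5 --train_batch_size 8
#   Multi-GPU (torch.distributed.launch supplies --local_rank to each process):
#     python -m torch.distributed.launch --nproc_per_node=2 traintry.py --dataset emptyJudge5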