Files
ieemoo-ai-isempty/train.py
2023-09-21 18:28:43 +08:00

395 lines
16 KiB
Python
Executable File

# coding=utf-8
from __future__ import absolute_import, division, print_function
import logging
import argparse
import os
import random
import numpy as np
import time
from datetime import timedelta
import torch
import torch.distributed as dist
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from models.modeling import VisionTransformer, CONFIGS
from utils.scheduler import WarmupLinearSchedule, WarmupCosineSchedule
from utils.data_utils import get_loader
from utils.dist_util import get_world_size
import pdb
# Module-level logger; its format and level are configured in main() through
# logging.basicConfig.
logger = logging.getLogger(__name__)
# Expose only GPUs 0, 1 and 2 to this process.
# NOTE(review): this assignment runs after `import torch`; presumably CUDA has
# not been initialised at that point so the setting still applies -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"
class AverageMeter(object):
    """Track the most recent value and the running weighted average of a metric.

    Attributes set by ``reset``/``update``:
        val: the value passed to the latest ``update`` call.
        sum: weighted sum of all values since the last ``reset``.
        count: accumulated weight ``n`` since the last ``reset``.
        avg: ``sum / count``, or 0 while ``count`` is still 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset every statistic to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. update(x, n=0) right after reset) used to
        # raise ZeroDivisionError; keep the average at 0 in that case instead.
        self.avg = self.sum / self.count if self.count else 0
def simple_accuracy(preds, labels):
    """Return the fraction of positions at which ``preds`` equals ``labels``.

    The element-wise comparison result must provide a ``mean`` method; the
    callers in this file pass NumPy arrays.
    """
    matches = preds == labels
    return matches.mean()
def reduce_mean(tensor, nprocs):
    """Sum a copy of ``tensor`` over all processes and divide it by ``nprocs``.

    The argument itself is not modified: the all-reduce and the in-place
    division are applied to a clone, which is returned.
    """
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    averaged /= nprocs
    return averaged
def save_model(args, model):
    """Save the model weights to ``<args.output_dir>/<args.name>_checkpoint.bin``.

    When ``model`` has a ``module`` attribute (as wrappers such as
    ``torch.nn.DataParallel`` do), the wrapped module's state dict is stored.
    The file holds a dict with the single key ``'model'``.
    """
    # Unwrap a parallel wrapper, if any, so the saved keys belong to the bare model.
    unwrapped = getattr(model, 'module', model)
    target_path = os.path.join(args.output_dir, "%s_checkpoint.bin" % args.name)
    torch.save({'model': unwrapped.state_dict()}, target_path)
    logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir)
def save_eve_model(args, model, eve_name):
    """Save the model weights to ``<args.output_dir>/<eve_name>_checkpoint.bin``.

    Same on-disk layout as ``save_model`` (a dict with the single key
    ``'model'``), but the file name is taken from ``eve_name`` instead of the
    run name, so snapshots do not overwrite each other.
    """
    # Unwrap a parallel wrapper, if any, so the saved keys belong to the bare model.
    unwrapped = getattr(model, 'module', model)
    target_path = os.path.join(args.output_dir, "%s_checkpoint.bin" % eve_name)
    torch.save({'model': unwrapped.state_dict()}, target_path)
    logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir)
def setup(args):
    """Build the Vision Transformer, load its weights and wrap it for multi-GPU use.

    Args:
        args: parsed command-line namespace. Reads ``model_type``, ``split``,
            ``slide_step``, ``dataset``, ``img_size``, ``smoothing_value``,
            ``pretrained_dir`` and ``pretrained_model``.

    Returns:
        Tuple ``(args, model)`` where ``model`` is wrapped in
        ``torch.nn.DataParallel`` and moved to CUDA.

    Raises:
        ValueError: if ``args.dataset`` has no known number of classes.
    """
    # Number of target classes for every dataset this script knows about.
    num_classes_by_dataset = {
        "CUB_200_2011": 200,
        "car": 196,
        "nabirds": 555,
        "dog": 120,
        "INat2017": 5089,
        "emptyJudge5": 5,
        "emptyJudge4": 4,
        "emptyJudge3": 3,
    }

    # Prepare model
    config = CONFIGS[args.model_type]
    config.split = args.split
    config.slide_step = args.slide_step

    # Fail with a clear message instead of an unbound-variable error further down.
    if args.dataset not in num_classes_by_dataset:
        raise ValueError("Unsupported dataset: %r" % (args.dataset,))
    num_classes = num_classes_by_dataset[args.dataset]

    model = VisionTransformer(config, args.img_size, zero_head=True,
                              num_classes=num_classes,
                              smoothing_value=args.smoothing_value)
    model.load_from(np.load(args.pretrained_dir))
    if args.pretrained_model is not None:
        # The model still lives on the CPU here, so map the checkpoint to the
        # CPU too; otherwise loading fails when it was saved from a GPU index
        # that this process cannot see.
        pretrained_model = torch.load(args.pretrained_model, map_location="cpu")['model']
        model.load_state_dict(pretrained_model)

    num_params = count_parameters(model)
    logger.info("{}".format(config))
    logger.info("Training parameters %s", args)
    logger.info("Total Parameter: \t%2.1fM" % num_params)

    # Replicate over GPUs 0 and 1 as before, but cap the list at the number of
    # visible GPUs so a single-GPU machine does not fail on an invalid device id.
    device_ids = list(range(min(2, torch.cuda.device_count())))
    model = torch.nn.DataParallel(model, device_ids=device_ids).cuda()
    return args, model
def count_parameters(model):
    """Return the number of trainable parameters of ``model``, in millions.

    Only parameters with ``requires_grad`` set are counted.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable/1000000
def set_seed(args):
    """Seed the Python, NumPy and PyTorch random generators with ``args.seed``.

    The CUDA generators are seeded as well when ``args.n_gpu`` is positive.
    """
    seed = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)
def valid(args, model, writer, test_loader, global_step):
    """Evaluate ``model`` on ``test_loader`` and log loss and accuracy.

    Args:
        args: run configuration; reads ``device`` and ``local_rank``.
        model: network called as ``model(x)`` to obtain logits. It is switched
            to eval mode and left there; the caller restores train mode.
        writer: summary writer exposing ``add_scalar``; only used when
            ``args.local_rank`` is -1 or 0.
        test_loader: iterable yielding ``(x, y)`` tensor batches.
        global_step: training step under which the accuracy is logged.

    Returns:
        Accuracy over all evaluated samples as a 0-d NumPy array. It is
        computed on this process only; no cross-process reduction is done.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    eval_losses = AverageMeter()
    logger.info("***** Running Validation *****")
    model.eval()

    # Collect per-batch arrays and concatenate once at the end; appending to a
    # growing array with np.append re-copies everything on every batch.
    batch_preds, batch_labels = [], []
    epoch_iterator = tqdm(test_loader,
                          desc="Validating... (loss=X.X)",
                          bar_format="{l_bar}{r_bar}",
                          dynamic_ncols=True,
                          disable=args.local_rank not in [-1, 0])
    loss_fct = torch.nn.CrossEntropyLoss()
    for batch in epoch_iterator:
        batch = tuple(t.to(args.device) for t in batch)
        x, y = batch
        with torch.no_grad():
            logits = model(x)
            eval_loss = loss_fct(logits, y)
            eval_loss = eval_loss.mean()
            eval_losses.update(eval_loss.item())
            preds = torch.argmax(logits, dim=-1)
        batch_preds.append(preds.detach().cpu().numpy())
        batch_labels.append(y.detach().cpu().numpy())
        epoch_iterator.set_description("Validating... (loss=%2.5f)" % eval_losses.val)

    if not batch_preds:
        raise ValueError("test_loader yielded no batches; cannot compute validation accuracy")

    all_preds = np.concatenate(batch_preds, axis=0)
    all_label = np.concatenate(batch_labels, axis=0)
    # Keep the 0-d array return type without the former round trip through a
    # device tensor, which did not change the value.
    val_accuracy = np.asarray(simple_accuracy(all_preds, all_label))

    logger.info("\n")
    logger.info("Validation Results")
    logger.info("Global Steps: %d" % global_step)
    logger.info("Valid Loss: %2.5f" % eval_losses.avg)
    logger.info("Valid Accuracy: %2.5f" % val_accuracy)
    if args.local_rank in [-1, 0]:
        writer.add_scalar("test/accuracy", scalar_value=val_accuracy, global_step=global_step)
    return val_accuracy
def train(args, model):
    """Train ``model`` until ``args.num_steps`` optimizer updates have been made.

    Each pass over the training loader accumulates gradients for
    ``args.gradient_accumulation_steps`` batches per update, evaluates every
    ``args.eval_every`` updates, saves the best-scoring checkpoint and logs
    loss, learning rate and accuracy.

    Args:
        args: run configuration. ``args.train_batch_size`` is overwritten with
            the per-accumulation-step batch size.
        model: network called as ``model(x, y)`` returning ``(loss, logits)``.

    Raises:
        ValueError: if the training loader yields no batches.
    """
    is_main_process = args.local_rank in [-1, 0]
    # Only the main process owns a summary writer. The name is bound on every
    # rank so that passing it to valid() or closing it cannot raise NameError.
    writer = None
    if is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=os.path.join("logs", args.name))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    # Prepare dataset
    train_loader, test_loader = get_loader(args)
    logger.info("train Num steps = %d", len(train_loader))
    logger.info("val Num steps = %d", len(test_loader))

    # Prepare optimizer and scheduler
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate,
                                momentum=0.9,
                                weight_decay=args.weight_decay)
    t_total = args.num_steps
    if args.decay_type == "cosine":
        scheduler = WarmupCosineSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
    else:
        scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Total optimization steps = %d", args.num_steps)
    logger.info(" Instantaneous batch size per GPU = %d", args.train_batch_size)
    logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
                args.train_batch_size * args.gradient_accumulation_steps * (
                    torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)

    model.zero_grad()
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    losses = AverageMeter()
    global_step, best_acc = 0, 0
    start_time = time.time()
    while True:
        model.train()
        epoch_iterator = tqdm(train_loader,
                              desc="Training (X / X Steps) (loss=X.X)",
                              bar_format="{l_bar}{r_bar}",
                              dynamic_ncols=True,
                              disable=args.local_rank not in [-1, 0])
        # Per-batch predictions/labels of this pass; concatenated once after
        # the loop instead of re-copying a growing array on every batch.
        epoch_preds, epoch_labels = [], []
        for step, batch in enumerate(epoch_iterator):
            batch = tuple(t.to(args.device) for t in batch)
            x, y = batch
            loss, logits = model(x, y)
            loss = loss.mean()
            preds = torch.argmax(logits, dim=-1)
            epoch_preds.append(preds.detach().cpu().numpy())
            epoch_labels.append(y.detach().cpu().numpy())

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                # Undo the accumulation scaling so the logged loss is per batch.
                losses.update(loss.item()*args.gradient_accumulation_steps)
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                # The optimizer must step before the scheduler; the reverse
                # order skips the first value of the learning-rate schedule.
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1
                epoch_iterator.set_description(
                    "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, t_total, losses.val)
                )
                if is_main_process:
                    writer.add_scalar("train/loss", scalar_value=losses.val, global_step=global_step)
                    # NOTE(review): get_lr() is kept because the scheduler
                    # classes come from utils.scheduler; newer PyTorch prefers
                    # get_last_lr() -- confirm before switching.
                    writer.add_scalar("train/lr", scalar_value=scheduler.get_lr()[0], global_step=global_step)
                if global_step % args.eval_every == 0:
                    with torch.no_grad():
                        accuracy = valid(args, model, writer, test_loader, global_step)
                    if is_main_process:
                        if best_acc < accuracy:
                            save_model(args, model)
                            best_acc = accuracy
                        logger.info("best accuracy so far: %f" % best_acc)
                        # Periodic snapshot named after the step count.
                        # NOTE(review): kept inside the evaluation branch, so
                        # it fires only on evaluation steps that are also
                        # multiples of 10000 -- confirm this nesting is intended.
                        if global_step % 10000 == 0:
                            save_eve_model(args, model, str(global_step))
                    # valid() leaves the model in eval mode.
                    model.train()
                # Checked only right after an update, so the initial
                # global_step of 0 can never end training prematurely.
                if global_step % t_total == 0:
                    break

        if not epoch_preds:
            raise ValueError("train_loader yielded no batches; nothing to train on")

        train_accuracy = simple_accuracy(np.concatenate(epoch_preds, axis=0),
                                         np.concatenate(epoch_labels, axis=0))
        logger.info("train accuracy so far: %f" % train_accuracy)
        losses.reset()
        if global_step % t_total == 0:
            break

    if writer is not None:
        writer.close()
    logger.info("Best Accuracy: \t%f" % best_acc)
    logger.info("End Training!")
    end_time = time.time()
    logger.info("Total Training Time: \t%f" % ((end_time - start_time) / 3600))
def _build_arg_parser():
    """Create the command-line parser with every training option."""
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--name", required=True,
                        help="Name of this run. Used for monitoring.")
    # "emptyJudge3" is listed because setup() knows its class count.
    parser.add_argument("--dataset",
                        choices=["CUB_200_2011", "car", "dog", "nabirds", "INat2017",
                                 "emptyJudge5", "emptyJudge4", "emptyJudge3"],
                        default="CUB_200_2011", help="Which dataset.")
    parser.add_argument('--data_root', type=str, default='/data/pfc/fineGrained')
    parser.add_argument("--model_type", choices=["ViT-B_16", "ViT-B_32", "ViT-L_16", "ViT-L_32", "ViT-H_14"],
                        default="ViT-B_16", help="Which variant to use.")
    parser.add_argument("--pretrained_dir", type=str, default="ckpts/ViT-B_16.npz",
                        help="Where to search for pretrained ViT models.")
    parser.add_argument("--pretrained_model", type=str, default="output/emptyjudge5_checkpoint.bin",
                        help="load pretrained model")
    parser.add_argument("--output_dir", default="./output", type=str,
                        help="The output directory where checkpoints will be written.")
    parser.add_argument("--img_size", default=448, type=int, help="Resolution size")
    parser.add_argument("--train_batch_size", default=64, type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size", default=16, type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--eval_every", default=200, type=int,
                        help="Run prediction on validation set every so many steps. "
                             "Will always run one evaluation at the end of training.")
    parser.add_argument("--learning_rate", default=3e-2, type=float,
                        help="The initial learning rate for SGD.")
    parser.add_argument("--weight_decay", default=0, type=float,
                        help="Weight decay if we apply some.")
    # The value counts optimizer steps (train() stops after this many updates).
    parser.add_argument("--num_steps", default=40000, type=int,
                        help="Total number of training steps to perform.")
    parser.add_argument("--decay_type", choices=["cosine", "linear"], default="cosine",
                        help="How to decay the learning rate.")
    parser.add_argument("--warmup_steps", default=500, type=int,
                        help="Step of training to perform learning rate warmup for.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O2',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument('--loss_scale', type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value\n")
    parser.add_argument('--split', type=str, default='overlap', help="Split method")  # non-overlap
    parser.add_argument('--slide_step', type=int, default=12, help="Slide step for overlap split")
    return parser


def main():
    """Parse the command line, configure device and logging, then train.

    Adds ``n_gpu``, ``device`` and ``nprocs`` to the parsed arguments and
    rewrites ``data_root`` to ``<data_root>/<dataset>`` before the model is
    built and training starts.
    """
    args = _build_arg_parser().parse_args()
    args.data_root = '{}/{}'.format(args.data_root, args.dataset)

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:
        # Initializes the distributed backend which will take care of
        # synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl', timeout=timedelta(minutes=60))
        args.n_gpu = 1
    args.device = device
    args.nprocs = torch.cuda.device_count()

    # Setup logging: full INFO output on the main process, warnings elsewhere.
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s" %
                   (args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1), args.fp16))

    # Set seed
    set_seed(args)

    # Model & Tokenizer Setup
    args, model = setup(args)

    # Training
    train(args, model)


if __name__ == "__main__":
    main()