# coding=utf-8
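"""Training script for a Vision Transformer on fine-grained image classification
datasets (CUB-200-2011, Stanford Cars, Stanford Dogs, NABirds, iNat2017, and the
emptyJudge* datasets). Supports single-GPU/DataParallel and torch.distributed
training, TensorBoard logging, and periodic checkpointing of the best model."""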
from __future__ import absolute_import, division, print_function

import logging
import argparse
import os
import random
import numpy as np
import time

from datetime import timedelta

import torch
import torch.distributed as dist

from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

from models.modeling import VisionTransformer, CONFIGS
from utils.scheduler import WarmupLinearSchedule, WarmupCosineSchedule
from utils.data_utils import get_loader
from utils.dist_util import get_world_size
import pdb

logger = logging.getLogger(__name__)

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def simple_accuracy(preds, labels):
    return (preds == labels).mean()


def reduce_mean(tensor, nprocs):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= nprocs
    return rt
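
# Illustrative use of reduce_mean under torch.distributed (not exercised in the
# DataParallel path below; assumes a process group has already been initialized):
#   acc = torch.tensor(local_accuracy).to(args.device)
#   acc = reduce_mean(acc, dist.get_world_size())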


def save_model(args, model):
    model_to_save = model.module if hasattr(model, 'module') else model
    model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.bin" % args.name)
    checkpoint = {
        'model': model_to_save.state_dict(),
    }
    torch.save(checkpoint, model_checkpoint)
    logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir)


def save_eve_model(args, model, eve_name):
    model_to_save = model.module if hasattr(model, 'module') else model
    model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.bin" % eve_name)
    checkpoint = {
        'model': model_to_save.state_dict(),
    }
    torch.save(checkpoint, model_checkpoint)
    logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir)
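
# Note: the checkpoints written above contain only {'model': state_dict}; setup()
# below reloads them with torch.load(path)['model'] when --pretrained_model is set.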


def setup(args):
    # Prepare model
    config = CONFIGS[args.model_type]
    config.split = args.split
    config.slide_step = args.slide_step

    if args.dataset == "CUB_200_2011":
        num_classes = 200
    elif args.dataset == "car":
        num_classes = 196
    elif args.dataset == "nabirds":
        num_classes = 555
    elif args.dataset == "dog":
        num_classes = 120
    elif args.dataset == "INat2017":
        num_classes = 5089
    elif args.dataset == "emptyJudge5":
        num_classes = 5
    elif args.dataset == "emptyJudge4":
        num_classes = 4
    elif args.dataset == "emptyJudge3":
        num_classes = 3
    else:
        raise ValueError("Unknown dataset: %s" % args.dataset)

    model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes,
                              smoothing_value=args.smoothing_value)

    model.load_from(np.load(args.pretrained_dir))
    if args.pretrained_model is not None:
        pretrained_model = torch.load(args.pretrained_model)['model']
        model.load_state_dict(pretrained_model)
    #model.to(args.device)
    #pdb.set_trace()
    num_params = count_parameters(model)

    logger.info("{}".format(config))
    logger.info("Training parameters %s", args)
    logger.info("Total Parameter: \t%2.1fM" % num_params)
    model = torch.nn.DataParallel(model, device_ids=[0, 1]).cuda()
    return args, model


def count_parameters(model):
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return params / 1000000


def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


def valid(args, model, writer, test_loader, global_step):
    eval_losses = AverageMeter()

    logger.info("***** Running Validation *****")
    # logger.info("val Num steps = %d", len(test_loader))
    # logger.info("val Batch size = %d", args.eval_batch_size)

    model.eval()
    all_preds, all_label = [], []
    epoch_iterator = tqdm(test_loader,
                          desc="Validating... (loss=X.X)",
                          bar_format="{l_bar}{r_bar}",
                          dynamic_ncols=True,
                          disable=args.local_rank not in [-1, 0])
    loss_fct = torch.nn.CrossEntropyLoss()
    for step, batch in enumerate(epoch_iterator):
        batch = tuple(t.to(args.device) for t in batch)
        x, y = batch
        with torch.no_grad():
            logits = model(x)

            eval_loss = loss_fct(logits, y)
            eval_loss = eval_loss.mean()
            eval_losses.update(eval_loss.item())

            preds = torch.argmax(logits, dim=-1)

        if len(all_preds) == 0:
            all_preds.append(preds.detach().cpu().numpy())
            all_label.append(y.detach().cpu().numpy())
        else:
            all_preds[0] = np.append(
                all_preds[0], preds.detach().cpu().numpy(), axis=0
            )
            all_label[0] = np.append(
                all_label[0], y.detach().cpu().numpy(), axis=0
            )
        epoch_iterator.set_description("Validating... (loss=%2.5f)" % eval_losses.val)

    all_preds, all_label = all_preds[0], all_label[0]
    accuracy = simple_accuracy(all_preds, all_label)
    accuracy = torch.tensor(accuracy).to(args.device)
    # dist.barrier()
    # val_accuracy = reduce_mean(accuracy, args.nprocs)
    # val_accuracy = val_accuracy.detach().cpu().numpy()
    val_accuracy = accuracy.detach().cpu().numpy()

    logger.info("\n")
    logger.info("Validation Results")
    logger.info("Global Steps: %d" % global_step)
    logger.info("Valid Loss: %2.5f" % eval_losses.avg)
    logger.info("Valid Accuracy: %2.5f" % val_accuracy)
    if args.local_rank in [-1, 0]:
        writer.add_scalar("test/accuracy", scalar_value=val_accuracy, global_step=global_step)
    return val_accuracy


def train(args, model):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=os.path.join("logs", args.name))
    else:
        writer = None

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    # Prepare dataset
    train_loader, test_loader = get_loader(args)
    logger.info("train Num steps = %d", len(train_loader))
    logger.info("val Num steps = %d", len(test_loader))
    # Prepare optimizer and scheduler
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate,
                                momentum=0.9,
                                weight_decay=args.weight_decay)
    t_total = args.num_steps
    if args.decay_type == "cosine":
        scheduler = WarmupCosineSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
    else:
        scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Total optimization steps = %d", args.num_steps)
    logger.info("  Instantaneous batch size per GPU = %d", args.train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                args.train_batch_size * args.gradient_accumulation_steps * (
                    torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)

    model.zero_grad()
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    losses = AverageMeter()
    global_step, best_acc = 0, 0
    start_time = time.time()
    while True:
        model.train()
        epoch_iterator = tqdm(train_loader,
                              desc="Training (X / X Steps) (loss=X.X)",
                              bar_format="{l_bar}{r_bar}",
                              dynamic_ncols=True,
                              disable=args.local_rank not in [-1, 0])
        all_preds, all_label = [], []
        for step, batch in enumerate(epoch_iterator):
            batch = tuple(t.to(args.device) for t in batch)
            x, y = batch

            loss, logits = model(x, y)
            loss = loss.mean()

            preds = torch.argmax(logits, dim=-1)

            if len(all_preds) == 0:
                all_preds.append(preds.detach().cpu().numpy())
                all_label.append(y.detach().cpu().numpy())
            else:
                all_preds[0] = np.append(
                    all_preds[0], preds.detach().cpu().numpy(), axis=0
                )
                all_label[0] = np.append(
                    all_label[0], y.detach().cpu().numpy(), axis=0
                )

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                losses.update(loss.item() * args.gradient_accumulation_steps)
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                # Step the optimizer before the LR scheduler (required order since PyTorch 1.1)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

                epoch_iterator.set_description(
                    "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, t_total, losses.val)
                )
                if args.local_rank in [-1, 0]:
                    writer.add_scalar("train/loss", scalar_value=losses.val, global_step=global_step)
                    writer.add_scalar("train/lr", scalar_value=scheduler.get_lr()[0], global_step=global_step)
                if global_step % args.eval_every == 0:
                    with torch.no_grad():
                        accuracy = valid(args, model, writer, test_loader, global_step)
                    if args.local_rank in [-1, 0]:
                        if best_acc < accuracy:
                            save_model(args, model)
                            best_acc = accuracy
                        logger.info("best accuracy so far: %f" % best_acc)
                        if global_step % 10000 == 0:
                            save_eve_model(args, model, str(global_step))
                    model.train()

                if global_step % t_total == 0:
                    break
        all_preds, all_label = all_preds[0], all_label[0]
        accuracy = simple_accuracy(all_preds, all_label)
        accuracy = torch.tensor(accuracy).to(args.device)
        # dist.barrier()
        # train_accuracy = reduce_mean(accuracy, args.nprocs)
        # train_accuracy = train_accuracy.detach().cpu().numpy()
        train_accuracy = accuracy.detach().cpu().numpy()
        logger.info("train accuracy so far: %f" % train_accuracy)
        losses.reset()
        if global_step % t_total == 0:
            break

    writer.close()
    logger.info("Best Accuracy: \t%f" % best_acc)
    logger.info("End Training!")
    end_time = time.time()
    logger.info("Total Training Time: \t%f hours" % ((end_time - start_time) / 3600))


def main():
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--name", required=True,
                        help="Name of this run. Used for monitoring.")
    parser.add_argument("--dataset", choices=["CUB_200_2011", "car", "dog", "nabirds", "INat2017",
                                              "emptyJudge5", "emptyJudge4", "emptyJudge3"],
                        default="CUB_200_2011", help="Which dataset.")
    parser.add_argument('--data_root', type=str, default='/data/pfc/fineGrained')
    parser.add_argument("--model_type", choices=["ViT-B_16", "ViT-B_32", "ViT-L_16", "ViT-L_32", "ViT-H_14"],
                        default="ViT-B_16", help="Which variant to use.")
    parser.add_argument("--pretrained_dir", type=str, default="ckpts/ViT-B_16.npz",
                        help="Where to search for pretrained ViT models.")
    parser.add_argument("--pretrained_model", type=str, default="output/emptyjudge5_checkpoint.bin",
                        help="load pretrained model")
    #parser.add_argument("--pretrained_model", type=str, default=None, help="load pretrained model")
    parser.add_argument("--output_dir", default="./output", type=str,
                        help="The output directory where checkpoints will be written.")
    parser.add_argument("--img_size", default=448, type=int, help="Resolution size")
    parser.add_argument("--train_batch_size", default=64, type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size", default=16, type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--eval_every", default=200, type=int,  # 200
                        help="Run prediction on validation set every so many steps. "
                             "Will always run one evaluation at the end of training.")

    parser.add_argument("--learning_rate", default=3e-2, type=float,
                        help="The initial learning rate for SGD.")
    parser.add_argument("--weight_decay", default=0, type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--num_steps", default=40000, type=int,  # 100000
                        help="Total number of training steps to perform.")
    parser.add_argument("--decay_type", choices=["cosine", "linear"], default="cosine",
                        help="How to decay the learning rate.")
    parser.add_argument("--warmup_steps", default=500, type=int,
                        help="Steps of training to perform learning rate warmup for.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")

    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O2',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument('--loss_scale', type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")

    parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value\n")

    parser.add_argument('--split', type=str, default='overlap', help="Split method")  # non-overlap
    parser.add_argument('--slide_step', type=int, default=12, help="Slide step for overlap split")
    args = parser.parse_args()

    args.data_root = '{}/{}'.format(args.data_root, args.dataset)
    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        #print('torch.cuda.device_count()>>>>>>>>>>>>>>>>>>>>>>>>>', torch.cuda.device_count())
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl', timeout=timedelta(minutes=60))
        args.n_gpu = 1
    args.device = device
    args.nprocs = torch.cuda.device_count()

    # Setup logging
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s" %
                   (args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1), args.fp16))

    # Set seed
    set_seed(args)

    # Model & Tokenizer Setup
    args, model = setup(args)
    # Training
    train(args, model)
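
# Example invocation (illustrative only; the script filename, run name, and all
# paths below are placeholders to adapt to your setup):
#   python3 train.py --name cub_run1 --dataset CUB_200_2011 \
#       --data_root /path/to/fineGrained --pretrained_dir ckpts/ViT-B_16.npz \
#       --train_batch_size 64 --num_steps 40000 --eval_every 200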


if __name__ == "__main__":
    main()