update train.py.
This commit is contained in:
103
train.py
103
train.py
@ -24,7 +24,9 @@ import pdb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
|
||||
#计算并存储平均值
|
||||
class AverageMeter(object):
|
||||
"""Computes and stores the average and current value"""
|
||||
def __init__(self):
|
||||
@ -42,57 +44,39 @@ class AverageMeter(object):
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
#简单准确率
|
||||
def simple_accuracy(preds, labels):
|
||||
return (preds == labels).mean()
|
||||
|
||||
|
||||
#求均值
|
||||
def reduce_mean(tensor, nprocs):
|
||||
rt = tensor.clone()
|
||||
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
|
||||
rt /= nprocs
|
||||
return rt
|
||||
|
||||
|
||||
#保存模型
|
||||
def save_model(args, model):
|
||||
model_to_save = model.module if hasattr(model, 'module') else model
|
||||
model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.bin" % args.name)
|
||||
checkpoint = {
|
||||
'model': model_to_save.state_dict(),
|
||||
}
|
||||
torch.save(checkpoint, model_checkpoint)
|
||||
logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir)
|
||||
|
||||
model_checkpoint = os.path.join("../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin")
|
||||
torch.save(model, model_checkpoint)
|
||||
logger.info("Saved model checkpoint to [File: %s]", "../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin")
|
||||
|
||||
#根据数据集配置模型
|
||||
def setup(args):
|
||||
# Prepare model
|
||||
config = CONFIGS[args.model_type]
|
||||
config.split = args.split
|
||||
config.slide_step = args.slide_step
|
||||
|
||||
if args.dataset == "CUB_200_2011":
|
||||
num_classes = 200
|
||||
elif args.dataset == "car":
|
||||
num_classes = 196
|
||||
elif args.dataset == "nabirds":
|
||||
num_classes = 555
|
||||
elif args.dataset == "dog":
|
||||
num_classes = 120
|
||||
elif args.dataset == "INat2017":
|
||||
num_classes = 5089
|
||||
elif args.dataset == "emptyJudge5":
|
||||
|
||||
if args.dataset == "emptyJudge5":
|
||||
num_classes = 5
|
||||
elif args.dataset == "emptyJudge4":
|
||||
num_classes = 4
|
||||
elif args.dataset == "emptyJudge3":
|
||||
num_classes = 3
|
||||
|
||||
|
||||
model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes, smoothing_value=args.smoothing_value)
|
||||
|
||||
model.load_from(np.load(args.pretrained_dir))
|
||||
if args.pretrained_dir is not None:
|
||||
model.load_from(np.load(args.pretrained_dir)) #他人预训练模型
|
||||
if args.pretrained_model is not None:
|
||||
pretrained_model = torch.load(args.pretrained_model)['model']
|
||||
model.load_state_dict(pretrained_model)
|
||||
model = torch.load(args.pretrained_model) #自己预训练模型
|
||||
#model.to(args.device)
|
||||
#pdb.set_trace()
|
||||
num_params = count_parameters(model)
|
||||
@ -100,15 +84,15 @@ def setup(args):
|
||||
logger.info("{}".format(config))
|
||||
logger.info("Training parameters %s", args)
|
||||
logger.info("Total Parameter: \t%2.1fM" % num_params)
|
||||
model = torch.nn.DataParallel(model, device_ids=[0,1]).cuda()
|
||||
model = torch.nn.DataParallel(model, device_ids=[0]).cuda()
|
||||
return args, model
|
||||
|
||||
|
||||
#计算模型参数数量
|
||||
def count_parameters(model):
|
||||
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
return params/1000000
|
||||
|
||||
|
||||
#随机种子
|
||||
def set_seed(args):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
@ -116,7 +100,7 @@ def set_seed(args):
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
|
||||
#模型验证
|
||||
def valid(args, model, writer, test_loader, global_step):
|
||||
eval_losses = AverageMeter()
|
||||
|
||||
@ -173,7 +157,7 @@ def valid(args, model, writer, test_loader, global_step):
|
||||
writer.add_scalar("test/accuracy", scalar_value=val_accuracy, global_step=global_step)
|
||||
return val_accuracy
|
||||
|
||||
|
||||
#模型训练
|
||||
def train(args, model):
|
||||
""" Train the model """
|
||||
if args.local_rank in [-1, 0]:
|
||||
@ -287,37 +271,41 @@ def train(args, model):
|
||||
end_time = time.time()
|
||||
logger.info("Total Training Time: \t%f" % ((end_time - start_time) / 3600))
|
||||
|
||||
|
||||
|
||||
#主函数
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument("--name", required=True,
|
||||
parser.add_argument("--name", type=str, default='ieemooempty',
|
||||
help="Name of this run. Used for monitoring.")
|
||||
parser.add_argument("--dataset", choices=["CUB_200_2011", "car", "dog", "nabirds", "INat2017", "emptyJudge5", "emptyJudge4"],
|
||||
default="CUB_200_2011", help="Which dataset.")
|
||||
parser.add_argument('--data_root', type=str, default='/data/fineGrained')
|
||||
default="emptyJudge5", help="Which dataset.")
|
||||
parser.add_argument('--data_root', type=str, default='./')
|
||||
parser.add_argument("--model_type", choices=["ViT-B_16", "ViT-B_32", "ViT-L_16", "ViT-L_32", "ViT-H_14"],
|
||||
default="ViT-B_16",help="Which variant to use.")
|
||||
parser.add_argument("--pretrained_dir", type=str, default="ckpts/ViT-B_16.npz",
|
||||
parser.add_argument("--pretrained_dir", type=str, default="./preckpts/ViT-B_16.npz",
|
||||
help="Where to search for pretrained ViT models.")
|
||||
parser.add_argument("--pretrained_model", type=str, default="output/emptyjudge5_checkpoint.bin", help="load pretrained model")
|
||||
#parser.add_argument("--pretrained_model", type=str, default=None, help="load pretrained model")
|
||||
#parser.add_argument("--pretrained_model", type=str, default="./output/ieemooempty_checkpoint_good.pth", help="load pretrained model") #None
|
||||
# parser.add_argument("--pretrained_dir", type=str, default=None,
|
||||
# help="Where to search for pretrained ViT models.")
|
||||
parser.add_argument("--pretrained_model", type=str, default=None, help="load pretrained model") #None
|
||||
parser.add_argument("--output_dir", default="./output", type=str,
|
||||
help="The output directory where checkpoints will be written.")
|
||||
parser.add_argument("--img_size", default=448, type=int, help="Resolution size")
|
||||
parser.add_argument("--train_batch_size", default=64, type=int,
|
||||
parser.add_argument("--img_size", default=320, type=int, help="Resolution size")
|
||||
parser.add_argument("--train_batch_size", default=8, type=int,
|
||||
help="Total batch size for training.")
|
||||
parser.add_argument("--eval_batch_size", default=16, type=int,
|
||||
parser.add_argument("--eval_batch_size", default=8, type=int,
|
||||
help="Total batch size for eval.")
|
||||
parser.add_argument("--eval_every", default=200, type=int,
|
||||
parser.add_argument("--eval_every", default=786, type=int,
|
||||
help="Run prediction on validation set every so many steps."
|
||||
"Will always run one evaluation at the end of training.")
|
||||
|
||||
parser.add_argument("--learning_rate", default=3e-2, type=float,
|
||||
parser.add_argument("--learning_rate", default=3e-2, type=float,
|
||||
help="The initial learning rate for SGD.")
|
||||
parser.add_argument("--weight_decay", default=0, type=float,
|
||||
parser.add_argument("--weight_decay", default=0.00001, type=float,
|
||||
help="Weight deay if we apply some.")
|
||||
parser.add_argument("--num_steps", default=8000, type=int, #100000
|
||||
parser.add_argument("--num_steps", default=78600, type=int, #100000
|
||||
help="Total number of training epochs to perform.")
|
||||
parser.add_argument("--decay_type", choices=["cosine", "linear"], default="cosine",
|
||||
help="How to decay the learning rate.")
|
||||
@ -332,15 +320,6 @@ def main():
|
||||
help="random seed for initialization")
|
||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||
parser.add_argument('--fp16', action='store_true',
|
||||
help="Whether to use 16-bit float precision instead of 32-bit")
|
||||
parser.add_argument('--fp16_opt_level', type=str, default='O2',
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html")
|
||||
parser.add_argument('--loss_scale', type=float, default=0,
|
||||
help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
|
||||
"0 (default value): dynamic loss scaling.\n"
|
||||
"Positive power of 2: static loss scaling value.\n")
|
||||
|
||||
parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value\n")
|
||||
|
||||
@ -352,7 +331,6 @@ def main():
|
||||
# Setup CUDA, GPU & distributed training
|
||||
if args.local_rank == -1:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print('torch.cuda.device_count()>>>>>>>>>>>>>>>>>>>>>>>>>', torch.cuda.device_count())
|
||||
args.n_gpu = torch.cuda.device_count()
|
||||
print('torch.cuda.device_count()>>>>>>>>>>>>>>>>>>>>>>>>>', torch.cuda.device_count())
|
||||
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||
@ -367,8 +345,8 @@ def main():
|
||||
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt='%m/%d/%Y %H:%M:%S',
|
||||
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s" %
|
||||
(args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1), args.fp16))
|
||||
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s" %
|
||||
(args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1)))
|
||||
|
||||
# Set seed
|
||||
set_seed(args)
|
||||
@ -380,4 +358,5 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.cuda.empty_cache()
|
||||
main()
|
||||
|
Reference in New Issue
Block a user