update train.py.

Brainway
2022-10-18 03:48:18 +00:00
committed by Gitee
parent fb45a96528
commit 9912dca40c

train.py

@@ -24,7 +24,9 @@ import pdb
 logger = logging.getLogger(__name__)
-os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+# Compute and store the average value
 class AverageMeter(object):
     """Computes and stores the average and current value"""
     def __init__(self):
@@ -42,57 +44,39 @@ class AverageMeter(object):
         self.count += n
         self.avg = self.sum / self.count
+# Simple accuracy
 def simple_accuracy(preds, labels):
     return (preds == labels).mean()
+# Compute the mean across processes
 def reduce_mean(tensor, nprocs):
     rt = tensor.clone()
     dist.all_reduce(rt, op=dist.ReduceOp.SUM)
     rt /= nprocs
     return rt
+# Save the model
 def save_model(args, model):
-    model_to_save = model.module if hasattr(model, 'module') else model
-    model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.bin" % args.name)
-    checkpoint = {
-        'model': model_to_save.state_dict(),
-    }
-    torch.save(checkpoint, model_checkpoint)
-    logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir)
+    model_checkpoint = os.path.join("../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin")
+    torch.save(model, model_checkpoint)
+    logger.info("Saved model checkpoint to [File: %s]", "../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin")
+# Configure the model according to the dataset
 def setup(args):
     # Prepare model
     config = CONFIGS[args.model_type]
     config.split = args.split
     config.slide_step = args.slide_step
-    if args.dataset == "CUB_200_2011":
-        num_classes = 200
-    elif args.dataset == "car":
-        num_classes = 196
-    elif args.dataset == "nabirds":
-        num_classes = 555
-    elif args.dataset == "dog":
-        num_classes = 120
-    elif args.dataset == "INat2017":
-        num_classes = 5089
-    elif args.dataset == "emptyJudge5":
+    if args.dataset == "emptyJudge5":
         num_classes = 5
-    elif args.dataset == "emptyJudge4":
-        num_classes = 4
-    elif args.dataset == "emptyJudge3":
-        num_classes = 3
     model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes, smoothing_value=args.smoothing_value)
-    model.load_from(np.load(args.pretrained_dir))
+    if args.pretrained_dir is not None:
+        model.load_from(np.load(args.pretrained_dir))  # third-party pretrained weights
     if args.pretrained_model is not None:
-        pretrained_model = torch.load(args.pretrained_model)['model']
-        model.load_state_dict(pretrained_model)
+        model = torch.load(args.pretrained_model)  # our own pretrained model
     #model.to(args.device)
     #pdb.set_trace()
     num_params = count_parameters(model)
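
Note on the checkpoint format change above: save_model() now serializes the entire model object with torch.save(model, ...) instead of writing a state_dict, and setup() restores it with a plain torch.load(). Below is a minimal round-trip sketch, not part of the commit; the path is the one hard-coded in the new save_model(), map_location is an added convenience, and the VisionTransformer class must be importable for unpickling to succeed.

import torch

ckpt_path = "../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin"  # path hard-coded in the new save_model()

# torch.save(model, ckpt_path)                      # what the new save_model() does
model = torch.load(ckpt_path, map_location="cpu")   # what the new setup() does when --pretrained_model is set
model.eval()

Unlike a state_dict checkpoint, a pickled module is tied to the exact class and module layout it was saved from, so moving or renaming VisionTransformer would break loading.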
@@ -100,15 +84,15 @@ def setup(args):
     logger.info("{}".format(config))
     logger.info("Training parameters %s", args)
     logger.info("Total Parameter: \t%2.1fM" % num_params)
-    model = torch.nn.DataParallel(model, device_ids=[0,1]).cuda()
+    model = torch.nn.DataParallel(model, device_ids=[0]).cuda()
     return args, model
+# Count the number of model parameters
 def count_parameters(model):
     params = sum(p.numel() for p in model.parameters() if p.requires_grad)
     return params/1000000
+# Random seed
 def set_seed(args):
     random.seed(args.seed)
     np.random.seed(args.seed)
@@ -116,7 +100,7 @@ def set_seed(args):
     if args.n_gpu > 0:
         torch.cuda.manual_seed_all(args.seed)
+# Model validation
 def valid(args, model, writer, test_loader, global_step):
     eval_losses = AverageMeter()
@@ -173,7 +157,7 @@ def valid(args, model, writer, test_loader, global_step):
     writer.add_scalar("test/accuracy", scalar_value=val_accuracy, global_step=global_step)
     return val_accuracy
+# Model training
 def train(args, model):
     """ Train the model """
     if args.local_rank in [-1, 0]:
@@ -287,37 +271,41 @@ def train(args, model):
     end_time = time.time()
     logger.info("Total Training Time: \t%f" % ((end_time - start_time) / 3600))
+# Main function
 def main():
     parser = argparse.ArgumentParser()
     # Required parameters
-    parser.add_argument("--name", required=True,
+    parser.add_argument("--name", type=str, default='ieemooempty',
                         help="Name of this run. Used for monitoring.")
     parser.add_argument("--dataset", choices=["CUB_200_2011", "car", "dog", "nabirds", "INat2017", "emptyJudge5", "emptyJudge4"],
-                        default="CUB_200_2011", help="Which dataset.")
+                        default="emptyJudge5", help="Which dataset.")
-    parser.add_argument('--data_root', type=str, default='/data/fineGrained')
+    parser.add_argument('--data_root', type=str, default='./')
     parser.add_argument("--model_type", choices=["ViT-B_16", "ViT-B_32", "ViT-L_16", "ViT-L_32", "ViT-H_14"],
                         default="ViT-B_16",help="Which variant to use.")
-    parser.add_argument("--pretrained_dir", type=str, default="ckpts/ViT-B_16.npz",
+    parser.add_argument("--pretrained_dir", type=str, default="./preckpts/ViT-B_16.npz",
                         help="Where to search for pretrained ViT models.")
-    parser.add_argument("--pretrained_model", type=str, default="output/emptyjudge5_checkpoint.bin", help="load pretrained model")
-    #parser.add_argument("--pretrained_model", type=str, default=None, help="load pretrained model")
+    #parser.add_argument("--pretrained_model", type=str, default="./output/ieemooempty_checkpoint_good.pth", help="load pretrained model") #None
+    # parser.add_argument("--pretrained_dir", type=str, default=None,
+    #                     help="Where to search for pretrained ViT models.")
+    parser.add_argument("--pretrained_model", type=str, default=None, help="load pretrained model") #None
     parser.add_argument("--output_dir", default="./output", type=str,
                         help="The output directory where checkpoints will be written.")
-    parser.add_argument("--img_size", default=448, type=int, help="Resolution size")
+    parser.add_argument("--img_size", default=320, type=int, help="Resolution size")
-    parser.add_argument("--train_batch_size", default=64, type=int,
+    parser.add_argument("--train_batch_size", default=8, type=int,
                         help="Total batch size for training.")
-    parser.add_argument("--eval_batch_size", default=16, type=int,
+    parser.add_argument("--eval_batch_size", default=8, type=int,
                         help="Total batch size for eval.")
-    parser.add_argument("--eval_every", default=200, type=int,
+    parser.add_argument("--eval_every", default=786, type=int,
                         help="Run prediction on validation set every so many steps."
                              "Will always run one evaluation at the end of training.")
     parser.add_argument("--learning_rate", default=3e-2, type=float,
                         help="The initial learning rate for SGD.")
-    parser.add_argument("--weight_decay", default=0, type=float,
+    parser.add_argument("--weight_decay", default=0.00001, type=float,
                         help="Weight deay if we apply some.")
-    parser.add_argument("--num_steps", default=8000, type=int, #100000
+    parser.add_argument("--num_steps", default=78600, type=int, #100000
                         help="Total number of training epochs to perform.")
     parser.add_argument("--decay_type", choices=["cosine", "linear"], default="cosine",
                         help="How to decay the learning rate.")
@@ -332,15 +320,6 @@ def main():
                         help="random seed for initialization")
     parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                         help="Number of updates steps to accumulate before performing a backward/update pass.")
-    parser.add_argument('--fp16', action='store_true',
-                        help="Whether to use 16-bit float precision instead of 32-bit")
-    parser.add_argument('--fp16_opt_level', type=str, default='O2',
-                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
-                             "See details at https://nvidia.github.io/apex/amp.html")
-    parser.add_argument('--loss_scale', type=float, default=0,
-                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
-                             "0 (default value): dynamic loss scaling.\n"
-                             "Positive power of 2: static loss scaling value.\n")
     parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value\n")
@@ -352,7 +331,6 @@ def main():
     # Setup CUDA, GPU & distributed training
     if args.local_rank == -1:
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        print('torch.cuda.device_count()>>>>>>>>>>>>>>>>>>>>>>>>>', torch.cuda.device_count())
         args.n_gpu = torch.cuda.device_count()
         print('torch.cuda.device_count()>>>>>>>>>>>>>>>>>>>>>>>>>', torch.cuda.device_count())
     else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
@@ -367,8 +345,8 @@ def main():
     logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                         datefmt='%m/%d/%Y %H:%M:%S',
                         level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
-    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s" %
-                   (args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1), args.fp16))
+    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s" %
+                   (args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1)))
     # Set seed
     set_seed(args)
@@ -380,4 +358,5 @@ def main():
 if __name__ == "__main__":
+    torch.cuda.empty_cache()
     main()
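
For reference, a minimal sketch (not from the commit) of the single-GPU configuration the diff converges on: CUDA_VISIBLE_DEVICES is pinned to "0", DataParallel wraps only device 0 and so acts as a thin pass-through, and torch.cuda.empty_cache() before main() only releases cached, unreferenced GPU memory back to the driver; it does not free live tensors. The Linear layer below is a hypothetical stand-in for the real VisionTransformer.

import os
import torch

os.environ["CUDA_VISIBLE_DEVICES"] = "0"   # value set at the top of train.py in this commit

net = torch.nn.Linear(8, 5)                # hypothetical stand-in for VisionTransformer
if torch.cuda.is_available():
    torch.cuda.empty_cache()               # frees cached, unreferenced blocks only
    net = torch.nn.DataParallel(net, device_ids=[0]).cuda()  # single visible device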