update train.py.

Author: Brainway
Date: 2022-10-18 03:48:18 +00:00
Committed by: Gitee
parent fb45a96528
commit 9912dca40c

train.py (103 changed lines)

@@ -24,7 +24,9 @@ import pdb
logger = logging.getLogger(__name__)
-os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+# Computes and stores the average value
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
@@ -42,57 +44,39 @@ class AverageMeter(object):
self.count += n
self.avg = self.sum / self.count
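
Note: a minimal usage sketch of AverageMeter (dummy loss values and batch size, not part of this commit):

    # Hypothetical usage: track a running loss average during training.
    losses = AverageMeter()
    for batch_loss in (0.92, 0.71, 0.64):  # dummy per-batch losses
        losses.update(batch_loss, n=8)     # n = number of samples in the batch
    print(losses.val, losses.avg)          # latest value and running mean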
+# Simple accuracy
def simple_accuracy(preds, labels):
return (preds == labels).mean()
+# Average a tensor across processes
def reduce_mean(tensor, nprocs):
rt = tensor.clone()
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
rt /= nprocs
return rt
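
Note: reduce_mean sums a tensor over all ranks and divides by the process count, so every rank ends up holding the global mean. A hedged sketch of a typical call (assumes torch.distributed is initialized; local_accuracy and nprocs are hypothetical):

    # Sketch: average a per-rank metric across GPUs under distributed training.
    acc = torch.tensor(local_accuracy).cuda()  # metric computed on this rank only
    acc = reduce_mean(acc, nprocs)             # identical global mean on every rank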
+# Save the model
def save_model(args, model):
model_to_save = model.module if hasattr(model, 'module') else model
model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.bin" % args.name)
checkpoint = {
'model': model_to_save.state_dict(),
}
torch.save(checkpoint, model_checkpoint)
logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir)
model_checkpoint = os.path.join("../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin")
torch.save(model, model_checkpoint)
logger.info("Saved model checkpoint to [File: %s]", "../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin")
+# Configure the model for the chosen dataset
def setup(args):
# Prepare model
config = CONFIGS[args.model_type]
config.split = args.split
config.slide_step = args.slide_step
if args.dataset == "CUB_200_2011":
num_classes = 200
elif args.dataset == "car":
num_classes = 196
elif args.dataset == "nabirds":
num_classes = 555
elif args.dataset == "dog":
num_classes = 120
elif args.dataset == "INat2017":
num_classes = 5089
elif args.dataset == "emptyJudge5":
if args.dataset == "emptyJudge5":
num_classes = 5
elif args.dataset == "emptyJudge4":
num_classes = 4
elif args.dataset == "emptyJudge3":
num_classes = 3
model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes, smoothing_value=args.smoothing_value)
-model.load_from(np.load(args.pretrained_dir))
+if args.pretrained_dir is not None:
+model.load_from(np.load(args.pretrained_dir))  # pretrained weights from a third party
if args.pretrained_model is not None:
-pretrained_model = torch.load(args.pretrained_model)['model']
-model.load_state_dict(pretrained_model)
+model = torch.load(args.pretrained_model)  # our own pretrained model
#model.to(args.device)
#pdb.set_trace()
num_params = count_parameters(model)
@@ -100,15 +84,15 @@ def setup(args):
logger.info("{}".format(config))
logger.info("Training parameters %s", args)
logger.info("Total Parameter: \t%2.1fM" % num_params)
-model = torch.nn.DataParallel(model, device_ids=[0,1]).cuda()
+model = torch.nn.DataParallel(model, device_ids=[0]).cuda()
return args, model
+# Count model parameters (in millions)
def count_parameters(model):
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
return params/1000000
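
Note: count_parameters reports trainable parameters in millions. A quick sanity check on a toy module (not from this repo):

    # A 10x10 linear layer has 10*10 weights + 10 biases = 110 parameters.
    print(count_parameters(torch.nn.Linear(10, 10)))  # prints 0.00011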
+# Random seed
def set_seed(args):
random.seed(args.seed)
np.random.seed(args.seed)
@@ -116,7 +100,7 @@ def set_seed(args):
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
+# Model validation
def valid(args, model, writer, test_loader, global_step):
eval_losses = AverageMeter()
@@ -173,7 +157,7 @@ def valid(args, model, writer, test_loader, global_step):
writer.add_scalar("test/accuracy", scalar_value=val_accuracy, global_step=global_step)
return val_accuracy
+# Model training
def train(args, model):
""" Train the model """
if args.local_rank in [-1, 0]:
@@ -287,37 +271,41 @@ def train(args, model):
end_time = time.time()
logger.info("Total Training Time: \t%f" % ((end_time - start_time) / 3600))
+# Main function
def main():
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("--name", required=True,
parser.add_argument("--name", type=str, default='ieemooempty',
help="Name of this run. Used for monitoring.")
parser.add_argument("--dataset", choices=["CUB_200_2011", "car", "dog", "nabirds", "INat2017", "emptyJudge5", "emptyJudge4"],
default="CUB_200_2011", help="Which dataset.")
parser.add_argument('--data_root', type=str, default='/data/fineGrained')
default="emptyJudge5", help="Which dataset.")
parser.add_argument('--data_root', type=str, default='./')
parser.add_argument("--model_type", choices=["ViT-B_16", "ViT-B_32", "ViT-L_16", "ViT-L_32", "ViT-H_14"],
default="ViT-B_16",help="Which variant to use.")
parser.add_argument("--pretrained_dir", type=str, default="ckpts/ViT-B_16.npz",
parser.add_argument("--pretrained_dir", type=str, default="./preckpts/ViT-B_16.npz",
help="Where to search for pretrained ViT models.")
parser.add_argument("--pretrained_model", type=str, default="output/emptyjudge5_checkpoint.bin", help="load pretrained model")
#parser.add_argument("--pretrained_model", type=str, default=None, help="load pretrained model")
#parser.add_argument("--pretrained_model", type=str, default="./output/ieemooempty_checkpoint_good.pth", help="load pretrained model") #None
# parser.add_argument("--pretrained_dir", type=str, default=None,
# help="Where to search for pretrained ViT models.")
parser.add_argument("--pretrained_model", type=str, default=None, help="load pretrained model") #None
parser.add_argument("--output_dir", default="./output", type=str,
help="The output directory where checkpoints will be written.")
parser.add_argument("--img_size", default=448, type=int, help="Resolution size")
parser.add_argument("--train_batch_size", default=64, type=int,
parser.add_argument("--img_size", default=320, type=int, help="Resolution size")
parser.add_argument("--train_batch_size", default=8, type=int,
help="Total batch size for training.")
parser.add_argument("--eval_batch_size", default=16, type=int,
parser.add_argument("--eval_batch_size", default=8, type=int,
help="Total batch size for eval.")
parser.add_argument("--eval_every", default=200, type=int,
parser.add_argument("--eval_every", default=786, type=int,
help="Run prediction on validation set every so many steps."
"Will always run one evaluation at the end of training.")
parser.add_argument("--learning_rate", default=3e-2, type=float,
parser.add_argument("--learning_rate", default=3e-2, type=float,
help="The initial learning rate for SGD.")
parser.add_argument("--weight_decay", default=0, type=float,
parser.add_argument("--weight_decay", default=0.00001, type=float,
help="Weight deay if we apply some.")
parser.add_argument("--num_steps", default=8000, type=int, #100000
parser.add_argument("--num_steps", default=78600, type=int, #100000
help="Total number of training epochs to perform.")
parser.add_argument("--decay_type", choices=["cosine", "linear"], default="cosine",
help="How to decay the learning rate.")
@@ -332,15 +320,6 @@ def main():
help="random seed for initialization")
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.")
-parser.add_argument('--fp16', action='store_true',
-help="Whether to use 16-bit float precision instead of 32-bit")
-parser.add_argument('--fp16_opt_level', type=str, default='O2',
-help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
-"See details at https://nvidia.github.io/apex/amp.html")
-parser.add_argument('--loss_scale', type=float, default=0,
-help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
-"0 (default value): dynamic loss scaling.\n"
-"Positive power of 2: static loss scaling value.\n")
parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value\n")
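
Note: with the apex-based fp16 flags removed, mixed precision is no longer configurable here. If it were still wanted, PyTorch's built-in torch.cuda.amp would be a natural substitute; a hedged sketch (not part of this commit; assumes the model returns a loss when labels are passed):

    # Sketch of a native-AMP training step replacing the removed apex path.
    scaler = torch.cuda.amp.GradScaler()
    with torch.cuda.amp.autocast():
        loss = model(x, y)  # assumed loss-returning call signature
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()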
@@ -352,7 +331,6 @@ def main():
# Setup CUDA, GPU & distributed training
if args.local_rank == -1:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-print('torch.cuda.device_count()>>>>>>>>>>>>>>>>>>>>>>>>>', torch.cuda.device_count())
args.n_gpu = torch.cuda.device_count()
print('torch.cuda.device_count()>>>>>>>>>>>>>>>>>>>>>>>>>', torch.cuda.device_count())
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
@@ -367,8 +345,8 @@ def main():
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
-logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s" %
-(args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1), args.fp16))
+logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s" %
+(args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1)))
# Set seed
set_seed(args)
@@ -380,4 +358,5 @@ def main():
if __name__ == "__main__":
+torch.cuda.empty_cache()
main()