update
This commit is contained in:
41
train.py
41
train.py
@ -57,9 +57,29 @@ def reduce_mean(tensor, nprocs):
|
||||
|
||||
#保存模型
|
||||
def save_model(args, model):
|
||||
<<<<<<< HEAD
|
||||
model_checkpoint = os.path.join("../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin")
|
||||
torch.save(model, model_checkpoint)
|
||||
logger.info("Saved model checkpoint to [File: %s]", "../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin")
|
||||
=======
|
||||
model_to_save = model.module if hasattr(model, 'module') else model
|
||||
model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.bin" % args.name)
|
||||
checkpoint = {
|
||||
'model': model_to_save.state_dict(),
|
||||
}
|
||||
torch.save(checkpoint, model_checkpoint)
|
||||
logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir)
|
||||
|
||||
def save_eve_model(args, model, eve_name):
|
||||
model_to_save = model.module if hasattr(model, 'module') else model
|
||||
model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.bin" % eve_name)
|
||||
checkpoint = {
|
||||
'model': model_to_save.state_dict(),
|
||||
}
|
||||
torch.save(checkpoint, model_checkpoint)
|
||||
logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir)
|
||||
|
||||
>>>>>>> develop
|
||||
|
||||
#根据数据集配置模型
|
||||
def setup(args):
|
||||
@ -249,6 +269,8 @@ def train(args, model):
|
||||
save_model(args, model)
|
||||
best_acc = accuracy
|
||||
logger.info("best accuracy so far: %f" % best_acc)
|
||||
if int(global_step)%10000 == 0:
|
||||
save_eve_model(args, model, str(global_step))
|
||||
model.train()
|
||||
|
||||
if global_step % t_total == 0:
|
||||
@ -280,8 +302,13 @@ def main():
|
||||
parser.add_argument("--name", type=str, default='ieemooempty',
|
||||
help="Name of this run. Used for monitoring.")
|
||||
parser.add_argument("--dataset", choices=["CUB_200_2011", "car", "dog", "nabirds", "INat2017", "emptyJudge5", "emptyJudge4"],
|
||||
<<<<<<< HEAD
|
||||
default="emptyJudge5", help="Which dataset.")
|
||||
parser.add_argument('--data_root', type=str, default='./')
|
||||
=======
|
||||
default="CUB_200_2011", help="Which dataset.")
|
||||
parser.add_argument('--data_root', type=str, default='/data/pfc/fineGrained')
|
||||
>>>>>>> develop
|
||||
parser.add_argument("--model_type", choices=["ViT-B_16", "ViT-B_32", "ViT-L_16", "ViT-L_32", "ViT-H_14"],
|
||||
default="ViT-B_16",help="Which variant to use.")
|
||||
parser.add_argument("--pretrained_dir", type=str, default="./preckpts/ViT-B_16.npz",
|
||||
@ -297,7 +324,11 @@ def main():
|
||||
help="Total batch size for training.")
|
||||
parser.add_argument("--eval_batch_size", default=8, type=int,
|
||||
help="Total batch size for eval.")
|
||||
<<<<<<< HEAD
|
||||
parser.add_argument("--eval_every", default=786, type=int,
|
||||
=======
|
||||
parser.add_argument("--eval_every", default=200, type=int, #200
|
||||
>>>>>>> develop
|
||||
help="Run prediction on validation set every so many steps."
|
||||
"Will always run one evaluation at the end of training.")
|
||||
|
||||
@ -305,7 +336,11 @@ def main():
|
||||
help="The initial learning rate for SGD.")
|
||||
parser.add_argument("--weight_decay", default=0.00001, type=float,
|
||||
help="Weight deay if we apply some.")
|
||||
<<<<<<< HEAD
|
||||
parser.add_argument("--num_steps", default=78600, type=int, #100000
|
||||
=======
|
||||
parser.add_argument("--num_steps", default=40000, type=int, #100000
|
||||
>>>>>>> develop
|
||||
help="Total number of training epochs to perform.")
|
||||
parser.add_argument("--decay_type", choices=["cosine", "linear"], default="cosine",
|
||||
help="How to decay the learning rate.")
|
||||
@ -331,8 +366,12 @@ def main():
|
||||
# Setup CUDA, GPU & distributed training
|
||||
if args.local_rank == -1:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
#print('torch.cuda.device_count()>>>>>>>>>>>>>>>>>>>>>>>>>', torch.cuda.device_count())
|
||||
>>>>>>> develop
|
||||
args.n_gpu = torch.cuda.device_count()
|
||||
print('torch.cuda.device_count()>>>>>>>>>>>>>>>>>>>>>>>>>', torch.cuda.device_count())
|
||||
#print('torch.cuda.device_count()>>>>>>>>>>>>>>>>>>>>>>>>>', torch.cuda.device_count())
|
||||
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
|
Reference in New Issue
Block a user