import torch
import torch.nn as nn
from vit_pytorch import ViT, SimpleViT, MAE
from vit_pytorch.distill import DistillableViT
from vit_pytorch.deepvit import DeepViT
from vit_pytorch.cait import CaiT
from vit_pytorch.pit import PiT
from vit_pytorch.regionvit import RegionViT
from vit_pytorch.sep_vit import SepViT
from vit_pytorch.crossformer import CrossFormer
from vit_pytorch.nest import NesT
from vit_pytorch.mobile_vit import MobileViT
from vit_pytorch.simmim import SimMIM
from vit_pytorch.ats_vit import ViT as ATSViT  # aliased so it no longer shadows vit_pytorch.ViT imported above
from utils.data_utils import get_loader_new
from utils.scheduler import WarmupCosineSchedule
from tqdm import tqdm
import os
import numpy as np


def net():
    # model = ATSViT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 5,
    #     dim = 1024,
    #     depth = 6,
    #     max_tokens_per_depth = (256, 128, 64, 32, 16, 8),  # the maximum number of tokens any given layer may keep; layers with more tokens undergo adaptive token sampling
    #     heads = 16,
    #     mlp_dim = 2048,
    #     dropout = 0.1,
    #     emb_dropout = 0.1
    # )

    # modelv = ViT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 5,
    #     dim = 1024,
    #     depth = 6,
    #     heads = 8,
    #     mlp_dim = 2048
    # )
    # model = MAE(
    #     encoder = modelv,
    #     decoder_dim = 512,   # required by vit_pytorch's MAE; 512 is an assumed value, tune as needed
    #     masking_ratio = 0.5  # they found 50% to yield the best results
    # )

    model = NesT(
        image_size = 600,
        patch_size = 30,
        dim = 256,
        heads = 18,                  # 16
        num_hierarchies = 3,         # number of hierarchies
        block_repeats = (2, 4, 18),  # (2,4,16) (2,4,12) (2,2,12); the number of transformer blocks at each hierarchy, starting from the bottom
        num_classes = 5
    )

    # model = NesT(
    #     image_size = 600,
    #     patch_size = 30,
    #     dim = 256,
    #     heads = 16,                  # 16
    #     num_hierarchies = 3,         # number of hierarchies
    #     block_repeats = (2, 3, 16),  # (2,2,12); the number of transformer blocks at each hierarchy, starting from the bottom
    #     num_classes = 5
    # )

    # model = CrossFormer(  # image size must be a multiple of 7, e.g. 448
    #     num_classes = 5,                    # number of output classes
    #     dim = (64, 128, 256, 512),          # dimension at each stage
    #     depth = (2, 2, 8, 2),               # depth of transformer at each stage
    #     global_window_size = (8, 4, 2, 1),  # global window sizes at each stage
    #     local_window_size = 7,              # local window size (can be customized per stage, but held constant at 7 for all stages in the paper)
    # )

    # model = RegionViT(  # image size must be a multiple of 7, e.g. 448
    #     dim = (64, 128, 256, 512),      # tuple of size 4, indicating dimension at each stage
    #     depth = (2, 2, 8, 2),           # depth of the region-to-local transformer at each stage
    #     window_size = 7,                # window size, which should be either 7 or 14
    #     num_classes = 5,                # number of output classes
    #     tokenize_local_3_conv = False,  # whether to use a 3-layer convolution to encode the local tokens from the image; the paper uses this for the smaller models, but only 1 conv (set to False) for the larger ones
    #     use_peg = False,                # whether to use the positional generating module; they used this for object detection for a boost in performance
    # )

    # model = SepViT(  # image size must be a multiple of 7, e.g. 448
    #     num_classes = 5,
    #     dim = 32,              # dimension of the first stage, which doubles every stage: (32, 64, 128, 256) for SepViT-Lite
    #     dim_head = 32,         # attention head dimension
    #     heads = (1, 2, 4, 8),  # number of heads per stage
    #     depth = (1, 2, 6, 2),  # number of transformer blocks per stage
    #     window_size = 7,       # window size of the DSS Attention block
    #     dropout = 0.1          # dropout
    # )

    # model = PiT(
    #     image_size = 600,
    #     patch_size = 30,
    #     dim = 1024,
    #     num_classes = 5,
    #     depth = (1, 1, 1),  # list of depths, indicating the number of rounds of each stage before a downsample
    #     heads = 8,
    #     mlp_dim = 3072,
    #     dropout = 0.1,
    #     emb_dropout = 0.1
    # )

    # model = CaiT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 1000,
    #     dim = 1024,
    #     depth = 12,           # depth of transformer for patch-to-patch attention only
    #     cls_depth = 2,        # depth of cross attention of CLS tokens to patches
    #     heads = 16,
    #     mlp_dim = 2048,
    #     dropout = 0.1,
    #     emb_dropout = 0.1,
    #     layer_dropout = 0.05  # randomly drop out 5% of the layers
    # )

    # model = DeepViT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 5,
    #     dim = 256,
    #     depth = 6,
    #     heads = 6,
    #     mlp_dim = 256,
    #     dropout = 0.1,
    #     emb_dropout = 0.1
    # )

    # model = DistillableViT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 5,
    #     dim = 1080,
    #     depth = 12,
    #     heads = 12,
    #     mlp_dim = 3072,
    #     dropout = 0.1,
    #     emb_dropout = 0.1
    # )

    # model = ViT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 5,
    #     dim = 768,
    #     depth = 12,
    #     heads = 12,
    #     mlp_dim = 3072,
    #     dropout = 0.1,
    #     emb_dropout = 0.1
    # )

    # model = SimpleViT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 2,
    #     dim = 256,
    #     depth = 6,
    #     heads = 16,
    #     mlp_dim = 256
    # )

    # ViT-best
    # model = ViT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 5,
    #     dim = 512,
    #     depth = 6,
    #     heads = 8,
    #     mlp_dim = 512,
    #     pool = 'cls',
    #     channels = 3,
    #     dim_head = 12,
    #     dropout = 0.1,
    #     emb_dropout = 0.1
    # )

    # ViT-small
    # model = ViT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 5,
    #     dim = 256,
    #     depth = 8,
    #     heads = 16,
    #     mlp_dim = 256,
    #     pool = 'cls',
    #     channels = 3,
    #     dim_head = 16,
    #     dropout = 0.1,
    #     emb_dropout = 0.1
    # )

    # ViT-tiny
    # model = ViT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 5,
    #     dim = 256,
    #     depth = 4,
    #     heads = 6,
    #     mlp_dim = 256,
    #     pool = 'cls',
    #     channels = 3,
    #     dim_head = 6,
    #     dropout = 0.1,
    #     emb_dropout = 0.1
    # )

    return model


# Quick check (commented out); for the active NesT config the input must be
# 600x600, giving an output of shape (1, 5):
# img = torch.randn(1, 3, 600, 600)
# model = net()
# preds = model(img)  # (1, 5)
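# A minimal, hedged smoke test: a runnable version of the commented quick
# check above. The 600x600 input and the (1, 5) output shape follow from
# image_size and num_classes in net(); nothing here is part of the original
# training pipeline. Call it manually if you want the sanity check.
def smoke_test_net():
    model = net()
    model.eval()
    img = torch.randn(1, 3, 600, 600)  # (batch, channels, height, width)
    with torch.no_grad():
        preds = model(img)
    assert preds.shape == (1, 5), preds.shape
    print("smoke test OK:", preds.shape)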
# Count trainable parameters, in millions
def count_parameters(model):
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return params / 1000000


# Running average of the loss
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


criterion = nn.CrossEntropyLoss()


# Simple accuracy
def simple_accuracy(preds, labels):
    return (preds == labels).mean()


# Model evaluation; global_step is accepted for interface parity but unused here
def test(device, model, test_loader, global_step):
    eval_losses = AverageMeter()
    print("***** Running Validation *****")
    model.eval()
    all_preds, all_label = [], []
    epoch_iterator = tqdm(test_loader,
                          desc="Validating... (loss=X.X)",
                          bar_format="{l_bar}{r_bar}",
                          dynamic_ncols=True)
    for step, batch in enumerate(epoch_iterator):
        batch = tuple(t.to(device) for t in batch)
        x, y = batch
        with torch.no_grad():
            logits = model(x)
            eval_loss = criterion(logits, y)
            eval_loss = eval_loss.mean()
            eval_losses.update(eval_loss.item())
            preds = torch.argmax(logits, dim=-1)
        if len(all_preds) == 0:
            all_preds.append(preds.detach().cpu().numpy())
            all_label.append(y.detach().cpu().numpy())
        else:
            all_preds[0] = np.append(all_preds[0], preds.detach().cpu().numpy(), axis=0)
            all_label[0] = np.append(all_label[0], y.detach().cpu().numpy(), axis=0)
        epoch_iterator.set_description("Validating... (loss=%2.5f)" % eval_losses.val)

    all_preds, all_label = all_preds[0], all_label[0]
    accuracy = simple_accuracy(all_preds, all_label)
    accuracy = torch.tensor(accuracy).to(device)
    val_accuracy = accuracy.detach().cpu().numpy()
    print("test Loss: %2.5f" % eval_losses.avg)
    print("test Accuracy: %2.5f" % val_accuracy)
    return val_accuracy


# Save the model
def save_model(model):
    model_checkpoint = os.path.join('./output', "%s_vit_checkpoint.pth" % 'ieemooempty')
    torch.save(model, model_checkpoint)
    print("Saved model checkpoint to [DIR: %s]" % './output')
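# Hedged counterpart to save_model (not in the original script). save_model
# pickles the whole nn.Module with torch.save(model, path), so the matching
# load is torch.load on the same path; the class definitions used (here,
# NesT) must be importable when unpickling. The path mirrors save_model.
def load_model(device='cpu'):
    model_checkpoint = os.path.join('./output', "%s_vit_checkpoint.pth" % 'ieemooempty')
    model = torch.load(model_checkpoint, map_location=device)
    model.eval()
    return model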
# Training
def train(model, train_loader, device, train_NUM_STEPS, LEARNING_RATE, WEIGHT_DECAY, WARMUP_STEPS, test_loader):
    optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=WEIGHT_DECAY)
    t_total = train_NUM_STEPS
    scheduler = WarmupCosineSchedule(optimizer, warmup_steps=WARMUP_STEPS, t_total=t_total)
    model.zero_grad()
    losses = AverageMeter()
    global_step, best_acc = 0, 0
    gradient_accumulation_steps = 1  # note: with values > 1 the loss is not rescaled by the accumulation count
    while True:
        model.train()
        epoch_iterator = tqdm(train_loader,
                              desc="Training (X / X Steps) (loss=X.X)",
                              bar_format="{l_bar}{r_bar}",
                              dynamic_ncols=True)
        all_preds, all_label = [], []
        for step, batch in enumerate(epoch_iterator):
            batch = tuple(t.to(device) for t in batch)
            x, y = batch
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            preds = torch.argmax(logits, dim=-1)
            if len(all_preds) == 0:
                all_preds.append(preds.detach().cpu().numpy())
                all_label.append(y.detach().cpu().numpy())
            else:
                all_preds[0] = np.append(all_preds[0], preds.detach().cpu().numpy(), axis=0)
                all_label[0] = np.append(all_label[0], y.detach().cpu().numpy(), axis=0)
            if (step + 1) % gradient_accumulation_steps == 0:
                losses.update(loss.item() * gradient_accumulation_steps)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()   # apply gradients first,
                scheduler.step()   # then advance the LR schedule (the PyTorch >= 1.1 order)
                optimizer.zero_grad()
                global_step += 1
                epoch_iterator.set_description(
                    "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, t_total, losses.val)
                )
                model.train()
                # needs to run twice before the trained model is saved
                if global_step % t_total == 0:
                    break

        all_preds, all_label = all_preds[0], all_label[0]
        accuracy = simple_accuracy(all_preds, all_label)
        accuracy = torch.tensor(accuracy).to(device)
        train_accuracy = accuracy.detach().cpu().numpy()
        print("train accuracy: %f" % train_accuracy)
        accuracy = test(device, model, test_loader, global_step)
        if best_acc < accuracy:
            save_model(model)
            best_acc = accuracy
        losses.reset()
        if global_step % t_total == 0:
            break
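# Worked example for the schedule driven by the hyperparameters below
# (illustrative numbers, not measured from this repo): with, say, 100 batches
# per epoch and epoch = 300, train_NUM_STEPS = 100 * 300 = 30000 total
# optimizer steps. Assuming the usual WarmupCosineSchedule behavior (linear
# warmup, then cosine decay), the LR rises linearly from 0 to 3e-2 over the
# first 500 steps, then decays along a cosine curve toward 0 over the
# remaining 29500 steps.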
if __name__ == "__main__":
    model = net()
    train_loader, test_loader = get_loader_new()
    trainnumsteps = len(train_loader)
    testnumsteps = len(test_loader)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(str(count_parameters(model)) + 'M parameters')
    # img = torch.randn(1, 3, 320, 320)  # (batch_size, channels, width, height); 3072 assumes the default 3 channels, 3 * 1024
    # preds = model(img)
    # print(preds)
    model = model.to(device)
    epoch = 300
    train_NUM_STEPS = trainnumsteps * epoch
    LEARNING_RATE = 3e-2
    WEIGHT_DECAY = 0
    WARMUP_STEPS = 500
    train(model, train_loader, device, train_NUM_STEPS, LEARNING_RATE, WEIGHT_DECAY, WARMUP_STEPS, test_loader)
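# Hedged usage sketch (not part of the original flow): evaluating the saved
# checkpoint after training, using the load_model helper defined above.
# Run manually, e.g. from a REPL, once ./output holds a checkpoint:
#
#     model = load_model(device)
#     test(device, model, test_loader, global_step=0)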