import torch
import torch.nn as nn

from vit_pytorch import ViT, SimpleViT, MAE
from vit_pytorch.distill import DistillableViT
from vit_pytorch.deepvit import DeepViT
from vit_pytorch.cait import CaiT
from vit_pytorch.pit import PiT
from vit_pytorch.regionvit import RegionViT
from vit_pytorch.sep_vit import SepViT
from vit_pytorch.crossformer import CrossFormer
from vit_pytorch.nest import NesT
from vit_pytorch.mobile_vit import MobileViT
from vit_pytorch.simmim import SimMIM
# Aliased so it does not shadow the plain ViT imported above.
from vit_pytorch.ats_vit import ViT as ATSViT

from utils.data_utils import get_loader_new
from utils.scheduler import WarmupCosineSchedule

from tqdm import tqdm

import os
import numpy as np

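# `utils.scheduler.WarmupCosineSchedule` is project-local. For reference, a
# minimal sketch of an equivalent schedule (an assumption about its behavior,
# not the project's exact code): linear warmup to the base LR, then cosine
# decay to zero, built on torch.optim.lr_scheduler.LambdaLR.
#
#   import math
#   from torch.optim.lr_scheduler import LambdaLR
#
#   class WarmupCosineSchedule(LambdaLR):
#       def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):
#           self.warmup_steps = warmup_steps
#           self.t_total = t_total
#           super().__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
#
#       def lr_lambda(self, step):
#           if step < self.warmup_steps:
#               return step / max(1, self.warmup_steps)
#           progress = (step - self.warmup_steps) / max(1, self.t_total - self.warmup_steps)
#           return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

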
def net():
    """Build and return the model under test. Alternative architectures are
    kept below as commented-out configurations for quick experimentation."""

    # model = ATSViT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 5,
    #     dim = 1024,
    #     depth = 6,
    #     max_tokens_per_depth = (256, 128, 64, 32, 16, 8),  # maximum number of tokens any given layer may keep; layers with more tokens undergo adaptive token sampling
    #     heads = 16,
    #     mlp_dim = 2048,
    #     dropout = 0.1,
    #     emb_dropout = 0.1
    # )

    # modelv = ViT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 5,
    #     dim = 1024,
    #     depth = 6,
    #     heads = 8,
    #     mlp_dim = 2048
    # )

    # Note: MAE is a self-supervised pretraining wrapper; its forward pass
    # returns a reconstruction loss, not class logits, so it cannot be dropped
    # into the classification loop below as-is.
    # model = MAE(
    #     encoder = modelv,
    #     masking_ratio = 0.5  # the authors found 50% to yield the best results
    # )

    model = NesT(
        image_size = 600,
        patch_size = 30,
        dim = 256,
        heads = 18,                  # was 16
        num_hierarchies = 3,         # number of hierarchies
        block_repeats = (3, 3, 16),  # was (2, 2, 12); number of transformer blocks at each hierarchy, starting from the bottom
        num_classes = 5
    )

    # model = CrossFormer(  # image size must be a multiple of 7, e.g. 448
    #     num_classes = 5,                    # number of output classes
    #     dim = (64, 128, 256, 512),          # dimension at each stage
    #     depth = (2, 2, 8, 2),               # depth of transformer at each stage
    #     global_window_size = (8, 4, 2, 1),  # global window sizes at each stage
    #     local_window_size = 7,              # local window size (can be customized per stage, but held constant at 7 for all stages in the paper)
    # )

    # model = RegionViT(  # image size must be a multiple of 7, e.g. 448
    #     dim = (64, 128, 256, 512),      # tuple of size 4, indicating dimension at each stage
    #     depth = (2, 2, 8, 2),           # depth of the region-to-local transformer at each stage
    #     window_size = 7,                # window size, which should be either 7 or 14
    #     num_classes = 5,                # number of output classes
    #     tokenize_local_3_conv = False,  # whether to use a 3-layer convolution to encode the local tokens from the image; the paper uses this for the smaller models, but only 1 conv (set to False) for the larger models
    #     use_peg = False,                # whether to use the positional encoding generator module; the authors used it in object detection for a boost in performance
    # )

    # model = SepViT(  # image size must be a multiple of 7, e.g. 448
    #     num_classes = 5,
    #     dim = 32,              # dimension of the first stage, which doubles every stage: (32, 64, 128, 256) for SepViT-Lite
    #     dim_head = 32,         # attention head dimension
    #     heads = (1, 2, 4, 8),  # number of heads per stage
    #     depth = (1, 2, 6, 2),  # number of transformer blocks per stage
    #     window_size = 7,       # window size of the DSS attention block
    #     dropout = 0.1          # dropout
    # )

    # model = PiT(
    #     image_size = 600,
    #     patch_size = 30,
    #     dim = 1024,
    #     num_classes = 5,
    #     depth = (1, 1, 1),  # list of depths, indicating the number of rounds of each stage before a downsample
    #     heads = 8,
    #     mlp_dim = 3072,
    #     dropout = 0.1,
    #     emb_dropout = 0.1
    # )

    # model = CaiT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 1000,
    #     dim = 1024,
    #     depth = 12,           # depth of transformer for patch-to-patch attention only
    #     cls_depth = 2,        # depth of cross attention of CLS tokens to patches
    #     heads = 16,
    #     mlp_dim = 2048,
    #     dropout = 0.1,
    #     emb_dropout = 0.1,
    #     layer_dropout = 0.05  # randomly drop out 5% of the layers
    # )

    # model = DeepViT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 5,
    #     dim = 256,
    #     depth = 6,
    #     heads = 6,
    #     mlp_dim = 256,
    #     dropout = 0.1,
    #     emb_dropout = 0.1
    # )

    # model = DistillableViT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 5,
    #     dim = 1080,
    #     depth = 12,
    #     heads = 12,
    #     mlp_dim = 3072,
    #     dropout = 0.1,
    #     emb_dropout = 0.1
    # )

    # model = ViT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 5,
    #     dim = 768,
    #     depth = 12,
    #     heads = 12,
    #     mlp_dim = 3072,
    #     dropout = 0.1,
    #     emb_dropout = 0.1
    # )

    # model = SimpleViT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 2,
    #     dim = 256,
    #     depth = 6,
    #     heads = 16,
    #     mlp_dim = 256
    # )

    # ViT-best
    # model = ViT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 5,
    #     dim = 512,
    #     depth = 6,
    #     heads = 8,
    #     mlp_dim = 512,
    #     pool = 'cls',
    #     channels = 3,
    #     dim_head = 12,
    #     dropout = 0.1,
    #     emb_dropout = 0.1
    # )

    # ViT-small
    # model = ViT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 5,
    #     dim = 256,
    #     depth = 8,
    #     heads = 16,
    #     mlp_dim = 256,
    #     pool = 'cls',
    #     channels = 3,
    #     dim_head = 16,
    #     dropout = 0.1,
    #     emb_dropout = 0.1
    # )

    # ViT-tiny
    # model = ViT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 5,
    #     dim = 256,
    #     depth = 4,
    #     heads = 6,
    #     mlp_dim = 256,
    #     pool = 'cls',
    #     channels = 3,
    #     dim_head = 6,
    #     dropout = 0.1,
    #     emb_dropout = 0.1
    # )

    return model

# Quick shape check (the active NesT config expects 600x600 inputs and
# 5 classes, so the output shape is (1, 5)):
# img = torch.randn(1, 3, 600, 600)
# model = net()
# preds = model(img)  # (1, 5)


# Count the number of model parameters
def count_parameters(model):
    """Return the number of trainable parameters, in millions."""
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return params / 1_000_000

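# Usage sketch:
#   print("%.2fM trainable parameters" % count_parameters(net()))
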
# Running average of the loss
class AverageMeter(object):
    """Computes and stores the average and current value."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

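# Usage sketch (the n argument weights the running mean, e.g. by batch size):
#   meter = AverageMeter()
#   meter.update(loss.item(), n=1)
#   meter.avg  # average over all updates so far
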
criterion = nn.CrossEntropyLoss()

# Simple accuracy
def simple_accuracy(preds, labels):
    return (preds == labels).mean()

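# simple_accuracy expects numpy arrays and returns the fraction of matches, e.g.:
#   simple_accuracy(np.array([1, 0, 1]), np.array([1, 1, 1]))  # -> 0.666...
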
# Model evaluation
def test(device, model, test_loader, global_step):
    eval_losses = AverageMeter()

    print("***** Running Validation *****")

    model.eval()
    all_preds, all_label = [], []
    epoch_iterator = tqdm(test_loader,
                          desc="Validating... (loss=X.X)",
                          bar_format="{l_bar}{r_bar}",
                          dynamic_ncols=True)

    for step, batch in enumerate(epoch_iterator):
        batch = tuple(t.to(device) for t in batch)
        x, y = batch
        with torch.no_grad():
            logits = model(x)

            eval_loss = criterion(logits, y)
            eval_loss = eval_loss.mean()
            eval_losses.update(eval_loss.item())

            preds = torch.argmax(logits, dim=-1)

        # Accumulate predictions and labels across batches.
        if len(all_preds) == 0:
            all_preds.append(preds.detach().cpu().numpy())
            all_label.append(y.detach().cpu().numpy())
        else:
            all_preds[0] = np.append(
                all_preds[0], preds.detach().cpu().numpy(), axis=0
            )
            all_label[0] = np.append(
                all_label[0], y.detach().cpu().numpy(), axis=0
            )
        epoch_iterator.set_description("Validating... (loss=%2.5f)" % eval_losses.val)

    all_preds, all_label = all_preds[0], all_label[0]
    val_accuracy = simple_accuracy(all_preds, all_label)

    print("Test Loss: %2.5f" % eval_losses.avg)
    print("Test Accuracy: %2.5f" % val_accuracy)

    return val_accuracy

# Save the model
def save_model(model):
    model_checkpoint = os.path.join('./output', "%s_vit_checkpoint.pth" % 'ieemooempty')
    torch.save(model, model_checkpoint)
    print("Saved model checkpoint to [DIR: %s]" % './output')

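# Note: torch.save(model, ...) pickles the entire module, so the class
# definitions must be importable at load time. A state_dict-based variant
# (a common alternative, not what this script does) would be:
#   torch.save(model.state_dict(), model_checkpoint)
#   model.load_state_dict(torch.load(model_checkpoint))
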
# Training
def train(model, train_loader, device, train_NUM_STEPS, LEARNING_RATE, WEIGHT_DECAY, WARMUP_STEPS, test_loader):

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=LEARNING_RATE,
                                momentum=0.9,
                                weight_decay=WEIGHT_DECAY)
    t_total = train_NUM_STEPS
    scheduler = WarmupCosineSchedule(optimizer, warmup_steps=WARMUP_STEPS, t_total=t_total)
    model.zero_grad()
    losses = AverageMeter()
    global_step, best_acc = 0, 0
    gradient_accumulation_steps = 1  # if raised above 1, divide the loss by this factor before backward()
    while True:
        model.train()
        epoch_iterator = tqdm(train_loader,
                              desc="Training (X / X Steps) (loss=X.X)",
                              bar_format="{l_bar}{r_bar}",
                              dynamic_ncols=True)
        all_preds, all_label = [], []
        for step, batch in enumerate(epoch_iterator):
            batch = tuple(t.to(device) for t in batch)
            x, y = batch
            logits = model(x)

            loss = criterion(logits, y)

            loss.backward()

            preds = torch.argmax(logits, dim=-1)

            # Accumulate predictions and labels across batches.
            if len(all_preds) == 0:
                all_preds.append(preds.detach().cpu().numpy())
                all_label.append(y.detach().cpu().numpy())
            else:
                all_preds[0] = np.append(
                    all_preds[0], preds.detach().cpu().numpy(), axis=0
                )
                all_label[0] = np.append(
                    all_label[0], y.detach().cpu().numpy(), axis=0
                )

            if (step + 1) % gradient_accumulation_steps == 0:
                losses.update(loss.item() * gradient_accumulation_steps)

                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                # Step the optimizer before the scheduler (required since PyTorch 1.1).
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

                epoch_iterator.set_description(
                    "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, t_total, losses.val)
                )

                model.train()  # original note: needs to run twice, otherwise the trained model is not saved

                if global_step % t_total == 0:
                    break

        all_preds, all_label = all_preds[0], all_label[0]
        train_accuracy = simple_accuracy(all_preds, all_label)
        print("train accuracy: %f" % train_accuracy)
        accuracy = test(device, model, test_loader, global_step)

        if best_acc < accuracy:
            save_model(model)
            best_acc = accuracy

        losses.reset()
        if global_step % t_total == 0:
            break

if __name__ == "__main__":

    model = net()
    train_loader, test_loader = get_loader_new()
    trainnumsteps = len(train_loader)
    testnumsteps = len(test_loader)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("%.2fM trainable parameters" % count_parameters(model))
    # img = torch.randn(1, 3, 320, 320)  # (batch_size, channels, height, width); 3072 comes from the default 3 channels, 3 * 1024
    # preds = model(img)
    # print(preds)

    model = model.to(device)
    epoch = 300
    train_NUM_STEPS = trainnumsteps * epoch
    LEARNING_RATE = 3e-2
    WEIGHT_DECAY = 0
    WARMUP_STEPS = 500
    train(model, train_loader, device, train_NUM_STEPS, LEARNING_RATE, WEIGHT_DECAY, WARMUP_STEPS, test_loader)