# ieemoo-ai-isempty/hello.py

import os

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

from vit_pytorch import ViT, SimpleViT, MAE
from vit_pytorch.distill import DistillableViT
from vit_pytorch.deepvit import DeepViT
from vit_pytorch.cait import CaiT
from vit_pytorch.pit import PiT
from vit_pytorch.regionvit import RegionViT
from vit_pytorch.sep_vit import SepViT
from vit_pytorch.crossformer import CrossFormer
from vit_pytorch.nest import NesT
from vit_pytorch.mobile_vit import MobileViT
from vit_pytorch.simmim import SimMIM
# aliased so it does not shadow the plain ViT imported above
from vit_pytorch.ats_vit import ViT as ATSViT

from utils.data_utils import get_loader_new
from utils.scheduler import WarmupCosineSchedule

def net():
    # model = ATSViT(  # adaptive token sampling ViT (vit_pytorch.ats_vit)
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 5,
    #     dim = 1024,
    #     depth = 6,
    #     max_tokens_per_depth = (256, 128, 64, 32, 16, 8),  # maximum number of tokens any given layer may keep; layers with more tokens than this undergo adaptive token sampling
    #     heads = 16,
    #     mlp_dim = 2048,
    #     dropout = 0.1,
    #     emb_dropout = 0.1
    # )
    # modelv = ViT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 5,
    #     dim = 1024,
    #     depth = 6,
    #     heads = 8,
    #     mlp_dim = 2048
    # )
    # model = MAE(
    #     encoder = modelv,
    #     masking_ratio = 0.5  # the MAE paper found 50% masking to yield the best results
    # )
    model = NesT(
        image_size = 600,
        patch_size = 30,
        dim = 256,
        heads = 18,  # 16
        num_hierarchies = 3,         # number of hierarchies
        block_repeats = (3, 3, 16),  # (2, 2, 12) the number of transformer blocks at each hierarchy, starting from the bottom
        num_classes = 5
    )
    # model = CrossFormer(  # image size must be a multiple of 7, e.g. 448
    #     num_classes = 5,                    # number of output classes
    #     dim = (64, 128, 256, 512),          # dimension at each stage
    #     depth = (2, 2, 8, 2),               # depth of transformer at each stage
    #     global_window_size = (8, 4, 2, 1),  # global window sizes at each stage
    #     local_window_size = 7,              # local window size (can be customized per stage, but held constant at 7 for all stages in the paper)
    # )
    # model = RegionViT(  # image size must be a multiple of 7, e.g. 448
    #     dim = (64, 128, 256, 512),      # tuple of size 4, indicating dimension at each stage
    #     depth = (2, 2, 8, 2),           # depth of the region-to-local transformer at each stage
    #     window_size = 7,                # window size, which should be either 7 or 14
    #     num_classes = 5,                # number of output classes
    #     tokenize_local_3_conv = False,  # whether to use a 3-layer convolution to encode the local tokens from the image; the paper uses this for the smaller models, but only 1 conv (set to False) for the larger ones
    #     use_peg = False,                # whether to use the positional generating module; they used this for object detection for a boost in performance
    # )
    # model = SepViT(  # image size must be a multiple of 7, e.g. 448
    #     num_classes = 5,
    #     dim = 32,              # dimension of the first stage, which doubles every stage: (32, 64, 128, 256) for SepViT-Lite
    #     dim_head = 32,         # attention head dimension
    #     heads = (1, 2, 4, 8),  # number of heads per stage
    #     depth = (1, 2, 6, 2),  # number of transformer blocks per stage
    #     window_size = 7,       # window size of the DSS Attention block
    #     dropout = 0.1          # dropout
    # )
    # model = PiT(
    #     image_size = 600,
    #     patch_size = 30,
    #     dim = 1024,
    #     num_classes = 5,
    #     depth = (1, 1, 1),  # list of depths, indicating the number of rounds of each stage before a downsample
    #     heads = 8,
    #     mlp_dim = 3072,
    #     dropout = 0.1,
    #     emb_dropout = 0.1
    # )
    # model = CaiT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 1000,
    #     dim = 1024,
    #     depth = 12,     # depth of transformer for patch-to-patch attention only
    #     cls_depth = 2,  # depth of cross attention of CLS tokens to patches
    #     heads = 16,
    #     mlp_dim = 2048,
    #     dropout = 0.1,
    #     emb_dropout = 0.1,
    #     layer_dropout = 0.05  # randomly drop 5% of the layers
    # )
    # model = DeepViT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 5,
    #     dim = 256,
    #     depth = 6,
    #     heads = 6,
    #     mlp_dim = 256,
    #     dropout = 0.1,
    #     emb_dropout = 0.1
    # )
    # model = DistillableViT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 5,
    #     dim = 1080,
    #     depth = 12,
    #     heads = 12,
    #     mlp_dim = 3072,
    #     dropout = 0.1,
    #     emb_dropout = 0.1
    # )
    # model = ViT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 5,
    #     dim = 768,
    #     depth = 12,
    #     heads = 12,
    #     mlp_dim = 3072,
    #     dropout = 0.1,
    #     emb_dropout = 0.1
    # )
    # model = SimpleViT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 2,
    #     dim = 256,
    #     depth = 6,
    #     heads = 16,
    #     mlp_dim = 256
    # )
    # Three ViT presets kept as notes; only one set of kwargs can be active at a time:
    # model = ViT(
    #     # ViT-best
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 5,
    #     dim = 512,
    #     depth = 6,
    #     heads = 8,
    #     mlp_dim = 512,
    #     pool = 'cls',
    #     channels = 3,
    #     dim_head = 12,
    #     dropout = 0.1,
    #     emb_dropout = 0.1
    #     # ViT-small
    #     # image_size = 600, patch_size = 30, num_classes = 5, dim = 256,
    #     # depth = 8, heads = 16, mlp_dim = 256, pool = 'cls', channels = 3,
    #     # dim_head = 16, dropout = 0.1, emb_dropout = 0.1
    #     # ViT-tiny
    #     # image_size = 600, patch_size = 30, num_classes = 5, dim = 256,
    #     # depth = 4, heads = 6, mlp_dim = 256, pool = 'cls', channels = 3,
    #     # dim_head = 6, dropout = 0.1, emb_dropout = 0.1
    # )
    return model
# Quick shape check (the active NesT config expects 600x600 inputs and outputs 5 classes):
# img = torch.randn(1, 3, 600, 600)
# model = net()
# preds = model(img)  # (1, 5)
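
# A runnable version of the check above, assuming the NesT config currently
# active in net(). Kept as a helper so importing this module stays
# side-effect free; call it manually when debugging.
def _smoke_test():
    model = net()
    img = torch.randn(1, 3, 600, 600)  # (batch, channels, height, width)
    preds = model(img)
    assert preds.shape == (1, 5), preds.shape
    print("smoke test ok:", preds.shape)
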
# Count trainable model parameters, in millions
def count_parameters(model):
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return params / 1_000_000
# Running average of the loss
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
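
# Usage sketch: the meter keeps both the latest value and the running mean, e.g.
# m = AverageMeter(); m.update(2.0); m.update(4.0)  ->  m.val == 4.0, m.avg == 3.0
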
criterion = nn.CrossEntropyLoss()

# Simple accuracy
def simple_accuracy(preds, labels):
    return (preds == labels).mean()
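# preds and labels arrive here as numpy arrays, so .mean() of the boolean
# comparison is the fraction correct, e.g. [1, 0, 1] vs [1, 1, 1] -> 0.666...
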
# Model evaluation
def test(device, model, test_loader, global_step):
    eval_losses = AverageMeter()
    print("***** Running Validation *****")
    model.eval()
    all_preds, all_label = [], []
    epoch_iterator = tqdm(test_loader,
                          desc="Validating... (loss=X.X)",
                          bar_format="{l_bar}{r_bar}",
                          dynamic_ncols=True)
    for step, batch in enumerate(epoch_iterator):
        batch = tuple(t.to(device) for t in batch)
        x, y = batch
        with torch.no_grad():
            logits = model(x)
            eval_loss = criterion(logits, y)
            eval_loss = eval_loss.mean()
            eval_losses.update(eval_loss.item())
            preds = torch.argmax(logits, dim=-1)
        if len(all_preds) == 0:
            all_preds.append(preds.detach().cpu().numpy())
            all_label.append(y.detach().cpu().numpy())
        else:
            all_preds[0] = np.append(
                all_preds[0], preds.detach().cpu().numpy(), axis=0
            )
            all_label[0] = np.append(
                all_label[0], y.detach().cpu().numpy(), axis=0
            )
        epoch_iterator.set_description("Validating... (loss=%2.5f)" % eval_losses.val)
    all_preds, all_label = all_preds[0], all_label[0]
    val_accuracy = simple_accuracy(all_preds, all_label)
    print("test Loss: %2.5f" % eval_losses.avg)
    print("test Accuracy: %2.5f" % val_accuracy)
    return val_accuracy
# Save the model
def save_model(model):
    model_checkpoint = os.path.join('./output', "%s_vit_checkpoint.pth" % 'ieemooempty')
    torch.save(model, model_checkpoint)
    print("Saved model checkpoint to [DIR: %s]" % './output')
# Training
def train(model, train_loader, device, train_NUM_STEPS, LEARNING_RATE, WEIGHT_DECAY, WARMUP_STEPS, test_loader):
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=LEARNING_RATE,
                                momentum=0.9,
                                weight_decay=WEIGHT_DECAY)
    t_total = train_NUM_STEPS
    scheduler = WarmupCosineSchedule(optimizer, warmup_steps=WARMUP_STEPS, t_total=t_total)
    model.zero_grad()
    losses = AverageMeter()
    global_step, best_acc = 0, 0
    gradient_accumulation_steps = 1
    while True:
        model.train()
        epoch_iterator = tqdm(train_loader,
                              desc="Training (X / X Steps) (loss=X.X)",
                              bar_format="{l_bar}{r_bar}",
                              dynamic_ncols=True)
        all_preds, all_label = [], []
        for step, batch in enumerate(epoch_iterator):
            batch = tuple(t.to(device) for t in batch)
            x, y = batch
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            preds = torch.argmax(logits, dim=-1)
            if len(all_preds) == 0:
                all_preds.append(preds.detach().cpu().numpy())
                all_label.append(y.detach().cpu().numpy())
            else:
                all_preds[0] = np.append(
                    all_preds[0], preds.detach().cpu().numpy(), axis=0
                )
                all_label[0] = np.append(
                    all_label[0], y.detach().cpu().numpy(), axis=0
                )
            if (step + 1) % gradient_accumulation_steps == 0:
                losses.update(loss.item() * gradient_accumulation_steps)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()  # step the optimizer before the scheduler (PyTorch >= 1.1 order)
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1
                epoch_iterator.set_description(
                    "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, t_total, losses.val)
                )
                model.train()  # needs to run twice before the trained model is saved
                if global_step % t_total == 0:
                    break
        all_preds, all_label = all_preds[0], all_label[0]
        train_accuracy = simple_accuracy(all_preds, all_label)
        print("train accuracy: %f" % train_accuracy)
        accuracy = test(device, model, test_loader, global_step)
        if best_acc < accuracy:
            save_model(model)
            best_acc = accuracy
        losses.reset()
        if global_step % t_total == 0:
            break
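
# Assumed behavior of WarmupCosineSchedule (utils.scheduler is not shown here;
# this mirrors the common ViT training utility it appears to come from):
# linear warmup over warmup_steps, then cosine decay to zero at t_total.
# A reference multiplier for sanity-checking the schedule, under that assumption:
import math

def warmup_cosine_multiplier(step, warmup_steps, t_total):
    if step < warmup_steps:
        return step / max(1.0, warmup_steps)  # linear ramp from 0 to 1
    progress = (step - warmup_steps) / max(1.0, t_total - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))  # cosine decay from 1 to 0
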
if __name__ == "__main__":
    model = net()
    train_loader, test_loader = get_loader_new()
    trainnumsteps = len(train_loader)
    testnumsteps = len(test_loader)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("%.2fM trainable parameters" % count_parameters(model))
    # img = torch.randn(1, 3, 320, 320)  # (batchsize, channels, width, height); 3072 is the default, i.e. channels 3 * 1024
    # preds = model(img)
    # print(preds)
    model = model.to(device)
    epoch = 300
    train_NUM_STEPS = trainnumsteps * epoch
    LEARNING_RATE = 3e-2
    WEIGHT_DECAY = 0
    WARMUP_STEPS = 500
    train(model, train_loader, device, train_NUM_STEPS, LEARNING_RATE, WEIGHT_DECAY, WARMUP_STEPS, test_loader)
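
    # Loading the best checkpoint for inference (sketch; save_model above stored
    # the full module, so torch.load returns the model object directly):
    # model = torch.load('./output/ieemooempty_vit_checkpoint.pth', map_location=device)
    # model.eval()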