Files
ieemoo-ai-isempty/trial.py
Brainway 33143a10ed update
2022-09-27 02:32:20 +00:00

228 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
from vit_pytorch import ViT
from utils.data_utils import get_loader_new
from utils.scheduler import WarmupCosineSchedule
from tqdm import tqdm
import os
import numpy as np
def net():
    """Build the Vision Transformer classifier used by this script.

    Returns:
        A newly constructed ``ViT`` instance configured with the fixed
        hyper-parameters below (320 image size, 32 patch size, 5 classes).
    """
    vit_config = dict(
        image_size=320,
        patch_size=32,
        num_classes=5,
        dim=768,
        depth=4,
        heads=12,
        mlp_dim=1024,
        pool='cls',
        channels=3,
        dim_head=12,
        dropout=0.1,
        emb_dropout=0.1,
    )
    return ViT(**vit_config)
# Count the model's trainable parameters (in millions)
def count_parameters(model):
    """Return the number of trainable parameters of ``model``, in millions.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total / 1000000
# Running average of the loss
class AverageMeter:
    """Track the most recent value of a metric and its running weighted mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic so the meter can be reused."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean.

        ``val`` becomes the latest value, ``val * n`` is added to the running
        sum, ``n`` to the sample count, and ``avg`` is recomputed from both.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
# Module-level cross-entropy loss shared by test() and train().
criterion = nn.CrossEntropyLoss()
# Simple accuracy
def simple_accuracy(preds, labels):
    """Return the fraction of positions where ``preds`` equals ``labels``.

    Both arguments are compared element-wise and the mean of the resulting
    boolean mask is returned.
    """
    hit_mask = preds == labels
    return hit_mask.mean()
# Model evaluation
def test(device, model, test_loader, global_step):
    """Evaluate ``model`` on ``test_loader`` and return its accuracy.

    Args:
        device: Device every tensor of a batch is moved to before the
            forward pass.
        model: Network to evaluate. It is switched to eval mode and is NOT
            switched back; callers must call ``model.train()`` themselves.
        test_loader: Iterable yielding ``(inputs, labels)`` tensor batches.
        global_step: Unused; kept so existing callers keep working.

    Returns:
        A 0-d NumPy array holding the fraction of correct predictions.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    eval_losses = AverageMeter()
    print("***** Running Validation *****")
    model.eval()
    # Collect per-batch arrays and concatenate once at the end; growing a
    # single array with np.append re-copies everything on every batch.
    batch_preds, batch_labels = [], []
    epoch_iterator = tqdm(test_loader,
                          desc="Validating... (loss=X.X)",
                          bar_format="{l_bar}{r_bar}",
                          dynamic_ncols=True)
    with torch.no_grad():
        for batch in epoch_iterator:
            x, y = tuple(t.to(device) for t in batch)
            logits = model(x)
            # .mean() is a no-op for a scalar loss but keeps this correct if
            # the criterion is ever configured to return per-sample losses.
            eval_loss = criterion(logits, y).mean()
            eval_losses.update(eval_loss.item())
            preds = torch.argmax(logits, dim=-1)
            batch_preds.append(preds.cpu().numpy())
            batch_labels.append(y.cpu().numpy())
            epoch_iterator.set_description("Validating... (loss=%2.5f)" % eval_losses.val)
    if not batch_preds:
        raise ValueError("test_loader yielded no batches; cannot compute accuracy")
    all_preds = np.concatenate(batch_preds, axis=0)
    all_label = np.concatenate(batch_labels, axis=0)
    # Wrap the scalar directly instead of round-tripping it through a tensor
    # on `device`; the result is still a 0-d array as before.
    val_accuracy = np.asarray(simple_accuracy(all_preds, all_label))
    print("test Loss: %2.5f" % eval_losses.avg)
    print("test Accuracy: %2.5f" % val_accuracy)
    return val_accuracy
# Save the model
def save_model(model):
    """Serialize ``model`` to ``./output/ieemooempty_vit_checkpoint.pth``.

    The whole module object is saved (not only its ``state_dict``), so
    loading the checkpoint requires the model's class to be importable.
    The output directory is created if it does not exist yet.
    """
    output_dir = './output'
    # torch.save does not create missing parent directories.
    os.makedirs(output_dir, exist_ok=True)
    model_checkpoint = os.path.join(output_dir, "%s_vit_checkpoint.pth" % 'ieemooempty')
    torch.save(model, model_checkpoint)
    # Use % so the directory is actually interpolated into the message.
    print("Saved model checkpoint to [DIR: %s]" % output_dir)
# Training
def train(model, train_loader, device, train_NUM_STEPS, LEARNING_RATE, WEIGHT_DECAY, WARMUP_STEPS, test_loader):
    """Train ``model`` with SGD and a warmup-cosine LR schedule.

    After every pass over ``train_loader`` the training accuracy is printed,
    the model is evaluated with ``test()``, and the checkpoint is saved via
    ``save_model()`` whenever the evaluation accuracy improves.

    Args:
        model: Network to optimise (expected to already live on ``device``).
        train_loader: Iterable yielding ``(inputs, labels)`` tensor batches.
        device: Device every batch tensor is moved to.
        train_NUM_STEPS: Total number of optimiser steps to run.
        LEARNING_RATE: Base learning rate for SGD.
        WEIGHT_DECAY: Weight decay for SGD.
        WARMUP_STEPS: Warmup step count handed to ``WarmupCosineSchedule``.
        test_loader: Loader passed to ``test()`` after each epoch.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=LEARNING_RATE,
                                momentum=0.9,
                                weight_decay=WEIGHT_DECAY)
    t_total = train_NUM_STEPS
    scheduler = WarmupCosineSchedule(optimizer, warmup_steps=WARMUP_STEPS, t_total=t_total)
    model.zero_grad()
    losses = AverageMeter()
    global_step, best_acc = 0, 0
    gradient_accumulation_steps = 1
    while True:
        # test() leaves the model in eval mode, so re-enable training mode at
        # the start of every epoch.
        model.train()
        epoch_iterator = tqdm(train_loader,
                              desc="Training (X / X Steps) (loss=X.X)",
                              bar_format="{l_bar}{r_bar}",
                              dynamic_ncols=True)
        # Collect per-batch arrays and concatenate once per epoch instead of
        # re-copying a growing array with np.append on every step.
        epoch_preds, epoch_labels = [], []
        for step, batch in enumerate(epoch_iterator):
            x, y = tuple(t.to(device) for t in batch)
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            preds = torch.argmax(logits, dim=-1)
            epoch_preds.append(preds.detach().cpu().numpy())
            epoch_labels.append(y.detach().cpu().numpy())
            if (step + 1) % gradient_accumulation_steps != 0:
                continue
            losses.update(loss.item() * gradient_accumulation_steps)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            # Update the weights first, then advance the schedule; stepping
            # the scheduler first skips the first learning-rate value.
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            global_step += 1
            epoch_iterator.set_description(
                "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, t_total, losses.val)
            )
            # NOTE(review): the original also called model.train() here with a
            # remark that "it takes two passes before the trained model is
            # saved"; the model is already in training mode at this point, so
            # the call was dropped -- confirm nothing relied on it.
            if global_step % t_total == 0:
                break
        if not epoch_preds:
            raise ValueError("train_loader yielded no batches; cannot train")
        train_accuracy = simple_accuracy(
            np.concatenate(epoch_preds, axis=0),
            np.concatenate(epoch_labels, axis=0),
        )
        print("train accuracy: %f" % train_accuracy)
        accuracy = test(device, model, test_loader, global_step)
        if best_acc < accuracy:
            save_model(model)
            best_acc = accuracy
        losses.reset()
        if global_step % t_total == 0:
            break
if __name__ == "__main__":
    # Script entry point: build the model and loaders, then train for a fixed
    # number of epochs with hard-coded hyper-parameters.
    model = net()
    train_loader, test_loader = get_loader_new()
    trainnumsteps = len(train_loader)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # count_parameters() returns a parameter count in millions, not a size in
    # megabytes, so label it accordingly.
    print("%sM parameters" % count_parameters(model))
    # Optional forward-pass smoke test:
    # img = torch.randn(1, 3, 320, 320)  # (batch_size, channels, width, height)
    # NOTE(review): the original remark mentioned 3072 as a default, presumably
    # 3 channels * 32 * 32 patch pixels -- confirm.
    # preds = model(img)
    # print(preds)
    model = model.to(device)
    epoch = 300
    # One optimiser step per batch, so total steps = batches per epoch * epochs.
    train_NUM_STEPS = trainnumsteps * epoch
    LEARNING_RATE = 3e-2
    WEIGHT_DECAY = 0
    WARMUP_STEPS = 500
    train(model, train_loader, device, train_NUM_STEPS, LEARNING_RATE, WEIGHT_DECAY, WARMUP_STEPS, test_loader)