From 1afc0d6617c44bca4c6eebc95c6b1123aa151ecd Mon Sep 17 00:00:00 2001 From: Brainway Date: Thu, 29 Dec 2022 06:19:17 +0000 Subject: [PATCH] update hello.py. --- hello.py | 47 +++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 39 insertions(+), 8 deletions(-) diff --git a/hello.py b/hello.py index 8b5f78f..3a6f262 100644 --- a/hello.py +++ b/hello.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -from vit_pytorch import ViT,SimpleViT +from vit_pytorch import ViT,SimpleViT,MAE from vit_pytorch.distill import DistillableViT from vit_pytorch.deepvit import DeepViT from vit_pytorch.cait import CaiT @@ -9,6 +9,9 @@ from vit_pytorch.regionvit import RegionViT from vit_pytorch.sep_vit import SepViT from vit_pytorch.crossformer import CrossFormer from vit_pytorch.nest import NesT +from vit_pytorch.mobile_vit import MobileViT +from vit_pytorch.simmim import SimMIM +from vit_pytorch.ats_vit import ViT from utils.data_utils import get_loader_new from utils.scheduler import WarmupCosineSchedule @@ -20,17 +23,45 @@ import numpy as np def net(): - - model = NesT( + model = ViT( image_size = 600, patch_size = 30, - dim = 256, + num_classes = 5, + dim = 1024, + depth = 6, + max_tokens_per_depth = (256, 128, 64, 32, 16, 8), # a tuple that denotes the maximum number of tokens that any given layer should have. if the layer has greater than this amount, it will undergo adaptive token sampling heads = 16, - num_hierarchies = 3, # number of hierarchies - block_repeats = (2, 2, 12), # the number of transformer blocks at each heirarchy, starting from the bottom - num_classes = 5 + mlp_dim = 2048, + dropout = 0.1, + emb_dropout = 0.1 ) +# modelv = ViT( +# image_size = 600, +# patch_size = 30, +# num_classes = 5, +# dim = 1024, +# depth = 6, +# heads = 8, +# mlp_dim = 2048 +# ) + +# model = MAE( +# encoder = modelv, +# masking_ratio = 0.5 # they found 50% to yield the best results +# ) + + +# model = NesT( +# image_size = 600, +# patch_size = 30, +# dim = 256, +# heads = 16, +# num_hierarchies = 3, # number of hierarchies +# block_repeats = (2, 2, 12), # the number of transformer blocks at each heirarchy, starting from the bottom +# num_classes = 5 +# ) + # model = CrossFormer( #图片尺寸要是7的倍数,如448 # num_classes = 5, # number of output classes # dim = (64, 128, 256, 512), # dimension at each stage @@ -296,7 +327,7 @@ def train(model,train_loader,device,train_NUM_STEPS,LEARNING_RATE,WEIGHT_DECAY,W batch = tuple(t.to(device) for t in batch) x, y = batch logits = model(x) - + loss = criterion(logits,y) loss.backward()