diff --git a/hello.py b/hello.py
index 936f952..8b5f78f 100644
--- a/hello.py
+++ b/hello.py
@@ -1,6 +1,14 @@
 import torch
 import torch.nn as nn
 from vit_pytorch import ViT,SimpleViT
+from vit_pytorch.distill import DistillableViT
+from vit_pytorch.deepvit import DeepViT
+from vit_pytorch.cait import CaiT
+from vit_pytorch.pit import PiT
+from vit_pytorch.regionvit import RegionViT
+from vit_pytorch.sep_vit import SepViT
+from vit_pytorch.crossformer import CrossFormer
+from vit_pytorch.nest import NesT
 from utils.data_utils import get_loader_new
 from utils.scheduler import WarmupCosineSchedule
 
@@ -11,19 +19,109 @@ import os
 import numpy as np
 
 def net():
-
-    model = ViT(
+
+
+    model = NesT(
         image_size = 600,
         patch_size = 30,
-        num_classes = 5,
-        dim = 768,
-        depth = 12,
-        heads = 12,
-        mlp_dim = 3072,
-        dropout = 0.1,
-        emb_dropout = 0.1
+        dim = 256,
+        heads = 16,
+        num_hierarchies = 3,        # number of hierarchies
+        block_repeats = (2, 2, 12), # the number of transformer blocks at each hierarchy, starting from the bottom
+        num_classes = 5
     )
+
+# model = CrossFormer(  # image size must be a multiple of 7, e.g. 448
+#     num_classes = 5,                   # number of output classes
+#     dim = (64, 128, 256, 512),         # dimension at each stage
+#     depth = (2, 2, 8, 2),              # depth of transformer at each stage
+#     global_window_size = (8, 4, 2, 1), # global window sizes at each stage
+#     local_window_size = 7,             # local window size (can be customized for each stage, but in paper, held constant at 7 for all stages)
+# )
+
+# model = RegionViT(  # image size must be a multiple of 7, e.g. 448
+#     dim = (64, 128, 256, 512),     # tuple of size 4, indicating dimension at each stage
+#     depth = (2, 2, 8, 2),          # depth of the region-to-local transformer at each stage
+#     window_size = 7,               # window size, which should be either 7 or 14
+#     num_classes = 5,               # number of output classes
+#     tokenize_local_3_conv = False, # whether to use a 3-layer convolution to encode the local tokens from the image. the paper uses this for the smaller models, but uses only 1 conv (set to False) for the larger models
+#     use_peg = False,               # whether to use the positional generating module. they used this for object detection for a boost in performance
+# )
+
+# model = SepViT(  # image size must be a multiple of 7, e.g. 448
+#     num_classes = 5,
+#     dim = 32,             # dimension of the first stage, which doubles every stage (32, 64, 128, 256) for SepViT-Lite
+#     dim_head = 32,        # attention head dimension
+#     heads = (1, 2, 4, 8), # number of heads per stage
+#     depth = (1, 2, 6, 2), # number of transformer blocks per stage
+#     window_size = 7,      # window size of DSS Attention block
+#     dropout = 0.1         # dropout
+# )
+
+# model = PiT(
+#     image_size = 600,
+#     patch_size = 30,
+#     dim = 1024,
+#     num_classes = 5,
+#     depth = (1, 1, 1),    # list of depths, indicating the number of rounds of each stage before a downsample
+#     heads = 8,
+#     mlp_dim = 3072,
+#     dropout = 0.1,
+#     emb_dropout = 0.1
+# )
+
+
+# model = CaiT(
+#     image_size = 600,
+#     patch_size = 30,
+#     num_classes = 1000,
+#     dim = 1024,
+#     depth = 12,           # depth of transformer for patch-to-patch attention only
+#     cls_depth = 2,        # depth of cross attention of CLS tokens to patch
+#     heads = 16,
+#     mlp_dim = 2048,
+#     dropout = 0.1,
+#     emb_dropout = 0.1,
+#     layer_dropout = 0.05  # randomly dropout 5% of the layers
+# )
+
+# model = DeepViT(
+#     image_size = 600,
+#     patch_size = 30,
+#     num_classes = 5,
+#     dim = 256,
+#     depth = 6,
+#     heads = 6,
+#     mlp_dim = 256,
+#     dropout = 0.1,
+#     emb_dropout = 0.1
+# )
+
+# model = DistillableViT(
+#     image_size = 600,
+#     patch_size = 30,
+#     num_classes = 5,
+#     dim = 1080,
+#     depth = 12,
+#     heads = 12,
+#     mlp_dim = 3072,
+#     dropout = 0.1,
+#     emb_dropout = 0.1
+# )
+
+
+# model = ViT(
+#     image_size = 600,
+#     patch_size = 30,
+#     num_classes = 5,
+#     dim = 768,
+#     depth = 12,
+#     heads = 12,
+#     mlp_dim = 3072,
+#     dropout = 0.1,
+#     emb_dropout = 0.1
+# )
+
 # model = SimpleViT(
 #     image_size = 600,
 #     patch_size = 30,
@@ -34,53 +132,56 @@ def net():
 #     mlp_dim = 256
 # )
 
-    # model = ViT(
-    #     #Vit-best
-    #     # image_size = 600,
-    #     # patch_size = 30,
-    #     # num_classes = 5,
-    #     # dim = 512,
-    #     # depth = 6,
-    #     # heads = 8,
-    #     # mlp_dim = 512,
-    #     # pool = 'cls',
-    #     # channels = 3,
-    #     # dim_head = 12,
-    #     # dropout = 0.1,
-    #     # emb_dropout = 0.1
+# model = ViT(
+#     # NB: uncomment exactly one of the three configs below; enabling more than one gives ViT(...) duplicate keyword arguments
+#     #Vit-best
+#     # image_size = 600,
+#     # patch_size = 30,
+#     # num_classes = 5,
+#     # dim = 512,
+#     # depth = 6,
+#     # heads = 8,
+#     # mlp_dim = 512,
+#     # pool = 'cls',
+#     # channels = 3,
+#     # dim_head = 12,
+#     # dropout = 0.1,
+#     # emb_dropout = 0.1
 
-    #     #Vit-small
-    #     image_size = 600,
-    #     patch_size = 30,
-    #     num_classes = 5,
-    #     dim = 256,
-    #     depth = 8,
-    #     heads = 16,
-    #     mlp_dim = 256,
-    #     pool = 'cls',
-    #     channels = 3,
-    #     dim_head = 16,
-    #     dropout = 0.1,
-    #     emb_dropout = 0.1
+#     #Vit-small
+#     image_size = 600,
+#     patch_size = 30,
+#     num_classes = 5,
+#     dim = 256,
+#     depth = 8,
+#     heads = 16,
+#     mlp_dim = 256,
+#     pool = 'cls',
+#     channels = 3,
+#     dim_head = 16,
+#     dropout = 0.1,
+#     emb_dropout = 0.1
 
-    #     #Vit-tiny
-    #     # image_size = 600,
-    #     # patch_size = 30,
-    #     # num_classes = 5,
-    #     # dim = 256,
-    #     # depth = 4,
-    #     # heads = 6,
-    #     # mlp_dim = 256,
-    #     # pool = 'cls',
-    #     # channels = 3,
-    #     # dim_head = 6,
-    #     # dropout = 0.1,
-    #     # emb_dropout = 0.1
-    # )
+#     #Vit-tiny
+#     # image_size = 600,
+#     # patch_size = 30,
+#     # num_classes = 5,
+#     # dim = 256,
+#     # depth = 4,
+#     # heads = 6,
+#     # mlp_dim = 256,
+#     # pool = 'cls',
+#     # channels = 3,
+#     # dim_head = 6,
+#     # dropout = 0.1,
+#     # emb_dropout = 0.1
+# )
 
     return model
+
+# img = torch.randn(1, 3, 448, 448)  # 448 fits the commented CrossFormer/RegionViT/SepViT configs; use 600 for the models built with image_size = 600
+# model = net()
+# preds = model(img)  # (1, 5) for num_classes = 5
+# count the model parameters: see the sketch after this diff
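
The commented-out smoke test at the end of the diff never actually performs the parameter count its trailing comment asks for. Below is a minimal standalone sketch of that check. It assumes only that vit-pytorch is installed, and it rebuilds the active NesT configuration directly instead of importing hello.py (whose top-level utils.* imports need the rest of the repo); the variable names are ours, and the 600x600 dummy input is sized to match the model's image_size.

import torch
from vit_pytorch.nest import NesT

# Same configuration as the active model in the diff above.
model = NesT(
    image_size = 600,
    patch_size = 30,
    dim = 256,
    heads = 16,
    num_hierarchies = 3,
    block_repeats = (2, 2, 12),
    num_classes = 5
)

# Count trainable parameters, as the trailing comment in the diff intends.
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: {n_params:,}")

# Forward-pass smoke test: the input side must equal image_size (600 here),
# not the 448 used by the commented-out CrossFormer/RegionViT/SepViT configs.
img = torch.randn(1, 3, 600, 600)
preds = model(img)
print(preds.shape)  # torch.Size([1, 5])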