update hello.py.
hello.py
@@ -1,6 +1,14 @@
 import torch
 import torch.nn as nn
 from vit_pytorch import ViT,SimpleViT
+from vit_pytorch.distill import DistillableViT
+from vit_pytorch.deepvit import DeepViT
+from vit_pytorch.cait import CaiT
+from vit_pytorch.pit import PiT
+from vit_pytorch.regionvit import RegionViT
+from vit_pytorch.sep_vit import SepViT
+from vit_pytorch.crossformer import CrossFormer
+from vit_pytorch.nest import NesT
 
 from utils.data_utils import get_loader_new
 from utils.scheduler import WarmupCosineSchedule
@@ -12,18 +20,108 @@ import numpy as np
 
 def net():
 
-    model = ViT(
+    model = NesT(
         image_size = 600,
         patch_size = 30,
-        num_classes = 5,
-        dim = 768,
-        depth = 12,
-        heads = 12,
-        mlp_dim = 3072,
-        dropout = 0.1,
-        emb_dropout = 0.1
+        dim = 256,
+        heads = 16,
+        num_hierarchies = 3,        # number of hierarchies
+        block_repeats = (2, 2, 12), # the number of transformer blocks at each hierarchy, starting from the bottom
+        num_classes = 5
     )
 
+    # model = CrossFormer( # the image size must be a multiple of 7, e.g. 448
+    # num_classes = 5, # number of output classes
+    # dim = (64, 128, 256, 512), # dimension at each stage
+    # depth = (2, 2, 8, 2), # depth of transformer at each stage
+    # global_window_size = (8, 4, 2, 1), # global window sizes at each stage
+    # local_window_size = 7, # local window size (can be customized for each stage, but in paper, held constant at 7 for all stages)
+    # )
+
+    # model = RegionViT( # the image size must be a multiple of 7, e.g. 448
+    # dim = (64, 128, 256, 512), # tuple of size 4, indicating dimension at each stage
+    # depth = (2, 2, 8, 2), # depth of the region to local transformer at each stage
+    # window_size = 7, # window size, which should be either 7 or 14
+    # num_classes = 5, # number of output classes
+    # tokenize_local_3_conv = False, # whether to use a 3 layer convolution to encode the local tokens from the image. the paper uses this for the smaller models, but uses only 1 conv (set to False) for the larger models
+    # use_peg = False, # whether to use positional generating module. they used this for object detection for a boost in performance
+    # )
+
+    # model = SepViT( # the image size must be a multiple of 7, e.g. 448
+    # num_classes = 5,
+    # dim = 32, # dimensions of first stage, which doubles every stage (32, 64, 128, 256) for SepViT-Lite
+    # dim_head = 32, # attention head dimension
+    # heads = (1, 2, 4, 8), # number of heads per stage
+    # depth = (1, 2, 6, 2), # number of transformer blocks per stage
+    # window_size = 7, # window size of DSS Attention block
+    # dropout = 0.1 # dropout
+    # )
+
+    # model = PiT(
+    # image_size = 600,
+    # patch_size = 30,
+    # dim = 1024,
+    # num_classes = 5,
+    # depth = (1, 1, 1), # list of depths, indicating the number of rounds of each stage before a downsample
+    # heads = 8,
+    # mlp_dim = 3072,
+    # dropout = 0.1,
+    # emb_dropout = 0.1
+    # )
+
+
+    # model = CaiT(
+    # image_size = 600,
+    # patch_size = 30,
+    # num_classes = 1000,
+    # dim = 1024,
+    # depth = 12, # depth of transformer for patch to patch attention only
+    # cls_depth = 2, # depth of cross attention of CLS tokens to patch
+    # heads = 16,
+    # mlp_dim = 2048,
+    # dropout = 0.1,
+    # emb_dropout = 0.1,
+    # layer_dropout = 0.05 # randomly dropout 5% of the layers
+    # )
+
+    # model = DeepViT(
+    # image_size = 600,
+    # patch_size = 30,
+    # num_classes = 5,
+    # dim = 256,
+    # depth = 6,
+    # heads = 6,
+    # mlp_dim = 256,
+    # dropout = 0.1,
+    # emb_dropout = 0.1
+    # )
+
+    # model = DistillableViT(
+    # image_size = 600,
+    # patch_size = 30,
+    # num_classes = 5,
+    # dim = 1080,
+    # depth = 12,
+    # heads = 12,
+    # mlp_dim = 3072,
+    # dropout = 0.1,
+    # emb_dropout = 0.1
+    # )
+
+
+    # model = ViT(
+    # image_size = 600,
+    # patch_size = 30,
+    # num_classes = 5,
+    # dim = 768,
+    # depth = 12,
+    # heads = 12,
+    # mlp_dim = 3072,
+    # dropout = 0.1,
+    # emb_dropout = 0.1
+    # )
+
     # model = SimpleViT(
     # image_size = 600,
     # patch_size = 30,
@@ -36,18 +134,18 @@ def net():
 
     # model = ViT(
     # #Vit-best
-    # # image_size = 600,
-    # # patch_size = 30,
-    # # num_classes = 5,
-    # # dim = 512,
-    # # depth = 6,
-    # # heads = 8,
-    # # mlp_dim = 512,
-    # # pool = 'cls',
-    # # channels = 3,
-    # # dim_head = 12,
-    # # dropout = 0.1,
-    # # emb_dropout = 0.1
+    # image_size = 600,
+    # patch_size = 30,
+    # num_classes = 5,
+    # dim = 512,
+    # depth = 6,
+    # heads = 8,
+    # mlp_dim = 512,
+    # pool = 'cls',
+    # channels = 3,
+    # dim_head = 12,
+    # dropout = 0.1,
+    # emb_dropout = 0.1
 
     # #Vit-small
     # image_size = 600,
@@ -64,23 +162,26 @@ def net():
     # emb_dropout = 0.1
 
     # #Vit-tiny
-    # # image_size = 600,
-    # # patch_size = 30,
-    # # num_classes = 5,
-    # # dim = 256,
-    # # depth = 4,
-    # # heads = 6,
-    # # mlp_dim = 256,
-    # # pool = 'cls',
-    # # channels = 3,
-    # # dim_head = 6,
-    # # dropout = 0.1,
-    # # emb_dropout = 0.1
+    # image_size = 600,
+    # patch_size = 30,
+    # num_classes = 5,
+    # dim = 256,
+    # depth = 4,
+    # heads = 6,
+    # mlp_dim = 256,
+    # pool = 'cls',
+    # channels = 3,
+    # dim_head = 6,
+    # dropout = 0.1,
+    # emb_dropout = 0.1
     # )
 
 
     return model
 
+# img = torch.randn(1, 3, 448, 448)
+# model = net()
+# preds = model(img) # (1, 1000)
 
 
 # count the number of model parameters
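The last hunk ends right at the "count the number of model parameters" comment, so the counting code itself is not part of the diff. Below is a minimal, hypothetical sketch, not the commit's actual code: it rebuilds the committed NesT configuration, runs a dummy forward pass, and counts parameters with standard PyTorch calls. All hyperparameters come straight from the diff; the variable names and the print format are illustrative assumptions.

import torch
from vit_pytorch.nest import NesT

# Same hyperparameters as the committed net() above.
model = NesT(
    image_size = 600,
    patch_size = 30,
    dim = 256,
    heads = 16,
    num_hierarchies = 3,
    block_repeats = (2, 2, 12),
    num_classes = 5
)

# Dummy forward pass: one 3-channel 600x600 image in, 5 class logits out.
img = torch.randn(1, 3, 600, 600)
preds = model(img)  # expected shape: (1, 5)

# One plausible way to count parameters (the step the final comment announces).
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"total: {total:,}  trainable: {trainable:,}")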