update hello.py.
hello.py
@@ -1,6 +1,14 @@
import torch
import torch.nn as nn
from vit_pytorch import ViT, SimpleViT
from vit_pytorch.distill import DistillableViT
from vit_pytorch.deepvit import DeepViT
from vit_pytorch.cait import CaiT
from vit_pytorch.pit import PiT
from vit_pytorch.regionvit import RegionViT
from vit_pytorch.sep_vit import SepViT
from vit_pytorch.crossformer import CrossFormer
from vit_pytorch.nest import NesT

from utils.data_utils import get_loader_new
from utils.scheduler import WarmupCosineSchedule
@@ -11,19 +19,109 @@ import os
import numpy as np

def net():
    model = NesT(
        image_size = 600,
        patch_size = 30,
        dim = 256,
        heads = 16,
        num_hierarchies = 3,         # number of hierarchies
        block_repeats = (2, 2, 12),  # number of transformer blocks at each hierarchy, starting from the bottom
        num_classes = 5
    )

    # model = CrossFormer(                   # image size must be a multiple of 7, e.g. 448
    #     num_classes = 5,                   # number of output classes
    #     dim = (64, 128, 256, 512),         # dimension at each stage
    #     depth = (2, 2, 8, 2),              # depth of transformer at each stage
    #     global_window_size = (8, 4, 2, 1), # global window sizes at each stage
    #     local_window_size = 7,             # local window size (can be customized per stage, but held constant at 7 for all stages in the paper)
    # )

    # model = RegionViT(                     # image size must be a multiple of 7, e.g. 448
    #     dim = (64, 128, 256, 512),         # tuple of size 4, indicating dimension at each stage
    #     depth = (2, 2, 8, 2),              # depth of the region-to-local transformer at each stage
    #     window_size = 7,                   # window size, which should be either 7 or 14
    #     num_classes = 5,                   # number of output classes
    #     tokenize_local_3_conv = False,     # whether to use a 3-layer convolution to encode the local tokens from the image; the paper uses this for the smaller models, but only 1 conv (set to False) for the larger models
    #     use_peg = False,                   # whether to use the positional generating module; they used this for object detection for a boost in performance
    # )

    # model = SepViT(                        # image size must be a multiple of 7, e.g. 448
    #     num_classes = 5,
    #     dim = 32,                          # dimension of the first stage, which doubles every stage (32, 64, 128, 256) for SepViT-Lite
    #     dim_head = 32,                     # attention head dimension
    #     heads = (1, 2, 4, 8),              # number of heads per stage
    #     depth = (1, 2, 6, 2),              # number of transformer blocks per stage
    #     window_size = 7,                   # window size of the DSS Attention block
    #     dropout = 0.1                      # dropout
    # )

    # model = PiT(
    #     image_size = 600,
    #     patch_size = 30,
    #     dim = 1024,
    #     num_classes = 5,
    #     depth = (1, 1, 1),                 # list of depths, indicating the number of rounds of each stage before a downsample
    #     heads = 8,
    #     mlp_dim = 3072,
    #     dropout = 0.1,
    #     emb_dropout = 0.1
    # )

    # model = CaiT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 1000,
    #     dim = 1024,
    #     depth = 12,                        # depth of transformer for patch-to-patch attention only
    #     cls_depth = 2,                     # depth of cross attention of CLS tokens to patches
    #     heads = 16,
    #     mlp_dim = 2048,
    #     dropout = 0.1,
    #     emb_dropout = 0.1,
    #     layer_dropout = 0.05               # randomly drop out 5% of the layers
    # )

    # model = DeepViT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 5,
    #     dim = 256,
    #     depth = 6,
    #     heads = 6,
    #     mlp_dim = 256,
    #     dropout = 0.1,
    #     emb_dropout = 0.1
    # )

    # model = DistillableViT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 5,
    #     dim = 1080,
    #     depth = 12,
    #     heads = 12,
    #     mlp_dim = 3072,
    #     dropout = 0.1,
    #     emb_dropout = 0.1
    # )

    # model = ViT(
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 5,
    #     dim = 768,
    #     depth = 12,
    #     heads = 12,
    #     mlp_dim = 3072,
    #     dropout = 0.1,
    #     emb_dropout = 0.1
    # )

    # model = SimpleViT(
    #     image_size = 600,
    #     patch_size = 30,
@@ -34,53 +132,56 @@ def net():
    #     mlp_dim = 256
    # )

    # model = ViT(
    #     # ViT-best
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 5,
    #     dim = 512,
    #     depth = 6,
    #     heads = 8,
    #     mlp_dim = 512,
    #     pool = 'cls',
    #     channels = 3,
    #     dim_head = 12,
    #     dropout = 0.1,
    #     emb_dropout = 0.1

    #     # ViT-small
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 5,
    #     dim = 256,
    #     depth = 8,
    #     heads = 16,
    #     mlp_dim = 256,
    #     pool = 'cls',
    #     channels = 3,
    #     dim_head = 16,
    #     dropout = 0.1,
    #     emb_dropout = 0.1

    #     # ViT-tiny
    #     image_size = 600,
    #     patch_size = 30,
    #     num_classes = 5,
    #     dim = 256,
    #     depth = 4,
    #     heads = 6,
    #     mlp_dim = 256,
    #     pool = 'cls',
    #     channels = 3,
    #     dim_head = 6,
    #     dropout = 0.1,
    #     emb_dropout = 0.1
    # )

    return model


# img = torch.randn(1, 3, 600, 600)
# model = net()
# preds = model(img)  # (1, 5)

# count the number of model parameters