update hello.py.

Author: Brainway (committed by Gitee)
Date: 2022-12-07 15:14:43 +00:00
Parent: 2fb1b29f5c
Commit: 29ae39e46f

hello.py | 201 changed lines

@@ -1,6 +1,14 @@
 import torch
 import torch.nn as nn
 from vit_pytorch import ViT,SimpleViT
+from vit_pytorch.distill import DistillableViT
+from vit_pytorch.deepvit import DeepViT
+from vit_pytorch.cait import CaiT
+from vit_pytorch.pit import PiT
+from vit_pytorch.regionvit import RegionViT
+from vit_pytorch.sep_vit import SepViT
+from vit_pytorch.crossformer import CrossFormer
+from vit_pytorch.nest import NesT
 from utils.data_utils import get_loader_new
 from utils.scheduler import WarmupCosineSchedule
@@ -11,19 +19,109 @@ import os
 import numpy as np
 def net():
-    model = ViT(
+    model = NesT(
         image_size = 600,
         patch_size = 30,
-        num_classes = 5,
-        dim = 768,
-        depth = 12,
-        heads = 12,
-        mlp_dim = 3072,
-        dropout = 0.1,
-        emb_dropout = 0.1
+        dim = 256,
+        heads = 16,
+        num_hierarchies = 3,        # number of hierarchies
+        block_repeats = (2, 2, 12), # number of transformer blocks at each hierarchy, starting from the bottom
+        num_classes = 5
     )
+    # model = CrossFormer(  # image size must be a multiple of 7, e.g. 448
+    #     num_classes = 5,                   # number of output classes
+    #     dim = (64, 128, 256, 512),         # dimension at each stage
+    #     depth = (2, 2, 8, 2),              # depth of transformer at each stage
+    #     global_window_size = (8, 4, 2, 1), # global window sizes at each stage
+    #     local_window_size = 7,             # local window size (can be customized per stage, but held constant at 7 for all stages in the paper)
+    # )
+    # model = RegionViT(  # image size must be a multiple of 7, e.g. 448
+    #     dim = (64, 128, 256, 512),         # tuple of size 4, indicating dimension at each stage
+    #     depth = (2, 2, 8, 2),              # depth of the region-to-local transformer at each stage
+    #     window_size = 7,                   # window size, which should be either 7 or 14
+    #     num_classes = 5,                   # number of output classes
+    #     tokenize_local_3_conv = False,     # whether to use a 3-layer convolution to encode the local tokens from the image; the paper uses this for the smaller models, but only 1 conv (set to False) for the larger models
+    #     use_peg = False,                   # whether to use the positional encoding generator module; they used this in object detection for a boost in performance
+    # )
+    # model = SepViT(  # image size must be a multiple of 7, e.g. 448
+    #     num_classes = 5,
+    #     dim = 32,               # dimension of first stage, which doubles every stage (32, 64, 128, 256) for SepViT-Lite
+    #     dim_head = 32,          # attention head dimension
+    #     heads = (1, 2, 4, 8),   # number of heads per stage
+    #     depth = (1, 2, 6, 2),   # number of transformer blocks per stage
+    #     window_size = 7,        # window size of DSS Attention block
+    #     dropout = 0.1           # dropout
+    # )
+    # model = PiT(
+    #     image_size = 600,
+    #     patch_size = 30,
+    #     dim = 1024,
+    #     num_classes = 5,
+    #     depth = (1, 1, 1),      # list of depths, indicating the number of rounds of each stage before a downsample
+    #     heads = 8,
+    #     mlp_dim = 3072,
+    #     dropout = 0.1,
+    #     emb_dropout = 0.1
+    # )
+    # model = CaiT(
+    #     image_size = 600,
+    #     patch_size = 30,
+    #     num_classes = 1000,
+    #     dim = 1024,
+    #     depth = 12,             # depth of transformer for patch-to-patch attention only
+    #     cls_depth = 2,          # depth of cross attention of CLS tokens to patches
+    #     heads = 16,
+    #     mlp_dim = 2048,
+    #     dropout = 0.1,
+    #     emb_dropout = 0.1,
+    #     layer_dropout = 0.05    # randomly drop out 5% of the layers
+    # )
+    # model = DeepViT(
+    #     image_size = 600,
+    #     patch_size = 30,
+    #     num_classes = 5,
+    #     dim = 256,
+    #     depth = 6,
+    #     heads = 6,
+    #     mlp_dim = 256,
+    #     dropout = 0.1,
+    #     emb_dropout = 0.1
+    # )
+    # model = DistillableViT(
+    #     image_size = 600,
+    #     patch_size = 30,
+    #     num_classes = 5,
+    #     dim = 1080,
+    #     depth = 12,
+    #     heads = 12,
+    #     mlp_dim = 3072,
+    #     dropout = 0.1,
+    #     emb_dropout = 0.1
+    # )
+    # model = ViT(
+    #     image_size = 600,
+    #     patch_size = 30,
+    #     num_classes = 5,
+    #     dim = 768,
+    #     depth = 12,
+    #     heads = 12,
+    #     mlp_dim = 3072,
+    #     dropout = 0.1,
+    #     emb_dropout = 0.1
+    # )
     # model = SimpleViT(
     #     image_size = 600,
     #     patch_size = 30,
@@ -34,53 +132,56 @@ def net():
     #     mlp_dim = 256
     # )
     # model = ViT(
     #     #Vit-best
-    #     # image_size = 600,
-    #     # patch_size = 30,
-    #     # num_classes = 5,
-    #     # dim = 512,
-    #     # depth = 6,
-    #     # heads = 8,
-    #     # mlp_dim = 512,
-    #     # pool = 'cls',
-    #     # channels = 3,
-    #     # dim_head = 12,
-    #     # dropout = 0.1,
-    #     # emb_dropout = 0.1
+    #     image_size = 600,
+    #     patch_size = 30,
+    #     num_classes = 5,
+    #     dim = 512,
+    #     depth = 6,
+    #     heads = 8,
+    #     mlp_dim = 512,
+    #     pool = 'cls',
+    #     channels = 3,
+    #     dim_head = 12,
+    #     dropout = 0.1,
+    #     emb_dropout = 0.1
     #     #Vit-small
     #     image_size = 600,
     #     patch_size = 30,
     #     num_classes = 5,
     #     dim = 256,
     #     depth = 8,
     #     heads = 16,
     #     mlp_dim = 256,
     #     pool = 'cls',
     #     channels = 3,
     #     dim_head = 16,
     #     dropout = 0.1,
     #     emb_dropout = 0.1
     #     #Vit-tiny
-    #     # image_size = 600,
-    #     # patch_size = 30,
-    #     # num_classes = 5,
-    #     # dim = 256,
-    #     # depth = 4,
-    #     # heads = 6,
-    #     # mlp_dim = 256,
-    #     # pool = 'cls',
-    #     # channels = 3,
-    #     # dim_head = 6,
-    #     # dropout = 0.1,
-    #     # emb_dropout = 0.1
+    #     image_size = 600,
+    #     patch_size = 30,
+    #     num_classes = 5,
+    #     dim = 256,
+    #     depth = 4,
+    #     heads = 6,
+    #     mlp_dim = 256,
+    #     pool = 'cls',
+    #     channels = 3,
+    #     dim_head = 6,
+    #     dropout = 0.1,
+    #     emb_dropout = 0.1
     # )
     return model
 # img = torch.randn(1, 3, 448, 448)
 # model = net()
 # preds = model(img)  # (1, 1000)
 # count the number of model parameters
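Note: the final comment above announces a parameter count, but the code under
it is not shown in this hunk. A minimal sketch of one common way to count
parameters in PyTorch (not necessarily what hello.py actually does), reusing
the net() defined above:

    from hello import net  # assumes hello.py is on the import path

    model = net()
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"total parameters: {total:,}, trainable: {trainable:,}")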
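For reference, a quick smoke test of the new NesT configuration (a sketch, not
part of the commit): with image_size = 600 and patch_size = 30 the feature map
is 20x20 patches, and with num_hierarchies = 3 the top hierarchy tiles it into
4x4 blocks of 5x5 patches, so the shapes divide evenly.

    import torch
    from vit_pytorch.nest import NesT

    model = NesT(
        image_size = 600,
        patch_size = 30,
        dim = 256,
        heads = 16,
        num_hierarchies = 3,
        block_repeats = (2, 2, 12),
        num_classes = 5
    )

    img = torch.randn(1, 3, 600, 600)  # one 600x600 RGB image
    preds = model(img)                 # expected shape: (1, 5)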
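The "multiple of 7" notes on the commented CrossFormer / RegionViT / SepViT
configs refer to their 7-wide attention windows: the feature map left after
stage-wise downsampling must still tile into 7x7 windows, which 448 (7 * 64)
satisfies. A sketch using the SepViT settings from the comment above (assumed
workable at 448, not verified against the author's runs):

    import torch
    from vit_pytorch.sep_vit import SepViT

    model = SepViT(
        num_classes = 5,
        dim = 32,
        dim_head = 32,
        heads = (1, 2, 4, 8),
        depth = (1, 2, 6, 2),
        window_size = 7,
        dropout = 0.1
    )

    img = torch.randn(1, 3, 448, 448)  # 448 is a multiple of 7
    preds = model(img)                 # expected shape: (1, 5)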