update hello.py.
hello.py
@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-from vit_pytorch import ViT,SimpleViT
+from vit_pytorch import ViT,SimpleViT,MAE
 from vit_pytorch.distill import DistillableViT
 from vit_pytorch.deepvit import DeepViT
 from vit_pytorch.cait import CaiT
@@ -9,6 +9,9 @@ from vit_pytorch.regionvit import RegionViT
 from vit_pytorch.sep_vit import SepViT
 from vit_pytorch.crossformer import CrossFormer
 from vit_pytorch.nest import NesT
+from vit_pytorch.mobile_vit import MobileViT
+from vit_pytorch.simmim import SimMIM
+from vit_pytorch.ats_vit import ViT
 
 from utils.data_utils import get_loader_new
 from utils.scheduler import WarmupCosineSchedule
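Note that the added `from vit_pytorch.ats_vit import ViT` rebinds the name `ViT` that the first hunk imports from the package root, so the `ViT(...)` call in `net()` below constructs the adaptive-token-sampling variant rather than the plain model. A minimal sketch of the shadowing (the commented alias is only a suggestion, not code from this commit):

    from vit_pytorch import ViT            # plain ViT
    from vit_pytorch.ats_vit import ViT    # ATS ViT; this later import shadows the one above
    # to keep both usable, alias one of them, e.g.:
    # from vit_pytorch import ViT as BaseViT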
@@ -20,17 +23,45 @@ import numpy as np
 
 def net():
 
 
-    model = NesT(
+    model = ViT(
         image_size = 600,
         patch_size = 30,
-        dim = 256,
+        num_classes = 5,
+        dim = 1024,
+        depth = 6,
+        max_tokens_per_depth = (256, 128, 64, 32, 16, 8), # a tuple that denotes the maximum number of tokens that any given layer should have. if the layer has greater than this amount, it will undergo adaptive token sampling
         heads = 16,
-        num_hierarchies = 3, # number of hierarchies
-        block_repeats = (2, 2, 12), # the number of transformer blocks at each hierarchy, starting from the bottom
-        num_classes = 5
+        mlp_dim = 2048,
+        dropout = 0.1,
+        emb_dropout = 0.1
     )
+
+    # modelv = ViT(
+    #     image_size = 600,
+    #     patch_size = 30,
+    #     num_classes = 5,
+    #     dim = 1024,
+    #     depth = 6,
+    #     heads = 8,
+    #     mlp_dim = 2048
+    # )
+
+    # model = MAE(
+    #     encoder = modelv,
+    #     masking_ratio = 0.5 # they found 50% to yield the best results
+    # )
+
+    # model = NesT(
+    #     image_size = 600,
+    #     patch_size = 30,
+    #     dim = 256,
+    #     heads = 16,
+    #     num_hierarchies = 3, # number of hierarchies
+    #     block_repeats = (2, 2, 12), # the number of transformer blocks at each hierarchy, starting from the bottom
+    #     num_classes = 5
+    # )
 
     # model = CrossFormer( # image size must be a multiple of 7, e.g. 448
     #     num_classes = 5, # number of output classes
     #     dim = (64, 128, 256, 512), # dimension at each stage
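For reference, a minimal sketch of how the model configured above would be exercised; the forward-pass interface follows the adaptive-token-sampling example in the vit-pytorch README, and the batch size is arbitrary:

    import torch
    from vit_pytorch.ats_vit import ViT

    model = ViT(
        image_size = 600,              # 600x600 inputs, as configured in the diff
        patch_size = 30,               # 20x20 = 400 patches per image
        num_classes = 5,
        dim = 1024,
        depth = 6,
        max_tokens_per_depth = (256, 128, 64, 32, 16, 8),  # per-layer token budget for adaptive token sampling
        heads = 16,
        mlp_dim = 2048,
        dropout = 0.1,
        emb_dropout = 0.1
    )

    imgs = torch.randn(4, 3, 600, 600)  # (batch, channels, height, width)
    logits = model(imgs)                 # (4, 5): one logit per class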
@@ -296,7 +327,7 @@ def train(model,train_loader,device,train_NUM_STEPS,LEARNING_RATE,WEIGHT_DECAY,W
         batch = tuple(t.to(device) for t in batch)
         x, y = batch
         logits = model(x)
 
 
         loss = criterion(logits,y)
 
         loss.backward()
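For context, a minimal sketch of the full optimization step around the lines shown above; the optimizer, the `zero_grad` placement, and the scheduler call are assumptions (only `WarmupCosineSchedule` is visible in the imports), not code from this commit:

    for step, batch in enumerate(train_loader):
        batch = tuple(t.to(device) for t in batch)  # move every tensor in the batch to the target device
        x, y = batch
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()       # assumed: e.g. torch.optim.SGD built from LEARNING_RATE / WEIGHT_DECAY
        scheduler.step()       # assumed: the WarmupCosineSchedule imported at the top of hello.py
        optimizer.zero_grad()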