From f3dd348fc8d0fda6ff803e78692461f7cd1f8011 Mon Sep 17 00:00:00 2001 From: Brainway Date: Tue, 1 Nov 2022 02:58:57 +0000 Subject: [PATCH] rename trial.py to hello.py. --- trial.py => hello.py | 80 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 66 insertions(+), 14 deletions(-) rename trial.py => hello.py (81%) diff --git a/trial.py b/hello.py similarity index 81% rename from trial.py rename to hello.py index 8d9c93b..936f952 100644 --- a/trial.py +++ b/hello.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -from vit_pytorch import ViT +from vit_pytorch import ViT,SimpleViT from utils.data_utils import get_loader_new from utils.scheduler import WarmupCosineSchedule @@ -11,20 +11,72 @@ import os import numpy as np def net(): + model = ViT( - image_size = 320, - patch_size = 32, - num_classes = 5, - dim = 768, - depth = 4, - heads = 12, - mlp_dim = 1024, - pool = 'cls', - channels = 3, - dim_head = 12, - dropout = 0.1, - emb_dropout = 0.1 - ) + image_size = 600, + patch_size = 30, + num_classes = 5, + dim = 768, + depth = 12, + heads = 12, + mlp_dim = 3072, + dropout = 0.1, + emb_dropout = 0.1 +) + +# model = SimpleViT( +# image_size = 600, +# patch_size = 30, +# num_classes = 2, +# dim = 256, +# depth = 6, +# heads = 16, +# mlp_dim = 256 +# ) + + # model = ViT( + # #Vit-best + # # image_size = 600, + # # patch_size = 30, + # # num_classes = 5, + # # dim = 512, + # # depth = 6, + # # heads = 8, + # # mlp_dim = 512, + # # pool = 'cls', + # # channels = 3, + # # dim_head = 12, + # # dropout = 0.1, + # # emb_dropout = 0.1 + + # #Vit-small + # image_size = 600, + # patch_size = 30, + # num_classes = 5, + # dim = 256, + # depth = 8, + # heads = 16, + # mlp_dim = 256, + # pool = 'cls', + # channels = 3, + # dim_head = 16, + # dropout = 0.1, + # emb_dropout = 0.1 + + # #Vit-tiny + # # image_size = 600, + # # patch_size = 30, + # # num_classes = 5, + # # dim = 256, + # # depth = 4, + # # heads = 6, + # # mlp_dim = 256, + # # pool = 'cls', + # # channels = 3, + # # dim_head = 6, + # # dropout = 0.1, + # # emb_dropout = 0.1 + # ) return model