From 90a4691e15d1156c88e11a1472978b1842badaef Mon Sep 17 00:00:00 2001 From: Brainway Date: Mon, 20 Mar 2023 00:44:36 +0000 Subject: [PATCH] update hello.py. --- hello.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/hello.py b/hello.py index 0b7ebc7..0798377 100644 --- a/hello.py +++ b/hello.py @@ -51,17 +51,27 @@ def net(): # masking_ratio = 0.5 # they found 50% to yield the best results # ) - model = NesT( image_size = 600, patch_size = 30, dim = 256, heads = 18,#16 num_hierarchies = 3, # number of hierarchies - block_repeats = (3, 3, 16), # (2,2,12) the number of transformer blocks at each heirarchy, starting from the bottom + block_repeats = (2, 4, 18), # (2,4,16)(2,4,12)(2,2,12) the number of transformer blocks at each heirarchy, starting from the bottom num_classes = 5 ) + +# model = NesT( +# image_size = 600, +# patch_size = 30, +# dim = 256, +# heads = 16,#16 +# num_hierarchies = 3, # number of hierarchies +# block_repeats = (2, 3, 16), # (2,2,12) the number of transformer blocks at each heirarchy, starting from the bottom +# num_classes = 5 +# ) + # model = CrossFormer( #图片尺寸要是7的倍数,如448 # num_classes = 5, # number of output classes # dim = (64, 128, 256, 512), # dimension at each stage @@ -404,8 +414,4 @@ if __name__ == "__main__": LEARNING_RATE = 3e-2 WEIGHT_DECAY = 0 WARMUP_STEPS = 500 - train(model,train_loader,device,train_NUM_STEPS,LEARNING_RATE,WEIGHT_DECAY,WARMUP_STEPS,test_loader) - - - - + train(model,train_loader,device,train_NUM_STEPS,LEARNING_RATE,WEIGHT_DECAY,WARMUP_STEPS,test_loader) \ No newline at end of file