update hello.py.

This commit is contained in:
Brainway
2023-03-20 00:44:36 +00:00
committed by Gitee
parent 6373df08b9
commit 90a4691e15

View File

@ -51,17 +51,27 @@ def net():
# masking_ratio = 0.5 # they found 50% to yield the best results # masking_ratio = 0.5 # they found 50% to yield the best results
# ) # )
model = NesT( model = NesT(
image_size = 600, image_size = 600,
patch_size = 30, patch_size = 30,
dim = 256, dim = 256,
heads = 18,#16 heads = 18,#16
num_hierarchies = 3, # number of hierarchies num_hierarchies = 3, # number of hierarchies
block_repeats = (3, 3, 16), # (2,2,12) the number of transformer blocks at each heirarchy, starting from the bottom block_repeats = (2, 4, 18), # (2,4,16)(2,4,12)(2,2,12) the number of transformer blocks at each heirarchy, starting from the bottom
num_classes = 5 num_classes = 5
) )
# model = NesT(
# image_size = 600,
# patch_size = 30,
# dim = 256,
# heads = 16,#16
# num_hierarchies = 3, # number of hierarchies
# block_repeats = (2, 3, 16), # (2,2,12) the number of transformer blocks at each heirarchy, starting from the bottom
# num_classes = 5
# )
# model = CrossFormer( #图片尺寸要是7的倍数如448 # model = CrossFormer( #图片尺寸要是7的倍数如448
# num_classes = 5, # number of output classes # num_classes = 5, # number of output classes
# dim = (64, 128, 256, 512), # dimension at each stage # dim = (64, 128, 256, 512), # dimension at each stage
@ -405,7 +415,3 @@ if __name__ == "__main__":
WEIGHT_DECAY = 0 WEIGHT_DECAY = 0
WARMUP_STEPS = 500 WARMUP_STEPS = 500
train(model,train_loader,device,train_NUM_STEPS,LEARNING_RATE,WEIGHT_DECAY,WARMUP_STEPS,test_loader) train(model,train_loader,device,train_NUM_STEPS,LEARNING_RATE,WEIGHT_DECAY,WARMUP_STEPS,test_loader)