update hello.py.
This commit is contained in:
20
hello.py
20
hello.py
@ -51,17 +51,27 @@ def net():
|
||||
# masking_ratio = 0.5 # they found 50% to yield the best results
|
||||
# )
|
||||
|
||||
|
||||
model = NesT(
|
||||
image_size = 600,
|
||||
patch_size = 30,
|
||||
dim = 256,
|
||||
heads = 18,#16
|
||||
num_hierarchies = 3, # number of hierarchies
|
||||
block_repeats = (3, 3, 16), # (2,2,12) the number of transformer blocks at each heirarchy, starting from the bottom
|
||||
block_repeats = (2, 4, 18), # (2,4,16)(2,4,12)(2,2,12) the number of transformer blocks at each heirarchy, starting from the bottom
|
||||
num_classes = 5
|
||||
)
|
||||
|
||||
|
||||
# model = NesT(
|
||||
# image_size = 600,
|
||||
# patch_size = 30,
|
||||
# dim = 256,
|
||||
# heads = 16,#16
|
||||
# num_hierarchies = 3, # number of hierarchies
|
||||
# block_repeats = (2, 3, 16), # (2,2,12) the number of transformer blocks at each heirarchy, starting from the bottom
|
||||
# num_classes = 5
|
||||
# )
|
||||
|
||||
# model = CrossFormer( #图片尺寸要是7的倍数,如448
|
||||
# num_classes = 5, # number of output classes
|
||||
# dim = (64, 128, 256, 512), # dimension at each stage
|
||||
@ -404,8 +414,4 @@ if __name__ == "__main__":
|
||||
LEARNING_RATE = 3e-2
|
||||
WEIGHT_DECAY = 0
|
||||
WARMUP_STEPS = 500
|
||||
train(model,train_loader,device,train_NUM_STEPS,LEARNING_RATE,WEIGHT_DECAY,WARMUP_STEPS,test_loader)
|
||||
|
||||
|
||||
|
||||
|
||||
train(model,train_loader,device,train_NUM_STEPS,LEARNING_RATE,WEIGHT_DECAY,WARMUP_STEPS,test_loader)
|
Reference in New Issue
Block a user