update models/configs.py.

This commit is contained in:
Brainway
2022-10-18 03:35:33 +00:00
committed by Gitee
parent 52427ac8a9
commit accca98d1c

View File

@ -1,7 +1,5 @@
import ml_collections
#几种模型的配置
def get_testing():
"""Returns a minimal configuration for testing."""
config = ml_collections.ConfigDict()
@ -23,18 +21,12 @@ def get_b16_config():
config = ml_collections.ConfigDict()
config.patches = ml_collections.ConfigDict({'size': (16, 16)})
config.split = 'non-overlap'
# config.slide_step = 12
# config.hidden_size = 768
# config.transformer = ml_collections.ConfigDict()
# config.transformer.mlp_dim = 3072
# config.transformer.num_heads = 12
# config.transformer.num_layers = 12
config.slide_step = 2
config.hidden_size = 768 #VIT16不能改
config.slide_step = 12
config.hidden_size = 768
config.transformer = ml_collections.ConfigDict()
config.transformer.mlp_dim = 3072 #VIT16不能改
config.transformer.num_heads = 2
config.transformer.num_layers = 2
config.transformer.mlp_dim = 3072
config.transformer.num_heads = 12
config.transformer.num_layers = 12
config.transformer.attention_dropout_rate = 0.0
config.transformer.dropout_rate = 0.1
config.classifier = 'token'
@ -65,7 +57,7 @@ def get_l16_config():
def get_l32_config():
"""Returns the ViT-L/32 configuration."""
config = get_l16_config()
config.patches.size = (32, 32) #patchsize就是16与32的区别
config.patches.size = (32, 32)
return config
def get_h14_config():