diff --git a/models/configs.py b/models/configs.py
index 3a5e93c..7fa3acb 100755
--- a/models/configs.py
+++ b/models/configs.py
@@ -1,7 +1,5 @@
 import ml_collections
 
-# Configurations for the various models
-
 def get_testing():
     """Returns a minimal configuration for testing."""
     config = ml_collections.ConfigDict()
@@ -23,18 +21,12 @@ def get_b16_config():
     config = ml_collections.ConfigDict()
     config.patches = ml_collections.ConfigDict({'size': (16, 16)})
     config.split = 'non-overlap'
-    # config.slide_step = 12
-    # config.hidden_size = 768
-    # config.transformer = ml_collections.ConfigDict()
-    # config.transformer.mlp_dim = 3072
-    # config.transformer.num_heads = 12
-    # config.transformer.num_layers = 12
-    config.slide_step = 2
-    config.hidden_size = 768 # do not change for ViT-16
+    config.slide_step = 12
+    config.hidden_size = 768
     config.transformer = ml_collections.ConfigDict()
-    config.transformer.mlp_dim = 3072 # do not change for ViT-16
-    config.transformer.num_heads = 2
-    config.transformer.num_layers = 2
+    config.transformer.mlp_dim = 3072
+    config.transformer.num_heads = 12
+    config.transformer.num_layers = 12
     config.transformer.attention_dropout_rate = 0.0
     config.transformer.dropout_rate = 0.1
     config.classifier = 'token'
@@ -65,7 +57,7 @@ def get_l16_config():
 def get_l32_config():
     """Returns the ViT-L/32 configuration."""
     config = get_l16_config()
-    config.patches.size = (32, 32) # patch size is the only difference between /16 and /32
+    config.patches.size = (32, 32)
     return config
 
 def get_h14_config():
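
For reference, a minimal sketch of how the restored ViT-B/16 configuration might be consumed downstream. This is a hypothetical usage example, not part of the diff: the import path is inferred from the file path in the header, and the head-dimension check is an illustrative assumption.

    # Minimal usage sketch (hypothetical; not part of this diff).
    from models.configs import get_b16_config

    config = get_b16_config()

    # Canonical ViT-B/16 values restored by this change.
    assert config.transformer.num_heads == 12
    assert config.transformer.num_layers == 12

    # hidden_size must split evenly across attention heads:
    # 768 / 12 = 64 dimensions per head.
    assert config.hidden_size % config.transformer.num_heads == 0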