diff --git a/models/configs.py b/models/configs.py index 7fa3acb..3a5e93c 100755 --- a/models/configs.py +++ b/models/configs.py @@ -1,5 +1,7 @@ import ml_collections +#几种模型的配置 + def get_testing(): """Returns a minimal configuration for testing.""" config = ml_collections.ConfigDict() @@ -21,12 +23,18 @@ def get_b16_config(): config = ml_collections.ConfigDict() config.patches = ml_collections.ConfigDict({'size': (16, 16)}) config.split = 'non-overlap' - config.slide_step = 12 - config.hidden_size = 768 + # config.slide_step = 12 + # config.hidden_size = 768 + # config.transformer = ml_collections.ConfigDict() + # config.transformer.mlp_dim = 3072 + # config.transformer.num_heads = 12 + # config.transformer.num_layers = 12 + config.slide_step = 2 + config.hidden_size = 768 #VIT16不能改 config.transformer = ml_collections.ConfigDict() - config.transformer.mlp_dim = 3072 - config.transformer.num_heads = 12 - config.transformer.num_layers = 12 + config.transformer.mlp_dim = 3072 #VIT16不能改 + config.transformer.num_heads = 2 + config.transformer.num_layers = 2 config.transformer.attention_dropout_rate = 0.0 config.transformer.dropout_rate = 0.1 config.classifier = 'token' @@ -57,7 +65,7 @@ def get_l16_config(): def get_l32_config(): """Returns the ViT-L/32 configuration.""" config = get_l16_config() - config.patches.size = (32, 32) + config.patches.size = (32, 32) #patchsize就是16与32的区别 return config def get_h14_config():