import ml_collections

# Configurations for several model variants.


def get_testing():
    """Returns a minimal configuration for testing."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 1
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 1
    config.transformer.num_heads = 1
    config.transformer.num_layers = 1
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config


def get_b16_config():
    """Returns the ViT-B/16 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.split = 'non-overlap'
    # Original ViT-B/16 values (kept for reference):
    # config.slide_step = 12
    # config.hidden_size = 768
    # config.transformer = ml_collections.ConfigDict()
    # config.transformer.mlp_dim = 3072
    # config.transformer.num_heads = 12
    # config.transformer.num_layers = 12
    config.slide_step = 2
    config.hidden_size = 768  # Fixed for ViT-B/16; do not change.
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 3072  # Fixed for ViT-B/16; do not change.
    config.transformer.num_heads = 2
    config.transformer.num_layers = 2
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config


def get_b32_config():
    """Returns the ViT-B/32 configuration."""
    config = get_b16_config()
    config.patches.size = (32, 32)
    return config


def get_l16_config():
    """Returns the ViT-L/16 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 1024
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 4096
    config.transformer.num_heads = 16
    config.transformer.num_layers = 24
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config


def get_l32_config():
    """Returns the ViT-L/32 configuration."""
    config = get_l16_config()
    config.patches.size = (32, 32)  # The patch size is the only difference between the /16 and /32 variants.
    return config


def get_h14_config():
    """Returns the ViT-H/14 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (14, 14)})
    config.hidden_size = 1280
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 5120
    config.transformer.num_heads = 16
    config.transformer.num_layers = 32
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config
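

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how these getters might be collected and consumed.
# The CONFIGS dict, the variant names, and the fields printed below are
# assumptions for illustration; the model-building code in this repository
# may look up or read the configs differently.

CONFIGS = {
    'testing': get_testing,
    'ViT-B_16': get_b16_config,
    'ViT-B_32': get_b32_config,
    'ViT-L_16': get_l16_config,
    'ViT-L_32': get_l32_config,
    'ViT-H_14': get_h14_config,
}


if __name__ == '__main__':
    # Pick a variant by name and inspect the fields a model builder would read.
    config = CONFIGS['ViT-B_16']()
    print('patch size:', config.patches.size)
    print('hidden size:', config.hidden_size)
    print('layers:', config.transformer.num_layers,
          'heads:', config.transformer.num_heads)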