import ml_collections def get_testing(): """Returns a minimal configuration for testing.""" config = ml_collections.ConfigDict() config.patches = ml_collections.ConfigDict({'size': (16, 16)}) config.hidden_size = 1 config.transformer = ml_collections.ConfigDict() config.transformer.mlp_dim = 1 config.transformer.num_heads = 1 config.transformer.num_layers = 1 config.transformer.attention_dropout_rate = 0.0 config.transformer.dropout_rate = 0.1 config.classifier = 'token' config.representation_size = None return config def get_b16_config(): """Returns the ViT-B/16 configuration.""" config = ml_collections.ConfigDict() config.patches = ml_collections.ConfigDict({'size': (16, 16)}) config.split = 'non-overlap' config.slide_step = 12 config.hidden_size = 768 config.transformer = ml_collections.ConfigDict() config.transformer.mlp_dim = 3072 config.transformer.num_heads = 12 config.transformer.num_layers = 12 config.transformer.attention_dropout_rate = 0.0 config.transformer.dropout_rate = 0.1 config.classifier = 'token' config.representation_size = None return config def get_b32_config(): """Returns the ViT-B/32 configuration.""" config = get_b16_config() config.patches.size = (32, 32) return config def get_l16_config(): """Returns the ViT-L/16 configuration.""" config = ml_collections.ConfigDict() config.patches = ml_collections.ConfigDict({'size': (16, 16)}) config.hidden_size = 1024 config.transformer = ml_collections.ConfigDict() config.transformer.mlp_dim = 4096 config.transformer.num_heads = 16 config.transformer.num_layers = 24 config.transformer.attention_dropout_rate = 0.0 config.transformer.dropout_rate = 0.1 config.classifier = 'token' config.representation_size = None return config def get_l32_config(): """Returns the ViT-L/32 configuration.""" config = get_l16_config() config.patches.size = (32, 32) return config def get_h14_config(): """Returns the ViT-L/16 configuration.""" config = ml_collections.ConfigDict() config.patches = ml_collections.ConfigDict({'size': (14, 14)}) config.hidden_size = 1280 config.transformer = ml_collections.ConfigDict() config.transformer.mlp_dim = 5120 config.transformer.num_heads = 16 config.transformer.num_layers = 32 config.transformer.attention_dropout_rate = 0.0 config.transformer.dropout_rate = 0.1 config.classifier = 'token' config.representation_size = None return config