77 lines
2.6 KiB
Python
Executable File
77 lines
2.6 KiB
Python
Executable File
import ml_collections
|
|
|
|
def get_testing():
    """Returns a minimal configuration for testing.

    All model sizes are set to 1 so a model can be built cheaply. The set of
    fields matches ``get_b16_config`` (including ``split`` and
    ``slide_step``), so code that reads those fields accepts this config too.

    Returns:
        ml_collections.ConfigDict: the test configuration.
    """
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    # Same values as get_b16_config. Without these keys, reading
    # config.split / config.slide_step on this config raises AttributeError.
    # NOTE(review): presumably these control how patches are extracted;
    # confirm against the model's embedding code.
    config.split = 'non-overlap'
    config.slide_step = 12
    config.hidden_size = 1
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 1
    config.transformer.num_heads = 1
    config.transformer.num_layers = 1
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config
|
|
|
|
|
|
def get_b16_config():
    """Build the ViT-B/16 configuration.

    Base-size model on 16x16 patches: 12 transformer layers with 12 heads,
    hidden width 768 and MLP width 3072.

    Returns:
        ml_collections.ConfigDict: the model configuration.
    """
    # Assemble the transformer sub-config first, then attach it.
    encoder = ml_collections.ConfigDict()
    encoder.mlp_dim = 3072
    encoder.num_heads = 12
    encoder.num_layers = 12
    encoder.attention_dropout_rate = 0.0
    encoder.dropout_rate = 0.1

    cfg = ml_collections.ConfigDict()
    cfg.patches = ml_collections.ConfigDict({'size': (16, 16)})
    cfg.split = 'non-overlap'
    cfg.slide_step = 12
    cfg.hidden_size = 768
    cfg.transformer = encoder
    cfg.classifier = 'token'
    cfg.representation_size = None
    return cfg
|
|
|
|
def get_b32_config():
    """Build the ViT-B/32 configuration.

    Starts from ``get_b16_config()`` and changes only the patch size to
    32x32; every other field is inherited unchanged.

    Returns:
        ml_collections.ConfigDict: the model configuration.
    """
    b32 = get_b16_config()
    b32.patches.size = (32, 32)
    return b32
|
|
|
|
def get_l16_config():
    """Returns the ViT-L/16 configuration.

    Large-size model on 16x16 patches: 24 transformer layers with 16 heads,
    hidden width 1024 and MLP width 4096. The set of fields matches
    ``get_b16_config`` (including ``split`` and ``slide_step``).

    Returns:
        ml_collections.ConfigDict: the model configuration.
    """
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    # Same values as get_b16_config. Without these keys, reading
    # config.split / config.slide_step on this config (or on get_l32_config,
    # which derives from it) raises AttributeError.
    # NOTE(review): presumably these control how patches are extracted;
    # confirm against the model's embedding code.
    config.split = 'non-overlap'
    config.slide_step = 12
    config.hidden_size = 1024
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 4096
    config.transformer.num_heads = 16
    config.transformer.num_layers = 24
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config
|
|
|
|
def get_l32_config():
    """Build the ViT-L/32 configuration.

    Starts from ``get_l16_config()`` and changes only the patch size to
    32x32; every other field is inherited unchanged.

    Returns:
        ml_collections.ConfigDict: the model configuration.
    """
    l32 = get_l16_config()
    l32.patches.size = (32, 32)
    return l32
|
|
|
|
def get_h14_config():
    """Returns the ViT-H/14 configuration.

    Huge-size model on 14x14 patches: 32 transformer layers with 16 heads,
    hidden width 1280 and MLP width 5120. The set of fields matches
    ``get_b16_config`` (including ``split`` and ``slide_step``).

    Returns:
        ml_collections.ConfigDict: the model configuration.
    """
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (14, 14)})
    # Same values as get_b16_config. Without these keys, reading
    # config.split / config.slide_step on this config raises AttributeError.
    # NOTE(review): slide_step=12 was chosen for 16x16 patches; revisit it
    # before switching `split` away from 'non-overlap' with 14x14 patches.
    config.split = 'non-overlap'
    config.slide_step = 12
    config.hidden_size = 1280
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 5120
    config.transformer.num_heads = 16
    config.transformer.num_layers = 32
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config
|