update models/configs.py.
This commit is contained in:
@ -1,7 +1,5 @@
|
|||||||
import ml_collections
|
import ml_collections
|
||||||
|
|
||||||
#几种模型的配置
|
|
||||||
|
|
||||||
def get_testing():
|
def get_testing():
|
||||||
"""Returns a minimal configuration for testing."""
|
"""Returns a minimal configuration for testing."""
|
||||||
config = ml_collections.ConfigDict()
|
config = ml_collections.ConfigDict()
|
||||||
@ -23,18 +21,12 @@ def get_b16_config():
|
|||||||
config = ml_collections.ConfigDict()
|
config = ml_collections.ConfigDict()
|
||||||
config.patches = ml_collections.ConfigDict({'size': (16, 16)})
|
config.patches = ml_collections.ConfigDict({'size': (16, 16)})
|
||||||
config.split = 'non-overlap'
|
config.split = 'non-overlap'
|
||||||
# config.slide_step = 12
|
config.slide_step = 12
|
||||||
# config.hidden_size = 768
|
config.hidden_size = 768
|
||||||
# config.transformer = ml_collections.ConfigDict()
|
|
||||||
# config.transformer.mlp_dim = 3072
|
|
||||||
# config.transformer.num_heads = 12
|
|
||||||
# config.transformer.num_layers = 12
|
|
||||||
config.slide_step = 2
|
|
||||||
config.hidden_size = 768 #VIT16不能改
|
|
||||||
config.transformer = ml_collections.ConfigDict()
|
config.transformer = ml_collections.ConfigDict()
|
||||||
config.transformer.mlp_dim = 3072 #VIT16不能改
|
config.transformer.mlp_dim = 3072
|
||||||
config.transformer.num_heads = 2
|
config.transformer.num_heads = 12
|
||||||
config.transformer.num_layers = 2
|
config.transformer.num_layers = 12
|
||||||
config.transformer.attention_dropout_rate = 0.0
|
config.transformer.attention_dropout_rate = 0.0
|
||||||
config.transformer.dropout_rate = 0.1
|
config.transformer.dropout_rate = 0.1
|
||||||
config.classifier = 'token'
|
config.classifier = 'token'
|
||||||
@ -65,7 +57,7 @@ def get_l16_config():
|
|||||||
def get_l32_config():
|
def get_l32_config():
|
||||||
"""Returns the ViT-L/32 configuration."""
|
"""Returns the ViT-L/32 configuration."""
|
||||||
config = get_l16_config()
|
config = get_l16_config()
|
||||||
config.patches.size = (32, 32) #patchsize就是16与32的区别
|
config.patches.size = (32, 32)
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def get_h14_config():
|
def get_h14_config():
|
||||||
|
Reference in New Issue
Block a user