update models/modeling.py.
This commit is contained in:
@ -22,35 +22,26 @@ import models.configs as configs
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
#多头注意力参数
|
|
||||||
ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
|
ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
|
||||||
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
|
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
|
||||||
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
|
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
|
||||||
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
|
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
|
||||||
|
|
||||||
#Dense全连接层
|
|
||||||
FC_0 = "MlpBlock_3/Dense_0"
|
FC_0 = "MlpBlock_3/Dense_0"
|
||||||
FC_1 = "MlpBlock_3/Dense_1"
|
FC_1 = "MlpBlock_3/Dense_1"
|
||||||
|
|
||||||
#批归一化曾
|
|
||||||
ATTENTION_NORM = "LayerNorm_0"
|
ATTENTION_NORM = "LayerNorm_0"
|
||||||
MLP_NORM = "LayerNorm_2"
|
MLP_NORM = "LayerNorm_2"
|
||||||
|
|
||||||
#numpy转tensor
|
|
||||||
def np2th(weights, conv=False):
|
def np2th(weights, conv=False):
|
||||||
"""Possibly convert HWIO to OIHW."""
|
"""Possibly convert HWIO to OIHW."""
|
||||||
if conv:
|
if conv:
|
||||||
weights = weights.transpose([3, 2, 0, 1])
|
weights = weights.transpose([3, 2, 0, 1])
|
||||||
return torch.from_numpy(weights)
|
return torch.from_numpy(weights)
|
||||||
|
|
||||||
#swish激活函数
|
|
||||||
def swish(x):
|
def swish(x):
|
||||||
return x * torch.sigmoid(x)
|
return x * torch.sigmoid(x)
|
||||||
|
|
||||||
#gelu激活函数
|
|
||||||
ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
|
ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
|
||||||
|
|
||||||
#标签平滑类,用于数据增强
|
|
||||||
class LabelSmoothing(nn.Module):
|
class LabelSmoothing(nn.Module):
|
||||||
"""
|
"""
|
||||||
NLL loss with label smoothing.
|
NLL loss with label smoothing.
|
||||||
@ -73,7 +64,6 @@ class LabelSmoothing(nn.Module):
|
|||||||
loss = self.confidence * nll_loss + self.smoothing * smooth_loss
|
loss = self.confidence * nll_loss + self.smoothing * smooth_loss
|
||||||
return loss.mean()
|
return loss.mean()
|
||||||
|
|
||||||
#注意力层
|
|
||||||
class Attention(nn.Module):
|
class Attention(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(Attention, self).__init__()
|
super(Attention, self).__init__()
|
||||||
@ -119,7 +109,6 @@ class Attention(nn.Module):
|
|||||||
attention_output = self.proj_dropout(attention_output)
|
attention_output = self.proj_dropout(attention_output)
|
||||||
return attention_output, weights
|
return attention_output, weights
|
||||||
|
|
||||||
#全连接层
|
|
||||||
class Mlp(nn.Module):
|
class Mlp(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(Mlp, self).__init__()
|
super(Mlp, self).__init__()
|
||||||
@ -144,7 +133,6 @@ class Mlp(nn.Module):
|
|||||||
x = self.dropout(x)
|
x = self.dropout(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
#嵌入编码
|
|
||||||
class Embeddings(nn.Module):
|
class Embeddings(nn.Module):
|
||||||
"""Construct the embeddings from patch, position embeddings.
|
"""Construct the embeddings from patch, position embeddings.
|
||||||
"""
|
"""
|
||||||
@ -186,7 +174,6 @@ class Embeddings(nn.Module):
|
|||||||
embeddings = self.dropout(embeddings)
|
embeddings = self.dropout(embeddings)
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
#块
|
|
||||||
class Block(nn.Module):
|
class Block(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(Block, self).__init__()
|
super(Block, self).__init__()
|
||||||
@ -245,7 +232,7 @@ class Block(nn.Module):
|
|||||||
self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
|
self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
|
||||||
self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))
|
self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))
|
||||||
|
|
||||||
#部分注意力层
|
|
||||||
class Part_Attention(nn.Module):
|
class Part_Attention(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Part_Attention, self).__init__()
|
super(Part_Attention, self).__init__()
|
||||||
@ -260,7 +247,7 @@ class Part_Attention(nn.Module):
|
|||||||
_, max_inx = last_map.max(2)
|
_, max_inx = last_map.max(2)
|
||||||
return _, max_inx
|
return _, max_inx
|
||||||
|
|
||||||
#编码器
|
|
||||||
class Encoder(nn.Module):
|
class Encoder(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(Encoder, self).__init__()
|
super(Encoder, self).__init__()
|
||||||
@ -290,7 +277,7 @@ class Encoder(nn.Module):
|
|||||||
|
|
||||||
return part_encoded
|
return part_encoded
|
||||||
|
|
||||||
#Transformer层
|
|
||||||
class Transformer(nn.Module):
|
class Transformer(nn.Module):
|
||||||
def __init__(self, config, img_size):
|
def __init__(self, config, img_size):
|
||||||
super(Transformer, self).__init__()
|
super(Transformer, self).__init__()
|
||||||
@ -302,85 +289,7 @@ class Transformer(nn.Module):
|
|||||||
part_encoded = self.encoder(embedding_output)
|
part_encoded = self.encoder(embedding_output)
|
||||||
return part_encoded
|
return part_encoded
|
||||||
|
|
||||||
#VIT层
|
|
||||||
class OldVisionTransformer(nn.Module):
|
|
||||||
def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
|
|
||||||
super(VisionTransformer, self).__init__()
|
|
||||||
self.num_classes = num_classes
|
|
||||||
self.zero_head = zero_head
|
|
||||||
self.classifier = config.classifier
|
|
||||||
|
|
||||||
self.transformer = Transformer(config, img_size, vis)
|
|
||||||
self.head = Linear(config.hidden_size, num_classes)
|
|
||||||
|
|
||||||
def forward(self, x, labels=None):
|
|
||||||
x, attn_weights = self.transformer(x)
|
|
||||||
logits = self.head(x[:, 0])
|
|
||||||
|
|
||||||
if labels is not None:
|
|
||||||
loss_fct = CrossEntropyLoss()
|
|
||||||
loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
|
|
||||||
return loss
|
|
||||||
else:
|
|
||||||
return logits, attn_weights
|
|
||||||
|
|
||||||
def load_from(self, weights):
|
|
||||||
with torch.no_grad():
|
|
||||||
if self.zero_head:
|
|
||||||
nn.init.zeros_(self.head.weight)
|
|
||||||
nn.init.zeros_(self.head.bias)
|
|
||||||
else:
|
|
||||||
self.head.weight.copy_(np2th(weights["head/kernel"]).t())
|
|
||||||
self.head.bias.copy_(np2th(weights["head/bias"]).t())
|
|
||||||
|
|
||||||
self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
|
|
||||||
self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))
|
|
||||||
self.transformer.embeddings.cls_token.copy_(np2th(weights["cls"]))
|
|
||||||
self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
|
|
||||||
self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))
|
|
||||||
|
|
||||||
posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
|
|
||||||
posemb_new = self.transformer.embeddings.position_embeddings
|
|
||||||
if posemb.size() == posemb_new.size():
|
|
||||||
self.transformer.embeddings.position_embeddings.copy_(posemb)
|
|
||||||
else:
|
|
||||||
print("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
|
|
||||||
ntok_new = posemb_new.size(1)
|
|
||||||
|
|
||||||
if self.classifier == "token":
|
|
||||||
posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
|
|
||||||
ntok_new -= 1
|
|
||||||
else:
|
|
||||||
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
|
|
||||||
|
|
||||||
gs_old = int(np.sqrt(len(posemb_grid)))
|
|
||||||
gs_new = int(np.sqrt(ntok_new))
|
|
||||||
print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
|
|
||||||
posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
|
|
||||||
|
|
||||||
zoom = (gs_new / gs_old, gs_new / gs_old, 1)
|
|
||||||
posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)
|
|
||||||
posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
|
|
||||||
posemb = np.concatenate([posemb_tok, posemb_grid], axis=1)
|
|
||||||
self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))
|
|
||||||
|
|
||||||
for bname, block in self.transformer.encoder.named_children():
|
|
||||||
for uname, unit in block.named_children():
|
|
||||||
unit.load_from(weights, n_block=uname)
|
|
||||||
|
|
||||||
if self.transformer.embeddings.hybrid:
|
|
||||||
self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(weights["conv_root/kernel"], conv=True))
|
|
||||||
gn_weight = np2th(weights["gn_root/scale"]).view(-1)
|
|
||||||
gn_bias = np2th(weights["gn_root/bias"]).view(-1)
|
|
||||||
self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
|
|
||||||
self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)
|
|
||||||
|
|
||||||
for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
|
|
||||||
for uname, unit in block.named_children():
|
|
||||||
unit.load_from(weights, n_block=bname, n_unit=uname)
|
|
||||||
|
|
||||||
|
|
||||||
#VIT FG层
|
|
||||||
class VisionTransformer(nn.Module):
|
class VisionTransformer(nn.Module):
|
||||||
def __init__(self, config, img_size=224, num_classes=21843, smoothing_value=0, zero_head=False):
|
def __init__(self, config, img_size=224, num_classes=21843, smoothing_value=0, zero_head=False):
|
||||||
super(VisionTransformer, self).__init__()
|
super(VisionTransformer, self).__init__()
|
||||||
@ -393,7 +302,7 @@ class VisionTransformer(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x, labels=None):
|
def forward(self, x, labels=None):
|
||||||
part_tokens = self.transformer(x)
|
part_tokens = self.transformer(x)
|
||||||
part_logits = self.part_head(part_tokens[:, 0]) #part部分可以理解是细粒度,它专注于捕捉微小差异。但生物其实不需要这个,因为生物视觉本身就是有part功能的,通过眼球转动调整感受野来完成这一点
|
part_logits = self.part_head(part_tokens[:, 0])
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
if self.smoothing_value == 0:
|
if self.smoothing_value == 0:
|
||||||
@ -456,7 +365,7 @@ class VisionTransformer(nn.Module):
|
|||||||
for uname, unit in block.named_children():
|
for uname, unit in block.named_children():
|
||||||
unit.load_from(weights, n_block=bname, n_unit=uname)
|
unit.load_from(weights, n_block=bname, n_unit=uname)
|
||||||
|
|
||||||
#loss计算
|
|
||||||
def con_loss(features, labels):
|
def con_loss(features, labels):
|
||||||
B, _ = features.shape
|
B, _ = features.shape
|
||||||
features = F.normalize(features)
|
features = F.normalize(features)
|
||||||
@ -470,7 +379,7 @@ def con_loss(features, labels):
|
|||||||
loss /= (B * B)
|
loss /= (B * B)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
#几种VIT模型配置
|
|
||||||
CONFIGS = {
|
CONFIGS = {
|
||||||
'ViT-B_16': configs.get_b16_config(),
|
'ViT-B_16': configs.get_b16_config(),
|
||||||
'ViT-B_32': configs.get_b32_config(),
|
'ViT-B_32': configs.get_b32_config(),
|
||||||
|
Reference in New Issue
Block a user