From 0c2e0dccacecf64114189c1975701342af9dc6de Mon Sep 17 00:00:00 2001
From: Brainway
Date: Tue, 18 Oct 2022 03:35:54 +0000
Subject: [PATCH] update models/modeling.py.

---
 models/modeling.py | 103 +++------------------------------------------
 1 file changed, 6 insertions(+), 97 deletions(-)

diff --git a/models/modeling.py b/models/modeling.py
index ea5b1ae..fc3082b 100755
--- a/models/modeling.py
+++ b/models/modeling.py
@@ -22,35 +22,26 @@ import models.configs as configs
 
 logger = logging.getLogger(__name__)
 
-# Multi-head attention parameter names
 ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
 ATTENTION_K = "MultiHeadDotProductAttention_1/key"
 ATTENTION_V = "MultiHeadDotProductAttention_1/value"
 ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
-
-# Dense (fully connected) layers
 FC_0 = "MlpBlock_3/Dense_0"
 FC_1 = "MlpBlock_3/Dense_1"
-
-# Normalization layers
 ATTENTION_NORM = "LayerNorm_0"
 MLP_NORM = "LayerNorm_2"
 
-# numpy-to-tensor conversion
 def np2th(weights, conv=False):
     """Possibly convert HWIO to OIHW."""
     if conv:
         weights = weights.transpose([3, 2, 0, 1])
     return torch.from_numpy(weights)
 
-# Swish activation function
 def swish(x):
     return x * torch.sigmoid(x)
 
-# GELU activation function
 ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
 
-# Label smoothing class, used for regularization
 class LabelSmoothing(nn.Module):
     """
     NLL loss with label smoothing.
@@ -73,7 +64,6 @@ class LabelSmoothing(nn.Module):
         loss = self.confidence * nll_loss + self.smoothing * smooth_loss
         return loss.mean()
 
-# Attention layer
 class Attention(nn.Module):
     def __init__(self, config):
         super(Attention, self).__init__()
@@ -119,7 +109,6 @@ class Attention(nn.Module):
         attention_output = self.proj_dropout(attention_output)
         return attention_output, weights
 
-# Fully connected (MLP) layer
 class Mlp(nn.Module):
     def __init__(self, config):
         super(Mlp, self).__init__()
@@ -144,7 +133,6 @@ class Mlp(nn.Module):
         x = self.dropout(x)
         return x
 
-# Embeddings
 class Embeddings(nn.Module):
     """Construct the embeddings from patch, position embeddings.
""" @@ -186,7 +174,6 @@ class Embeddings(nn.Module): embeddings = self.dropout(embeddings) return embeddings -#块 class Block(nn.Module): def __init__(self, config): super(Block, self).__init__() @@ -245,7 +232,7 @@ class Block(nn.Module): self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")])) self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")])) -#部分注意力层 + class Part_Attention(nn.Module): def __init__(self): super(Part_Attention, self).__init__() @@ -260,7 +247,7 @@ class Part_Attention(nn.Module): _, max_inx = last_map.max(2) return _, max_inx -#编码器 + class Encoder(nn.Module): def __init__(self, config): super(Encoder, self).__init__() @@ -290,7 +277,7 @@ class Encoder(nn.Module): return part_encoded -#Transformer层 + class Transformer(nn.Module): def __init__(self, config, img_size): super(Transformer, self).__init__() @@ -302,85 +289,7 @@ class Transformer(nn.Module): part_encoded = self.encoder(embedding_output) return part_encoded -#VIT层 -class OldVisionTransformer(nn.Module): - def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False): - super(VisionTransformer, self).__init__() - self.num_classes = num_classes - self.zero_head = zero_head - self.classifier = config.classifier - self.transformer = Transformer(config, img_size, vis) - self.head = Linear(config.hidden_size, num_classes) - - def forward(self, x, labels=None): - x, attn_weights = self.transformer(x) - logits = self.head(x[:, 0]) - - if labels is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1)) - return loss - else: - return logits, attn_weights - - def load_from(self, weights): - with torch.no_grad(): - if self.zero_head: - nn.init.zeros_(self.head.weight) - nn.init.zeros_(self.head.bias) - else: - self.head.weight.copy_(np2th(weights["head/kernel"]).t()) - self.head.bias.copy_(np2th(weights["head/bias"]).t()) - - self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True)) - self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"])) - self.transformer.embeddings.cls_token.copy_(np2th(weights["cls"])) - self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"])) - self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"])) - - posemb = np2th(weights["Transformer/posembed_input/pos_embedding"]) - posemb_new = self.transformer.embeddings.position_embeddings - if posemb.size() == posemb_new.size(): - self.transformer.embeddings.position_embeddings.copy_(posemb) - else: - print("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size())) - ntok_new = posemb_new.size(1) - - if self.classifier == "token": - posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] - ntok_new -= 1 - else: - posemb_tok, posemb_grid = posemb[:, :0], posemb[0] - - gs_old = int(np.sqrt(len(posemb_grid))) - gs_new = int(np.sqrt(ntok_new)) - print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new)) - posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) - - zoom = (gs_new / gs_old, gs_new / gs_old, 1) - posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1) - posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) - posemb = np.concatenate([posemb_tok, posemb_grid], axis=1) - self.transformer.embeddings.position_embeddings.copy_(np2th(posemb)) - - for bname, block in self.transformer.encoder.named_children(): - for uname, unit in 
-                    unit.load_from(weights, n_block=uname)
-
-            if self.transformer.embeddings.hybrid:
-                self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(weights["conv_root/kernel"], conv=True))
-                gn_weight = np2th(weights["gn_root/scale"]).view(-1)
-                gn_bias = np2th(weights["gn_root/bias"]).view(-1)
-                self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
-                self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)
-
-                for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
-                    for uname, unit in block.named_children():
-                        unit.load_from(weights, n_block=bname, n_unit=uname)
-
-
-# Fine-grained ViT (ViT FG)
 class VisionTransformer(nn.Module):
     def __init__(self, config, img_size=224, num_classes=21843, smoothing_value=0, zero_head=False):
         super(VisionTransformer, self).__init__()
@@ -393,7 +302,7 @@ class VisionTransformer(nn.Module):
 
     def forward(self, x, labels=None):
         part_tokens = self.transformer(x)
-        part_logits = self.part_head(part_tokens[:, 0]) # The "part" head can be read as the fine-grained component: it focuses on capturing subtle differences. Biological vision arguably does not need this, because it already has a built-in "part" mechanism: eye movements adjust the receptive field to achieve the same thing.
+        part_logits = self.part_head(part_tokens[:, 0])
 
         if labels is not None:
             if self.smoothing_value == 0:
@@ -456,7 +365,7 @@ class VisionTransformer(nn.Module):
                 for uname, unit in block.named_children():
                     unit.load_from(weights, n_block=bname, n_unit=uname)
 
-# Loss computation
+
 def con_loss(features, labels):
     B, _ = features.shape
     features = F.normalize(features)
@@ -470,7 +379,7 @@ def con_loss(features, labels):
     loss /= (B * B)
     return loss
 
-# Available ViT model configurations
+
 CONFIGS = {
     'ViT-B_16': configs.get_b16_config(),
     'ViT-B_32': configs.get_b32_config(),
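
For reference, here is a minimal usage sketch of the VisionTransformer class this patch keeps. It is not part of the patch itself: the CONFIGS key, the constructor signature, and con_loss all appear in the hunks above, while the class count (num_classes=200), the dummy batch shapes, and the exact return value of forward() are illustrative assumptions, since the full forward() body is not shown in the diff.

    # Usage sketch under the assumptions stated in the note above.
    import torch
    from models.modeling import CONFIGS, VisionTransformer, con_loss

    config = CONFIGS['ViT-B_16']                 # one of the configs defined above
    model = VisionTransformer(config,
                              img_size=224,
                              num_classes=200,      # assumed, e.g. CUB-200-2011
                              smoothing_value=0.1)  # forward() branches on == 0

    images = torch.randn(2, 3, 224, 224)        # dummy batch of shape (B, C, H, W)
    labels = torch.randint(0, 200, (2,))        # dummy integer class labels

    out = model(images, labels)                 # with labels: the loss branch runs
    part_logits = model(images)                 # without labels: part logits only

    # con_loss pairs up the B normalized feature vectors and averages over all
    # B * B pairs (per the retained "loss /= (B * B)" line above).
    penalty = con_loss(torch.randn(4, config.hidden_size), torch.tensor([0, 0, 1, 1]))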