update models/modeling.py: strip the inline Chinese comments and remove the superseded OldVisionTransformer class.

This commit is contained in:
Brainway
2022-10-18 03:35:54 +00:00
committed by Gitee
parent accca98d1c
commit 0c2e0dccac


@@ -22,35 +22,26 @@ import models.configs as configs
logger = logging.getLogger(__name__)
-# Multi-head attention parameter names
ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
-# Dense (fully-connected) layers
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
-# Layer normalization layers
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"
-# Convert a numpy array to a torch tensor
def np2th(weights, conv=False):
    """Possibly convert HWIO to OIHW."""
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)
-# Swish activation function
def swish(x):
    return x * torch.sigmoid(x)
-# Activation function lookup table
ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
-# Label smoothing class, used as a regularizer
class LabelSmoothing(nn.Module):
    """
    NLL loss with label smoothing.
@@ -73,7 +64,6 @@ class LabelSmoothing(nn.Module):
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()
-# Attention layer
class Attention(nn.Module):
    def __init__(self, config):
        super(Attention, self).__init__()
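The only surviving body line above, loss = self.confidence * nll_loss + self.smoothing * smooth_loss, is the heart of LabelSmoothing. As a reading aid, here is a minimal sketch of how the full forward pass typically looks, assuming the standard formulation; everything except that line is an assumption rather than this repository's exact code:

import torch.nn as nn
import torch.nn.functional as F

class LabelSmoothingSketch(nn.Module):
    # Sketch: NLL loss with label smoothing, assuming the standard recipe.
    def __init__(self, smoothing=0.1):
        super(LabelSmoothingSketch, self).__init__()
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing  # weight on the true-class term

    def forward(self, x, target):
        logprobs = F.log_softmax(x, dim=-1)
        # Negative log-likelihood of the true class
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        # Uniform smoothing term: mean negative log-probability over classes
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()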
@@ -119,7 +109,6 @@ class Attention(nn.Module):
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights
-# Fully-connected (MLP) block
class Mlp(nn.Module):
    def __init__(self, config):
        super(Mlp, self).__init__()
@@ -144,7 +133,6 @@ class Mlp(nn.Module):
        x = self.dropout(x)
        return x
-# Embeddings
class Embeddings(nn.Module):
    """Construct the embeddings from patch, position embeddings.
    """
@@ -186,7 +174,6 @@ class Embeddings(nn.Module):
        embeddings = self.dropout(embeddings)
        return embeddings
-# Transformer block
class Block(nn.Module):
    def __init__(self, config):
        super(Block, self).__init__()
@@ -245,7 +232,7 @@ class Block(nn.Module):
            self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
            self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))
-# Part attention layer
class Part_Attention(nn.Module):
    def __init__(self):
        super(Part_Attention, self).__init__()
@@ -260,7 +247,7 @@ class Part_Attention(nn.Module):
        _, max_inx = last_map.max(2)
        return _, max_inx
-# Encoder
class Encoder(nn.Module):
    def __init__(self, config):
        super(Encoder, self).__init__()
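The hunk above preserves only the tail of Part_Attention: a per-head max over the final attention map. A plausible reconstruction of the elided forward body, in the attention-rollout style used by TransFG-like models (treat it as a sketch consistent with the visible lines, not the verbatim file):

import torch
import torch.nn as nn

class PartAttentionSketch(nn.Module):
    def __init__(self):
        super(PartAttentionSketch, self).__init__()

    def forward(self, x):
        # x: list of per-layer attention maps,
        # each of shape (batch, heads, tokens, tokens)
        last_map = x[0]
        for i in range(1, len(x)):
            # Chain attention maps across layers (attention rollout)
            last_map = torch.matmul(x[i], last_map)
        # Keep the CLS-token row, dropping its own column
        last_map = last_map[:, :, 0, 1:]
        # For each head, index of the patch the CLS token attends to most
        _, max_inx = last_map.max(2)
        return _, max_inx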
@@ -290,7 +277,7 @@ class Encoder(nn.Module):
        return part_encoded
-# Transformer layer
class Transformer(nn.Module):
    def __init__(self, config, img_size):
        super(Transformer, self).__init__()
@@ -302,85 +289,7 @@ class Transformer(nn.Module):
        part_encoded = self.encoder(embedding_output)
        return part_encoded
-# ViT model
-class OldVisionTransformer(nn.Module):
-    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
-        super(OldVisionTransformer, self).__init__()
-        self.num_classes = num_classes
-        self.zero_head = zero_head
-        self.classifier = config.classifier
-        self.transformer = Transformer(config, img_size, vis)
-        self.head = Linear(config.hidden_size, num_classes)
-    def forward(self, x, labels=None):
-        x, attn_weights = self.transformer(x)
-        logits = self.head(x[:, 0])
-        if labels is not None:
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
-            return loss
-        else:
-            return logits, attn_weights
-    def load_from(self, weights):
-        with torch.no_grad():
-            if self.zero_head:
-                nn.init.zeros_(self.head.weight)
-                nn.init.zeros_(self.head.bias)
-            else:
-                self.head.weight.copy_(np2th(weights["head/kernel"]).t())
-                self.head.bias.copy_(np2th(weights["head/bias"]).t())
-            self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
-            self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))
-            self.transformer.embeddings.cls_token.copy_(np2th(weights["cls"]))
-            self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
-            self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))
-            posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
-            posemb_new = self.transformer.embeddings.position_embeddings
-            if posemb.size() == posemb_new.size():
-                self.transformer.embeddings.position_embeddings.copy_(posemb)
-            else:
-                print("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
-                ntok_new = posemb_new.size(1)
-                if self.classifier == "token":
-                    posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
-                    ntok_new -= 1
-                else:
-                    posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
-                gs_old = int(np.sqrt(len(posemb_grid)))
-                gs_new = int(np.sqrt(ntok_new))
-                print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
-                posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
-                zoom = (gs_new / gs_old, gs_new / gs_old, 1)
-                posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)
-                posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
-                posemb = np.concatenate([posemb_tok, posemb_grid], axis=1)
-                self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))
-            for bname, block in self.transformer.encoder.named_children():
-                for uname, unit in block.named_children():
-                    unit.load_from(weights, n_block=uname)
-            if self.transformer.embeddings.hybrid:
-                self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(weights["conv_root/kernel"], conv=True))
-                gn_weight = np2th(weights["gn_root/scale"]).view(-1)
-                gn_bias = np2th(weights["gn_root/bias"]).view(-1)
-                self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
-                self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)
-                for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
-                    for uname, unit in block.named_children():
-                        unit.load_from(weights, n_block=bname, n_unit=uname)
-# Fine-grained ViT (ViT-FG) model
class VisionTransformer(nn.Module):
    def __init__(self, config, img_size=224, num_classes=21843, smoothing_value=0, zero_head=False):
        super(VisionTransformer, self).__init__()
@@ -393,7 +302,7 @@ class VisionTransformer(nn.Module):
    def forward(self, x, labels=None):
        part_tokens = self.transformer(x)
-        part_logits = self.part_head(part_tokens[:, 0])  # The "part" head can be read as the fine-grained component: it focuses on capturing subtle differences. Biological vision arguably has no need for this, since it has a built-in "part" mechanism, adjusting the receptive field through eye movements.
+        part_logits = self.part_head(part_tokens[:, 0])
        if labels is not None:
            if self.smoothing_value == 0:
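The hunk stops right after the smoothing check. A hedged sketch of how the truncated training branch plausibly continues, choosing plain cross-entropy when smoothing_value is 0 and LabelSmoothing otherwise, then adding the contrastive term from con_loss defined below; the equal weighting and the return signature are assumptions:

        # Continuation sketch of forward(); not verbatim repository code.
        if labels is not None:
            if self.smoothing_value == 0:
                loss_fct = CrossEntropyLoss()
            else:
                loss_fct = LabelSmoothing(self.smoothing_value)
            part_loss = loss_fct(part_logits.view(-1, self.num_classes), labels.view(-1))
            contrast_loss = con_loss(part_tokens[:, 0], labels.view(-1))
            loss = part_loss + contrast_loss
            return loss, part_logits
        else:
            return part_logits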
@@ -456,7 +365,7 @@ class VisionTransformer(nn.Module):
                    for uname, unit in block.named_children():
                        unit.load_from(weights, n_block=bname, n_unit=uname)
-# Loss computation
def con_loss(features, labels):
    B, _ = features.shape
    features = F.normalize(features)
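This hunk keeps only the head of con_loss, and the next one its final two lines. The elided middle plausibly scores all pairs in the batch by cosine similarity, penalizing low similarity for same-label pairs and above-margin similarity for different-label pairs; the 0.4 margin below is an assumption:

import torch
import torch.nn.functional as F

def con_loss_sketch(features, labels):
    # Batch contrastive loss sketch: pull same-label embeddings together,
    # push different-label embeddings apart beyond a margin.
    B, _ = features.shape
    features = F.normalize(features)
    cos_matrix = features.mm(features.t())  # (B, B) pairwise cosine similarity
    pos_label_matrix = torch.stack([labels == labels[i] for i in range(B)]).float()
    neg_label_matrix = 1 - pos_label_matrix
    pos_cos_matrix = 1 - cos_matrix          # cost for positive pairs
    neg_cos_matrix = cos_matrix - 0.4        # margin for negative pairs (assumed)
    neg_cos_matrix[neg_cos_matrix < 0] = 0
    loss = (pos_cos_matrix * pos_label_matrix).sum()
    loss += (neg_cos_matrix * neg_label_matrix).sum()
    loss /= (B * B)
    return loss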
@@ -470,7 +379,7 @@ def con_loss(features, labels):
    loss /= (B * B)
    return loss
-# Available ViT model configurations
CONFIGS = {
    'ViT-B_16': configs.get_b16_config(),
    'ViT-B_32': configs.get_b32_config(),