commit f23dc227525afac407858578695d8bfb1c0dc134
Author: li chen <770918727@qq.com>
Date: Fri Apr 8 18:13:02 2022 +0800
update
diff --git a/.gitignore b/.gitignore
new file mode 100755
index 0000000..2bdab8a
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,7 @@
+.idea/
+ckpts/
+logs/
+models/__pycache__/
+utils/__pycache__/
+output/
+attention_data/
diff --git a/LICENSE b/LICENSE
new file mode 100755
index 0000000..46baa2e
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 Ju He
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/docs/TransFG.pdf b/docs/TransFG.pdf
new file mode 100755
index 0000000..08efe99
Binary files /dev/null and b/docs/TransFG.pdf differ
diff --git a/docs/TransFG.png b/docs/TransFG.png
new file mode 100755
index 0000000..3ce1357
Binary files /dev/null and b/docs/TransFG.png differ
diff --git a/ieemoo-ai-isempty.py b/ieemoo-ai-isempty.py
new file mode 100644
index 0000000..15712b1
--- /dev/null
+++ b/ieemoo-ai-isempty.py
@@ -0,0 +1,136 @@
+# -*- coding: utf-8 -*-
+from flask import request, Flask
+import numpy as np
+import json
+import time
+import cv2, base64
+import argparse
+import sys, os
+import torch
+from PIL import Image
+from torchvision import transforms
+from models.modeling import VisionTransformer, CONFIGS
+sys.path.insert(0, ".")
+
+
+app = Flask(__name__)
+app.use_reloader=False
+
+
+def parse_args(model_file="ckpts/emptyjudge5_checkpoint.bin"):
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--img_size", default=448, type=int, help="Resolution size")
+ parser.add_argument('--split', type=str, default='overlap', help="Split method")
+ parser.add_argument('--slide_step', type=int, default=12, help="Slide step for overlap split")
+ parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value")
+ parser.add_argument("--pretrained_model", type=str, default=model_file, help="load pretrained model")
+ opt, unknown = parser.parse_known_args()
+ return opt
+
+
+class Predictor(object):
+ def __init__(self, args):
+ self.args = args
+ self.args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print(self.args.device)
+ self.args.nprocs = torch.cuda.device_count()
+ self.cls_dict = {}
+ self.num_classes = 0
+ self.model = None
+ self.prepare_model()
+ self.test_transform = transforms.Compose([transforms.Resize((448, 448), Image.BILINEAR),
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
+
+ def prepare_model(self):
+ config = CONFIGS["ViT-B_16"]
+ config.split = self.args.split
+ config.slide_step = self.args.slide_step
+ model_name = os.path.basename(self.args.pretrained_model).replace("_checkpoint.bin", "")
+ print("use model_name: ", model_name)
+ self.num_classes = 5
+ self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "fly", 4: "stack"}
+ self.model = VisionTransformer(config, self.args.img_size, zero_head=True, num_classes=self.num_classes, smoothing_value=self.args.smoothing_value)
+ if self.args.pretrained_model is not None:
+ if not torch.cuda.is_available():
+ pretrained_model = torch.load(self.args.pretrained_model, map_location=torch.device('cpu'))['model']
+ self.model.load_state_dict(pretrained_model)
+ else:
+ pretrained_model = torch.load(self.args.pretrained_model)['model']
+ self.model.load_state_dict(pretrained_model)
+ self.model.eval()
+ self.model.to(self.args.device)
+ #self.model.eval()
+
+ def normal_predict(self, img_data, result):
+ # img = Image.open(img_path)
+ if img_data is None:
+ print('error, img data is None')
+ return result
+ else:
+ with torch.no_grad():
+ x = self.test_transform(img_data)
+ if torch.cuda.is_available():
+ x = x.cuda()
+ part_logits = self.model(x.unsqueeze(0))
+ probs = torch.nn.Softmax(dim=-1)(part_logits)
+ topN = torch.argsort(probs, dim=-1, descending=True).tolist()
+ clas_ids = topN[0][0]
+                clas_ids = 0 if int(clas_ids) in (0, 2, 3) else 1
+ print("cur_img result: class id: %d, score: %0.3f" % (clas_ids, probs[0, clas_ids].item()))
+ result["success"] = "true"
+ result["rst_cls"] = str(clas_ids)
+ return result
+
+
+model_file = "/data/ieemoo/emptypredict_pfc_FG/ckpts/emptyjudge5_checkpoint.bin"
+args = parse_args(model_file)
+predictor = Predictor(args)
+
+
+@app.route("/isempty", methods=['POST'])
+def get_isempty():
+ start = time.time()
+ print('--------------------EmptyPredict-----------------')
+ data = request.get_data()
+ ip = request.remote_addr
+ print('------ ip = %s ------' % ip)
+
+ json_data = json.loads(data.decode("utf-8"))
+    getdataend = time.time()
+    print('get data use time: {0:.2f}s'.format(getdataend - start))
+
+ pic = json_data.get("pic")
+ result = {"success": "false",
+ "rst_cls": '-1',
+ }
+ try:
+ imgdata = base64.b64decode(pic)
+ imgdata_np = np.frombuffer(imgdata, dtype='uint8')
+ img_src = cv2.imdecode(imgdata_np, cv2.IMREAD_COLOR)
+ img_data = Image.fromarray(np.uint8(img_src))
+ result = predictor.normal_predict(img_data, result) # 1==empty, 0==nonEmpty
+    except Exception:
+ return repr(result)
+
+ return repr(result)
+
+
+if __name__ == "__main__":
+ app.run()
+# app.run("0.0.0.0", port=8083)
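+#
+# Example client call (an illustrative sketch, not part of the service):
+# "sample.jpg" is a placeholder path, and the URL assumes the default Flask
+# development server started by app.run() above.
+#
+#   import base64, requests
+#   with open("sample.jpg", "rb") as f:
+#       pic = base64.b64encode(f.read()).decode("utf-8")
+#   rsp = requests.post("http://127.0.0.1:5000/isempty", json={"pic": pic})
+#   print(rsp.text)  # e.g. {'success': 'true', 'rst_cls': '0'}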
diff --git a/init.sh b/init.sh
new file mode 100644
index 0000000..2037ecb
--- /dev/null
+++ b/init.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+# "conda activate" only takes effect if the activate script is sourced
+source /opt/miniconda3/bin/activate ieemoo
+
+/opt/miniconda3/envs/ieemoo/bin/pip install -r requirements.txt
+
diff --git a/models/configs.py b/models/configs.py
new file mode 100755
index 0000000..7fa3acb
--- /dev/null
+++ b/models/configs.py
@@ -0,0 +1,76 @@
+import ml_collections
+
+def get_testing():
+ """Returns a minimal configuration for testing."""
+ config = ml_collections.ConfigDict()
+ config.patches = ml_collections.ConfigDict({'size': (16, 16)})
+ config.hidden_size = 1
+ config.transformer = ml_collections.ConfigDict()
+ config.transformer.mlp_dim = 1
+ config.transformer.num_heads = 1
+ config.transformer.num_layers = 1
+ config.transformer.attention_dropout_rate = 0.0
+ config.transformer.dropout_rate = 0.1
+ config.classifier = 'token'
+ config.representation_size = None
+ return config
+
+
+def get_b16_config():
+ """Returns the ViT-B/16 configuration."""
+ config = ml_collections.ConfigDict()
+ config.patches = ml_collections.ConfigDict({'size': (16, 16)})
+ config.split = 'non-overlap'
+ config.slide_step = 12
+ config.hidden_size = 768
+ config.transformer = ml_collections.ConfigDict()
+ config.transformer.mlp_dim = 3072
+ config.transformer.num_heads = 12
+ config.transformer.num_layers = 12
+ config.transformer.attention_dropout_rate = 0.0
+ config.transformer.dropout_rate = 0.1
+ config.classifier = 'token'
+ config.representation_size = None
+ return config
+
+def get_b32_config():
+ """Returns the ViT-B/32 configuration."""
+ config = get_b16_config()
+ config.patches.size = (32, 32)
+ return config
+
+def get_l16_config():
+ """Returns the ViT-L/16 configuration."""
+ config = ml_collections.ConfigDict()
+ config.patches = ml_collections.ConfigDict({'size': (16, 16)})
+ config.hidden_size = 1024
+ config.transformer = ml_collections.ConfigDict()
+ config.transformer.mlp_dim = 4096
+ config.transformer.num_heads = 16
+ config.transformer.num_layers = 24
+ config.transformer.attention_dropout_rate = 0.0
+ config.transformer.dropout_rate = 0.1
+ config.classifier = 'token'
+ config.representation_size = None
+ return config
+
+def get_l32_config():
+ """Returns the ViT-L/32 configuration."""
+ config = get_l16_config()
+ config.patches.size = (32, 32)
+ return config
+
+def get_h14_config():
+    """Returns the ViT-H/14 configuration."""
+ config = ml_collections.ConfigDict()
+ config.patches = ml_collections.ConfigDict({'size': (14, 14)})
+ config.hidden_size = 1280
+ config.transformer = ml_collections.ConfigDict()
+ config.transformer.mlp_dim = 5120
+ config.transformer.num_heads = 16
+ config.transformer.num_layers = 32
+ config.transformer.attention_dropout_rate = 0.0
+ config.transformer.dropout_rate = 0.1
+ config.classifier = 'token'
+ config.representation_size = None
+ return config
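+
+# Note: these factory functions are gathered into the CONFIGS dict in
+# models/modeling.py; callers (train.py, predict.py, ieemoo-ai-isempty.py)
+# select one by name, e.g. CONFIGS["ViT-B_16"], and then override
+# config.split and config.slide_step from their command-line arguments.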
diff --git a/models/modeling.py b/models/modeling.py
new file mode 100755
index 0000000..fc3082b
--- /dev/null
+++ b/models/modeling.py
@@ -0,0 +1,390 @@
+# coding=utf-8
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+import logging
+import math
+
+from os.path import join as pjoin
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
+from torch.nn.modules.utils import _pair
+from scipy import ndimage
+
+import models.configs as configs
+
+logger = logging.getLogger(__name__)
+
+ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
+ATTENTION_K = "MultiHeadDotProductAttention_1/key"
+ATTENTION_V = "MultiHeadDotProductAttention_1/value"
+ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
+FC_0 = "MlpBlock_3/Dense_0"
+FC_1 = "MlpBlock_3/Dense_1"
+ATTENTION_NORM = "LayerNorm_0"
+MLP_NORM = "LayerNorm_2"
+
+def np2th(weights, conv=False):
+ """Possibly convert HWIO to OIHW."""
+ if conv:
+ weights = weights.transpose([3, 2, 0, 1])
+ return torch.from_numpy(weights)
+
+def swish(x):
+ return x * torch.sigmoid(x)
+
+ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
+
+class LabelSmoothing(nn.Module):
+ """
+ NLL loss with label smoothing.
+ """
+ def __init__(self, smoothing=0.0):
+ """
+ Constructor for the LabelSmoothing module.
+ :param smoothing: label smoothing factor
+ """
+ super(LabelSmoothing, self).__init__()
+ self.confidence = 1.0 - smoothing
+ self.smoothing = smoothing
+
+ def forward(self, x, target):
+ logprobs = torch.nn.functional.log_softmax(x, dim=-1)
+
+ nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
+ nll_loss = nll_loss.squeeze(1)
+ smooth_loss = -logprobs.mean(dim=-1)
+ loss = self.confidence * nll_loss + self.smoothing * smooth_loss
+ return loss.mean()
+
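+# Standard multi-head self-attention. Besides the projected output, forward()
+# also returns the raw (pre-dropout) attention probabilities; Encoder collects
+# these per-layer weights so Part_Attention below can reuse them.
+# Shape sketch for ViT-B/16: hidden_states (B, N, 768) -> per-head tensors
+# (B, 12, N, 64) -> attention scores/probs (B, 12, N, N).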
+class Attention(nn.Module):
+ def __init__(self, config):
+ super(Attention, self).__init__()
+ self.num_attention_heads = config.transformer["num_heads"]
+ self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = Linear(config.hidden_size, self.all_head_size)
+ self.key = Linear(config.hidden_size, self.all_head_size)
+ self.value = Linear(config.hidden_size, self.all_head_size)
+
+ self.out = Linear(config.hidden_size, config.hidden_size)
+ self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
+ self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])
+
+ self.softmax = Softmax(dim=-1)
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(self, hidden_states):
+ mixed_query_layer = self.query(hidden_states)
+ mixed_key_layer = self.key(hidden_states)
+ mixed_value_layer = self.value(hidden_states)
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+ key_layer = self.transpose_for_scores(mixed_key_layer)
+ value_layer = self.transpose_for_scores(mixed_value_layer)
+
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ attention_probs = self.softmax(attention_scores)
+ weights = attention_probs
+ attention_probs = self.attn_dropout(attention_probs)
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+ attention_output = self.out(context_layer)
+ attention_output = self.proj_dropout(attention_output)
+ return attention_output, weights
+
+class Mlp(nn.Module):
+ def __init__(self, config):
+ super(Mlp, self).__init__()
+ self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
+ self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
+ self.act_fn = ACT2FN["gelu"]
+ self.dropout = Dropout(config.transformer["dropout_rate"])
+
+ self._init_weights()
+
+ def _init_weights(self):
+ nn.init.xavier_uniform_(self.fc1.weight)
+ nn.init.xavier_uniform_(self.fc2.weight)
+ nn.init.normal_(self.fc1.bias, std=1e-6)
+ nn.init.normal_(self.fc2.bias, std=1e-6)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act_fn(x)
+ x = self.dropout(x)
+ x = self.fc2(x)
+ x = self.dropout(x)
+ return x
+
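+# Patch + position embeddings. A worked example for the settings used by the
+# scripts in this repo (img_size=448, 16x16 patches, split='overlap',
+# slide_step=12): n_patches = ((448 - 16) // 12 + 1) ** 2 = 37 ** 2 = 1369,
+# plus one CLS token, giving 1370 position-embedding slots.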
+class Embeddings(nn.Module):
+ """Construct the embeddings from patch, position embeddings.
+ """
+ def __init__(self, config, img_size, in_channels=3):
+ super(Embeddings, self).__init__()
+ self.hybrid = None
+ img_size = _pair(img_size)
+
+ patch_size = _pair(config.patches["size"])
+ if config.split == 'non-overlap':
+ n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
+ self.patch_embeddings = Conv2d(in_channels=in_channels,
+ out_channels=config.hidden_size,
+ kernel_size=patch_size,
+ stride=patch_size)
+ elif config.split == 'overlap':
+ n_patches = ((img_size[0] - patch_size[0]) // config.slide_step + 1) * ((img_size[1] - patch_size[1]) // config.slide_step + 1)
+ self.patch_embeddings = Conv2d(in_channels=in_channels,
+ out_channels=config.hidden_size,
+ kernel_size=patch_size,
+ stride=(config.slide_step, config.slide_step))
+ self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches+1, config.hidden_size))
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+
+ self.dropout = Dropout(config.transformer["dropout_rate"])
+
+ def forward(self, x):
+ B = x.shape[0]
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+
+ if self.hybrid:
+ x = self.hybrid_model(x)
+ x = self.patch_embeddings(x)
+ x = x.flatten(2)
+ x = x.transpose(-1, -2)
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ embeddings = x + self.position_embeddings
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+class Block(nn.Module):
+ def __init__(self, config):
+ super(Block, self).__init__()
+ self.hidden_size = config.hidden_size
+ self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
+ self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
+ self.ffn = Mlp(config)
+ self.attn = Attention(config)
+
+ def forward(self, x):
+ h = x
+ x = self.attention_norm(x)
+ x, weights = self.attn(x)
+ x = x + h
+
+ h = x
+ x = self.ffn_norm(x)
+ x = self.ffn(x)
+ x = x + h
+ return x, weights
+
+ def load_from(self, weights, n_block):
+ ROOT = f"Transformer/encoderblock_{n_block}"
+ with torch.no_grad():
+ query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
+ key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
+ value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
+ out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()
+
+ query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
+ key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
+ value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
+ out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)
+
+ self.attn.query.weight.copy_(query_weight)
+ self.attn.key.weight.copy_(key_weight)
+ self.attn.value.weight.copy_(value_weight)
+ self.attn.out.weight.copy_(out_weight)
+ self.attn.query.bias.copy_(query_bias)
+ self.attn.key.bias.copy_(key_bias)
+ self.attn.value.bias.copy_(value_bias)
+ self.attn.out.bias.copy_(out_bias)
+
+ mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
+ mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
+ mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
+ mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()
+
+ self.ffn.fc1.weight.copy_(mlp_weight_0)
+ self.ffn.fc2.weight.copy_(mlp_weight_1)
+ self.ffn.fc1.bias.copy_(mlp_bias_0)
+ self.ffn.fc2.bias.copy_(mlp_bias_1)
+
+ self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
+ self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
+ self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
+ self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))
+
+
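+# Part selection (TransFG): the attention maps of all previous layers are
+# chained by matrix multiplication to estimate how the CLS token attends to
+# each patch through the whole stack; the CLS row is kept, the CLS column is
+# dropped, and the argmax over patches yields one "part" token index per head.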
+class Part_Attention(nn.Module):
+ def __init__(self):
+ super(Part_Attention, self).__init__()
+
+ def forward(self, x):
+ length = len(x)
+ last_map = x[0]
+ for i in range(1, length):
+ last_map = torch.matmul(x[i], last_map)
+ last_map = last_map[:,:,0,1:]
+
+        max_val, max_inx = last_map.max(2)
+        return max_val, max_inx
+
+
+class Encoder(nn.Module):
+ def __init__(self, config):
+ super(Encoder, self).__init__()
+ self.layer = nn.ModuleList()
+ for _ in range(config.transformer["num_layers"] - 1):
+ layer = Block(config)
+ self.layer.append(copy.deepcopy(layer))
+ self.part_select = Part_Attention()
+ self.part_layer = Block(config)
+ self.part_norm = LayerNorm(config.hidden_size, eps=1e-6)
+
+ def forward(self, hidden_states):
+ attn_weights = []
+ for layer in self.layer:
+ hidden_states, weights = layer(hidden_states)
+ attn_weights.append(weights)
+ part_num, part_inx = self.part_select(attn_weights)
+ part_inx = part_inx + 1
+ parts = []
+ B, num = part_inx.shape
+ for i in range(B):
+ parts.append(hidden_states[i, part_inx[i,:]])
+ parts = torch.stack(parts).squeeze(1)
+ concat = torch.cat((hidden_states[:,0].unsqueeze(1), parts), dim=1)
+ part_states, part_weights = self.part_layer(concat)
+ part_encoded = self.part_norm(part_states)
+
+ return part_encoded
+
+
+class Transformer(nn.Module):
+ def __init__(self, config, img_size):
+ super(Transformer, self).__init__()
+ self.embeddings = Embeddings(config, img_size=img_size)
+ self.encoder = Encoder(config)
+
+ def forward(self, input_ids):
+ embedding_output = self.embeddings(input_ids)
+ part_encoded = self.encoder(embedding_output)
+ return part_encoded
+
+
+class VisionTransformer(nn.Module):
+ def __init__(self, config, img_size=224, num_classes=21843, smoothing_value=0, zero_head=False):
+ super(VisionTransformer, self).__init__()
+ self.num_classes = num_classes
+ self.smoothing_value = smoothing_value
+ self.zero_head = zero_head
+ self.classifier = config.classifier
+ self.transformer = Transformer(config, img_size)
+ self.part_head = Linear(config.hidden_size, num_classes)
+
+ def forward(self, x, labels=None):
+ part_tokens = self.transformer(x)
+ part_logits = self.part_head(part_tokens[:, 0])
+
+ if labels is not None:
+ if self.smoothing_value == 0:
+ loss_fct = CrossEntropyLoss()
+ else:
+ loss_fct = LabelSmoothing(self.smoothing_value)
+ part_loss = loss_fct(part_logits.view(-1, self.num_classes), labels.view(-1))
+ contrast_loss = con_loss(part_tokens[:, 0], labels.view(-1))
+ loss = part_loss + contrast_loss
+ return loss, part_logits
+ else:
+ return part_logits
+
+ def load_from(self, weights):
+ with torch.no_grad():
+ self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
+ self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))
+ self.transformer.embeddings.cls_token.copy_(np2th(weights["cls"]))
+ self.transformer.encoder.part_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
+ self.transformer.encoder.part_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))
+
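+            # Position embeddings: if the checkpoint was trained at a different
+            # resolution, the patch-grid portion is reshaped to a 2D grid and
+            # resized with first-order interpolation (ndimage.zoom, order=1)
+            # to the new grid size before being copied in.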
+ posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
+ posemb_new = self.transformer.embeddings.position_embeddings
+ if posemb.size() == posemb_new.size():
+ self.transformer.embeddings.position_embeddings.copy_(posemb)
+ else:
+ logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
+ ntok_new = posemb_new.size(1)
+
+ if self.classifier == "token":
+ posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
+ ntok_new -= 1
+ else:
+ posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
+
+ gs_old = int(np.sqrt(len(posemb_grid)))
+ gs_new = int(np.sqrt(ntok_new))
+ print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
+ posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
+
+ zoom = (gs_new / gs_old, gs_new / gs_old, 1)
+ posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)
+ posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
+ posemb = np.concatenate([posemb_tok, posemb_grid], axis=1)
+ self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))
+
+ for bname, block in self.transformer.encoder.named_children():
+                if not bname.startswith('part'):
+ for uname, unit in block.named_children():
+ unit.load_from(weights, n_block=uname)
+
+ if self.transformer.embeddings.hybrid:
+ self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(weights["conv_root/kernel"], conv=True))
+ gn_weight = np2th(weights["gn_root/scale"]).view(-1)
+ gn_bias = np2th(weights["gn_root/bias"]).view(-1)
+ self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
+ self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)
+
+ for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
+ for uname, unit in block.named_children():
+ unit.load_from(weights, n_block=bname, n_unit=uname)
+
+
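+# Contrastive loss on the CLS features: after L2-normalisation, same-label
+# pairs are penalised by (1 - cos_sim) and different-label pairs by
+# max(0, cos_sim - 0.4), i.e. negatives are only pushed apart once their
+# cosine similarity exceeds the 0.4 margin; the sum is averaged over B*B pairs.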
+def con_loss(features, labels):
+ B, _ = features.shape
+ features = F.normalize(features)
+ cos_matrix = features.mm(features.t())
+ pos_label_matrix = torch.stack([labels == labels[i] for i in range(B)]).float()
+ neg_label_matrix = 1 - pos_label_matrix
+ pos_cos_matrix = 1 - cos_matrix
+ neg_cos_matrix = cos_matrix - 0.4
+ neg_cos_matrix[neg_cos_matrix < 0] = 0
+ loss = (pos_cos_matrix * pos_label_matrix).sum() + (neg_cos_matrix * neg_label_matrix).sum()
+ loss /= (B * B)
+ return loss
+
+
+CONFIGS = {
+ 'ViT-B_16': configs.get_b16_config(),
+ 'ViT-B_32': configs.get_b32_config(),
+ 'ViT-L_16': configs.get_l16_config(),
+ 'ViT-L_32': configs.get_l32_config(),
+ 'ViT-H_14': configs.get_h14_config(),
+ 'testing': configs.get_testing(),
+}
diff --git a/predict.py b/predict.py
new file mode 100755
index 0000000..f14096f
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,153 @@
+import numpy as np
+import cv2
+import time
+import os
+import argparse
+import torch
+from sklearn.metrics import confusion_matrix
+from sklearn.metrics import f1_score
+from PIL import Image
+from torchvision import transforms
+from models.modeling import VisionTransformer, CONFIGS
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--img_size", default=448, type=int, help="Resolution size")
+ parser.add_argument('--split', type=str, default='overlap', help="Split method") # non-overlap
+ parser.add_argument('--slide_step', type=int, default=12, help="Slide step for overlap split")
+ parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value\n")
+ parser.add_argument("--pretrained_model", type=str, default="output/emptyjudge5_checkpoint.bin", help="load pretrained model")
+ return parser.parse_args()
+
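+# Example invocation (a sketch; the checkpoint path below is this script's
+# default and the test directory is hard-coded in __main__):
+#   python predict.py --pretrained_model output/emptyjudge5_checkpoint.bin \
+#                     --split overlap --slide_step 12 --img_size 448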
+
+class Predictor(object):
+ def __init__(self, args):
+ self.args = args
+ self.args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print("self.args.device =", self.args.device)
+ self.args.nprocs = torch.cuda.device_count()
+
+ self.cls_dict = {}
+ self.num_classes = 0
+ self.model = None
+ self.prepare_model()
+ self.test_transform = transforms.Compose([transforms.Resize((448, 448), Image.BILINEAR),
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
+
+ def prepare_model(self):
+ config = CONFIGS["ViT-B_16"]
+ config.split = self.args.split
+ config.slide_step = self.args.slide_step
+ model_name = os.path.basename(self.args.pretrained_model).replace("_checkpoint.bin", "")
+ print("use model_name: ", model_name)
+ if model_name.lower() == "emptyJudge5".lower():
+ self.num_classes = 5
+ self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "fly", 4: "stack"}
+ elif model_name.lower() == "emptyJudge4".lower():
+ self.num_classes = 4
+ self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "stack"}
+ elif model_name.lower() == "emptyJudge3".lower():
+ self.num_classes = 3
+ self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard"}
+ elif model_name.lower() == "emptyJudge2".lower():
+ self.num_classes = 2
+ self.cls_dict = {0: "noemp", 1: "yesemp"}
+ self.model = VisionTransformer(config, self.args.img_size, zero_head=True, num_classes=self.num_classes, smoothing_value=self.args.smoothing_value)
+ if self.args.pretrained_model is not None:
+ if not torch.cuda.is_available():
+ pretrained_model = torch.load(self.args.pretrained_model, map_location=torch.device('cpu'))['model']
+ self.model.load_state_dict(pretrained_model)
+ else:
+ pretrained_model = torch.load(self.args.pretrained_model)['model']
+ self.model.load_state_dict(pretrained_model)
+ self.model.to(self.args.device)
+ self.model.eval()
+
+ def normal_predict(self, img_path):
+ # img = cv2.imread(img_path)
+ img = Image.open(img_path)
+ if img is None:
+            print("Failed to read image file: {}".format(img_path))
+ else:
+ x = self.test_transform(img)
+ if torch.cuda.is_available():
+ x = x.cuda()
+ part_logits = self.model(x.unsqueeze(0))
+ probs = torch.nn.Softmax(dim=-1)(part_logits)
+ topN = torch.argsort(probs, dim=-1, descending=True).tolist()
+ clas_ids = topN[0][0]
+ # print(probs[0, topN[0][0]].item())
+ return clas_ids, probs[0, clas_ids].item()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ predictor = Predictor(args)
+
+ y_true = []
+ y_pred = []
+ test_dir = "/data/pfc/fineGrained/test_5cls"
+ dir_dict = {"noemp":"0", "yesemp":"1", "hard": "2", "fly": "3", "stack": "4"}
+ total = 0
+ num = 0
+ t0 = time.time()
+ for dir_name, label in dir_dict.items():
+ cur_folder = os.path.join(test_dir, dir_name)
+ errorPath = os.path.join(test_dir, dir_name + "_error")
+ # os.makedirs(errorPath, exist_ok=True)
+ for cur_file in os.listdir(cur_folder):
+ total += 1
+ print("%d processing: %s" % (total, cur_file))
+ cur_img_file = os.path.join(cur_folder, cur_file)
+ error_img_dst = os.path.join(errorPath, cur_file)
+ cur_pred, pred_score = predictor.normal_predict(cur_img_file)
+
+        label = 0 if int(label) in (2, 3, 4) else int(label)
+        cur_pred = 0 if int(cur_pred) in (2, 3, 4) else int(cur_pred)
+ y_true.append(int(label))
+ y_pred.append(int(cur_pred))
+ if int(label) == int(cur_pred):
+ num += 1
+ # else:
+ # print(cur_file, "predict: ", cur_pred, "true: ", int(label))
+ # print(cur_file, "predict: ", cur_pred, "true: ", int(label), "pred_score:", pred_score)
+ # os.system("cp %s %s" % (cur_img_file, error_img_dst))
+ t1 = time.time()
+    print('Elapsed time: %f seconds' % (t1-t0))
+ rate = float(num)/total
+ print('The classification accuracy is %f' % rate)
+
+ rst_C = confusion_matrix(y_true, y_pred)
+ rst_f1 = f1_score(y_true, y_pred, average='macro')
+ print(rst_C)
+ print(rst_f1)
+
+'''
+test_imgs: yesemp=145, noemp=453  (large images)
+
+output/emptyjudge5_checkpoint.bin
+The classification accuracy is 0.976589
+[[446 7] 1.5%
+ [ 7 138]] 4.8%
+0.968135799649844
+
+output/emptyjudge4_checkpoint.bin
+The classification accuracy is 0.976589
+[[450 3] 0.6%
+ [ 11 134]] 7.5%
+0.9675186616384996
+
+test_5cls: yesemp=319, noemp=925  (small images)
+
+
+output/emptyjudge4_checkpoint.bin
+The classification accuracy is 0.937299
+[[885 40] 4.3%
+ [ 38 281]] 11.9%
+0.9179586038961038
+
+
+'''
diff --git a/prepara_data.py b/prepara_data.py
new file mode 100755
index 0000000..1d779d5
--- /dev/null
+++ b/prepara_data.py
@@ -0,0 +1,119 @@
+import os
+import cv2
+import numpy as np
+import subprocess
+import random
+
+
+# ----------- Rename files --------------
+# index = 0
+# src_dir = "/data/fineGrained/emptyJudge5"
+# dst_dir = src_dir + "_new"
+# os.makedirs(dst_dir, exist_ok=True)
+# for sub in os.listdir(src_dir):
+# sub_path = os.path.join(src_dir, sub)
+# sub_path_dst = os.path.join(dst_dir, sub)
+# os.makedirs(sub_path_dst, exist_ok=True)
+# for cur_f in os.listdir(sub_path):
+# cur_img = os.path.join(sub_path, cur_f)
+# cur_img_dst = os.path.join(sub_path_dst, "a%05d.jpg" % index)
+# index += 1
+# os.system("mv %s %s" % (cur_img, cur_img_dst))
+
+
+# ----------- Remove images that are too small --------------
+# src_dir = "/data/fineGrained/emptyJudge5"
+# for sub in os.listdir(src_dir):
+# sub_path = os.path.join(src_dir, sub)
+# for cur_f in os.listdir(sub_path):
+# filepath = os.path.join(sub_path, cur_f)
+# res = subprocess.check_output(['file', filepath])
+# pp = res.decode("utf-8").split(",")[-2]
+# height = int(pp.split("x")[1])
+# width = int(pp.split("x")[0])
+# min_l = min(height, width)
+# if min_l <= 448:
+# os.system("rm %s" % filepath)
+
+
+# ----------- Collect valid images and write images.txt --------------
+# src_dir = "/data/fineGrained/emptyJudge4/images"
+# src_dict = {"noemp":"0", "yesemp":"1", "hard": "2", "stack": "3"}
+# all_dict = {"yesemp":[], "noemp":[], "hard": [], "stack": []}
+# for sub, value in src_dict.items():
+# sub_path = os.path.join(src_dir, sub)
+# for cur_f in os.listdir(sub_path):
+# all_dict[sub].append(os.path.join(sub, cur_f))
+#
+# yesnum = len(all_dict["yesemp"])
+# nonum = len(all_dict["noemp"])
+# hardnum = len(all_dict["hard"])
+# stacknum = len(all_dict["stack"])
+# thnum = min(yesnum, nonum, hardnum, stacknum)
+# images_txt = src_dir + ".txt"
+# index = 1
+#
+# def write_images(cur_list, thnum, fw, index):
+# for feat_path in random.sample(cur_list, thnum):
+# fw.write(str(index) + " " + feat_path + "\n")
+# index += 1
+# return index
+#
+# with open(images_txt, "w") as fw:
+# index = write_images(all_dict["noemp"], thnum, fw, index)
+# index = write_images(all_dict["yesemp"], thnum, fw, index)
+# index = write_images(all_dict["hard"], thnum, fw, index)
+# index = write_images(all_dict["stack"], thnum, fw, index)
+
+# ----------- Write image_class_labels.txt + train_test_split.txt --------------
+# src_dir = "/data/fineGrained/emptyJudge4"
+# src_dict = {"noemp":"0", "yesemp":"1", "hard": "2", "stack": "3"}
+# images_txt = os.path.join(src_dir, "images.txt")
+# image_class_labels_txt = os.path.join(src_dir, "image_class_labels.txt")
+# imgs_cnt = 0
+# with open(image_class_labels_txt, "w") as fw:
+# with open(images_txt, "r") as fr:
+# for cur_l in fr:
+# imgs_cnt += 1
+# img_index, img_f = cur_l.strip().split(" ")
+# folder_name = img_f.split("/")[0]
+# if folder_name in src_dict:
+# cur_line = img_index + " " + str(int(src_dict[folder_name])+1)
+# fw.write(cur_line + "\n")
+#
+# train_num = int(imgs_cnt*0.85)
+# print("train_num= ", train_num, ", imgs_cnt= ", imgs_cnt)
+# all_list = [1]*train_num + [0]*(imgs_cnt-train_num)
+# assert len(all_list) == imgs_cnt
+# random.shuffle(all_list)
+# train_test_split_txt = os.path.join(src_dir, "train_test_split.txt")
+# with open(train_test_split_txt, "w") as fw:
+# with open(images_txt, "r") as fr:
+# for cur_l in fr:
+# img_index, img_f = cur_l.strip().split(" ")
+# cur_line = img_index + " " + str(all_list[int(img_index) - 1])
+# fw.write(cur_line + "\n")
+
+# ----------- Build a standard test set --------------
+# src_dir = "/data/fineGrained/emptyJudge5/images"
+# src_dict = {"noemp":"0", "yesemp":"1", "hard": "2", "fly": "3", "stack": "4"}
+# all_dict = {"noemp":[], "yesemp":[], "hard": [], "fly": [], "stack": []}
+# for sub, value in src_dict.items():
+# sub_path = os.path.join(src_dir, sub)
+# for cur_f in os.listdir(sub_path):
+# all_dict[sub].append(cur_f)
+#
+# dst_dir = src_dir + "_test"
+# os.makedirs(dst_dir, exist_ok=True)
+# for sub, value in src_dict.items():
+# sub_path = os.path.join(src_dir, sub)
+# sub_path_dst = os.path.join(dst_dir, sub)
+# os.makedirs(sub_path_dst, exist_ok=True)
+#
+# cur_list = all_dict[sub]
+# test_num = int(len(cur_list) * 0.05)
+# for cur_f in random.sample(cur_list, test_num):
+# cur_path = os.path.join(sub_path, cur_f)
+# cur_path_dst = os.path.join(sub_path_dst, cur_f)
+# os.system("cp %s %s" % (cur_path, cur_path_dst))
+
diff --git a/requirements.txt b/requirements.txt
new file mode 100755
index 0000000..64d403c
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,80 @@
+absl-py==1.0.0
+Bottleneck==1.3.2
+brotlipy==0.7.0
+cachetools==5.0.0
+certifi==2021.10.8
+cffi @ file:///tmp/build/80754af9/cffi_1625807838443/work
+charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
+click==8.0.3
+contextlib2==21.6.0
+cryptography @ file:///tmp/build/80754af9/cryptography_1635366571107/work
+cycler @ file:///tmp/build/80754af9/cycler_1637851556182/work
+docopt==0.6.2
+esdk-obs-python==3.21.8
+faiss==1.7.1
+Flask @ file:///tmp/build/80754af9/flask_1634118196080/work
+fonttools==4.25.0
+gevent @ file:///tmp/build/80754af9/gevent_1628273677693/work
+google-auth==2.6.0
+google-auth-oauthlib==0.4.6
+greenlet @ file:///tmp/build/80754af9/greenlet_1628887725296/work
+grpcio==1.44.0
+gunicorn==20.1.0
+h5py @ file:///tmp/build/80754af9/h5py_1637138879700/work
+idna @ file:///tmp/build/80754af9/idna_1637925883363/work
+importlib-metadata==4.11.3
+itsdangerous @ file:///tmp/build/80754af9/itsdangerous_1621432558163/work
+Jinja2 @ file:///tmp/build/80754af9/jinja2_1635780242639/work
+kiwisolver @ file:///tmp/build/80754af9/kiwisolver_1612282420641/work
+Markdown==3.3.6
+MarkupSafe @ file:///tmp/build/80754af9/markupsafe_1621528148836/work
+matplotlib @ file:///tmp/build/80754af9/matplotlib-suite_1638289681807/work
+mkl-fft==1.3.1
+mkl-random @ file:///tmp/build/80754af9/mkl_random_1626186064646/work
+mkl-service==2.4.0
+ml-collections==0.1.0
+munkres==1.1.4
+numexpr @ file:///tmp/build/80754af9/numexpr_1618856167419/work
+numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1634095647912/work
+oauthlib==3.2.0
+olefile @ file:///Users/ktietz/demo/mc3/conda-bld/olefile_1629805411829/work
+opencv-python==4.5.4.60
+packaging @ file:///tmp/build/80754af9/packaging_1637314298585/work
+pandas==1.3.4
+Pillow==8.4.0
+pipreqs==0.4.11
+protobuf==3.19.4
+pyasn1==0.4.8
+pyasn1-modules==0.2.8
+pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
+pycryptodome==3.10.1
+pyOpenSSL @ file:///tmp/build/80754af9/pyopenssl_1635333100036/work
+pyparsing @ file:///tmp/build/80754af9/pyparsing_1635766073266/work
+PySocks @ file:///tmp/build/80754af9/pysocks_1605305779399/work
+python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work
+pytz==2021.3
+PyYAML==6.0
+requests @ file:///tmp/build/80754af9/requests_1629994808627/work
+requests-oauthlib==1.3.1
+rsa==4.8
+scipy @ file:///tmp/build/80754af9/scipy_1630606796110/work
+seaborn @ file:///tmp/build/80754af9/seaborn_1629307859561/work
+sip==4.19.13
+six @ file:///tmp/build/80754af9/six_1623709665295/work
+supervisor==4.2.2
+tensorboard==2.8.0
+tensorboard-data-server==0.6.1
+tensorboard-plugin-wit==1.8.1
+torch==1.8.0
+torchaudio==0.8.0a0+a751e1d
+torchvision==0.9.0
+tornado @ file:///tmp/build/80754af9/tornado_1606942300299/work
+tqdm @ file:///tmp/build/80754af9/tqdm_1635330843403/work
+typing-extensions @ file:///tmp/build/80754af9/typing_extensions_1631814937681/work
+urllib3==1.26.7
+Werkzeug @ file:///tmp/build/80754af9/werkzeug_1635505089296/work
+yacs @ file:///tmp/build/80754af9/yacs_1634047592950/work
+yarg==0.1.9
+zipp==3.7.0
+zope.event==4.5.0
+zope.interface @ file:///tmp/build/80754af9/zope.interface_1625035545636/work
diff --git a/start.sh b/start.sh
new file mode 100644
index 0000000..915018f
--- /dev/null
+++ b/start.sh
@@ -0,0 +1,3 @@
+#!/bin/bash
+supervisorctl start ieemoo-ai-isempty
+
diff --git a/stop.sh b/stop.sh
new file mode 100644
index 0000000..f5e9f04
--- /dev/null
+++ b/stop.sh
@@ -0,0 +1,2 @@
+#!/bin/bash
+supervisorctl stop ieemoo-ai-isempty
diff --git a/test_single.py b/test_single.py
new file mode 100755
index 0000000..1f1072b
--- /dev/null
+++ b/test_single.py
@@ -0,0 +1,64 @@
+# coding=utf-8
+import os
+import torch
+import numpy as np
+from PIL import Image
+from torchvision import transforms
+import argparse
+from models.modeling import VisionTransformer, CONFIGS
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--dataset", choices=["CUB_200_2011", "emptyJudge5", "emptyJudge4"], default="emptyJudge5", help="Which dataset.")
+parser.add_argument("--img_size", default=448, type=int, help="Resolution size")
+parser.add_argument('--split', type=str, default='overlap', help="Split method") # non-overlap
+parser.add_argument('--slide_step', type=int, default=12, help="Slide step for overlap split")
+parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value\n")
+parser.add_argument("--pretrained_model", type=str, default="output/emptyjudge5_checkpoint.bin", help="load pretrained model")
+args = parser.parse_args()
+
+args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+args.nprocs = torch.cuda.device_count()
+
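+# Example run (a sketch): place an image named "img.jpg" next to this script
+# (see the hard-coded Image.open("img.jpg") below) and call, e.g.:
+#   python test_single.py --dataset emptyJudge5 --pretrained_model output/emptyjudge5_checkpoint.bin
+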
+# Prepare Model
+config = CONFIGS["ViT-B_16"]
+config.split = args.split
+config.slide_step = args.slide_step
+
+cls_dict = {}
+num_classes = 0
+if args.dataset == "emptyJudge5":
+ num_classes = 5
+ cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "fly", 4: "stack"}
+elif args.dataset == "emptyJudge4":
+ num_classes = 4
+ cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "stack"}
+elif args.dataset == "emptyJudge3":
+ num_classes = 3
+ cls_dict = {0: "noemp", 1: "yesemp", 2: "hard"}
+elif args.dataset == "emptyJudge2":
+ num_classes = 2
+ cls_dict = {0: "noemp", 1: "yesemp"}
+model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes, smoothing_value=args.smoothing_value)
+if args.pretrained_model is not None:
+ pretrained_model = torch.load(args.pretrained_model, map_location=torch.device('cpu'))['model']
+ model.load_state_dict(pretrained_model)
+model.to(args.device)
+model.eval()
+# test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
+# transforms.CenterCrop((448, 448)),
+# transforms.ToTensor(),
+# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
+test_transform = transforms.Compose([transforms.Resize((448, 448), Image.BILINEAR),
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
+img = Image.open("img.jpg")
+x = test_transform(img)
+part_logits = model(x.unsqueeze(0))
+
+probs = torch.nn.Softmax(dim=-1)(part_logits)
+top5 = torch.argsort(probs, dim=-1, descending=True)
+print("Prediction Label\n")
+for idx in top5[0, :5]:
+ print(f'{probs[0, idx.item()]:.5f} : {cls_dict[idx.item()]}', end='\n')
+
diff --git a/train.py b/train.py
new file mode 100755
index 0000000..fb444c7
--- /dev/null
+++ b/train.py
@@ -0,0 +1,379 @@
+# coding=utf-8
+from __future__ import absolute_import, division, print_function
+
+import logging
+import argparse
+import os
+import random
+import numpy as np
+import time
+
+from datetime import timedelta
+
+import torch
+import torch.distributed as dist
+
+from tqdm import tqdm
+from torch.utils.tensorboard import SummaryWriter
+
+from models.modeling import VisionTransformer, CONFIGS
+from utils.scheduler import WarmupLinearSchedule, WarmupCosineSchedule
+from utils.data_utils import get_loader
+from utils.dist_util import get_world_size
+
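+# Example launches (sketches; the dataset folder under --data_root is assumed
+# to follow the images.txt / image_class_labels.txt / train_test_split.txt
+# layout produced by prepara_data.py):
+#   single GPU : python train.py --name emptyjudge5_run --dataset emptyJudge5
+#   multi GPU  : python -m torch.distributed.launch --nproc_per_node=4 \
+#                    train.py --name emptyjudge5_run --dataset emptyJudge5
+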
+logger = logging.getLogger(__name__)
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+
+def simple_accuracy(preds, labels):
+ return (preds == labels).mean()
+
+
+def reduce_mean(tensor, nprocs):
+ rt = tensor.clone()
+ dist.all_reduce(rt, op=dist.ReduceOp.SUM)
+ rt /= nprocs
+ return rt
+
+
+def save_model(args, model):
+ model_to_save = model.module if hasattr(model, 'module') else model
+ model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.bin" % args.name)
+ checkpoint = {
+ 'model': model_to_save.state_dict(),
+ }
+ torch.save(checkpoint, model_checkpoint)
+ logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir)
+
+
+def setup(args):
+ # Prepare model
+ config = CONFIGS[args.model_type]
+ config.split = args.split
+ config.slide_step = args.slide_step
+
+ if args.dataset == "CUB_200_2011":
+ num_classes = 200
+ elif args.dataset == "car":
+ num_classes = 196
+ elif args.dataset == "nabirds":
+ num_classes = 555
+ elif args.dataset == "dog":
+ num_classes = 120
+ elif args.dataset == "INat2017":
+ num_classes = 5089
+ elif args.dataset == "emptyJudge5":
+ num_classes = 5
+ elif args.dataset == "emptyJudge4":
+ num_classes = 4
+ elif args.dataset == "emptyJudge3":
+ num_classes = 3
+
+ model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes, smoothing_value=args.smoothing_value)
+
+ model.load_from(np.load(args.pretrained_dir))
+ if args.pretrained_model is not None:
+ pretrained_model = torch.load(args.pretrained_model)['model']
+ model.load_state_dict(pretrained_model)
+ model.to(args.device)
+ num_params = count_parameters(model)
+
+ logger.info("{}".format(config))
+ logger.info("Training parameters %s", args)
+ logger.info("Total Parameter: \t%2.1fM" % num_params)
+ return args, model
+
+
+def count_parameters(model):
+ params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ return params/1000000
+
+
+def set_seed(args):
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ if args.n_gpu > 0:
+ torch.cuda.manual_seed_all(args.seed)
+
+
+def valid(args, model, writer, test_loader, global_step):
+ eval_losses = AverageMeter()
+
+ logger.info("***** Running Validation *****")
+ # logger.info("val Num steps = %d", len(test_loader))
+ # logger.info("val Batch size = %d", args.eval_batch_size)
+
+ model.eval()
+ all_preds, all_label = [], []
+ epoch_iterator = tqdm(test_loader,
+ desc="Validating... (loss=X.X)",
+ bar_format="{l_bar}{r_bar}",
+ dynamic_ncols=True,
+ disable=args.local_rank not in [-1, 0])
+ loss_fct = torch.nn.CrossEntropyLoss()
+ for step, batch in enumerate(epoch_iterator):
+ batch = tuple(t.to(args.device) for t in batch)
+ x, y = batch
+ with torch.no_grad():
+ logits = model(x)
+
+ eval_loss = loss_fct(logits, y)
+ eval_loss = eval_loss.mean()
+ eval_losses.update(eval_loss.item())
+
+ preds = torch.argmax(logits, dim=-1)
+
+ if len(all_preds) == 0:
+ all_preds.append(preds.detach().cpu().numpy())
+ all_label.append(y.detach().cpu().numpy())
+ else:
+ all_preds[0] = np.append(
+ all_preds[0], preds.detach().cpu().numpy(), axis=0
+ )
+ all_label[0] = np.append(
+ all_label[0], y.detach().cpu().numpy(), axis=0
+ )
+ epoch_iterator.set_description("Validating... (loss=%2.5f)" % eval_losses.val)
+
+ all_preds, all_label = all_preds[0], all_label[0]
+ accuracy = simple_accuracy(all_preds, all_label)
+ accuracy = torch.tensor(accuracy).to(args.device)
+ # dist.barrier()
+ # val_accuracy = reduce_mean(accuracy, args.nprocs)
+ # val_accuracy = val_accuracy.detach().cpu().numpy()
+ val_accuracy = accuracy.detach().cpu().numpy()
+
+ logger.info("\n")
+ logger.info("Validation Results")
+ logger.info("Global Steps: %d" % global_step)
+ logger.info("Valid Loss: %2.5f" % eval_losses.avg)
+ logger.info("Valid Accuracy: %2.5f" % val_accuracy)
+ if args.local_rank in [-1, 0]:
+ writer.add_scalar("test/accuracy", scalar_value=val_accuracy, global_step=global_step)
+ return val_accuracy
+
+
+def train(args, model):
+ """ Train the model """
+ if args.local_rank in [-1, 0]:
+ os.makedirs(args.output_dir, exist_ok=True)
+ writer = SummaryWriter(log_dir=os.path.join("logs", args.name))
+
+ args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
+
+ # Prepare dataset
+ train_loader, test_loader = get_loader(args)
+ logger.info("train Num steps = %d", len(train_loader))
+ logger.info("val Num steps = %d", len(test_loader))
+ # Prepare optimizer and scheduler
+ optimizer = torch.optim.SGD(model.parameters(),
+ lr=args.learning_rate,
+ momentum=0.9,
+ weight_decay=args.weight_decay)
+ t_total = args.num_steps
+ if args.decay_type == "cosine":
+ scheduler = WarmupCosineSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
+ else:
+ scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
+
+ # Train!
+ logger.info("***** Running training *****")
+ logger.info(" Total optimization steps = %d", args.num_steps)
+ logger.info(" Instantaneous batch size per GPU = %d", args.train_batch_size)
+ logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
+ args.train_batch_size * args.gradient_accumulation_steps * (
+ torch.distributed.get_world_size() if args.local_rank != -1 else 1))
+ logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
+
+ model.zero_grad()
+ set_seed(args) # Added here for reproducibility (even between python 2 and 3)
+ losses = AverageMeter()
+ global_step, best_acc = 0, 0
+ start_time = time.time()
+ while True:
+ model.train()
+ epoch_iterator = tqdm(train_loader,
+ desc="Training (X / X Steps) (loss=X.X)",
+ bar_format="{l_bar}{r_bar}",
+ dynamic_ncols=True,
+ disable=args.local_rank not in [-1, 0])
+ all_preds, all_label = [], []
+ for step, batch in enumerate(epoch_iterator):
+ batch = tuple(t.to(args.device) for t in batch)
+ x, y = batch
+
+ loss, logits = model(x, y)
+ loss = loss.mean()
+
+ preds = torch.argmax(logits, dim=-1)
+
+ if len(all_preds) == 0:
+ all_preds.append(preds.detach().cpu().numpy())
+ all_label.append(y.detach().cpu().numpy())
+ else:
+ all_preds[0] = np.append(
+ all_preds[0], preds.detach().cpu().numpy(), axis=0
+ )
+ all_label[0] = np.append(
+ all_label[0], y.detach().cpu().numpy(), axis=0
+ )
+
+ if args.gradient_accumulation_steps > 1:
+ loss = loss / args.gradient_accumulation_steps
+ loss.backward()
+
+ if (step + 1) % args.gradient_accumulation_steps == 0:
+ losses.update(loss.item()*args.gradient_accumulation_steps)
+ torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+ scheduler.step()
+ optimizer.step()
+ optimizer.zero_grad()
+ global_step += 1
+
+ epoch_iterator.set_description(
+ "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, t_total, losses.val)
+ )
+ if args.local_rank in [-1, 0]:
+ writer.add_scalar("train/loss", scalar_value=losses.val, global_step=global_step)
+ writer.add_scalar("train/lr", scalar_value=scheduler.get_lr()[0], global_step=global_step)
+ if global_step % args.eval_every == 0:
+ with torch.no_grad():
+ accuracy = valid(args, model, writer, test_loader, global_step)
+ if args.local_rank in [-1, 0]:
+ if best_acc < accuracy:
+ save_model(args, model)
+ best_acc = accuracy
+ logger.info("best accuracy so far: %f" % best_acc)
+ model.train()
+
+ if global_step % t_total == 0:
+ break
+ all_preds, all_label = all_preds[0], all_label[0]
+ accuracy = simple_accuracy(all_preds, all_label)
+ accuracy = torch.tensor(accuracy).to(args.device)
+ # dist.barrier()
+ # train_accuracy = reduce_mean(accuracy, args.nprocs)
+ # train_accuracy = train_accuracy.detach().cpu().numpy()
+ train_accuracy = accuracy.detach().cpu().numpy()
+ logger.info("train accuracy so far: %f" % train_accuracy)
+ losses.reset()
+ if global_step % t_total == 0:
+ break
+
+ writer.close()
+ logger.info("Best Accuracy: \t%f" % best_acc)
+ logger.info("End Training!")
+ end_time = time.time()
+    logger.info("Total Training Time: \t%f hours" % ((end_time - start_time) / 3600))
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument("--name", required=True,
+ help="Name of this run. Used for monitoring.")
+ parser.add_argument("--dataset", choices=["CUB_200_2011", "car", "dog", "nabirds", "INat2017", "emptyJudge5", "emptyJudge4"],
+ default="CUB_200_2011", help="Which dataset.")
+ parser.add_argument('--data_root', type=str, default='/data/fineGrained')
+ parser.add_argument("--model_type", choices=["ViT-B_16", "ViT-B_32", "ViT-L_16", "ViT-L_32", "ViT-H_14"],
+ default="ViT-B_16",help="Which variant to use.")
+ parser.add_argument("--pretrained_dir", type=str, default="ckpts/ViT-B_16.npz",
+ help="Where to search for pretrained ViT models.")
+ parser.add_argument("--pretrained_model", type=str, default="output/emptyjudge5_checkpoint.bin", help="load pretrained model")
+ #parser.add_argument("--pretrained_model", type=str, default=None, help="load pretrained model")
+ parser.add_argument("--output_dir", default="./output", type=str,
+ help="The output directory where checkpoints will be written.")
+ parser.add_argument("--img_size", default=448, type=int, help="Resolution size")
+ parser.add_argument("--train_batch_size", default=64, type=int,
+ help="Total batch size for training.")
+ parser.add_argument("--eval_batch_size", default=16, type=int,
+ help="Total batch size for eval.")
+ parser.add_argument("--eval_every", default=200, type=int,
+ help="Run prediction on validation set every so many steps."
+ "Will always run one evaluation at the end of training.")
+
+ parser.add_argument("--learning_rate", default=3e-2, type=float,
+ help="The initial learning rate for SGD.")
+ parser.add_argument("--weight_decay", default=0, type=float,
+                        help="Weight decay if we apply some.")
+ parser.add_argument("--num_steps", default=8000, type=int, #100000
+                        help="Total number of training steps to perform.")
+ parser.add_argument("--decay_type", choices=["cosine", "linear"], default="cosine",
+ help="How to decay the learning rate.")
+ parser.add_argument("--warmup_steps", default=500, type=int,
+ help="Step of training to perform learning rate warmup for.")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float,
+ help="Max gradient norm.")
+
+ parser.add_argument("--local_rank", type=int, default=-1,
+ help="local_rank for distributed training on gpus")
+ parser.add_argument('--seed', type=int, default=42,
+ help="random seed for initialization")
+ parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.")
+ parser.add_argument('--fp16', action='store_true',
+ help="Whether to use 16-bit float precision instead of 32-bit")
+ parser.add_argument('--fp16_opt_level', type=str, default='O2',
+ help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+ "See details at https://nvidia.github.io/apex/amp.html")
+ parser.add_argument('--loss_scale', type=float, default=0,
+ help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
+ "0 (default value): dynamic loss scaling.\n"
+ "Positive power of 2: static loss scaling value.\n")
+
+ parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value\n")
+
+ parser.add_argument('--split', type=str, default='overlap', help="Split method") # non-overlap
+ parser.add_argument('--slide_step', type=int, default=12, help="Slide step for overlap split")
+ args = parser.parse_args()
+
+ args.data_root = '{}/{}'.format(args.data_root, args.dataset)
+ # Setup CUDA, GPU & distributed training
+ if args.local_rank == -1:
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        print('torch.cuda.device_count():', torch.cuda.device_count())
+ args.n_gpu = torch.cuda.device_count()
+ else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
+ torch.cuda.set_device(args.local_rank)
+ device = torch.device("cuda", args.local_rank)
+ torch.distributed.init_process_group(backend='nccl', timeout=timedelta(minutes=60))
+ args.n_gpu = 1
+ args.device = device
+ args.nprocs = torch.cuda.device_count()
+
+ # Setup logging
+ logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+ datefmt='%m/%d/%Y %H:%M:%S',
+ level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
+ logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s" %
+ (args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1), args.fp16))
+
+ # Set seed
+ set_seed(args)
+
+ # Model & Tokenizer Setup
+ args, model = setup(args)
+ # Training
+ train(args, model)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100755
index 0000000..e69de29
diff --git a/utils/autoaugment.py b/utils/autoaugment.py
new file mode 100755
index 0000000..1caf904
--- /dev/null
+++ b/utils/autoaugment.py
@@ -0,0 +1,204 @@
+"""
+Copied from https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py
+"""
+
+from PIL import Image, ImageEnhance, ImageOps
+import numpy as np
+import random
+
+__all__ = ['AutoAugImageNetPolicy', 'AutoAugCIFAR10Policy', 'AutoAugSVHNPolicy']
+
+
+class AutoAugImageNetPolicy(object):
+ def __init__(self, fillcolor=(128, 128, 128)):
+ self.policies = [
+ SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
+ SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
+ SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
+ SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
+ SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
+
+ SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
+ SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
+ SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
+ SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
+ SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),
+
+ SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
+ SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
+ SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
+ SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
+ SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
+
+ SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
+ SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
+ SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
+ SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
+ SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),
+
+ SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
+ SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
+ SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
+ SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor)
+ ]
+
+ def __call__(self, img):
+ policy_idx = random.randint(0, len(self.policies) - 1)
+ return self.policies[policy_idx](img)
+
+ def __repr__(self):
+ return "AutoAugment ImageNet Policy"
+
+
+class AutoAugCIFAR10Policy(object):
+ def __init__(self, fillcolor=(128, 128, 128)):
+ self.policies = [
+ SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
+ SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
+ SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
+ SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
+ SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),
+
+ SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
+ SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
+ SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
+ SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
+ SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),
+
+ SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
+ SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
+ SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
+ SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
+ SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),
+
+ SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
+ SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor),
+ SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
+ SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
+ SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),
+
+ SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
+ SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
+ SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
+ SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
+ SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor)
+ ]
+
+ def __call__(self, img):
+ policy_idx = random.randint(0, len(self.policies) - 1)
+ return self.policies[policy_idx](img)
+
+ def __repr__(self):
+ return "AutoAugment CIFAR10 Policy"
+
+
+class AutoAugSVHNPolicy(object):
+ def __init__(self, fillcolor=(128, 128, 128)):
+ self.policies = [
+ SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
+ SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
+ SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
+ SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
+ SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),
+
+ SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
+ SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
+ SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
+ SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
+ SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),
+
+ SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
+ SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
+ SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
+ SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
+ SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),
+
+ SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
+ SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
+ SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
+ SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
+ SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),
+
+ SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
+ SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
+ SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
+ SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
+ SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
+ ]
+
+ def __call__(self, img):
+ policy_idx = random.randint(0, len(self.policies) - 1)
+ return self.policies[policy_idx](img)
+
+ def __repr__(self):
+ return "AutoAugment SVHN Policy"
+
+
+class SubPolicy(object):
+ def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
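+ # Each magnitude index (0-9) selects a concrete strength from the fixed 10-step range of its operation below.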
+ ranges = {
+ "shearX": np.linspace(0, 0.3, 10),
+ "shearY": np.linspace(0, 0.3, 10),
+ "translateX": np.linspace(0, 150 / 331, 10),
+ "translateY": np.linspace(0, 150 / 331, 10),
+ "rotate": np.linspace(0, 30, 10),
+ "color": np.linspace(0.0, 0.9, 10),
+ "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int),
+ "solarize": np.linspace(256, 0, 10),
+ "contrast": np.linspace(0.0, 0.9, 10),
+ "sharpness": np.linspace(0.0, 0.9, 10),
+ "brightness": np.linspace(0.0, 0.9, 10),
+ "autocontrast": [0] * 10,
+ "equalize": [0] * 10,
+ "invert": [0] * 10
+ }
+
+ def rotate_with_fill(img, magnitude):
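+ # Rotate on an RGBA canvas and composite it over a gray background so the exposed corners are filled rather than black.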
+ rot = img.convert("RGBA").rotate(magnitude)
+ return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)
+
+ func = {
+ "shearX": lambda img, magnitude: img.transform(
+ img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
+ Image.BICUBIC, fillcolor=fillcolor),
+ "shearY": lambda img, magnitude: img.transform(
+ img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
+ Image.BICUBIC, fillcolor=fillcolor),
+ "translateX": lambda img, magnitude: img.transform(
+ img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
+ fillcolor=fillcolor),
+ "translateY": lambda img, magnitude: img.transform(
+ img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
+ fillcolor=fillcolor),
+ "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
+ # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])),
+ "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
+ "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
+ "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
+ "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
+ 1 + magnitude * random.choice([-1, 1])),
+ "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
+ 1 + magnitude * random.choice([-1, 1])),
+ "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
+ 1 + magnitude * random.choice([-1, 1])),
+ "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
+ "equalize": lambda img, magnitude: ImageOps.equalize(img),
+ "invert": lambda img, magnitude: ImageOps.invert(img)
+ }
+
+ # self.name = "{}_{:.2f}_and_{}_{:.2f}".format(
+ # operation1, ranges[operation1][magnitude_idx1],
+ # operation2, ranges[operation2][magnitude_idx2])
+ self.p1 = p1
+ self.operation1 = func[operation1]
+ self.magnitude1 = ranges[operation1][magnitude_idx1]
+ self.p2 = p2
+ self.operation2 = func[operation2]
+ self.magnitude2 = ranges[operation2][magnitude_idx2]
+
+ def __call__(self, img):
+ if random.random() < self.p1:
+ img = self.operation1(img, self.magnitude1)
+ if random.random() < self.p2:
+ img = self.operation2(img, self.magnitude2)
+ return img
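+
+
+if __name__ == "__main__":
+ # Minimal self-check (a sketch, not part of the upstream AutoAugment file): apply the
+ # ImageNet policy to a synthetic RGB image and confirm a PIL image of the same size comes back.
+ demo_img = Image.fromarray(np.uint8(np.random.rand(224, 224, 3) * 255))
+ print(AutoAugImageNetPolicy()(demo_img).size)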
diff --git a/utils/data_utils.py b/utils/data_utils.py
new file mode 100755
index 0000000..86df2a9
--- /dev/null
+++ b/utils/data_utils.py
@@ -0,0 +1,135 @@
+import logging
+from PIL import Image
+import os
+
+import torch
+
+from torchvision import transforms
+from torch.utils.data import DataLoader, RandomSampler, DistributedSampler, SequentialSampler
+
+from .dataset import CUB, CarsDataset, NABirds, dogs, INat2017, emptyJudge
+from .autoaugment import AutoAugImageNetPolicy
+
+logger = logging.getLogger(__name__)
+
+
+def get_loader(args):
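+ # In distributed runs, hold every rank except 0 at the barrier so rank 0 builds/caches the datasets first.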
+ if args.local_rank not in [-1, 0]:
+ torch.distributed.barrier()
+
+ if args.dataset == 'CUB_200_2011':
+ train_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
+ transforms.RandomCrop((448, 448)),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
+ test_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
+ transforms.CenterCrop((448, 448)),
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
+ trainset = CUB(root=args.data_root, is_train=True, transform=train_transform)
+ testset = CUB(root=args.data_root, is_train=False, transform=test_transform)
+ elif args.dataset == 'car':
+ trainset = CarsDataset(os.path.join(args.data_root,'devkit/cars_train_annos.mat'),
+ os.path.join(args.data_root,'cars_train'),
+ os.path.join(args.data_root,'devkit/cars_meta.mat'),
+ # cleaned=os.path.join(data_dir,'cleaned.dat'),
+ transform=transforms.Compose([
+ transforms.Resize((600, 600), Image.BILINEAR),
+ transforms.RandomCrop((448, 448)),
+ transforms.RandomHorizontalFlip(),
+ AutoAugImageNetPolicy(),
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
+ )
+ testset = CarsDataset(os.path.join(args.data_root,'cars_test_annos_withlabels.mat'),
+ os.path.join(args.data_root,'cars_test'),
+ os.path.join(args.data_root,'devkit/cars_meta.mat'),
+ # cleaned=os.path.join(data_dir,'cleaned_test.dat'),
+ transform=transforms.Compose([
+ transforms.Resize((600, 600), Image.BILINEAR),
+ transforms.CenterCrop((448, 448)),
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
+ )
+ elif args.dataset == 'dog':
+ train_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
+ transforms.RandomCrop((448, 448)),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
+ test_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
+ transforms.CenterCrop((448, 448)),
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
+ trainset = dogs(root=args.data_root,
+ train=True,
+ cropped=False,
+ transform=train_transform,
+ download=False
+ )
+ testset = dogs(root=args.data_root,
+ train=False,
+ cropped=False,
+ transform=test_transform,
+ download=False
+ )
+ elif args.dataset == 'nabirds':
+ train_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
+ transforms.RandomCrop((448, 448)),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
+ test_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
+ transforms.CenterCrop((448, 448)),
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
+ trainset = NABirds(root=args.data_root, train=True, transform=train_transform)
+ testset = NABirds(root=args.data_root, train=False, transform=test_transform)
+ elif args.dataset == 'INat2017':
+ train_transform=transforms.Compose([transforms.Resize((400, 400), Image.BILINEAR),
+ transforms.RandomCrop((304, 304)),
+ transforms.RandomHorizontalFlip(),
+ AutoAugImageNetPolicy(),
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
+ test_transform=transforms.Compose([transforms.Resize((400, 400), Image.BILINEAR),
+ transforms.CenterCrop((304, 304)),
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
+ trainset = INat2017(args.data_root, 'train', train_transform)
+ testset = INat2017(args.data_root, 'val', test_transform)
+ elif args.dataset == 'emptyJudge5' or args.dataset == 'emptyJudge4':
+ train_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
+ transforms.RandomCrop((448, 448)),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
+ # test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
+ # transforms.CenterCrop((448, 448)),
+ # transforms.ToTensor(),
+ # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
+ test_transform = transforms.Compose([transforms.Resize((448, 448), Image.BILINEAR),
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
+ trainset = emptyJudge(root=args.data_root, is_train=True, transform=train_transform)
+ testset = emptyJudge(root=args.data_root, is_train=False, transform=test_transform)
+
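+ # Rank 0 has finished building the datasets; hit the barrier to release the other ranks.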
+ if args.local_rank == 0:
+ torch.distributed.barrier()
+
+ train_sampler = RandomSampler(trainset) if args.local_rank == -1 else DistributedSampler(trainset)
+ test_sampler = SequentialSampler(testset) if args.local_rank == -1 else DistributedSampler(testset)
+ train_loader = DataLoader(trainset,
+ sampler=train_sampler,
+ batch_size=args.train_batch_size,
+ num_workers=4,
+ drop_last=True,
+ pin_memory=True)
+ test_loader = DataLoader(testset,
+ sampler=test_sampler,
+ batch_size=args.eval_batch_size,
+ num_workers=4,
+ pin_memory=True) if testset is not None else None
+
+ return train_loader, test_loader
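+
+
+# Illustrative usage (a sketch, not part of the original file). The attribute names below are
+# exactly the ones get_loader reads from `args`; the concrete values are placeholders:
+#   import argparse
+#   args = argparse.Namespace(local_rank=-1, dataset='emptyJudge5', data_root='./data/emptyJudge5',
+#                             train_batch_size=16, eval_batch_size=16)
+#   train_loader, test_loader = get_loader(args)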
diff --git a/utils/dataset.py b/utils/dataset.py
new file mode 100755
index 0000000..7a06567
--- /dev/null
+++ b/utils/dataset.py
@@ -0,0 +1,629 @@
+import os
+import json
+from os.path import join
+
+import numpy as np
+import scipy
+from scipy import io
+from PIL import Image
+import pandas as pd
+import matplotlib.pyplot as plt
+
+import torch
+from torch.utils.data import Dataset
+from torchvision.datasets import VisionDataset
+from torchvision.datasets.folder import default_loader
+from torchvision.datasets.utils import download_url, list_dir, check_integrity, extract_archive, verify_str_arg
+
+
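+# Expected layout under `root` for emptyJudge (and CUB below), inferred from the CUB_200_2011-style
+# metadata files read in the constructors:
+#   images.txt              "<id> <relative image path>" per line
+#   image_class_labels.txt  "<id> <1-based class label>" per line
+#   train_test_split.txt    "<id> <1 for train / 0 for test>" per line
+#   images/                 the image files themselves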
+class emptyJudge(Dataset):
+ def __init__(self, root, is_train=True, data_len=None, transform=None):
+ self.root = root
+ self.is_train = is_train
+ self.transform = transform
+ img_txt_file = open(os.path.join(self.root, 'images.txt'))
+ label_txt_file = open(os.path.join(self.root, 'image_class_labels.txt'))
+ train_val_file = open(os.path.join(self.root, 'train_test_split.txt'))
+ img_name_list = []
+ for line in img_txt_file:
+ img_name_list.append(line[:-1].split(' ')[-1])
+ label_list = []
+ for line in label_txt_file:
+ label_list.append(int(line[:-1].split(' ')[-1]) - 1)
+ train_test_list = []
+ for line in train_val_file:
+ train_test_list.append(int(line[:-1].split(' ')[-1]))
+ train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
+ test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
+ if self.is_train:
+ # scipy.misc.imread was removed from SciPy, so load with PIL and convert to an array instead.
+ self.train_img = [np.asarray(Image.open(os.path.join(self.root, 'images', train_file))) for train_file in
+ train_file_list[:data_len]]
+ self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
+ self.train_imgname = [x for x in train_file_list[:data_len]]
+ if not self.is_train:
+ self.test_img = [np.asarray(Image.open(os.path.join(self.root, 'images', test_file))) for test_file in
+ test_file_list[:data_len]]
+ self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
+ self.test_imgname = [x for x in test_file_list[:data_len]]
+
+ def __getitem__(self, index):
+ if self.is_train:
+ img, target, imgname = self.train_img[index], self.train_label[index], self.train_imgname[index]
+ if len(img.shape) == 2:
+ img = np.stack([img] * 3, 2)
+ img = Image.fromarray(img, mode='RGB')
+ if self.transform is not None:
+ img = self.transform(img)
+ else:
+ img, target, imgname = self.test_img[index], self.test_label[index], self.test_imgname[index]
+ if len(img.shape) == 2:
+ img = np.stack([img] * 3, 2)
+ img = Image.fromarray(img, mode='RGB')
+ if self.transform is not None:
+ img = self.transform(img)
+ return img, target
+
+ def __len__(self):
+ if self.is_train:
+ return len(self.train_label)
+ else:
+ return len(self.test_label)
+
+
+class CUB(Dataset):
+ def __init__(self, root, is_train=True, data_len=None, transform=None):
+ self.root = root
+ self.is_train = is_train
+ self.transform = transform
+ img_txt_file = open(os.path.join(self.root, 'images.txt'))
+ label_txt_file = open(os.path.join(self.root, 'image_class_labels.txt'))
+ train_val_file = open(os.path.join(self.root, 'train_test_split.txt'))
+ img_name_list = []
+ for line in img_txt_file:
+ img_name_list.append(line[:-1].split(' ')[-1])
+ label_list = []
+ for line in label_txt_file:
+ label_list.append(int(line[:-1].split(' ')[-1]) - 1)
+ train_test_list = []
+ for line in train_val_file:
+ train_test_list.append(int(line[:-1].split(' ')[-1]))
+ train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
+ test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
+ if self.is_train:
+ self.train_img = [np.asarray(Image.open(os.path.join(self.root, 'images', train_file))) for train_file in
+ train_file_list[:data_len]]
+ self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
+ self.train_imgname = [x for x in train_file_list[:data_len]]
+ if not self.is_train:
+ self.test_img = [np.asarray(Image.open(os.path.join(self.root, 'images', test_file))) for test_file in
+ test_file_list[:data_len]]
+ self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
+ self.test_imgname = [x for x in test_file_list[:data_len]]
+
+ def __getitem__(self, index):
+ if self.is_train:
+ img, target, imgname = self.train_img[index], self.train_label[index], self.train_imgname[index]
+ if len(img.shape) == 2:
+ img = np.stack([img] * 3, 2)
+ img = Image.fromarray(img, mode='RGB')
+ if self.transform is not None:
+ img = self.transform(img)
+ else:
+ img, target, imgname = self.test_img[index], self.test_label[index], self.test_imgname[index]
+ if len(img.shape) == 2:
+ img = np.stack([img] * 3, 2)
+ img = Image.fromarray(img, mode='RGB')
+ if self.transform is not None:
+ img = self.transform(img)
+
+ return img, target
+
+ def __len__(self):
+ if self.is_train:
+ return len(self.train_label)
+ else:
+ return len(self.test_label)
+
+
+class CarsDataset(Dataset):
+ def __init__(self, mat_anno, data_dir, car_names, cleaned=None, transform=None):
+ """
+ Args:
+ mat_anno (string): Path to the MATLAB annotation file.
+ data_dir (string): Directory with all the images.
+ transform (callable, optional): Optional transform to be applied
+ on a sample.
+ """
+
+ self.full_data_set = io.loadmat(mat_anno)
+ self.car_annotations = self.full_data_set['annotations']
+ self.car_annotations = self.car_annotations[0]
+
+ if cleaned is not None:
+ cleaned_annos = []
+ print("Cleaning up data set (only take pics with rgb chans)...")
+ clean_files = np.loadtxt(cleaned, dtype=str)
+ for c in self.car_annotations:
+ if c[-1][0] in clean_files:
+ cleaned_annos.append(c)
+ self.car_annotations = cleaned_annos
+
+ self.car_names = scipy.io.loadmat(car_names)['class_names']
+ self.car_names = np.array(self.car_names[0])
+
+ self.data_dir = data_dir
+ self.transform = transform
+
+ def __len__(self):
+ return len(self.car_annotations)
+
+ def __getitem__(self, idx):
+ img_name = os.path.join(self.data_dir, self.car_annotations[idx][-1][0])
+ image = Image.open(img_name).convert('RGB')
+ car_class = self.car_annotations[idx][-2][0][0]
+ car_class = torch.from_numpy(np.array(car_class.astype(np.float32))).long() - 1
+ assert car_class < 196
+
+ if self.transform:
+ image = self.transform(image)
+
+ # return image, car_class, img_name
+ return image, car_class
+
+ def map_class(self, id):
+ id = np.ravel(id)
+ ret = self.car_names[id - 1][0][0]
+ return ret
+
+ def show_batch(self, img_batch, class_batch):
+
+ for i in range(img_batch.shape[0]):
+ ax = plt.subplot(1, img_batch.shape[0], i + 1)
+ title_str = self.map_class(int(class_batch[i]))
+ img = np.transpose(img_batch[i, ...], (1, 2, 0))
+ ax.imshow(img)
+ ax.set_title(title_str.__str__(), {'fontsize': 5})
+ plt.tight_layout()
+
+
+def make_dataset(dir, image_ids, targets):
+ assert(len(image_ids) == len(targets))
+ images = []
+ dir = os.path.expanduser(dir)
+ for i in range(len(image_ids)):
+ item = (os.path.join(dir, 'data', 'images', '%s.jpg' % image_ids[i]), targets[i])
+ images.append(item)
+ return images
+
+
+def find_classes(classes_file):
+ # read classes file, separating out image IDs and class names
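+ # each line is expected to look like "<image_id> <class name ...>"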
+ image_ids = []
+ targets = []
+ f = open(classes_file, 'r')
+ for line in f:
+ split_line = line.split(' ')
+ image_ids.append(split_line[0])
+ targets.append(' '.join(split_line[1:]))
+ f.close()
+
+ # index class names
+ classes = np.unique(targets)
+ class_to_idx = {classes[i]: i for i in range(len(classes))}
+ targets = [class_to_idx[c] for c in targets]
+ return (image_ids, targets, classes, class_to_idx)
+
+
+class dogs(Dataset):
+ """`Stanford Dogs `_ Dataset.
+ Args:
+ root (string): Root directory of the dataset, containing the ``Images`` and ``Annotation``
+ folders and the train_list.mat / test_list.mat split files.
+ cropped (bool, optional): If true, the images will be cropped into the bounding box specified
+ in the annotations
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): If true, downloads the dataset tar files from the internet and
+ puts it in root directory. If the tar files are already downloaded, they are not
+ downloaded again.
+ """
+ folder = 'dog'
+ download_url_prefix = 'http://vision.stanford.edu/aditya86/ImageNetDogs'
+
+ def __init__(self,
+ root,
+ train=True,
+ cropped=False,
+ transform=None,
+ target_transform=None,
+ download=False):
+
+ # self.root = join(os.path.expanduser(root), self.folder)
+ self.root = root
+ self.train = train
+ self.cropped = cropped
+ self.transform = transform
+ self.target_transform = target_transform
+
+ if download:
+ self.download()
+
+ split = self.load_split()
+
+ self.images_folder = join(self.root, 'Images')
+ self.annotations_folder = join(self.root, 'Annotation')
+ self._breeds = list_dir(self.images_folder)
+
+ if self.cropped:
+ self._breed_annotations = [[(annotation, box, idx)
+ for box in self.get_boxes(join(self.annotations_folder, annotation))]
+ for annotation, idx in split]
+ self._flat_breed_annotations = sum(self._breed_annotations, [])
+
+ self._flat_breed_images = [(annotation+'.jpg', idx) for annotation, box, idx in self._flat_breed_annotations]
+ else:
+ self._breed_images = [(annotation+'.jpg', idx) for annotation, idx in split]
+
+ self._flat_breed_images = self._breed_images
+
+ self.classes = ["Chihuaha",
+ "Japanese Spaniel",
+ "Maltese Dog",
+ "Pekinese",
+ "Shih-Tzu",
+ "Blenheim Spaniel",
+ "Papillon",
+ "Toy Terrier",
+ "Rhodesian Ridgeback",
+ "Afghan Hound",
+ "Basset Hound",
+ "Beagle",
+ "Bloodhound",
+ "Bluetick",
+ "Black-and-tan Coonhound",
+ "Walker Hound",
+ "English Foxhound",
+ "Redbone",
+ "Borzoi",
+ "Irish Wolfhound",
+ "Italian Greyhound",
+ "Whippet",
+ "Ibizian Hound",
+ "Norwegian Elkhound",
+ "Otterhound",
+ "Saluki",
+ "Scottish Deerhound",
+ "Weimaraner",
+ "Staffordshire Bullterrier",
+ "American Staffordshire Terrier",
+ "Bedlington Terrier",
+ "Border Terrier",
+ "Kerry Blue Terrier",
+ "Irish Terrier",
+ "Norfolk Terrier",
+ "Norwich Terrier",
+ "Yorkshire Terrier",
+ "Wirehaired Fox Terrier",
+ "Lakeland Terrier",
+ "Sealyham Terrier",
+ "Airedale",
+ "Cairn",
+ "Australian Terrier",
+ "Dandi Dinmont",
+ "Boston Bull",
+ "Miniature Schnauzer",
+ "Giant Schnauzer",
+ "Standard Schnauzer",
+ "Scotch Terrier",
+ "Tibetan Terrier",
+ "Silky Terrier",
+ "Soft-coated Wheaten Terrier",
+ "West Highland White Terrier",
+ "Lhasa",
+ "Flat-coated Retriever",
+ "Curly-coater Retriever",
+ "Golden Retriever",
+ "Labrador Retriever",
+ "Chesapeake Bay Retriever",
+ "German Short-haired Pointer",
+ "Vizsla",
+ "English Setter",
+ "Irish Setter",
+ "Gordon Setter",
+ "Brittany",
+ "Clumber",
+ "English Springer Spaniel",
+ "Welsh Springer Spaniel",
+ "Cocker Spaniel",
+ "Sussex Spaniel",
+ "Irish Water Spaniel",
+ "Kuvasz",
+ "Schipperke",
+ "Groenendael",
+ "Malinois",
+ "Briard",
+ "Kelpie",
+ "Komondor",
+ "Old English Sheepdog",
+ "Shetland Sheepdog",
+ "Collie",
+ "Border Collie",
+ "Bouvier des Flandres",
+ "Rottweiler",
+ "German Shepard",
+ "Doberman",
+ "Miniature Pinscher",
+ "Greater Swiss Mountain Dog",
+ "Bernese Mountain Dog",
+ "Appenzeller",
+ "EntleBucher",
+ "Boxer",
+ "Bull Mastiff",
+ "Tibetan Mastiff",
+ "French Bulldog",
+ "Great Dane",
+ "Saint Bernard",
+ "Eskimo Dog",
+ "Malamute",
+ "Siberian Husky",
+ "Affenpinscher",
+ "Basenji",
+ "Pug",
+ "Leonberg",
+ "Newfoundland",
+ "Great Pyrenees",
+ "Samoyed",
+ "Pomeranian",
+ "Chow",
+ "Keeshond",
+ "Brabancon Griffon",
+ "Pembroke",
+ "Cardigan",
+ "Toy Poodle",
+ "Miniature Poodle",
+ "Standard Poodle",
+ "Mexican Hairless",
+ "Dingo",
+ "Dhole",
+ "African Hunting Dog"]
+
+ def __len__(self):
+ return len(self._flat_breed_images)
+
+ def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+ Returns:
+ tuple: (image, target) where target is index of the target character class.
+ """
+ image_name, target_class = self._flat_breed_images[index]
+ image_path = join(self.images_folder, image_name)
+ image = Image.open(image_path).convert('RGB')
+
+ if self.cropped:
+ image = image.crop(self._flat_breed_annotations[index][1])
+
+ if self.transform:
+ image = self.transform(image)
+
+ if self.target_transform:
+ target_class = self.target_transform(target_class)
+
+ return image, target_class
+
+ def download(self):
+ import tarfile
+
+ if os.path.exists(join(self.root, 'Images')) and os.path.exists(join(self.root, 'Annotation')):
+ if len(os.listdir(join(self.root, 'Images'))) == len(os.listdir(join(self.root, 'Annotation'))) == 120:
+ print('Files already downloaded and verified')
+ return
+
+ for filename in ['images', 'annotation', 'lists']:
+ tar_filename = filename + '.tar'
+ url = self.download_url_prefix + '/' + tar_filename
+ download_url(url, self.root, tar_filename, None)
+ print('Extracting downloaded file: ' + join(self.root, tar_filename))
+ with tarfile.open(join(self.root, tar_filename), 'r') as tar_file:
+ tar_file.extractall(self.root)
+ os.remove(join(self.root, tar_filename))
+
+ @staticmethod
+ def get_boxes(path):
+ import xml.etree.ElementTree
+ e = xml.etree.ElementTree.parse(path).getroot()
+ boxes = []
+ for objs in e.iter('object'):
+ boxes.append([int(objs.find('bndbox').find('xmin').text),
+ int(objs.find('bndbox').find('ymin').text),
+ int(objs.find('bndbox').find('xmax').text),
+ int(objs.find('bndbox').find('ymax').text)])
+ return boxes
+
+ def load_split(self):
+ if self.train:
+ split = scipy.io.loadmat(join(self.root, 'train_list.mat'))['annotation_list']
+ labels = scipy.io.loadmat(join(self.root, 'train_list.mat'))['labels']
+ else:
+ split = scipy.io.loadmat(join(self.root, 'test_list.mat'))['annotation_list']
+ labels = scipy.io.loadmat(join(self.root, 'test_list.mat'))['labels']
+
+ split = [item[0][0] for item in split]
+ labels = [item[0]-1 for item in labels]
+ return list(zip(split, labels))
+
+ def stats(self):
+ counts = {}
+ for index in range(len(self._flat_breed_images)):
+ image_name, target_class = self._flat_breed_images[index]
+ if target_class not in counts.keys():
+ counts[target_class] = 1
+ else:
+ counts[target_class] += 1
+
+ print("%d samples spanning %d classes (avg %f per class)"%(len(self._flat_breed_images), len(counts.keys()), float(len(self._flat_breed_images))/float(len(counts.keys()))))
+ return counts
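+
+# Illustrative usage (a sketch, not part of the original file): `root` is assumed to already hold
+# the Images/ and Annotation/ folders plus train_list.mat / test_list.mat, e.g.
+#   train_set = dogs(root='/path/to/stanford_dogs', train=True, cropped=False,
+#                    transform=train_transform, download=False)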
+
+
+class NABirds(Dataset):
+ """`NABirds `_ Dataset.
+
+ Args:
+ root (string): Root directory of the dataset.
+ train (bool, optional): If True, creates dataset from training set, otherwise
+ creates from test set.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ """
+ base_folder = 'nabirds/images'
+
+ def __init__(self, root, train=True, transform=None):
+ dataset_path = os.path.join(root, 'nabirds')
+ self.root = root
+ self.loader = default_loader
+ self.train = train
+ self.transform = transform
+
+ image_paths = pd.read_csv(os.path.join(dataset_path, 'images.txt'),
+ sep=' ', names=['img_id', 'filepath'])
+ image_class_labels = pd.read_csv(os.path.join(dataset_path, 'image_class_labels.txt'),
+ sep=' ', names=['img_id', 'target'])
+ # Since the raw labels are non-continuous, map them to new ones
+ self.label_map = get_continuous_class_map(image_class_labels['target'])
+ train_test_split = pd.read_csv(os.path.join(dataset_path, 'train_test_split.txt'),
+ sep=' ', names=['img_id', 'is_training_img'])
+ data = image_paths.merge(image_class_labels, on='img_id')
+ self.data = data.merge(train_test_split, on='img_id')
+ # Load in the train / test split
+ if self.train:
+ self.data = self.data[self.data.is_training_img == 1]
+ else:
+ self.data = self.data[self.data.is_training_img == 0]
+
+ # Load in the class data
+ self.class_names = load_class_names(dataset_path)
+ self.class_hierarchy = load_hierarchy(dataset_path)
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ sample = self.data.iloc[idx]
+ path = os.path.join(self.root, self.base_folder, sample.filepath)
+ target = self.label_map[sample.target]
+ img = self.loader(path)
+
+ if self.transform is not None:
+ img = self.transform(img)
+ return img, target
+
+
+def get_continuous_class_map(class_labels):
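+ # Map the sparse raw NABirds label ids onto contiguous indices 0..N-1.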
+ label_set = set(class_labels)
+ return {k: i for i, k in enumerate(label_set)}
+
+
+def load_class_names(dataset_path=''):
+ names = {}
+
+ with open(os.path.join(dataset_path, 'classes.txt')) as f:
+ for line in f:
+ pieces = line.strip().split()
+ class_id = pieces[0]
+ names[class_id] = ' '.join(pieces[1:])
+
+ return names
+
+
+def load_hierarchy(dataset_path=''):
+ parents = {}
+
+ with open(os.path.join(dataset_path, 'hierarchy.txt')) as f:
+ for line in f:
+ pieces = line.strip().split()
+ child_id, parent_id = pieces
+ parents[child_id] = parent_id
+
+ return parents
+
+
+class INat2017(VisionDataset):
+ """`iNaturalist 2017 `_ Dataset.
+ Args:
+ root (string): Root directory of the dataset.
+ split (string, optional): The dataset split, supports ``train``, or ``val``.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+ """
+ base_folder = 'train_val_images/'
+ file_list = {
+ 'imgs': ('https://storage.googleapis.com/asia_inat_data/train_val/train_val_images.tar.gz',
+ 'train_val_images.tar.gz',
+ '7c784ea5e424efaec655bd392f87301f'),
+ 'annos': ('https://storage.googleapis.com/asia_inat_data/train_val/train_val2017.zip',
+ 'train_val2017.zip',
+ '444c835f6459867ad69fcb36478786e7')
+ }
+
+ def __init__(self, root, split='train', transform=None, target_transform=None, download=False):
+ super(INat2017, self).__init__(root, transform=transform, target_transform=target_transform)
+ self.loader = default_loader
+ self.split = verify_str_arg(split, "split", ("train", "val",))
+
+ if self._check_exists():
+ print('Files already downloaded and verified.')
+ elif download:
+ if not (os.path.exists(os.path.join(self.root, self.file_list['imgs'][1]))
+ and os.path.exists(os.path.join(self.root, self.file_list['annos'][1]))):
+ print('Downloading...')
+ self._download()
+ print('Extracting...')
+ extract_archive(os.path.join(self.root, self.file_list['imgs'][1]))
+ extract_archive(os.path.join(self.root, self.file_list['annos'][1]))
+ else:
+ raise RuntimeError(
+ 'Dataset not found. You can use download=True to download it.')
+ anno_filename = split + '2017.json'
+ with open(os.path.join(self.root, anno_filename), 'r') as fp:
+ all_annos = json.load(fp)
+
+ self.annos = all_annos['annotations']
+ self.images = all_annos['images']
+
+ def __getitem__(self, index):
+ path = os.path.join(self.root, self.images[index]['file_name'])
+ target = self.annos[index]['category_id']
+
+ image = self.loader(path)
+ if self.transform is not None:
+ image = self.transform(image)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return image, target
+
+ def __len__(self):
+ return len(self.images)
+
+ def _check_exists(self):
+ return os.path.exists(os.path.join(self.root, self.base_folder))
+
+ def _download(self):
+ for url, filename, md5 in self.file_list.values():
+ download_url(url, root=self.root, filename=filename)
+ if not check_integrity(os.path.join(self.root, filename), md5):
+ raise RuntimeError("File not found or corrupted.")
diff --git a/utils/dist_util.py b/utils/dist_util.py
new file mode 100755
index 0000000..ab8c447
--- /dev/null
+++ b/utils/dist_util.py
@@ -0,0 +1,30 @@
+import torch.distributed as dist
+
+def get_rank():
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ return dist.get_rank()
+
+def get_world_size():
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size()
+
+def is_main_process():
+ return get_rank() == 0
+
+def format_step(step):
+ if isinstance(step, str):
+ return step
+ s = ""
+ if len(step) > 0:
+ s += "Training Epoch: {} ".format(step[0])
+ if len(step) > 1:
+ s += "Training Iteration: {} ".format(step[1])
+ if len(step) > 2:
+ s += "Validation Iteration: {} ".format(step[2])
+ return s
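+
+
+# Typical usage sketch (not part of the original file); `logger`, `epoch` and `global_step` are
+# placeholder names from the calling training script:
+#   if is_main_process():
+#       logger.info(format_step((epoch, global_step)))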
diff --git a/utils/scheduler.py b/utils/scheduler.py
new file mode 100755
index 0000000..9daaf6e
--- /dev/null
+++ b/utils/scheduler.py
@@ -0,0 +1,63 @@
+import logging
+import math
+
+from torch.optim.lr_scheduler import LambdaLR
+
+logger = logging.getLogger(__name__)
+
+class ConstantLRSchedule(LambdaLR):
+ """ Constant learning rate schedule.
+ """
+ def __init__(self, optimizer, last_epoch=-1):
+ super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch)
+
+
+class WarmupConstantSchedule(LambdaLR):
+ """ Linear warmup and then constant.
+ Linearly increases learning rate schedule from 0 to 1 over `warmup_steps` training steps.
+ Keeps learning rate schedule equal to 1. after warmup_steps.
+ """
+ def __init__(self, optimizer, warmup_steps, last_epoch=-1):
+ self.warmup_steps = warmup_steps
+ super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
+
+ def lr_lambda(self, step):
+ if step < self.warmup_steps:
+ return float(step) / float(max(1.0, self.warmup_steps))
+ return 1.
+
+
+class WarmupLinearSchedule(LambdaLR):
+ """ Linear warmup and then linear decay.
+ Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps.
+ Linearly decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps.
+ """
+ def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):
+ self.warmup_steps = warmup_steps
+ self.t_total = t_total
+ super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
+
+ def lr_lambda(self, step):
+ if step < self.warmup_steps:
+ return float(step) / float(max(1, self.warmup_steps))
+ return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))
+
+
+class WarmupCosineSchedule(LambdaLR):
+ """ Linear warmup and then cosine decay.
+ Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps.
+ Decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve.
+ `cycles` sets how many full cosine periods the schedule traverses after warmup; the default of 0.5 gives a single smooth decay from 1. to 0.
+ """
+ def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):
+ self.warmup_steps = warmup_steps
+ self.t_total = t_total
+ self.cycles = cycles
+ super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
+
+ def lr_lambda(self, step):
+ if step < self.warmup_steps:
+ return float(step) / float(max(1.0, self.warmup_steps))
+ # progress after warmup
+ progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
+ return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))
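+
+
+if __name__ == "__main__":
+ # Illustrative sketch (not part of the original file): warm up for 500 steps, then follow a
+ # cosine decay over a 10k-step budget. get_last_lr() requires PyTorch >= 1.4.
+ import torch
+ optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=3e-2)
+ scheduler = WarmupCosineSchedule(optimizer, warmup_steps=500, t_total=10000)
+ for _ in range(3):
+ optimizer.step()
+ scheduler.step()
+ print(scheduler.get_last_lr())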