commit f23dc227525afac407858578695d8bfb1c0dc134 Author: li chen <770918727@qq.com> Date: Fri Apr 8 18:13:02 2022 +0800 update diff --git a/.gitignore b/.gitignore new file mode 100755 index 0000000..2bdab8a --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +.idea/ +ckpts/ +logs/ +models/__pycache__/ +utils/__pycache__/ +output/ +attention_data/ diff --git a/LICENSE b/LICENSE new file mode 100755 index 0000000..46baa2e --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Ju He + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/docs/TransFG.pdf b/docs/TransFG.pdf new file mode 100755 index 0000000..08efe99 Binary files /dev/null and b/docs/TransFG.pdf differ diff --git a/docs/TransFG.png b/docs/TransFG.png new file mode 100755 index 0000000..3ce1357 Binary files /dev/null and b/docs/TransFG.png differ diff --git a/ieemoo-ai-isempty.py b/ieemoo-ai-isempty.py new file mode 100644 index 0000000..15712b1 --- /dev/null +++ b/ieemoo-ai-isempty.py @@ -0,0 +1,136 @@ +# -*- coding: utf-8 -*- +from flask import request, Flask +import numpy as np +import json +import time +import cv2, base64 +import argparse +import sys, os +import torch +from PIL import Image +from torchvision import transforms +from models.modeling import VisionTransformer, CONFIGS +sys.path.insert(0, ".") + + +app = Flask(__name__) +app.use_reloader=False + + +def parse_args(model_file="ckpts/emptyjudge5_checkpoint.bin"): + parser = argparse.ArgumentParser() + parser.add_argument("--img_size", default=448, type=int, help="Resolution size") + parser.add_argument('--split', type=str, default='overlap', help="Split method") + parser.add_argument('--slide_step', type=int, default=12, help="Slide step for overlap split") + parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value") + parser.add_argument("--pretrained_model", type=str, default=model_file, help="load pretrained model") + opt, unknown = parser.parse_known_args() + return opt + + +class Predictor(object): + def __init__(self, args): + self.args = args + self.args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(self.args.device) + self.args.nprocs = torch.cuda.device_count() + self.cls_dict = {} + self.num_classes = 0 + self.model = None + self.prepare_model() + self.test_transform = transforms.Compose([transforms.Resize((448, 448), Image.BILINEAR), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) + + def prepare_model(self): + config = CONFIGS["ViT-B_16"] 
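# A minimal, illustrative sketch (not code from this repository) of the single-image
# inference path that Predictor wires up: the test_transform defined in __init__
# resizes to 448x448, converts to a tensor and normalizes with ImageNet statistics;
# normal_predict (further below) adds a batch dimension, runs the ViT and applies a
# softmax over the class logits. The checkpoint and image paths are placeholders.
import torch
from PIL import Image

args = parse_args("ckpts/emptyjudge5_checkpoint.bin")   # placeholder checkpoint path
predictor = Predictor(args)
img = Image.open("example.jpg").convert("RGB")          # placeholder image path
x = predictor.test_transform(img).unsqueeze(0)          # shape (1, 3, 448, 448)
with torch.no_grad():
    probs = torch.softmax(predictor.model(x.to(args.device)), dim=-1)
cls_id = int(torch.argmax(probs, dim=-1))               # index into predictor.cls_dict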
+ config.split = self.args.split + config.slide_step = self.args.slide_step + model_name = os.path.basename(self.args.pretrained_model).replace("_checkpoint.bin", "") + print("use model_name: ", model_name) + self.num_classes = 5 + self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "fly", 4: "stack"} + self.model = VisionTransformer(config, self.args.img_size, zero_head=True, num_classes=self.num_classes, smoothing_value=self.args.smoothing_value) + if self.args.pretrained_model is not None: + if not torch.cuda.is_available(): + pretrained_model = torch.load(self.args.pretrained_model, map_location=torch.device('cpu'))['model'] + self.model.load_state_dict(pretrained_model) + else: + pretrained_model = torch.load(self.args.pretrained_model)['model'] + self.model.load_state_dict(pretrained_model) + self.model.eval() + self.model.to(self.args.device) + #self.model.eval() + + def normal_predict(self, img_data, result): + # img = Image.open(img_path) + if img_data is None: + print('error, img data is None') + return result + else: + with torch.no_grad(): + x = self.test_transform(img_data) + if torch.cuda.is_available(): + x = x.cuda() + part_logits = self.model(x.unsqueeze(0)) + probs = torch.nn.Softmax(dim=-1)(part_logits) + topN = torch.argsort(probs, dim=-1, descending=True).tolist() + clas_ids = topN[0][0] + clas_ids = 0 if 0==int(clas_ids) or 2 == int(clas_ids) or 3 == int(clas_ids) else 1 + print("cur_img result: class id: %d, score: %0.3f" % (clas_ids, probs[0, clas_ids].item())) + result["success"] = "true" + result["rst_cls"] = str(clas_ids) + return result + + +model_file ="/data/ieemoo/emptypredict_pfc_FG/ckpts/emptyjudge5_checkpoint.bin" +args = parse_args(model_file) +predictor = Predictor(args) + + +@app.route("/isempty", methods=['POST']) +def get_isempty(): + start = time.time() + print('--------------------EmptyPredict-----------------') + data = request.get_data() + ip = request.remote_addr + print('------ ip = %s ------' % ip) + + json_data = json.loads(data.decode("utf-8")) + getdateend = time.time() + print('get date use time: {0:.2f}s'.format(getdateend - start)) + + pic = json_data.get("pic") + result = {"success": "false", + "rst_cls": '-1', + } + try: + imgdata = base64.b64decode(pic) + imgdata_np = np.frombuffer(imgdata, dtype='uint8') + img_src = cv2.imdecode(imgdata_np, cv2.IMREAD_COLOR) + img_data = Image.fromarray(np.uint8(img_src)) + result = predictor.normal_predict(img_data, result) # 1==empty, 0==nonEmpty + except: + return repr(result) + + return repr(result) + + +if __name__ == "__main__": + app.run() +# app.run("0.0.0.0", port=8083) + + + + + + + + + + + + + + + diff --git a/init.sh b/init.sh new file mode 100644 index 0000000..2037ecb --- /dev/null +++ b/init.sh @@ -0,0 +1,4 @@ +/opt/miniconda3/bin/conda activate ieemoo + +/opt/miniconda3/envs/ieemoo/bin/pip install -r requirements.txt + diff --git a/models/configs.py b/models/configs.py new file mode 100755 index 0000000..7fa3acb --- /dev/null +++ b/models/configs.py @@ -0,0 +1,76 @@ +import ml_collections + +def get_testing(): + """Returns a minimal configuration for testing.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 1 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 1 + config.transformer.num_heads = 1 + config.transformer.num_layers = 1 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.classifier = 'token' + 
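# A minimal, illustrative sketch (not code from this repository) of how a client
# could call the /isempty endpoint defined in ieemoo-ai-isempty.py above: the
# request body is JSON with a base64-encoded image under the key "pic", and the
# response is the Python repr() of a dict such as {'success': 'true', 'rst_cls': '0'}
# (a repr string, not strict JSON). Host and port are assumptions; app.run() with no
# arguments serves on 127.0.0.1:5000, and the commented-out line suggests port 8083.
import base64
import json
import requests

with open("example.jpg", "rb") as f:                     # placeholder image path
    payload = {"pic": base64.b64encode(f.read()).decode("utf-8")}
resp = requests.post("http://127.0.0.1:5000/isempty", data=json.dumps(payload))
print(resp.text)  # e.g. "{'success': 'true', 'rst_cls': '0'}" where 0=non-empty, 1=empty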
config.representation_size = None + return config + + +def get_b16_config(): + """Returns the ViT-B/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.split = 'non-overlap' + config.slide_step = 12 + config.hidden_size = 768 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 3072 + config.transformer.num_heads = 12 + config.transformer.num_layers = 12 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.classifier = 'token' + config.representation_size = None + return config + +def get_b32_config(): + """Returns the ViT-B/32 configuration.""" + config = get_b16_config() + config.patches.size = (32, 32) + return config + +def get_l16_config(): + """Returns the ViT-L/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 1024 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 4096 + config.transformer.num_heads = 16 + config.transformer.num_layers = 24 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.classifier = 'token' + config.representation_size = None + return config + +def get_l32_config(): + """Returns the ViT-L/32 configuration.""" + config = get_l16_config() + config.patches.size = (32, 32) + return config + +def get_h14_config(): + """Returns the ViT-L/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (14, 14)}) + config.hidden_size = 1280 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 5120 + config.transformer.num_heads = 16 + config.transformer.num_layers = 32 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.classifier = 'token' + config.representation_size = None + return config diff --git a/models/modeling.py b/models/modeling.py new file mode 100755 index 0000000..fc3082b --- /dev/null +++ b/models/modeling.py @@ -0,0 +1,390 @@ +# coding=utf-8 +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import logging +import math + +from os.path import join as pjoin + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm +from torch.nn.modules.utils import _pair +from scipy import ndimage + +import models.configs as configs + +logger = logging.getLogger(__name__) + +ATTENTION_Q = "MultiHeadDotProductAttention_1/query" +ATTENTION_K = "MultiHeadDotProductAttention_1/key" +ATTENTION_V = "MultiHeadDotProductAttention_1/value" +ATTENTION_OUT = "MultiHeadDotProductAttention_1/out" +FC_0 = "MlpBlock_3/Dense_0" +FC_1 = "MlpBlock_3/Dense_1" +ATTENTION_NORM = "LayerNorm_0" +MLP_NORM = "LayerNorm_2" + +def np2th(weights, conv=False): + """Possibly convert HWIO to OIHW.""" + if conv: + weights = weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(weights) + +def swish(x): + return x * torch.sigmoid(x) + +ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish} + +class LabelSmoothing(nn.Module): + """ + NLL loss with label smoothing. + """ + def __init__(self, smoothing=0.0): + """ + Constructor for the LabelSmoothing module. 
+ :param smoothing: label smoothing factor + """ + super(LabelSmoothing, self).__init__() + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + + def forward(self, x, target): + logprobs = torch.nn.functional.log_softmax(x, dim=-1) + + nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) + nll_loss = nll_loss.squeeze(1) + smooth_loss = -logprobs.mean(dim=-1) + loss = self.confidence * nll_loss + self.smoothing * smooth_loss + return loss.mean() + +class Attention(nn.Module): + def __init__(self, config): + super(Attention, self).__init__() + self.num_attention_heads = config.transformer["num_heads"] + self.attention_head_size = int(config.hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = Linear(config.hidden_size, self.all_head_size) + self.key = Linear(config.hidden_size, self.all_head_size) + self.value = Linear(config.hidden_size, self.all_head_size) + + self.out = Linear(config.hidden_size, config.hidden_size) + self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"]) + self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"]) + + self.softmax = Softmax(dim=-1) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + attention_probs = self.softmax(attention_scores) + weights = attention_probs + attention_probs = self.attn_dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + attention_output = self.out(context_layer) + attention_output = self.proj_dropout(attention_output) + return attention_output, weights + +class Mlp(nn.Module): + def __init__(self, config): + super(Mlp, self).__init__() + self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"]) + self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size) + self.act_fn = ACT2FN["gelu"] + self.dropout = Dropout(config.transformer["dropout_rate"]) + + self._init_weights() + + def _init_weights(self): + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + nn.init.normal_(self.fc1.bias, std=1e-6) + nn.init.normal_(self.fc2.bias, std=1e-6) + + def forward(self, x): + x = self.fc1(x) + x = self.act_fn(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + +class Embeddings(nn.Module): + """Construct the embeddings from patch, position embeddings. 
+ """ + def __init__(self, config, img_size, in_channels=3): + super(Embeddings, self).__init__() + self.hybrid = None + img_size = _pair(img_size) + + patch_size = _pair(config.patches["size"]) + if config.split == 'non-overlap': + n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) + self.patch_embeddings = Conv2d(in_channels=in_channels, + out_channels=config.hidden_size, + kernel_size=patch_size, + stride=patch_size) + elif config.split == 'overlap': + n_patches = ((img_size[0] - patch_size[0]) // config.slide_step + 1) * ((img_size[1] - patch_size[1]) // config.slide_step + 1) + self.patch_embeddings = Conv2d(in_channels=in_channels, + out_channels=config.hidden_size, + kernel_size=patch_size, + stride=(config.slide_step, config.slide_step)) + self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches+1, config.hidden_size)) + self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + + self.dropout = Dropout(config.transformer["dropout_rate"]) + + def forward(self, x): + B = x.shape[0] + cls_tokens = self.cls_token.expand(B, -1, -1) + + if self.hybrid: + x = self.hybrid_model(x) + x = self.patch_embeddings(x) + x = x.flatten(2) + x = x.transpose(-1, -2) + x = torch.cat((cls_tokens, x), dim=1) + + embeddings = x + self.position_embeddings + embeddings = self.dropout(embeddings) + return embeddings + +class Block(nn.Module): + def __init__(self, config): + super(Block, self).__init__() + self.hidden_size = config.hidden_size + self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6) + self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6) + self.ffn = Mlp(config) + self.attn = Attention(config) + + def forward(self, x): + h = x + x = self.attention_norm(x) + x, weights = self.attn(x) + x = x + h + + h = x + x = self.ffn_norm(x) + x = self.ffn(x) + x = x + h + return x, weights + + def load_from(self, weights, n_block): + ROOT = f"Transformer/encoderblock_{n_block}" + with torch.no_grad(): + query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t() + key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t() + value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t() + out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t() + + query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1) + key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1) + value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1) + out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1) + + self.attn.query.weight.copy_(query_weight) + self.attn.key.weight.copy_(key_weight) + self.attn.value.weight.copy_(value_weight) + self.attn.out.weight.copy_(out_weight) + self.attn.query.bias.copy_(query_bias) + self.attn.key.bias.copy_(key_bias) + self.attn.value.bias.copy_(value_bias) + self.attn.out.bias.copy_(out_bias) + + mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t() + mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t() + mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t() + mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t() + + self.ffn.fc1.weight.copy_(mlp_weight_0) + self.ffn.fc2.weight.copy_(mlp_weight_1) + self.ffn.fc1.bias.copy_(mlp_bias_0) + self.ffn.fc2.bias.copy_(mlp_bias_1) + + self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")])) + 
self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")])) + self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")])) + self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")])) + + +class Part_Attention(nn.Module): + def __init__(self): + super(Part_Attention, self).__init__() + + def forward(self, x): + length = len(x) + last_map = x[0] + for i in range(1, length): + last_map = torch.matmul(x[i], last_map) + last_map = last_map[:,:,0,1:] + + _, max_inx = last_map.max(2) + return _, max_inx + + +class Encoder(nn.Module): + def __init__(self, config): + super(Encoder, self).__init__() + self.layer = nn.ModuleList() + for _ in range(config.transformer["num_layers"] - 1): + layer = Block(config) + self.layer.append(copy.deepcopy(layer)) + self.part_select = Part_Attention() + self.part_layer = Block(config) + self.part_norm = LayerNorm(config.hidden_size, eps=1e-6) + + def forward(self, hidden_states): + attn_weights = [] + for layer in self.layer: + hidden_states, weights = layer(hidden_states) + attn_weights.append(weights) + part_num, part_inx = self.part_select(attn_weights) + part_inx = part_inx + 1 + parts = [] + B, num = part_inx.shape + for i in range(B): + parts.append(hidden_states[i, part_inx[i,:]]) + parts = torch.stack(parts).squeeze(1) + concat = torch.cat((hidden_states[:,0].unsqueeze(1), parts), dim=1) + part_states, part_weights = self.part_layer(concat) + part_encoded = self.part_norm(part_states) + + return part_encoded + + +class Transformer(nn.Module): + def __init__(self, config, img_size): + super(Transformer, self).__init__() + self.embeddings = Embeddings(config, img_size=img_size) + self.encoder = Encoder(config) + + def forward(self, input_ids): + embedding_output = self.embeddings(input_ids) + part_encoded = self.encoder(embedding_output) + return part_encoded + + +class VisionTransformer(nn.Module): + def __init__(self, config, img_size=224, num_classes=21843, smoothing_value=0, zero_head=False): + super(VisionTransformer, self).__init__() + self.num_classes = num_classes + self.smoothing_value = smoothing_value + self.zero_head = zero_head + self.classifier = config.classifier + self.transformer = Transformer(config, img_size) + self.part_head = Linear(config.hidden_size, num_classes) + + def forward(self, x, labels=None): + part_tokens = self.transformer(x) + part_logits = self.part_head(part_tokens[:, 0]) + + if labels is not None: + if self.smoothing_value == 0: + loss_fct = CrossEntropyLoss() + else: + loss_fct = LabelSmoothing(self.smoothing_value) + part_loss = loss_fct(part_logits.view(-1, self.num_classes), labels.view(-1)) + contrast_loss = con_loss(part_tokens[:, 0], labels.view(-1)) + loss = part_loss + contrast_loss + return loss, part_logits + else: + return part_logits + + def load_from(self, weights): + with torch.no_grad(): + self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True)) + self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"])) + self.transformer.embeddings.cls_token.copy_(np2th(weights["cls"])) + self.transformer.encoder.part_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"])) + self.transformer.encoder.part_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"])) + + posemb = np2th(weights["Transformer/posembed_input/pos_embedding"]) + posemb_new = self.transformer.embeddings.position_embeddings + if posemb.size() == posemb_new.size(): + 
self.transformer.embeddings.position_embeddings.copy_(posemb) + else: + logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size())) + ntok_new = posemb_new.size(1) + + if self.classifier == "token": + posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] + ntok_new -= 1 + else: + posemb_tok, posemb_grid = posemb[:, :0], posemb[0] + + gs_old = int(np.sqrt(len(posemb_grid))) + gs_new = int(np.sqrt(ntok_new)) + print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new)) + posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) + + zoom = (gs_new / gs_old, gs_new / gs_old, 1) + posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1) + posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) + posemb = np.concatenate([posemb_tok, posemb_grid], axis=1) + self.transformer.embeddings.position_embeddings.copy_(np2th(posemb)) + + for bname, block in self.transformer.encoder.named_children(): + if bname.startswith('part') == False: + for uname, unit in block.named_children(): + unit.load_from(weights, n_block=uname) + + if self.transformer.embeddings.hybrid: + self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(weights["conv_root/kernel"], conv=True)) + gn_weight = np2th(weights["gn_root/scale"]).view(-1) + gn_bias = np2th(weights["gn_root/bias"]).view(-1) + self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight) + self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias) + + for bname, block in self.transformer.embeddings.hybrid_model.body.named_children(): + for uname, unit in block.named_children(): + unit.load_from(weights, n_block=bname, n_unit=uname) + + +def con_loss(features, labels): + B, _ = features.shape + features = F.normalize(features) + cos_matrix = features.mm(features.t()) + pos_label_matrix = torch.stack([labels == labels[i] for i in range(B)]).float() + neg_label_matrix = 1 - pos_label_matrix + pos_cos_matrix = 1 - cos_matrix + neg_cos_matrix = cos_matrix - 0.4 + neg_cos_matrix[neg_cos_matrix < 0] = 0 + loss = (pos_cos_matrix * pos_label_matrix).sum() + (neg_cos_matrix * neg_label_matrix).sum() + loss /= (B * B) + return loss + + +CONFIGS = { + 'ViT-B_16': configs.get_b16_config(), + 'ViT-B_32': configs.get_b32_config(), + 'ViT-L_16': configs.get_l16_config(), + 'ViT-L_32': configs.get_l32_config(), + 'ViT-H_14': configs.get_h14_config(), + 'testing': configs.get_testing(), +} diff --git a/predict.py b/predict.py new file mode 100755 index 0000000..f14096f --- /dev/null +++ b/predict.py @@ -0,0 +1,153 @@ +import numpy as np +import cv2 +import time +import os +import argparse +import torch +from sklearn.metrics import confusion_matrix +from sklearn.metrics import f1_score +from PIL import Image +from torchvision import transforms +from models.modeling import VisionTransformer, CONFIGS + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--img_size", default=448, type=int, help="Resolution size") + parser.add_argument('--split', type=str, default='overlap', help="Split method") # non-overlap + parser.add_argument('--slide_step', type=int, default=12, help="Slide step for overlap split") + parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value\n") + parser.add_argument("--pretrained_model", type=str, default="output/emptyjudge5_checkpoint.bin", help="load pretrained model") + return parser.parse_args() + + +class Predictor(object): + def __init__(self, args): + self.args = args + self.args.device = torch.device("cuda" if 
torch.cuda.is_available() else "cpu") + print("self.args.device =", self.args.device) + self.args.nprocs = torch.cuda.device_count() + + self.cls_dict = {} + self.num_classes = 0 + self.model = None + self.prepare_model() + self.test_transform = transforms.Compose([transforms.Resize((448, 448), Image.BILINEAR), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) + + def prepare_model(self): + config = CONFIGS["ViT-B_16"] + config.split = self.args.split + config.slide_step = self.args.slide_step + model_name = os.path.basename(self.args.pretrained_model).replace("_checkpoint.bin", "") + print("use model_name: ", model_name) + if model_name.lower() == "emptyJudge5".lower(): + self.num_classes = 5 + self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "fly", 4: "stack"} + elif model_name.lower() == "emptyJudge4".lower(): + self.num_classes = 4 + self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "stack"} + elif model_name.lower() == "emptyJudge3".lower(): + self.num_classes = 3 + self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard"} + elif model_name.lower() == "emptyJudge2".lower(): + self.num_classes = 2 + self.cls_dict = {0: "noemp", 1: "yesemp"} + self.model = VisionTransformer(config, self.args.img_size, zero_head=True, num_classes=self.num_classes, smoothing_value=self.args.smoothing_value) + if self.args.pretrained_model is not None: + if not torch.cuda.is_available(): + pretrained_model = torch.load(self.args.pretrained_model, map_location=torch.device('cpu'))['model'] + self.model.load_state_dict(pretrained_model) + else: + pretrained_model = torch.load(self.args.pretrained_model)['model'] + self.model.load_state_dict(pretrained_model) + self.model.to(self.args.device) + self.model.eval() + + def normal_predict(self, img_path): + # img = cv2.imread(img_path) + img = Image.open(img_path) + if img is None: + print( + "Image file failed to read: {}".format(img_path)) + else: + x = self.test_transform(img) + if torch.cuda.is_available(): + x = x.cuda() + part_logits = self.model(x.unsqueeze(0)) + probs = torch.nn.Softmax(dim=-1)(part_logits) + topN = torch.argsort(probs, dim=-1, descending=True).tolist() + clas_ids = topN[0][0] + # print(probs[0, topN[0][0]].item()) + return clas_ids, probs[0, clas_ids].item() + + +if __name__ == "__main__": + args = parse_args() + predictor = Predictor(args) + + y_true = [] + y_pred = [] + test_dir = "/data/pfc/fineGrained/test_5cls" + dir_dict = {"noemp":"0", "yesemp":"1", "hard": "2", "fly": "3", "stack": "4"} + total = 0 + num = 0 + t0 = time.time() + for dir_name, label in dir_dict.items(): + cur_folder = os.path.join(test_dir, dir_name) + errorPath = os.path.join(test_dir, dir_name + "_error") + # os.makedirs(errorPath, exist_ok=True) + for cur_file in os.listdir(cur_folder): + total += 1 + print("%d processing: %s" % (total, cur_file)) + cur_img_file = os.path.join(cur_folder, cur_file) + error_img_dst = os.path.join(errorPath, cur_file) + cur_pred, pred_score = predictor.normal_predict(cur_img_file) + + label = 0 if 2 == int(label) or 3 == int(label) or 4 == int(label) else int(label) + cur_pred = 0 if 2 == int(cur_pred) or 3 == int(cur_pred) or 4 == int(cur_pred) else int(cur_pred) + y_true.append(int(label)) + y_pred.append(int(cur_pred)) + if int(label) == int(cur_pred): + num += 1 + # else: + # print(cur_file, "predict: ", cur_pred, "true: ", int(label)) + # print(cur_file, "predict: ", cur_pred, "true: ", int(label), "pred_score:", pred_score) + # os.system("cp %s %s" % (cur_img_file, 
error_img_dst)) + t1 = time.time() + print('The cast of time is :%f seconds' % (t1-t0)) + rate = float(num)/total + print('The classification accuracy is %f' % rate) + + rst_C = confusion_matrix(y_true, y_pred) + rst_f1 = f1_score(y_true, y_pred, average='macro') + print(rst_C) + print(rst_f1) + +''' +test_imgs: yesemp=145, noemp=453 大图 + +output/emptyjudge5_checkpoint.bin +The classification accuracy is 0.976589 +[[446 7] 1.5% + [ 7 138]] 4.8% +0.968135799649844 + +output/emptyjudge4_checkpoint.bin +The classification accuracy is 0.976589 +[[450 3] 0.6% + [ 11 134]] 7.5% +0.9675186616384996 + +test_5cls: yesemp=319, noemp=925 小图 + + +output/emptyjudge4_checkpoint.bin +The classification accuracy is 0.937299 +[[885 40] 4.3% + [ 38 281]] 11.9% +0.9179586038961038 + + +''' diff --git a/prepara_data.py b/prepara_data.py new file mode 100755 index 0000000..1d779d5 --- /dev/null +++ b/prepara_data.py @@ -0,0 +1,119 @@ +import os +import cv2 +import numpy as np +import subprocess +import random + + +# ----------- 改写名称 -------------- +# index = 0 +# src_dir = "/data/fineGrained/emptyJudge5" +# dst_dir = src_dir + "_new" +# os.makedirs(dst_dir, exist_ok=True) +# for sub in os.listdir(src_dir): +# sub_path = os.path.join(src_dir, sub) +# sub_path_dst = os.path.join(dst_dir, sub) +# os.makedirs(sub_path_dst, exist_ok=True) +# for cur_f in os.listdir(sub_path): +# cur_img = os.path.join(sub_path, cur_f) +# cur_img_dst = os.path.join(sub_path_dst, "a%05d.jpg" % index) +# index += 1 +# os.system("mv %s %s" % (cur_img, cur_img_dst)) + + +# ----------- 删除过小图像 -------------- +# src_dir = "/data/fineGrained/emptyJudge5" +# for sub in os.listdir(src_dir): +# sub_path = os.path.join(src_dir, sub) +# for cur_f in os.listdir(sub_path): +# filepath = os.path.join(sub_path, cur_f) +# res = subprocess.check_output(['file', filepath]) +# pp = res.decode("utf-8").split(",")[-2] +# height = int(pp.split("x")[1]) +# width = int(pp.split("x")[0]) +# min_l = min(height, width) +# if min_l <= 448: +# os.system("rm %s" % filepath) + + +# ----------- 获取有效图片并写images.txt -------------- +# src_dir = "/data/fineGrained/emptyJudge4/images" +# src_dict = {"noemp":"0", "yesemp":"1", "hard": "2", "stack": "3"} +# all_dict = {"yesemp":[], "noemp":[], "hard": [], "stack": []} +# for sub, value in src_dict.items(): +# sub_path = os.path.join(src_dir, sub) +# for cur_f in os.listdir(sub_path): +# all_dict[sub].append(os.path.join(sub, cur_f)) +# +# yesnum = len(all_dict["yesemp"]) +# nonum = len(all_dict["noemp"]) +# hardnum = len(all_dict["hard"]) +# stacknum = len(all_dict["stack"]) +# thnum = min(yesnum, nonum, hardnum, stacknum) +# images_txt = src_dir + ".txt" +# index = 1 +# +# def write_images(cur_list, thnum, fw, index): +# for feat_path in random.sample(cur_list, thnum): +# fw.write(str(index) + " " + feat_path + "\n") +# index += 1 +# return index +# +# with open(images_txt, "w") as fw: +# index = write_images(all_dict["noemp"], thnum, fw, index) +# index = write_images(all_dict["yesemp"], thnum, fw, index) +# index = write_images(all_dict["hard"], thnum, fw, index) +# index = write_images(all_dict["stack"], thnum, fw, index) + +# ----------- 写 image_class_labels.txt + train_test_split.txt -------------- +# src_dir = "/data/fineGrained/emptyJudge4" +# src_dict = {"noemp":"0", "yesemp":"1", "hard": "2", "stack": "3"} +# images_txt = os.path.join(src_dir, "images.txt") +# image_class_labels_txt = os.path.join(src_dir, "image_class_labels.txt") +# imgs_cnt = 0 +# with open(image_class_labels_txt, "w") as fw: +# with 
open(images_txt, "r") as fr: +# for cur_l in fr: +# imgs_cnt += 1 +# img_index, img_f = cur_l.strip().split(" ") +# folder_name = img_f.split("/")[0] +# if folder_name in src_dict: +# cur_line = img_index + " " + str(int(src_dict[folder_name])+1) +# fw.write(cur_line + "\n") +# +# train_num = int(imgs_cnt*0.85) +# print("train_num= ", train_num, ", imgs_cnt= ", imgs_cnt) +# all_list = [1]*train_num + [0]*(imgs_cnt-train_num) +# assert len(all_list) == imgs_cnt +# random.shuffle(all_list) +# train_test_split_txt = os.path.join(src_dir, "train_test_split.txt") +# with open(train_test_split_txt, "w") as fw: +# with open(images_txt, "r") as fr: +# for cur_l in fr: +# img_index, img_f = cur_l.strip().split(" ") +# cur_line = img_index + " " + str(all_list[int(img_index) - 1]) +# fw.write(cur_line + "\n") + +# ----------- 生成标准测试集 -------------- +# src_dir = "/data/fineGrained/emptyJudge5/images" +# src_dict = {"noemp":"0", "yesemp":"1", "hard": "2", "fly": "3", "stack": "4"} +# all_dict = {"noemp":[], "yesemp":[], "hard": [], "fly": [], "stack": []} +# for sub, value in src_dict.items(): +# sub_path = os.path.join(src_dir, sub) +# for cur_f in os.listdir(sub_path): +# all_dict[sub].append(cur_f) +# +# dst_dir = src_dir + "_test" +# os.makedirs(dst_dir, exist_ok=True) +# for sub, value in src_dict.items(): +# sub_path = os.path.join(src_dir, sub) +# sub_path_dst = os.path.join(dst_dir, sub) +# os.makedirs(sub_path_dst, exist_ok=True) +# +# cur_list = all_dict[sub] +# test_num = int(len(cur_list) * 0.05) +# for cur_f in random.sample(cur_list, test_num): +# cur_path = os.path.join(sub_path, cur_f) +# cur_path_dst = os.path.join(sub_path_dst, cur_f) +# os.system("cp %s %s" % (cur_path, cur_path_dst)) + diff --git a/requirements.txt b/requirements.txt new file mode 100755 index 0000000..64d403c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,80 @@ +absl-py==1.0.0 +Bottleneck==1.3.2 +brotlipy==0.7.0 +cachetools==5.0.0 +certifi==2021.10.8 +cffi @ file:///tmp/build/80754af9/cffi_1625807838443/work +charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work +click==8.0.3 +contextlib2==21.6.0 +cryptography @ file:///tmp/build/80754af9/cryptography_1635366571107/work +cycler @ file:///tmp/build/80754af9/cycler_1637851556182/work +docopt==0.6.2 +esdk-obs-python==3.21.8 +faiss==1.7.1 +Flask @ file:///tmp/build/80754af9/flask_1634118196080/work +fonttools==4.25.0 +gevent @ file:///tmp/build/80754af9/gevent_1628273677693/work +google-auth==2.6.0 +google-auth-oauthlib==0.4.6 +greenlet @ file:///tmp/build/80754af9/greenlet_1628887725296/work +grpcio==1.44.0 +gunicorn==20.1.0 +h5py @ file:///tmp/build/80754af9/h5py_1637138879700/work +idna @ file:///tmp/build/80754af9/idna_1637925883363/work +importlib-metadata==4.11.3 +itsdangerous @ file:///tmp/build/80754af9/itsdangerous_1621432558163/work +Jinja2 @ file:///tmp/build/80754af9/jinja2_1635780242639/work +kiwisolver @ file:///tmp/build/80754af9/kiwisolver_1612282420641/work +Markdown==3.3.6 +MarkupSafe @ file:///tmp/build/80754af9/markupsafe_1621528148836/work +matplotlib @ file:///tmp/build/80754af9/matplotlib-suite_1638289681807/work +mkl-fft==1.3.1 +mkl-random @ file:///tmp/build/80754af9/mkl_random_1626186064646/work +mkl-service==2.4.0 +ml-collections==0.1.0 +munkres==1.1.4 +numexpr @ file:///tmp/build/80754af9/numexpr_1618856167419/work +numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1634095647912/work +oauthlib==3.2.0 +olefile @ file:///Users/ktietz/demo/mc3/conda-bld/olefile_1629805411829/work +opencv-python==4.5.4.60 
+packaging @ file:///tmp/build/80754af9/packaging_1637314298585/work +pandas==1.3.4 +Pillow==8.4.0 +pipreqs==0.4.11 +protobuf==3.19.4 +pyasn1==0.4.8 +pyasn1-modules==0.2.8 +pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work +pycryptodome==3.10.1 +pyOpenSSL @ file:///tmp/build/80754af9/pyopenssl_1635333100036/work +pyparsing @ file:///tmp/build/80754af9/pyparsing_1635766073266/work +PySocks @ file:///tmp/build/80754af9/pysocks_1605305779399/work +python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work +pytz==2021.3 +PyYAML==6.0 +requests @ file:///tmp/build/80754af9/requests_1629994808627/work +requests-oauthlib==1.3.1 +rsa==4.8 +scipy @ file:///tmp/build/80754af9/scipy_1630606796110/work +seaborn @ file:///tmp/build/80754af9/seaborn_1629307859561/work +sip==4.19.13 +six @ file:///tmp/build/80754af9/six_1623709665295/work +supervisor==4.2.2 +tensorboard==2.8.0 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.1 +torch==1.8.0 +torchaudio==0.8.0a0+a751e1d +torchvision==0.9.0 +tornado @ file:///tmp/build/80754af9/tornado_1606942300299/work +tqdm @ file:///tmp/build/80754af9/tqdm_1635330843403/work +typing-extensions @ file:///tmp/build/80754af9/typing_extensions_1631814937681/work +urllib3==1.26.7 +Werkzeug @ file:///tmp/build/80754af9/werkzeug_1635505089296/work +yacs @ file:///tmp/build/80754af9/yacs_1634047592950/work +yarg==0.1.9 +zipp==3.7.0 +zope.event==4.5.0 +zope.interface @ file:///tmp/build/80754af9/zope.interface_1625035545636/work diff --git a/start.sh b/start.sh new file mode 100644 index 0000000..915018f --- /dev/null +++ b/start.sh @@ -0,0 +1,3 @@ +#!/bin/bash +supervisorctl start ieemoo-ai-isempty + diff --git a/stop.sh b/stop.sh new file mode 100644 index 0000000..f5e9f04 --- /dev/null +++ b/stop.sh @@ -0,0 +1,2 @@ +#!/bin/bash +supervisorctl stop ieemoo-ai-isempty diff --git a/test_single.py b/test_single.py new file mode 100755 index 0000000..1f1072b --- /dev/null +++ b/test_single.py @@ -0,0 +1,64 @@ +# coding=utf-8 +import os +import torch +import numpy as np +from PIL import Image +from torchvision import transforms +import argparse +from models.modeling import VisionTransformer, CONFIGS + + +parser = argparse.ArgumentParser() +parser.add_argument("--dataset", choices=["CUB_200_2011", "emptyJudge5", "emptyJudge4"], default="emptyJudge5", help="Which dataset.") +parser.add_argument("--img_size", default=448, type=int, help="Resolution size") +parser.add_argument('--split', type=str, default='overlap', help="Split method") # non-overlap +parser.add_argument('--slide_step', type=int, default=12, help="Slide step for overlap split") +parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value\n") +parser.add_argument("--pretrained_model", type=str, default="output/emptyjudge5_checkpoint.bin", help="load pretrained model") +args = parser.parse_args() + +args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +args.nprocs = torch.cuda.device_count() + +# Prepare Model +config = CONFIGS["ViT-B_16"] +config.split = args.split +config.slide_step = args.slide_step + +cls_dict = {} +num_classes = 0 +if args.dataset == "emptyJudge5": + num_classes = 5 + cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "fly", 4: "stack"} +elif args.dataset == "emptyJudge4": + num_classes = 4 + cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "stack"} +elif args.dataset == "emptyJudge3": + num_classes = 3 + cls_dict = {0: "noemp", 1: "yesemp", 2: "hard"} +elif args.dataset == "emptyJudge2": + 
num_classes = 2 + cls_dict = {0: "noemp", 1: "yesemp"} +model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes, smoothing_value=args.smoothing_value) +if args.pretrained_model is not None: + pretrained_model = torch.load(args.pretrained_model, map_location=torch.device('cpu'))['model'] + model.load_state_dict(pretrained_model) +model.to(args.device) +model.eval() +# test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), +# transforms.CenterCrop((448, 448)), +# transforms.ToTensor(), +# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) +test_transform = transforms.Compose([transforms.Resize((448, 448), Image.BILINEAR), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) +img = Image.open("img.jpg") +x = test_transform(img) +part_logits = model(x.unsqueeze(0)) + +probs = torch.nn.Softmax(dim=-1)(part_logits) +top5 = torch.argsort(probs, dim=-1, descending=True) +print("Prediction Label\n") +for idx in top5[0, :5]: + print(f'{probs[0, idx.item()]:.5f} : {cls_dict[idx.item()]}', end='\n') + diff --git a/train.py b/train.py new file mode 100755 index 0000000..fb444c7 --- /dev/null +++ b/train.py @@ -0,0 +1,379 @@ +# coding=utf-8 +from __future__ import absolute_import, division, print_function + +import logging +import argparse +import os +import random +import numpy as np +import time + +from datetime import timedelta + +import torch +import torch.distributed as dist + +from tqdm import tqdm +from torch.utils.tensorboard import SummaryWriter + +from models.modeling import VisionTransformer, CONFIGS +from utils.scheduler import WarmupLinearSchedule, WarmupCosineSchedule +from utils.data_utils import get_loader +from utils.dist_util import get_world_size + +logger = logging.getLogger(__name__) + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def simple_accuracy(preds, labels): + return (preds == labels).mean() + + +def reduce_mean(tensor, nprocs): + rt = tensor.clone() + dist.all_reduce(rt, op=dist.ReduceOp.SUM) + rt /= nprocs + return rt + + +def save_model(args, model): + model_to_save = model.module if hasattr(model, 'module') else model + model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.bin" % args.name) + checkpoint = { + 'model': model_to_save.state_dict(), + } + torch.save(checkpoint, model_checkpoint) + logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir) + + +def setup(args): + # Prepare model + config = CONFIGS[args.model_type] + config.split = args.split + config.slide_step = args.slide_step + + if args.dataset == "CUB_200_2011": + num_classes = 200 + elif args.dataset == "car": + num_classes = 196 + elif args.dataset == "nabirds": + num_classes = 555 + elif args.dataset == "dog": + num_classes = 120 + elif args.dataset == "INat2017": + num_classes = 5089 + elif args.dataset == "emptyJudge5": + num_classes = 5 + elif args.dataset == "emptyJudge4": + num_classes = 4 + elif args.dataset == "emptyJudge3": + num_classes = 3 + + model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes, smoothing_value=args.smoothing_value) + + model.load_from(np.load(args.pretrained_dir)) + if args.pretrained_model is not None: + 
pretrained_model = torch.load(args.pretrained_model)['model'] + model.load_state_dict(pretrained_model) + model.to(args.device) + num_params = count_parameters(model) + + logger.info("{}".format(config)) + logger.info("Training parameters %s", args) + logger.info("Total Parameter: \t%2.1fM" % num_params) + return args, model + + +def count_parameters(model): + params = sum(p.numel() for p in model.parameters() if p.requires_grad) + return params/1000000 + + +def set_seed(args): + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed) + + +def valid(args, model, writer, test_loader, global_step): + eval_losses = AverageMeter() + + logger.info("***** Running Validation *****") + # logger.info("val Num steps = %d", len(test_loader)) + # logger.info("val Batch size = %d", args.eval_batch_size) + + model.eval() + all_preds, all_label = [], [] + epoch_iterator = tqdm(test_loader, + desc="Validating... (loss=X.X)", + bar_format="{l_bar}{r_bar}", + dynamic_ncols=True, + disable=args.local_rank not in [-1, 0]) + loss_fct = torch.nn.CrossEntropyLoss() + for step, batch in enumerate(epoch_iterator): + batch = tuple(t.to(args.device) for t in batch) + x, y = batch + with torch.no_grad(): + logits = model(x) + + eval_loss = loss_fct(logits, y) + eval_loss = eval_loss.mean() + eval_losses.update(eval_loss.item()) + + preds = torch.argmax(logits, dim=-1) + + if len(all_preds) == 0: + all_preds.append(preds.detach().cpu().numpy()) + all_label.append(y.detach().cpu().numpy()) + else: + all_preds[0] = np.append( + all_preds[0], preds.detach().cpu().numpy(), axis=0 + ) + all_label[0] = np.append( + all_label[0], y.detach().cpu().numpy(), axis=0 + ) + epoch_iterator.set_description("Validating... (loss=%2.5f)" % eval_losses.val) + + all_preds, all_label = all_preds[0], all_label[0] + accuracy = simple_accuracy(all_preds, all_label) + accuracy = torch.tensor(accuracy).to(args.device) + # dist.barrier() + # val_accuracy = reduce_mean(accuracy, args.nprocs) + # val_accuracy = val_accuracy.detach().cpu().numpy() + val_accuracy = accuracy.detach().cpu().numpy() + + logger.info("\n") + logger.info("Validation Results") + logger.info("Global Steps: %d" % global_step) + logger.info("Valid Loss: %2.5f" % eval_losses.avg) + logger.info("Valid Accuracy: %2.5f" % val_accuracy) + if args.local_rank in [-1, 0]: + writer.add_scalar("test/accuracy", scalar_value=val_accuracy, global_step=global_step) + return val_accuracy + + +def train(args, model): + """ Train the model """ + if args.local_rank in [-1, 0]: + os.makedirs(args.output_dir, exist_ok=True) + writer = SummaryWriter(log_dir=os.path.join("logs", args.name)) + + args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps + + # Prepare dataset + train_loader, test_loader = get_loader(args) + logger.info("train Num steps = %d", len(train_loader)) + logger.info("val Num steps = %d", len(test_loader)) + # Prepare optimizer and scheduler + optimizer = torch.optim.SGD(model.parameters(), + lr=args.learning_rate, + momentum=0.9, + weight_decay=args.weight_decay) + t_total = args.num_steps + if args.decay_type == "cosine": + scheduler = WarmupCosineSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) + else: + scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) + + # Train! 
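# A small illustrative example (numbers are hypothetical, not taken from any run) of
# the batch-size bookkeeping used here: train() divides --train_batch_size by
# --gradient_accumulation_steps to get the per-iteration loader batch, and the loop
# below only calls optimizer.step() every gradient_accumulation_steps iterations, so
# the effective batch per parameter update stays at the configured --train_batch_size.
configured_batch = 64                              # --train_batch_size (default)
accum_steps = 4                                    # --gradient_accumulation_steps (hypothetical)
per_iter_batch = configured_batch // accum_steps   # 16 samples per forward/backward pass
effective_batch = per_iter_batch * accum_steps     # 64 samples per optimizer.step()
assert effective_batch == configured_batch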
+ logger.info("***** Running training *****") + logger.info(" Total optimization steps = %d", args.num_steps) + logger.info(" Instantaneous batch size per GPU = %d", args.train_batch_size) + logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", + args.train_batch_size * args.gradient_accumulation_steps * ( + torch.distributed.get_world_size() if args.local_rank != -1 else 1)) + logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) + + model.zero_grad() + set_seed(args) # Added here for reproducibility (even between python 2 and 3) + losses = AverageMeter() + global_step, best_acc = 0, 0 + start_time = time.time() + while True: + model.train() + epoch_iterator = tqdm(train_loader, + desc="Training (X / X Steps) (loss=X.X)", + bar_format="{l_bar}{r_bar}", + dynamic_ncols=True, + disable=args.local_rank not in [-1, 0]) + all_preds, all_label = [], [] + for step, batch in enumerate(epoch_iterator): + batch = tuple(t.to(args.device) for t in batch) + x, y = batch + + loss, logits = model(x, y) + loss = loss.mean() + + preds = torch.argmax(logits, dim=-1) + + if len(all_preds) == 0: + all_preds.append(preds.detach().cpu().numpy()) + all_label.append(y.detach().cpu().numpy()) + else: + all_preds[0] = np.append( + all_preds[0], preds.detach().cpu().numpy(), axis=0 + ) + all_label[0] = np.append( + all_label[0], y.detach().cpu().numpy(), axis=0 + ) + + if args.gradient_accumulation_steps > 1: + loss = loss / args.gradient_accumulation_steps + loss.backward() + + if (step + 1) % args.gradient_accumulation_steps == 0: + losses.update(loss.item()*args.gradient_accumulation_steps) + torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) + scheduler.step() + optimizer.step() + optimizer.zero_grad() + global_step += 1 + + epoch_iterator.set_description( + "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, t_total, losses.val) + ) + if args.local_rank in [-1, 0]: + writer.add_scalar("train/loss", scalar_value=losses.val, global_step=global_step) + writer.add_scalar("train/lr", scalar_value=scheduler.get_lr()[0], global_step=global_step) + if global_step % args.eval_every == 0: + with torch.no_grad(): + accuracy = valid(args, model, writer, test_loader, global_step) + if args.local_rank in [-1, 0]: + if best_acc < accuracy: + save_model(args, model) + best_acc = accuracy + logger.info("best accuracy so far: %f" % best_acc) + model.train() + + if global_step % t_total == 0: + break + all_preds, all_label = all_preds[0], all_label[0] + accuracy = simple_accuracy(all_preds, all_label) + accuracy = torch.tensor(accuracy).to(args.device) + # dist.barrier() + # train_accuracy = reduce_mean(accuracy, args.nprocs) + # train_accuracy = train_accuracy.detach().cpu().numpy() + train_accuracy = accuracy.detach().cpu().numpy() + logger.info("train accuracy so far: %f" % train_accuracy) + losses.reset() + if global_step % t_total == 0: + break + + writer.close() + logger.info("Best Accuracy: \t%f" % best_acc) + logger.info("End Training!") + end_time = time.time() + logger.info("Total Training Time: \t%f" % ((end_time - start_time) / 3600)) + + +def main(): + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument("--name", required=True, + help="Name of this run. 
Used for monitoring.") + parser.add_argument("--dataset", choices=["CUB_200_2011", "car", "dog", "nabirds", "INat2017", "emptyJudge5", "emptyJudge4"], + default="CUB_200_2011", help="Which dataset.") + parser.add_argument('--data_root', type=str, default='/data/fineGrained') + parser.add_argument("--model_type", choices=["ViT-B_16", "ViT-B_32", "ViT-L_16", "ViT-L_32", "ViT-H_14"], + default="ViT-B_16",help="Which variant to use.") + parser.add_argument("--pretrained_dir", type=str, default="ckpts/ViT-B_16.npz", + help="Where to search for pretrained ViT models.") + parser.add_argument("--pretrained_model", type=str, default="output/emptyjudge5_checkpoint.bin", help="load pretrained model") + #parser.add_argument("--pretrained_model", type=str, default=None, help="load pretrained model") + parser.add_argument("--output_dir", default="./output", type=str, + help="The output directory where checkpoints will be written.") + parser.add_argument("--img_size", default=448, type=int, help="Resolution size") + parser.add_argument("--train_batch_size", default=64, type=int, + help="Total batch size for training.") + parser.add_argument("--eval_batch_size", default=16, type=int, + help="Total batch size for eval.") + parser.add_argument("--eval_every", default=200, type=int, + help="Run prediction on validation set every so many steps." + "Will always run one evaluation at the end of training.") + + parser.add_argument("--learning_rate", default=3e-2, type=float, + help="The initial learning rate for SGD.") + parser.add_argument("--weight_decay", default=0, type=float, + help="Weight deay if we apply some.") + parser.add_argument("--num_steps", default=8000, type=int, #100000 + help="Total number of training epochs to perform.") + parser.add_argument("--decay_type", choices=["cosine", "linear"], default="cosine", + help="How to decay the learning rate.") + parser.add_argument("--warmup_steps", default=500, type=int, + help="Step of training to perform learning rate warmup for.") + parser.add_argument("--max_grad_norm", default=1.0, type=float, + help="Max gradient norm.") + + parser.add_argument("--local_rank", type=int, default=-1, + help="local_rank for distributed training on gpus") + parser.add_argument('--seed', type=int, default=42, + help="random seed for initialization") + parser.add_argument('--gradient_accumulation_steps', type=int, default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.") + parser.add_argument('--fp16', action='store_true', + help="Whether to use 16-bit float precision instead of 32-bit") + parser.add_argument('--fp16_opt_level', type=str, default='O2', + help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." + "See details at https://nvidia.github.io/apex/amp.html") + parser.add_argument('--loss_scale', type=float, default=0, + help="Loss scaling to improve fp16 numeric stability. 
Only used when fp16 set to True.\n" + "0 (default value): dynamic loss scaling.\n" + "Positive power of 2: static loss scaling value.\n") + + parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value\n") + + parser.add_argument('--split', type=str, default='overlap', help="Split method") # non-overlap + parser.add_argument('--slide_step', type=int, default=12, help="Slide step for overlap split") + args = parser.parse_args() + + args.data_root = '{}/{}'.format(args.data_root, args.dataset) + # Setup CUDA, GPU & distributed training + if args.local_rank == -1: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print('torch.cuda.device_count()>>>>>>>>>>>>>>>>>>>>>>>>>', torch.cuda.device_count()) + args.n_gpu = torch.cuda.device_count() + else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs + torch.cuda.set_device(args.local_rank) + device = torch.device("cuda", args.local_rank) + torch.distributed.init_process_group(backend='nccl', timeout=timedelta(minutes=60)) + args.n_gpu = 1 + args.device = device + args.nprocs = torch.cuda.device_count() + + # Setup logging + logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt='%m/%d/%Y %H:%M:%S', + level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) + logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s" % + (args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1), args.fp16)) + + # Set seed + set_seed(args) + + # Model & Tokenizer Setup + args, model = setup(args) + # Training + train(args, model) + + +if __name__ == "__main__": + main() diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/utils/autoaugment.py b/utils/autoaugment.py new file mode 100755 index 0000000..1caf904 --- /dev/null +++ b/utils/autoaugment.py @@ -0,0 +1,204 @@ +""" +Copy from https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py +""" + +from PIL import Image, ImageEnhance, ImageOps +import numpy as np +import random + +__all__ = ['AutoAugImageNetPolicy', 'AutoAugCIFAR10Policy', 'AutoAugSVHNPolicy'] + + +class AutoAugImageNetPolicy(object): + def __init__(self, fillcolor=(128, 128, 128)): + self.policies = [ + SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), + SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), + SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), + SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), + SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), + + SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), + SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), + SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), + SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), + SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), + + SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), + SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor), + SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), + SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), + SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), + + SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), + SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), + SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), + SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor), + 
SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), + + SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), + SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), + SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), + SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor) + ] + + def __call__(self, img): + policy_idx = random.randint(0, len(self.policies) - 1) + return self.policies[policy_idx](img) + + def __repr__(self): + return "AutoAugment ImageNet Policy" + + +class AutoAugCIFAR10Policy(object): + def __init__(self, fillcolor=(128, 128, 128)): + self.policies = [ + SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), + SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor), + SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor), + SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), + SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), + + SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor), + SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), + SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor), + SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), + SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor), + + SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor), + SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), + SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor), + SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor), + SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor), + + SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), + SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor), + SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor), + SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor), + SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), + + SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), + SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), + SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), + SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), + SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor) + ] + + def __call__(self, img): + policy_idx = random.randint(0, len(self.policies) - 1) + return self.policies[policy_idx](img) + + def __repr__(self): + return "AutoAugment CIFAR10 Policy" + + +class AutoAugSVHNPolicy(object): + def __init__(self, fillcolor=(128, 128, 128)): + self.policies = [ + SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor), + SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor), + SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor), + SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor), + SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor), + + SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor), + SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor), + SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor), + SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor), + SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor), + + SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor), + SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor), + SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor), + SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor), + SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor), + + SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, 
fillcolor), + SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor), + SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor), + SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor), + SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor), + + SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor), + SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor), + SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor), + SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor), + SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor) + ] + + def __call__(self, img): + policy_idx = random.randint(0, len(self.policies) - 1) + return self.policies[policy_idx](img) + + def __repr__(self): + return "AutoAugment SVHN Policy" + + +class SubPolicy(object): + def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)): + ranges = { + "shearX": np.linspace(0, 0.3, 10), + "shearY": np.linspace(0, 0.3, 10), + "translateX": np.linspace(0, 150 / 331, 10), + "translateY": np.linspace(0, 150 / 331, 10), + "rotate": np.linspace(0, 30, 10), + "color": np.linspace(0.0, 0.9, 10), + "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int), + "solarize": np.linspace(256, 0, 10), + "contrast": np.linspace(0.0, 0.9, 10), + "sharpness": np.linspace(0.0, 0.9, 10), + "brightness": np.linspace(0.0, 0.9, 10), + "autocontrast": [0] * 10, + "equalize": [0] * 10, + "invert": [0] * 10 + } + + def rotate_with_fill(img, magnitude): + rot = img.convert("RGBA").rotate(magnitude) + return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode) + + func = { + "shearX": lambda img, magnitude: img.transform( + img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), + Image.BICUBIC, fillcolor=fillcolor), + "shearY": lambda img, magnitude: img.transform( + img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), + Image.BICUBIC, fillcolor=fillcolor), + "translateX": lambda img, magnitude: img.transform( + img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), + fillcolor=fillcolor), + "translateY": lambda img, magnitude: img.transform( + img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), + fillcolor=fillcolor), + "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), + # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])), + "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])), + "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), + "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), + "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( + 1 + magnitude * random.choice([-1, 1])), + "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( + 1 + magnitude * random.choice([-1, 1])), + "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( + 1 + magnitude * random.choice([-1, 1])), + "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), + "equalize": lambda img, magnitude: ImageOps.equalize(img), + "invert": lambda img, magnitude: ImageOps.invert(img) + } + + # self.name = "{}_{:.2f}_and_{}_{:.2f}".format( + # operation1, ranges[operation1][magnitude_idx1], + # operation2, ranges[operation2][magnitude_idx2]) + self.p1 = p1 + self.operation1 = func[operation1] + self.magnitude1 = 
ranges[operation1][magnitude_idx1] + self.p2 = p2 + self.operation2 = func[operation2] + self.magnitude2 = ranges[operation2][magnitude_idx2] + + def __call__(self, img): + if random.random() < self.p1: + img = self.operation1(img, self.magnitude1) + if random.random() < self.p2: + img = self.operation2(img, self.magnitude2) + return img diff --git a/utils/data_utils.py b/utils/data_utils.py new file mode 100755 index 0000000..86df2a9 --- /dev/null +++ b/utils/data_utils.py @@ -0,0 +1,135 @@ +import logging +from PIL import Image +import os + +import torch + +from torchvision import transforms +from torch.utils.data import DataLoader, RandomSampler, DistributedSampler, SequentialSampler + +from .dataset import CUB, CarsDataset, NABirds, dogs, INat2017, emptyJudge +from .autoaugment import AutoAugImageNetPolicy + +logger = logging.getLogger(__name__) + + +def get_loader(args): + if args.local_rank not in [-1, 0]: + torch.distributed.barrier() + + if args.dataset == 'CUB_200_2011': + train_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), + transforms.RandomCrop((448, 448)), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) + test_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), + transforms.CenterCrop((448, 448)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) + trainset = CUB(root=args.data_root, is_train=True, transform=train_transform) + testset = CUB(root=args.data_root, is_train=False, transform=test_transform) + elif args.dataset == 'car': + trainset = CarsDataset(os.path.join(args.data_root,'devkit/cars_train_annos.mat'), + os.path.join(args.data_root,'cars_train'), + os.path.join(args.data_root,'devkit/cars_meta.mat'), + # cleaned=os.path.join(data_dir,'cleaned.dat'), + transform=transforms.Compose([ + transforms.Resize((600, 600), Image.BILINEAR), + transforms.RandomCrop((448, 448)), + transforms.RandomHorizontalFlip(), + AutoAugImageNetPolicy(), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) + ) + testset = CarsDataset(os.path.join(args.data_root,'cars_test_annos_withlabels.mat'), + os.path.join(args.data_root,'cars_test'), + os.path.join(args.data_root,'devkit/cars_meta.mat'), + # cleaned=os.path.join(data_dir,'cleaned_test.dat'), + transform=transforms.Compose([ + transforms.Resize((600, 600), Image.BILINEAR), + transforms.CenterCrop((448, 448)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) + ) + elif args.dataset == 'dog': + train_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), + transforms.RandomCrop((448, 448)), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) + test_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), + transforms.CenterCrop((448, 448)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) + trainset = dogs(root=args.data_root, + train=True, + cropped=False, + transform=train_transform, + download=False + ) + testset = dogs(root=args.data_root, + train=False, + cropped=False, + transform=test_transform, + download=False + ) + elif args.dataset == 'nabirds': + train_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), + transforms.RandomCrop((448, 448)), + 
transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) + test_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), + transforms.CenterCrop((448, 448)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) + trainset = NABirds(root=args.data_root, train=True, transform=train_transform) + testset = NABirds(root=args.data_root, train=False, transform=test_transform) + elif args.dataset == 'INat2017': + train_transform=transforms.Compose([transforms.Resize((400, 400), Image.BILINEAR), + transforms.RandomCrop((304, 304)), + transforms.RandomHorizontalFlip(), + AutoAugImageNetPolicy(), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) + test_transform=transforms.Compose([transforms.Resize((400, 400), Image.BILINEAR), + transforms.CenterCrop((304, 304)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) + trainset = INat2017(args.data_root, 'train', train_transform) + testset = INat2017(args.data_root, 'val', test_transform) + elif args.dataset == 'emptyJudge5' or args.dataset == 'emptyJudge4': + train_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), + transforms.RandomCrop((448, 448)), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) + # test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), + # transforms.CenterCrop((448, 448)), + # transforms.ToTensor(), + # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) + test_transform = transforms.Compose([transforms.Resize((448, 448), Image.BILINEAR), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) + trainset = emptyJudge(root=args.data_root, is_train=True, transform=train_transform) + testset = emptyJudge(root=args.data_root, is_train=False, transform=test_transform) + + if args.local_rank == 0: + torch.distributed.barrier() + + train_sampler = RandomSampler(trainset) if args.local_rank == -1 else DistributedSampler(trainset) + test_sampler = SequentialSampler(testset) if args.local_rank == -1 else DistributedSampler(testset) + train_loader = DataLoader(trainset, + sampler=train_sampler, + batch_size=args.train_batch_size, + num_workers=4, + drop_last=True, + pin_memory=True) + test_loader = DataLoader(testset, + sampler=test_sampler, + batch_size=args.eval_batch_size, + num_workers=4, + pin_memory=True) if testset is not None else None + + return train_loader, test_loader diff --git a/utils/dataset.py b/utils/dataset.py new file mode 100755 index 0000000..7a06567 --- /dev/null +++ b/utils/dataset.py @@ -0,0 +1,629 @@ +import os +import json +from os.path import join + +import numpy as np +import scipy +from scipy import io +import scipy.misc +from PIL import Image +import pandas as pd +import matplotlib.pyplot as plt + +import torch +from torch.utils.data import Dataset +from torchvision.datasets import VisionDataset +from torchvision.datasets.folder import default_loader +from torchvision.datasets.utils import download_url, list_dir, check_integrity, extract_archive, verify_str_arg + + +class emptyJudge(): + def __init__(self, root, is_train=True, data_len=None, transform=None): + self.root = root + self.is_train = is_train + self.transform = transform + img_txt_file = open(os.path.join(self.root, 'images.txt')) + 
label_txt_file = open(os.path.join(self.root, 'image_class_labels.txt')) + train_val_file = open(os.path.join(self.root, 'train_test_split.txt')) + img_name_list = [] + for line in img_txt_file: + img_name_list.append(line[:-1].split(' ')[-1]) + label_list = [] + for line in label_txt_file: + label_list.append(int(line[:-1].split(' ')[-1]) - 1) + train_test_list = [] + for line in train_val_file: + train_test_list.append(int(line[:-1].split(' ')[-1])) + train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i] + test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i] + if self.is_train: + self.train_img = [scipy.misc.imread(os.path.join(self.root, 'images', train_file)) for train_file in + train_file_list[:data_len]] + self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len] + self.train_imgname = [x for x in train_file_list[:data_len]] + if not self.is_train: + self.test_img = [scipy.misc.imread(os.path.join(self.root, 'images', test_file)) for test_file in + test_file_list[:data_len]] + self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len] + self.test_imgname = [x for x in test_file_list[:data_len]] + + def __getitem__(self, index): + if self.is_train: + img, target, imgname = self.train_img[index], self.train_label[index], self.train_imgname[index] + if len(img.shape) == 2: + img = np.stack([img] * 3, 2) + img = Image.fromarray(img, mode='RGB') + if self.transform is not None: + img = self.transform(img) + else: + img, target, imgname = self.test_img[index], self.test_label[index], self.test_imgname[index] + if len(img.shape) == 2: + img = np.stack([img] * 3, 2) + img = Image.fromarray(img, mode='RGB') + if self.transform is not None: + img = self.transform(img) + return img, target + + def __len__(self): + if self.is_train: + return len(self.train_label) + else: + return len(self.test_label) + + +class CUB(): + def __init__(self, root, is_train=True, data_len=None, transform=None): + self.root = root + self.is_train = is_train + self.transform = transform + img_txt_file = open(os.path.join(self.root, 'images.txt')) + label_txt_file = open(os.path.join(self.root, 'image_class_labels.txt')) + train_val_file = open(os.path.join(self.root, 'train_test_split.txt')) + img_name_list = [] + for line in img_txt_file: + img_name_list.append(line[:-1].split(' ')[-1]) + label_list = [] + for line in label_txt_file: + label_list.append(int(line[:-1].split(' ')[-1]) - 1) + train_test_list = [] + for line in train_val_file: + train_test_list.append(int(line[:-1].split(' ')[-1])) + train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i] + test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i] + if self.is_train: + self.train_img = [scipy.misc.imread(os.path.join(self.root, 'images', train_file)) for train_file in + train_file_list[:data_len]] + self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len] + self.train_imgname = [x for x in train_file_list[:data_len]] + if not self.is_train: + self.test_img = [scipy.misc.imread(os.path.join(self.root, 'images', test_file)) for test_file in + test_file_list[:data_len]] + self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len] + self.test_imgname = [x for x in test_file_list[:data_len]] + + def __getitem__(self, index): + if self.is_train: + img, target, imgname = self.train_img[index], self.train_label[index], self.train_imgname[index] + if len(img.shape) == 2: + 
img = np.stack([img] * 3, 2) + img = Image.fromarray(img, mode='RGB') + if self.transform is not None: + img = self.transform(img) + else: + img, target, imgname = self.test_img[index], self.test_label[index], self.test_imgname[index] + if len(img.shape) == 2: + img = np.stack([img] * 3, 2) + img = Image.fromarray(img, mode='RGB') + if self.transform is not None: + img = self.transform(img) + + return img, target + + def __len__(self): + if self.is_train: + return len(self.train_label) + else: + return len(self.test_label) + + +class CarsDataset(Dataset): + def __init__(self, mat_anno, data_dir, car_names, cleaned=None, transform=None): + """ + Args: + mat_anno (string): Path to the MATLAB annotation file. + data_dir (string): Directory with all the images. + transform (callable, optional): Optional transform to be applied + on a sample. + """ + + self.full_data_set = io.loadmat(mat_anno) + self.car_annotations = self.full_data_set['annotations'] + self.car_annotations = self.car_annotations[0] + + if cleaned is not None: + cleaned_annos = [] + print("Cleaning up data set (only take pics with rgb chans)...") + clean_files = np.loadtxt(cleaned, dtype=str) + for c in self.car_annotations: + if c[-1][0] in clean_files: + cleaned_annos.append(c) + self.car_annotations = cleaned_annos + + self.car_names = scipy.io.loadmat(car_names)['class_names'] + self.car_names = np.array(self.car_names[0]) + + self.data_dir = data_dir + self.transform = transform + + def __len__(self): + return len(self.car_annotations) + + def __getitem__(self, idx): + img_name = os.path.join(self.data_dir, self.car_annotations[idx][-1][0]) + image = Image.open(img_name).convert('RGB') + car_class = self.car_annotations[idx][-2][0][0] + car_class = torch.from_numpy(np.array(car_class.astype(np.float32))).long() - 1 + assert car_class < 196 + + if self.transform: + image = self.transform(image) + + # return image, car_class, img_name + return image, car_class + + def map_class(self, id): + id = np.ravel(id) + ret = self.car_names[id - 1][0][0] + return ret + + def show_batch(self, img_batch, class_batch): + + for i in range(img_batch.shape[0]): + ax = plt.subplot(1, img_batch.shape[0], i + 1) + title_str = self.map_class(int(class_batch[i])) + img = np.transpose(img_batch[i, ...], (1, 2, 0)) + ax.imshow(img) + ax.set_title(title_str.__str__(), {'fontsize': 5}) + plt.tight_layout() + + +def make_dataset(dir, image_ids, targets): + assert(len(image_ids) == len(targets)) + images = [] + dir = os.path.expanduser(dir) + for i in range(len(image_ids)): + item = (os.path.join(dir, 'data', 'images', '%s.jpg' % image_ids[i]), targets[i]) + images.append(item) + return images + + +def find_classes(classes_file): + # read classes file, separating out image IDs and class names + image_ids = [] + targets = [] + f = open(classes_file, 'r') + for line in f: + split_line = line.split(' ') + image_ids.append(split_line[0]) + targets.append(' '.join(split_line[1:])) + f.close() + + # index class names + classes = np.unique(targets) + class_to_idx = {classes[i]: i for i in range(len(classes))} + targets = [class_to_idx[c] for c in targets] + return (image_ids, targets, classes, class_to_idx) + + +class dogs(Dataset): + """`Stanford Dogs `_ Dataset. + Args: + root (string): Root directory of dataset where directory + ``omniglot-py`` exists. 
+ cropped (bool, optional): If true, the images will be cropped into the bounding box specified + in the annotations + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + download (bool, optional): If true, downloads the dataset tar files from the internet and + puts it in root directory. If the tar files are already downloaded, they are not + downloaded again. + """ + folder = 'dog' + download_url_prefix = 'http://vision.stanford.edu/aditya86/ImageNetDogs' + + def __init__(self, + root, + train=True, + cropped=False, + transform=None, + target_transform=None, + download=False): + + # self.root = join(os.path.expanduser(root), self.folder) + self.root = root + self.train = train + self.cropped = cropped + self.transform = transform + self.target_transform = target_transform + + if download: + self.download() + + split = self.load_split() + + self.images_folder = join(self.root, 'Images') + self.annotations_folder = join(self.root, 'Annotation') + self._breeds = list_dir(self.images_folder) + + if self.cropped: + self._breed_annotations = [[(annotation, box, idx) + for box in self.get_boxes(join(self.annotations_folder, annotation))] + for annotation, idx in split] + self._flat_breed_annotations = sum(self._breed_annotations, []) + + self._flat_breed_images = [(annotation+'.jpg', idx) for annotation, box, idx in self._flat_breed_annotations] + else: + self._breed_images = [(annotation+'.jpg', idx) for annotation, idx in split] + + self._flat_breed_images = self._breed_images + + self.classes = ["Chihuaha", + "Japanese Spaniel", + "Maltese Dog", + "Pekinese", + "Shih-Tzu", + "Blenheim Spaniel", + "Papillon", + "Toy Terrier", + "Rhodesian Ridgeback", + "Afghan Hound", + "Basset Hound", + "Beagle", + "Bloodhound", + "Bluetick", + "Black-and-tan Coonhound", + "Walker Hound", + "English Foxhound", + "Redbone", + "Borzoi", + "Irish Wolfhound", + "Italian Greyhound", + "Whippet", + "Ibizian Hound", + "Norwegian Elkhound", + "Otterhound", + "Saluki", + "Scottish Deerhound", + "Weimaraner", + "Staffordshire Bullterrier", + "American Staffordshire Terrier", + "Bedlington Terrier", + "Border Terrier", + "Kerry Blue Terrier", + "Irish Terrier", + "Norfolk Terrier", + "Norwich Terrier", + "Yorkshire Terrier", + "Wirehaired Fox Terrier", + "Lakeland Terrier", + "Sealyham Terrier", + "Airedale", + "Cairn", + "Australian Terrier", + "Dandi Dinmont", + "Boston Bull", + "Miniature Schnauzer", + "Giant Schnauzer", + "Standard Schnauzer", + "Scotch Terrier", + "Tibetan Terrier", + "Silky Terrier", + "Soft-coated Wheaten Terrier", + "West Highland White Terrier", + "Lhasa", + "Flat-coated Retriever", + "Curly-coater Retriever", + "Golden Retriever", + "Labrador Retriever", + "Chesapeake Bay Retriever", + "German Short-haired Pointer", + "Vizsla", + "English Setter", + "Irish Setter", + "Gordon Setter", + "Brittany", + "Clumber", + "English Springer Spaniel", + "Welsh Springer Spaniel", + "Cocker Spaniel", + "Sussex Spaniel", + "Irish Water Spaniel", + "Kuvasz", + "Schipperke", + "Groenendael", + "Malinois", + "Briard", + "Kelpie", + "Komondor", + "Old English Sheepdog", + "Shetland Sheepdog", + "Collie", + "Border Collie", + "Bouvier des Flandres", + "Rottweiler", + "German Shepard", + "Doberman", + "Miniature Pinscher", + "Greater Swiss Mountain Dog", + "Bernese Mountain Dog", + "Appenzeller", + 
"EntleBucher", + "Boxer", + "Bull Mastiff", + "Tibetan Mastiff", + "French Bulldog", + "Great Dane", + "Saint Bernard", + "Eskimo Dog", + "Malamute", + "Siberian Husky", + "Affenpinscher", + "Basenji", + "Pug", + "Leonberg", + "Newfoundland", + "Great Pyrenees", + "Samoyed", + "Pomeranian", + "Chow", + "Keeshond", + "Brabancon Griffon", + "Pembroke", + "Cardigan", + "Toy Poodle", + "Miniature Poodle", + "Standard Poodle", + "Mexican Hairless", + "Dingo", + "Dhole", + "African Hunting Dog"] + + def __len__(self): + return len(self._flat_breed_images) + + def __getitem__(self, index): + """ + Args: + index (int): Index + Returns: + tuple: (image, target) where target is index of the target character class. + """ + image_name, target_class = self._flat_breed_images[index] + image_path = join(self.images_folder, image_name) + image = Image.open(image_path).convert('RGB') + + if self.cropped: + image = image.crop(self._flat_breed_annotations[index][1]) + + if self.transform: + image = self.transform(image) + + if self.target_transform: + target_class = self.target_transform(target_class) + + return image, target_class + + def download(self): + import tarfile + + if os.path.exists(join(self.root, 'Images')) and os.path.exists(join(self.root, 'Annotation')): + if len(os.listdir(join(self.root, 'Images'))) == len(os.listdir(join(self.root, 'Annotation'))) == 120: + print('Files already downloaded and verified') + return + + for filename in ['images', 'annotation', 'lists']: + tar_filename = filename + '.tar' + url = self.download_url_prefix + '/' + tar_filename + download_url(url, self.root, tar_filename, None) + print('Extracting downloaded file: ' + join(self.root, tar_filename)) + with tarfile.open(join(self.root, tar_filename), 'r') as tar_file: + tar_file.extractall(self.root) + os.remove(join(self.root, tar_filename)) + + @staticmethod + def get_boxes(path): + import xml.etree.ElementTree + e = xml.etree.ElementTree.parse(path).getroot() + boxes = [] + for objs in e.iter('object'): + boxes.append([int(objs.find('bndbox').find('xmin').text), + int(objs.find('bndbox').find('ymin').text), + int(objs.find('bndbox').find('xmax').text), + int(objs.find('bndbox').find('ymax').text)]) + return boxes + + def load_split(self): + if self.train: + split = scipy.io.loadmat(join(self.root, 'train_list.mat'))['annotation_list'] + labels = scipy.io.loadmat(join(self.root, 'train_list.mat'))['labels'] + else: + split = scipy.io.loadmat(join(self.root, 'test_list.mat'))['annotation_list'] + labels = scipy.io.loadmat(join(self.root, 'test_list.mat'))['labels'] + + split = [item[0][0] for item in split] + labels = [item[0]-1 for item in labels] + return list(zip(split, labels)) + + def stats(self): + counts = {} + for index in range(len(self._flat_breed_images)): + image_name, target_class = self._flat_breed_images[index] + if target_class not in counts.keys(): + counts[target_class] = 1 + else: + counts[target_class] += 1 + + print("%d samples spanning %d classes (avg %f per class)"%(len(self._flat_breed_images), len(counts.keys()), float(len(self._flat_breed_images))/float(len(counts.keys())))) + return counts + + +class NABirds(Dataset): + """`NABirds `_ Dataset. + + Args: + root (string): Root directory of the dataset. + train (bool, optional): If True, creates dataset from training set, otherwise + creates from test set. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. 
E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + """ + base_folder = 'nabirds/images' + + def __init__(self, root, train=True, transform=None): + dataset_path = os.path.join(root, 'nabirds') + self.root = root + self.loader = default_loader + self.train = train + self.transform = transform + + image_paths = pd.read_csv(os.path.join(dataset_path, 'images.txt'), + sep=' ', names=['img_id', 'filepath']) + image_class_labels = pd.read_csv(os.path.join(dataset_path, 'image_class_labels.txt'), + sep=' ', names=['img_id', 'target']) + # Since the raw labels are non-continuous, map them to new ones + self.label_map = get_continuous_class_map(image_class_labels['target']) + train_test_split = pd.read_csv(os.path.join(dataset_path, 'train_test_split.txt'), + sep=' ', names=['img_id', 'is_training_img']) + data = image_paths.merge(image_class_labels, on='img_id') + self.data = data.merge(train_test_split, on='img_id') + # Load in the train / test split + if self.train: + self.data = self.data[self.data.is_training_img == 1] + else: + self.data = self.data[self.data.is_training_img == 0] + + # Load in the class data + self.class_names = load_class_names(dataset_path) + self.class_hierarchy = load_hierarchy(dataset_path) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + sample = self.data.iloc[idx] + path = os.path.join(self.root, self.base_folder, sample.filepath) + target = self.label_map[sample.target] + img = self.loader(path) + + if self.transform is not None: + img = self.transform(img) + return img, target + + +def get_continuous_class_map(class_labels): + label_set = set(class_labels) + return {k: i for i, k in enumerate(label_set)} + + +def load_class_names(dataset_path=''): + names = {} + + with open(os.path.join(dataset_path, 'classes.txt')) as f: + for line in f: + pieces = line.strip().split() + class_id = pieces[0] + names[class_id] = ' '.join(pieces[1:]) + + return names + + +def load_hierarchy(dataset_path=''): + parents = {} + + with open(os.path.join(dataset_path, 'hierarchy.txt')) as f: + for line in f: + pieces = line.strip().split() + child_id, parent_id = pieces + parents[child_id] = parent_id + + return parents + + +class INat2017(VisionDataset): + """`iNaturalist 2017 `_ Dataset. + Args: + root (string): Root directory of the dataset. + split (string, optional): The dataset split, supports ``train``, or ``val``. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. 
+ """ + base_folder = 'train_val_images/' + file_list = { + 'imgs': ('https://storage.googleapis.com/asia_inat_data/train_val/train_val_images.tar.gz', + 'train_val_images.tar.gz', + '7c784ea5e424efaec655bd392f87301f'), + 'annos': ('https://storage.googleapis.com/asia_inat_data/train_val/train_val2017.zip', + 'train_val2017.zip', + '444c835f6459867ad69fcb36478786e7') + } + + def __init__(self, root, split='train', transform=None, target_transform=None, download=False): + super(INat2017, self).__init__(root, transform=transform, target_transform=target_transform) + self.loader = default_loader + self.split = verify_str_arg(split, "split", ("train", "val",)) + + if self._check_exists(): + print('Files already downloaded and verified.') + elif download: + if not (os.path.exists(os.path.join(self.root, self.file_list['imgs'][1])) + and os.path.exists(os.path.join(self.root, self.file_list['annos'][1]))): + print('Downloading...') + self._download() + print('Extracting...') + extract_archive(os.path.join(self.root, self.file_list['imgs'][1])) + extract_archive(os.path.join(self.root, self.file_list['annos'][1])) + else: + raise RuntimeError( + 'Dataset not found. You can use download=True to download it.') + anno_filename = split + '2017.json' + with open(os.path.join(self.root, anno_filename), 'r') as fp: + all_annos = json.load(fp) + + self.annos = all_annos['annotations'] + self.images = all_annos['images'] + + def __getitem__(self, index): + path = os.path.join(self.root, self.images[index]['file_name']) + target = self.annos[index]['category_id'] + + image = self.loader(path) + if self.transform is not None: + image = self.transform(image) + if self.target_transform is not None: + target = self.target_transform(target) + + return image, target + + def __len__(self): + return len(self.images) + + def _check_exists(self): + return os.path.exists(os.path.join(self.root, self.base_folder)) + + def _download(self): + for url, filename, md5 in self.file_list.values(): + download_url(url, root=self.root, filename=filename) + if not check_integrity(os.path.join(self.root, filename), md5): + raise RuntimeError("File not found or corrupted.") diff --git a/utils/dist_util.py b/utils/dist_util.py new file mode 100755 index 0000000..ab8c447 --- /dev/null +++ b/utils/dist_util.py @@ -0,0 +1,30 @@ +import torch.distributed as dist + +def get_rank(): + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + +def get_world_size(): + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + +def is_main_process(): + return get_rank() == 0 + +def format_step(step): + if isinstance(step, str): + return step + s = "" + if len(step) > 0: + s += "Training Epoch: {} ".format(step[0]) + if len(step) > 1: + s += "Training Iteration: {} ".format(step[1]) + if len(step) > 2: + s += "Validation Iteration: {} ".format(step[2]) + return s diff --git a/utils/scheduler.py b/utils/scheduler.py new file mode 100755 index 0000000..9daaf6e --- /dev/null +++ b/utils/scheduler.py @@ -0,0 +1,63 @@ +import logging +import math + +from torch.optim.lr_scheduler import LambdaLR + +logger = logging.getLogger(__name__) + +class ConstantLRSchedule(LambdaLR): + """ Constant learning rate schedule. + """ + def __init__(self, optimizer, last_epoch=-1): + super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch) + + +class WarmupConstantSchedule(LambdaLR): + """ Linear warmup and then constant. 
+ Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. + Keeps learning rate equal to 1. after warmup_steps. + """ + def __init__(self, optimizer, warmup_steps, last_epoch=-1): + self.warmup_steps = warmup_steps + super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) + + def lr_lambda(self, step): + if step < self.warmup_steps: + return float(step) / float(max(1.0, self.warmup_steps)) + return 1. + + +class WarmupLinearSchedule(LambdaLR): + """ Linear warmup and then linear decay. + Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. + Linearly decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps. + """ + def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1): + self.warmup_steps = warmup_steps + self.t_total = t_total + super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) + + def lr_lambda(self, step): + if step < self.warmup_steps: + return float(step) / float(max(1, self.warmup_steps)) + return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps))) + + +class WarmupCosineSchedule(LambdaLR): + """ Linear warmup and then cosine decay. + Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. + Decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve. + With the default `cycles`=0.5 the decay spans half a cosine period; other values of `cycles` change how many cosine cycles are traversed after warmup. + """ + def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1): + self.warmup_steps = warmup_steps + self.t_total = t_total + self.cycles = cycles + super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) + + def lr_lambda(self, step): + if step < self.warmup_steps: + return float(step) / float(max(1.0, self.warmup_steps)) + # progress after warmup + progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) + return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))
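The three warmup schedules are thin LambdaLR wrappers: they return a multiplier for the optimizer's base learning rate as a function of the global step, so they are meant to be stepped once per optimization step (t_total counts steps, not epochs). A minimal usage sketch, with purely illustrative hyperparameters rather than the values train.py actually passes:

import torch
from utils.scheduler import WarmupCosineSchedule

model = torch.nn.Linear(10, 2)                       # stand-in for the real model
optimizer = torch.optim.SGD(model.parameters(), lr=3e-2, momentum=0.9)
scheduler = WarmupCosineSchedule(optimizer, warmup_steps=500, t_total=10000)

for step in range(10000):
    # forward / backward on a batch would go here
    optimizer.step()
    scheduler.step()        # advance the LR multiplier once per optimization step
    optimizer.zero_grad()

During the first 500 steps the multiplier climbs linearly from 0 to 1; afterwards it follows half a cosine period down to 0 at t_total.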
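Each AutoAugment policy above is just a list of SubPolicy pairs: a sub-policy applies its first PIL operation with probability p1 at a magnitude taken from a 10-level table, then its second operation with probability p2, and the policy's __call__ picks one pair at random per image. A stripped-down sketch of that mechanism with only two operations (the full table is the ranges dict above; note the plain int cast, since the np.int alias used there was deprecated in NumPy 1.20 and removed in 1.24):

import random
import numpy as np
from PIL import Image, ImageOps

# Two of the magnitude tables from SubPolicy; plain `int` replaces the removed np.int alias.
RANGES = {
    "rotate": np.linspace(0, 30, 10),
    "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int),
}

OPS = {
    "rotate": lambda img, m: img.rotate(m * random.choice([-1, 1])),
    "posterize": lambda img, m: ImageOps.posterize(img, int(m)),
}

def apply_sub_policy(img, p1, op1, idx1, p2, op2, idx2):
    # Mirrors SubPolicy.__call__: each op fires independently with its own probability.
    if random.random() < p1:
        img = OPS[op1](img, RANGES[op1][idx1])
    if random.random() < p2:
        img = OPS[op2](img, RANGES[op2][idx2])
    return img

img = Image.new("RGB", (448, 448), (128, 128, 128))
out = apply_sub_policy(img, 0.4, "posterize", 8, 0.6, "rotate", 9)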
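emptyJudge and CUB read the same CUB-style annotation triplet from the dataset root: images.txt, image_class_labels.txt and train_test_split.txt, one "<image_id> <value>" pair per line, with labels stored 1-based on disk and shifted to 0-based in __init__. Both classes also decode every image up front via scipy.misc.imread, which was removed in SciPy 1.2; the sketch below keeps the same split logic but opens images lazily with PIL (emptyJudgeLazy is a hypothetical helper, not part of this commit):

import os
from PIL import Image
from torch.utils.data import Dataset


class emptyJudgeLazy(Dataset):
    # Sketch: same txt-file layout as emptyJudge/CUB, but images are opened on demand.

    def __init__(self, root, is_train=True, transform=None):
        self.root = root
        self.transform = transform

        def read_col(name):
            with open(os.path.join(root, name)) as f:
                return [line.strip().split(' ')[-1] for line in f]

        names = read_col('images.txt')
        labels = [int(v) - 1 for v in read_col('image_class_labels.txt')]   # 1-based on disk -> 0-based
        in_train = [int(v) == 1 for v in read_col('train_test_split.txt')]
        self.samples = [(n, l) for n, l, t in zip(names, labels, in_train) if t == is_train]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        name, target = self.samples[index]
        # convert('RGB') covers the grayscale images the original code patched with np.stack
        img = Image.open(os.path.join(self.root, 'images', name)).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, target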
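get_loader wraps dataset construction in the usual rank-0-first barrier pattern and hands back DataLoaders built with RandomSampler/SequentialSampler in single-process mode or DistributedSampler under DDP. On one GPU it only needs a handful of the argparse fields; the values below are illustrative, not the repo's defaults, and the call assumes an environment where the eager scipy.misc.imread in the dataset classes above still works (SciPy < 1.2):

from types import SimpleNamespace
from utils.data_utils import get_loader

# Minimal argument namespace for single-process use (local_rank == -1 skips the barriers).
args = SimpleNamespace(local_rank=-1,
                       dataset='emptyJudge5',
                       data_root='./data/emptyJudge5',   # illustrative path to the annotation txt files
                       train_batch_size=16,
                       eval_batch_size=8)

train_loader, test_loader = get_loader(args)
images, labels = next(iter(train_loader))   # images: [16, 3, 448, 448] after the train transforms above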