li chen
2022-04-08 18:13:02 +08:00
commit f23dc22752
21 changed files with 2495 additions and 0 deletions

7
.gitignore vendored Executable file

@@ -0,0 +1,7 @@
.idea/
ckpts/
logs/
models/__pycache__/
utils/__pycache__/
output/
attention_data/

21
LICENSE Executable file

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2021 Ju He
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

BIN
docs/TransFG.pdf Executable file

Binary file not shown.

BIN
docs/TransFG.png Executable file

Binary file not shown.


136
ieemoo-ai-isempty.py Normal file

@@ -0,0 +1,136 @@
# -*- coding: utf-8 -*-
from flask import request, Flask
import numpy as np
import json
import time
import cv2, base64
import argparse
import sys, os
import torch
from PIL import Image
from torchvision import transforms
from models.modeling import VisionTransformer, CONFIGS
sys.path.insert(0, ".")
app = Flask(__name__)
app.use_reloader = False  # note: Flask ignores this attribute; use_reloader is an argument to app.run()
def parse_args(model_file="ckpts/emptyjudge5_checkpoint.bin"):
parser = argparse.ArgumentParser()
parser.add_argument("--img_size", default=448, type=int, help="Resolution size")
parser.add_argument('--split', type=str, default='overlap', help="Split method")
parser.add_argument('--slide_step', type=int, default=12, help="Slide step for overlap split")
parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value")
parser.add_argument("--pretrained_model", type=str, default=model_file, help="load pretrained model")
opt, unknown = parser.parse_known_args()
return opt
class Predictor(object):
def __init__(self, args):
self.args = args
self.args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(self.args.device)
self.args.nprocs = torch.cuda.device_count()
self.cls_dict = {}
self.num_classes = 0
self.model = None
self.prepare_model()
self.test_transform = transforms.Compose([transforms.Resize((448, 448), Image.BILINEAR),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
def prepare_model(self):
config = CONFIGS["ViT-B_16"]
config.split = self.args.split
config.slide_step = self.args.slide_step
model_name = os.path.basename(self.args.pretrained_model).replace("_checkpoint.bin", "")
print("use model_name: ", model_name)
self.num_classes = 5
self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "fly", 4: "stack"}
self.model = VisionTransformer(config, self.args.img_size, zero_head=True, num_classes=self.num_classes, smoothing_value=self.args.smoothing_value)
if self.args.pretrained_model is not None:
if not torch.cuda.is_available():
pretrained_model = torch.load(self.args.pretrained_model, map_location=torch.device('cpu'))['model']
self.model.load_state_dict(pretrained_model)
else:
pretrained_model = torch.load(self.args.pretrained_model)['model']
self.model.load_state_dict(pretrained_model)
self.model.eval()
self.model.to(self.args.device)
#self.model.eval()
def normal_predict(self, img_data, result):
# img = Image.open(img_path)
if img_data is None:
print('error, img data is None')
return result
else:
with torch.no_grad():
x = self.test_transform(img_data)
if torch.cuda.is_available():
x = x.cuda()
part_logits = self.model(x.unsqueeze(0))
probs = torch.nn.Softmax(dim=-1)(part_logits)
topN = torch.argsort(probs, dim=-1, descending=True).tolist()
clas_ids = topN[0][0]
clas_ids = 0 if 0==int(clas_ids) or 2 == int(clas_ids) or 3 == int(clas_ids) else 1
print("cur_img result: class id: %d, score: %0.3f" % (clas_ids, probs[0, clas_ids].item()))
result["success"] = "true"
result["rst_cls"] = str(clas_ids)
return result
model_file ="/data/ieemoo/emptypredict_pfc_FG/ckpts/emptyjudge5_checkpoint.bin"
args = parse_args(model_file)
predictor = Predictor(args)
@app.route("/isempty", methods=['POST'])
def get_isempty():
start = time.time()
print('--------------------EmptyPredict-----------------')
data = request.get_data()
ip = request.remote_addr
print('------ ip = %s ------' % ip)
json_data = json.loads(data.decode("utf-8"))
getdateend = time.time()
print('get data use time: {0:.2f}s'.format(getdateend - start))
pic = json_data.get("pic")
result = {"success": "false",
"rst_cls": '-1',
}
try:
imgdata = base64.b64decode(pic)
imgdata_np = np.frombuffer(imgdata, dtype='uint8')
img_src = cv2.imdecode(imgdata_np, cv2.IMREAD_COLOR)
img_data = Image.fromarray(np.uint8(img_src))
result = predictor.normal_predict(img_data, result) # 1==empty, 0==nonEmpty
except Exception:
return repr(result)
return repr(result)
if __name__ == "__main__":
app.run()
# app.run("0.0.0.0", port=8083)
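For reference, a minimal client sketch for this endpoint. This is not part of the commit; the host and port are assumptions, since app.run() defaults to 127.0.0.1:5000 and the commented-out line suggests port 8083 in deployment.
import base64
import requests

with open("img.jpg", "rb") as f:  # placeholder image path
    pic = base64.b64encode(f.read()).decode("utf-8")

resp = requests.post("http://127.0.0.1:5000/isempty", json={"pic": pic})
print(resp.text)  # repr of a dict such as {'success': 'true', 'rst_cls': '1'}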

4
init.sh Normal file

@@ -0,0 +1,4 @@
source /opt/miniconda3/bin/activate ieemoo  # 'conda activate' does not work when the conda binary is called directly in a script
/opt/miniconda3/envs/ieemoo/bin/pip install -r requirements.txt

76
models/configs.py Executable file

@@ -0,0 +1,76 @@
import ml_collections
def get_testing():
"""Returns a minimal configuration for testing."""
config = ml_collections.ConfigDict()
config.patches = ml_collections.ConfigDict({'size': (16, 16)})
config.hidden_size = 1
config.transformer = ml_collections.ConfigDict()
config.transformer.mlp_dim = 1
config.transformer.num_heads = 1
config.transformer.num_layers = 1
config.transformer.attention_dropout_rate = 0.0
config.transformer.dropout_rate = 0.1
config.classifier = 'token'
config.representation_size = None
return config
def get_b16_config():
"""Returns the ViT-B/16 configuration."""
config = ml_collections.ConfigDict()
config.patches = ml_collections.ConfigDict({'size': (16, 16)})
config.split = 'non-overlap'
config.slide_step = 12
config.hidden_size = 768
config.transformer = ml_collections.ConfigDict()
config.transformer.mlp_dim = 3072
config.transformer.num_heads = 12
config.transformer.num_layers = 12
config.transformer.attention_dropout_rate = 0.0
config.transformer.dropout_rate = 0.1
config.classifier = 'token'
config.representation_size = None
return config
def get_b32_config():
"""Returns the ViT-B/32 configuration."""
config = get_b16_config()
config.patches.size = (32, 32)
return config
def get_l16_config():
"""Returns the ViT-L/16 configuration."""
config = ml_collections.ConfigDict()
config.patches = ml_collections.ConfigDict({'size': (16, 16)})
config.hidden_size = 1024
config.transformer = ml_collections.ConfigDict()
config.transformer.mlp_dim = 4096
config.transformer.num_heads = 16
config.transformer.num_layers = 24
config.transformer.attention_dropout_rate = 0.0
config.transformer.dropout_rate = 0.1
config.classifier = 'token'
config.representation_size = None
return config
def get_l32_config():
"""Returns the ViT-L/32 configuration."""
config = get_l16_config()
config.patches.size = (32, 32)
return config
def get_h14_config():
"""Returns the ViT-L/16 configuration."""
config = ml_collections.ConfigDict()
config.patches = ml_collections.ConfigDict({'size': (14, 14)})
config.hidden_size = 1280
config.transformer = ml_collections.ConfigDict()
config.transformer.mlp_dim = 5120
config.transformer.num_heads = 16
config.transformer.num_layers = 32
config.transformer.attention_dropout_rate = 0.0
config.transformer.dropout_rate = 0.1
config.classifier = 'token'
config.representation_size = None
return config
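As a quick orientation (a sketch, not code from this commit): these factories return mutable ml_collections.ConfigDict objects, and the scripts in this commit pick one via the CONFIGS table in models/modeling.py and then override the split settings.
import models.configs as configs

config = configs.get_b16_config()
config.split = 'overlap'    # the scripts below override the default 'non-overlap'
config.slide_step = 12
print(config.hidden_size, config.transformer.num_layers)  # 768 12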

390
models/modeling.py Executable file

@@ -0,0 +1,390 @@
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import logging
import math
from os.path import join as pjoin
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage
import models.configs as configs
logger = logging.getLogger(__name__)
ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"
def np2th(weights, conv=False):
"""Possibly convert HWIO to OIHW."""
if conv:
weights = weights.transpose([3, 2, 0, 1])
return torch.from_numpy(weights)
def swish(x):
return x * torch.sigmoid(x)
ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
class LabelSmoothing(nn.Module):
"""
NLL loss with label smoothing.
"""
def __init__(self, smoothing=0.0):
"""
Constructor for the LabelSmoothing module.
:param smoothing: label smoothing factor
"""
super(LabelSmoothing, self).__init__()
self.confidence = 1.0 - smoothing
self.smoothing = smoothing
def forward(self, x, target):
logprobs = torch.nn.functional.log_softmax(x, dim=-1)
nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
nll_loss = nll_loss.squeeze(1)
smooth_loss = -logprobs.mean(dim=-1)
loss = self.confidence * nll_loss + self.smoothing * smooth_loss
return loss.mean()
class Attention(nn.Module):
def __init__(self, config):
super(Attention, self).__init__()
self.num_attention_heads = config.transformer["num_heads"]
self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = Linear(config.hidden_size, self.all_head_size)
self.key = Linear(config.hidden_size, self.all_head_size)
self.value = Linear(config.hidden_size, self.all_head_size)
self.out = Linear(config.hidden_size, config.hidden_size)
self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])
self.softmax = Softmax(dim=-1)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states):
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
attention_probs = self.softmax(attention_scores)
weights = attention_probs
attention_probs = self.attn_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
attention_output = self.out(context_layer)
attention_output = self.proj_dropout(attention_output)
return attention_output, weights
class Mlp(nn.Module):
def __init__(self, config):
super(Mlp, self).__init__()
self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
self.act_fn = ACT2FN["gelu"]
self.dropout = Dropout(config.transformer["dropout_rate"])
self._init_weights()
def _init_weights(self):
nn.init.xavier_uniform_(self.fc1.weight)
nn.init.xavier_uniform_(self.fc2.weight)
nn.init.normal_(self.fc1.bias, std=1e-6)
nn.init.normal_(self.fc2.bias, std=1e-6)
def forward(self, x):
x = self.fc1(x)
x = self.act_fn(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x
class Embeddings(nn.Module):
"""Construct the embeddings from patch, position embeddings.
"""
def __init__(self, config, img_size, in_channels=3):
super(Embeddings, self).__init__()
self.hybrid = None
img_size = _pair(img_size)
patch_size = _pair(config.patches["size"])
if config.split == 'non-overlap':
n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
self.patch_embeddings = Conv2d(in_channels=in_channels,
out_channels=config.hidden_size,
kernel_size=patch_size,
stride=patch_size)
elif config.split == 'overlap':
n_patches = ((img_size[0] - patch_size[0]) // config.slide_step + 1) * ((img_size[1] - patch_size[1]) // config.slide_step + 1)
self.patch_embeddings = Conv2d(in_channels=in_channels,
out_channels=config.hidden_size,
kernel_size=patch_size,
stride=(config.slide_step, config.slide_step))
self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches+1, config.hidden_size))
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
self.dropout = Dropout(config.transformer["dropout_rate"])
def forward(self, x):
B = x.shape[0]
cls_tokens = self.cls_token.expand(B, -1, -1)
if self.hybrid:
x = self.hybrid_model(x)
x = self.patch_embeddings(x)
x = x.flatten(2)
x = x.transpose(-1, -2)
x = torch.cat((cls_tokens, x), dim=1)
embeddings = x + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings
class Block(nn.Module):
def __init__(self, config):
super(Block, self).__init__()
self.hidden_size = config.hidden_size
self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
self.ffn = Mlp(config)
self.attn = Attention(config)
def forward(self, x):
h = x
x = self.attention_norm(x)
x, weights = self.attn(x)
x = x + h
h = x
x = self.ffn_norm(x)
x = self.ffn(x)
x = x + h
return x, weights
def load_from(self, weights, n_block):
ROOT = f"Transformer/encoderblock_{n_block}"
with torch.no_grad():
query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()
query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)
self.attn.query.weight.copy_(query_weight)
self.attn.key.weight.copy_(key_weight)
self.attn.value.weight.copy_(value_weight)
self.attn.out.weight.copy_(out_weight)
self.attn.query.bias.copy_(query_bias)
self.attn.key.bias.copy_(key_bias)
self.attn.value.bias.copy_(value_bias)
self.attn.out.bias.copy_(out_bias)
mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()
self.ffn.fc1.weight.copy_(mlp_weight_0)
self.ffn.fc2.weight.copy_(mlp_weight_1)
self.ffn.fc1.bias.copy_(mlp_bias_0)
self.ffn.fc2.bias.copy_(mlp_bias_1)
self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))
class Part_Attention(nn.Module):
def __init__(self):
super(Part_Attention, self).__init__()
def forward(self, x):
length = len(x)
last_map = x[0]
for i in range(1, length):
last_map = torch.matmul(x[i], last_map)
last_map = last_map[:,:,0,1:]
max_val, max_inx = last_map.max(2)
return max_val, max_inx
class Encoder(nn.Module):
def __init__(self, config):
super(Encoder, self).__init__()
self.layer = nn.ModuleList()
for _ in range(config.transformer["num_layers"] - 1):
layer = Block(config)
self.layer.append(copy.deepcopy(layer))
self.part_select = Part_Attention()
self.part_layer = Block(config)
self.part_norm = LayerNorm(config.hidden_size, eps=1e-6)
def forward(self, hidden_states):
attn_weights = []
for layer in self.layer:
hidden_states, weights = layer(hidden_states)
attn_weights.append(weights)
part_num, part_inx = self.part_select(attn_weights)
part_inx = part_inx + 1
parts = []
B, num = part_inx.shape
for i in range(B):
parts.append(hidden_states[i, part_inx[i,:]])
parts = torch.stack(parts).squeeze(1)
concat = torch.cat((hidden_states[:,0].unsqueeze(1), parts), dim=1)
part_states, part_weights = self.part_layer(concat)
part_encoded = self.part_norm(part_states)
return part_encoded
class Transformer(nn.Module):
def __init__(self, config, img_size):
super(Transformer, self).__init__()
self.embeddings = Embeddings(config, img_size=img_size)
self.encoder = Encoder(config)
def forward(self, input_ids):
embedding_output = self.embeddings(input_ids)
part_encoded = self.encoder(embedding_output)
return part_encoded
class VisionTransformer(nn.Module):
def __init__(self, config, img_size=224, num_classes=21843, smoothing_value=0, zero_head=False):
super(VisionTransformer, self).__init__()
self.num_classes = num_classes
self.smoothing_value = smoothing_value
self.zero_head = zero_head
self.classifier = config.classifier
self.transformer = Transformer(config, img_size)
self.part_head = Linear(config.hidden_size, num_classes)
def forward(self, x, labels=None):
part_tokens = self.transformer(x)
part_logits = self.part_head(part_tokens[:, 0])
if labels is not None:
if self.smoothing_value == 0:
loss_fct = CrossEntropyLoss()
else:
loss_fct = LabelSmoothing(self.smoothing_value)
part_loss = loss_fct(part_logits.view(-1, self.num_classes), labels.view(-1))
contrast_loss = con_loss(part_tokens[:, 0], labels.view(-1))
loss = part_loss + contrast_loss
return loss, part_logits
else:
return part_logits
def load_from(self, weights):
with torch.no_grad():
self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))
self.transformer.embeddings.cls_token.copy_(np2th(weights["cls"]))
self.transformer.encoder.part_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
self.transformer.encoder.part_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))
posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
posemb_new = self.transformer.embeddings.position_embeddings
if posemb.size() == posemb_new.size():
self.transformer.embeddings.position_embeddings.copy_(posemb)
else:
logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
ntok_new = posemb_new.size(1)
if self.classifier == "token":
posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
ntok_new -= 1
else:
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
gs_old = int(np.sqrt(len(posemb_grid)))
gs_new = int(np.sqrt(ntok_new))
print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
zoom = (gs_new / gs_old, gs_new / gs_old, 1)
posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)
posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
posemb = np.concatenate([posemb_tok, posemb_grid], axis=1)
self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))
for bname, block in self.transformer.encoder.named_children():
if not bname.startswith('part'):
for uname, unit in block.named_children():
unit.load_from(weights, n_block=uname)
if self.transformer.embeddings.hybrid:
self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(weights["conv_root/kernel"], conv=True))
gn_weight = np2th(weights["gn_root/scale"]).view(-1)
gn_bias = np2th(weights["gn_root/bias"]).view(-1)
self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)
for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
for uname, unit in block.named_children():
unit.load_from(weights, n_block=bname, n_unit=uname)
def con_loss(features, labels):
B, _ = features.shape
features = F.normalize(features)
cos_matrix = features.mm(features.t())
pos_label_matrix = torch.stack([labels == labels[i] for i in range(B)]).float()
neg_label_matrix = 1 - pos_label_matrix
pos_cos_matrix = 1 - cos_matrix
neg_cos_matrix = cos_matrix - 0.4
neg_cos_matrix[neg_cos_matrix < 0] = 0
loss = (pos_cos_matrix * pos_label_matrix).sum() + (neg_cos_matrix * neg_label_matrix).sum()
loss /= (B * B)
return loss
CONFIGS = {
'ViT-B_16': configs.get_b16_config(),
'ViT-B_32': configs.get_b32_config(),
'ViT-L_16': configs.get_l16_config(),
'ViT-L_32': configs.get_l32_config(),
'ViT-H_14': configs.get_h14_config(),
'testing': configs.get_testing(),
}
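A minimal smoke-test sketch for the model above (assumed usage mirroring the scripts in this commit; random input, no pretrained weights):
import torch
from models.modeling import VisionTransformer, CONFIGS

config = CONFIGS["ViT-B_16"]
config.split = 'overlap'
config.slide_step = 12
model = VisionTransformer(config, img_size=448, zero_head=True, num_classes=5, smoothing_value=0.0)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 448, 448))  # no labels: returns part_logits only
loss, logits = model(torch.randn(1, 3, 448, 448), torch.tensor([2]))  # with labels: (CE + con_loss, logits)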

153
predict.py Executable file

@@ -0,0 +1,153 @@
import numpy as np
import cv2
import time
import os
import argparse
import torch
from sklearn.metrics import confusion_matrix
from sklearn.metrics import f1_score
from PIL import Image
from torchvision import transforms
from models.modeling import VisionTransformer, CONFIGS
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--img_size", default=448, type=int, help="Resolution size")
parser.add_argument('--split', type=str, default='overlap', help="Split method") # non-overlap
parser.add_argument('--slide_step', type=int, default=12, help="Slide step for overlap split")
parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value\n")
parser.add_argument("--pretrained_model", type=str, default="output/emptyjudge5_checkpoint.bin", help="load pretrained model")
return parser.parse_args()
class Predictor(object):
def __init__(self, args):
self.args = args
self.args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("self.args.device =", self.args.device)
self.args.nprocs = torch.cuda.device_count()
self.cls_dict = {}
self.num_classes = 0
self.model = None
self.prepare_model()
self.test_transform = transforms.Compose([transforms.Resize((448, 448), Image.BILINEAR),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
def prepare_model(self):
config = CONFIGS["ViT-B_16"]
config.split = self.args.split
config.slide_step = self.args.slide_step
model_name = os.path.basename(self.args.pretrained_model).replace("_checkpoint.bin", "")
print("use model_name: ", model_name)
if model_name.lower() == "emptyjudge5":
self.num_classes = 5
self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "fly", 4: "stack"}
elif model_name.lower() == "emptyjudge4":
self.num_classes = 4
self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "stack"}
elif model_name.lower() == "emptyjudge3":
self.num_classes = 3
self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard"}
elif model_name.lower() == "emptyjudge2":
self.num_classes = 2
self.cls_dict = {0: "noemp", 1: "yesemp"}
self.model = VisionTransformer(config, self.args.img_size, zero_head=True, num_classes=self.num_classes, smoothing_value=self.args.smoothing_value)
if self.args.pretrained_model is not None:
if not torch.cuda.is_available():
pretrained_model = torch.load(self.args.pretrained_model, map_location=torch.device('cpu'))['model']
self.model.load_state_dict(pretrained_model)
else:
pretrained_model = torch.load(self.args.pretrained_model)['model']
self.model.load_state_dict(pretrained_model)
self.model.to(self.args.device)
self.model.eval()
def normal_predict(self, img_path):
# img = cv2.imread(img_path)
img = Image.open(img_path)
if img is None:
print("Image file failed to read: {}".format(img_path))
else:
x = self.test_transform(img)
if torch.cuda.is_available():
x = x.cuda()
part_logits = self.model(x.unsqueeze(0))
probs = torch.nn.Softmax(dim=-1)(part_logits)
topN = torch.argsort(probs, dim=-1, descending=True).tolist()
clas_ids = topN[0][0]
# print(probs[0, topN[0][0]].item())
return clas_ids, probs[0, clas_ids].item()
if __name__ == "__main__":
args = parse_args()
predictor = Predictor(args)
y_true = []
y_pred = []
test_dir = "/data/pfc/fineGrained/test_5cls"
dir_dict = {"noemp":"0", "yesemp":"1", "hard": "2", "fly": "3", "stack": "4"}
total = 0
num = 0
t0 = time.time()
for dir_name, label in dir_dict.items():
cur_folder = os.path.join(test_dir, dir_name)
errorPath = os.path.join(test_dir, dir_name + "_error")
# os.makedirs(errorPath, exist_ok=True)
for cur_file in os.listdir(cur_folder):
total += 1
print("%d processing: %s" % (total, cur_file))
cur_img_file = os.path.join(cur_folder, cur_file)
error_img_dst = os.path.join(errorPath, cur_file)
cur_pred, pred_score = predictor.normal_predict(cur_img_file)
label = 0 if 2 == int(label) or 3 == int(label) or 4 == int(label) else int(label)
cur_pred = 0 if 2 == int(cur_pred) or 3 == int(cur_pred) or 4 == int(cur_pred) else int(cur_pred)
y_true.append(int(label))
y_pred.append(int(cur_pred))
if int(label) == int(cur_pred):
num += 1
# else:
# print(cur_file, "predict: ", cur_pred, "true: ", int(label))
# print(cur_file, "predict: ", cur_pred, "true: ", int(label), "pred_score:", pred_score)
# os.system("cp %s %s" % (cur_img_file, error_img_dst))
t1 = time.time()
print('The cost of time is: %f seconds' % (t1-t0))
rate = float(num)/total
print('The classification accuracy is %f' % rate)
rst_C = confusion_matrix(y_true, y_pred)
rst_f1 = f1_score(y_true, y_pred, average='macro')
print(rst_C)
print(rst_f1)
'''
test_imgs: yesemp=145, noemp=453 (large images)
output/emptyjudge5_checkpoint.bin
The classification accuracy is 0.976589
[[446 7] 1.5%
[ 7 138]] 4.8%
0.968135799649844
output/emptyjudge4_checkpoint.bin
The classification accuracy is 0.976589
[[450 3] 0.6%
[ 11 134]] 7.5%
0.9675186616384996
test_5cls: yesemp=319, noemp=925 (small images)
output/emptyjudge4_checkpoint.bin
The classification accuracy is 0.937299
[[885 40] 4.3%
[ 38 281]] 11.9%
0.9179586038961038
'''
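A single-image usage sketch for the Predictor above (illustration only; the image path is a placeholder):
from predict import parse_args, Predictor  # assumes the repo root is on sys.path

args = parse_args()                        # defaults: ViT-B_16, overlap split, 448 px
predictor = Predictor(args)
cls_id, score = predictor.normal_predict("/path/to/img.jpg")
print(predictor.cls_dict[cls_id], score)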

119
prepara_data.py Executable file

@@ -0,0 +1,119 @@
import os
import cv2
import numpy as np
import subprocess
import random
# ----------- Rename files --------------
# index = 0
# src_dir = "/data/fineGrained/emptyJudge5"
# dst_dir = src_dir + "_new"
# os.makedirs(dst_dir, exist_ok=True)
# for sub in os.listdir(src_dir):
# sub_path = os.path.join(src_dir, sub)
# sub_path_dst = os.path.join(dst_dir, sub)
# os.makedirs(sub_path_dst, exist_ok=True)
# for cur_f in os.listdir(sub_path):
# cur_img = os.path.join(sub_path, cur_f)
# cur_img_dst = os.path.join(sub_path_dst, "a%05d.jpg" % index)
# index += 1
# os.system("mv %s %s" % (cur_img, cur_img_dst))
# ----------- Remove images that are too small --------------
# src_dir = "/data/fineGrained/emptyJudge5"
# for sub in os.listdir(src_dir):
# sub_path = os.path.join(src_dir, sub)
# for cur_f in os.listdir(sub_path):
# filepath = os.path.join(sub_path, cur_f)
# res = subprocess.check_output(['file', filepath])
# pp = res.decode("utf-8").split(",")[-2]
# height = int(pp.split("x")[1])
# width = int(pp.split("x")[0])
# min_l = min(height, width)
# if min_l <= 448:
# os.system("rm %s" % filepath)
# ----------- Collect valid images and write images.txt --------------
# src_dir = "/data/fineGrained/emptyJudge4/images"
# src_dict = {"noemp":"0", "yesemp":"1", "hard": "2", "stack": "3"}
# all_dict = {"yesemp":[], "noemp":[], "hard": [], "stack": []}
# for sub, value in src_dict.items():
# sub_path = os.path.join(src_dir, sub)
# for cur_f in os.listdir(sub_path):
# all_dict[sub].append(os.path.join(sub, cur_f))
#
# yesnum = len(all_dict["yesemp"])
# nonum = len(all_dict["noemp"])
# hardnum = len(all_dict["hard"])
# stacknum = len(all_dict["stack"])
# thnum = min(yesnum, nonum, hardnum, stacknum)
# images_txt = src_dir + ".txt"
# index = 1
#
# def write_images(cur_list, thnum, fw, index):
# for feat_path in random.sample(cur_list, thnum):
# fw.write(str(index) + " " + feat_path + "\n")
# index += 1
# return index
#
# with open(images_txt, "w") as fw:
# index = write_images(all_dict["noemp"], thnum, fw, index)
# index = write_images(all_dict["yesemp"], thnum, fw, index)
# index = write_images(all_dict["hard"], thnum, fw, index)
# index = write_images(all_dict["stack"], thnum, fw, index)
# ----------- Write image_class_labels.txt + train_test_split.txt --------------
# src_dir = "/data/fineGrained/emptyJudge4"
# src_dict = {"noemp":"0", "yesemp":"1", "hard": "2", "stack": "3"}
# images_txt = os.path.join(src_dir, "images.txt")
# image_class_labels_txt = os.path.join(src_dir, "image_class_labels.txt")
# imgs_cnt = 0
# with open(image_class_labels_txt, "w") as fw:
# with open(images_txt, "r") as fr:
# for cur_l in fr:
# imgs_cnt += 1
# img_index, img_f = cur_l.strip().split(" ")
# folder_name = img_f.split("/")[0]
# if folder_name in src_dict:
# cur_line = img_index + " " + str(int(src_dict[folder_name])+1)
# fw.write(cur_line + "\n")
#
# train_num = int(imgs_cnt*0.85)
# print("train_num= ", train_num, ", imgs_cnt= ", imgs_cnt)
# all_list = [1]*train_num + [0]*(imgs_cnt-train_num)
# assert len(all_list) == imgs_cnt
# random.shuffle(all_list)
# train_test_split_txt = os.path.join(src_dir, "train_test_split.txt")
# with open(train_test_split_txt, "w") as fw:
# with open(images_txt, "r") as fr:
# for cur_l in fr:
# img_index, img_f = cur_l.strip().split(" ")
# cur_line = img_index + " " + str(all_list[int(img_index) - 1])
# fw.write(cur_line + "\n")
# ----------- Generate a standard test set --------------
# src_dir = "/data/fineGrained/emptyJudge5/images"
# src_dict = {"noemp":"0", "yesemp":"1", "hard": "2", "fly": "3", "stack": "4"}
# all_dict = {"noemp":[], "yesemp":[], "hard": [], "fly": [], "stack": []}
# for sub, value in src_dict.items():
# sub_path = os.path.join(src_dir, sub)
# for cur_f in os.listdir(sub_path):
# all_dict[sub].append(cur_f)
#
# dst_dir = src_dir + "_test"
# os.makedirs(dst_dir, exist_ok=True)
# for sub, value in src_dict.items():
# sub_path = os.path.join(src_dir, sub)
# sub_path_dst = os.path.join(dst_dir, sub)
# os.makedirs(sub_path_dst, exist_ok=True)
#
# cur_list = all_dict[sub]
# test_num = int(len(cur_list) * 0.05)
# for cur_f in random.sample(cur_list, test_num):
# cur_path = os.path.join(sub_path, cur_f)
# cur_path_dst = os.path.join(sub_path_dst, cur_f)
# os.system("cp %s %s" % (cur_path, cur_path_dst))

80
requirements.txt Executable file

@@ -0,0 +1,80 @@
absl-py==1.0.0
Bottleneck==1.3.2
brotlipy==0.7.0
cachetools==5.0.0
certifi==2021.10.8
cffi @ file:///tmp/build/80754af9/cffi_1625807838443/work
charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
click==8.0.3
contextlib2==21.6.0
cryptography @ file:///tmp/build/80754af9/cryptography_1635366571107/work
cycler @ file:///tmp/build/80754af9/cycler_1637851556182/work
docopt==0.6.2
esdk-obs-python==3.21.8
faiss==1.7.1
Flask @ file:///tmp/build/80754af9/flask_1634118196080/work
fonttools==4.25.0
gevent @ file:///tmp/build/80754af9/gevent_1628273677693/work
google-auth==2.6.0
google-auth-oauthlib==0.4.6
greenlet @ file:///tmp/build/80754af9/greenlet_1628887725296/work
grpcio==1.44.0
gunicorn==20.1.0
h5py @ file:///tmp/build/80754af9/h5py_1637138879700/work
idna @ file:///tmp/build/80754af9/idna_1637925883363/work
importlib-metadata==4.11.3
itsdangerous @ file:///tmp/build/80754af9/itsdangerous_1621432558163/work
Jinja2 @ file:///tmp/build/80754af9/jinja2_1635780242639/work
kiwisolver @ file:///tmp/build/80754af9/kiwisolver_1612282420641/work
Markdown==3.3.6
MarkupSafe @ file:///tmp/build/80754af9/markupsafe_1621528148836/work
matplotlib @ file:///tmp/build/80754af9/matplotlib-suite_1638289681807/work
mkl-fft==1.3.1
mkl-random @ file:///tmp/build/80754af9/mkl_random_1626186064646/work
mkl-service==2.4.0
ml-collections==0.1.0
munkres==1.1.4
numexpr @ file:///tmp/build/80754af9/numexpr_1618856167419/work
numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1634095647912/work
oauthlib==3.2.0
olefile @ file:///Users/ktietz/demo/mc3/conda-bld/olefile_1629805411829/work
opencv-python==4.5.4.60
packaging @ file:///tmp/build/80754af9/packaging_1637314298585/work
pandas==1.3.4
Pillow==8.4.0
pipreqs==0.4.11
protobuf==3.19.4
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
pycryptodome==3.10.1
pyOpenSSL @ file:///tmp/build/80754af9/pyopenssl_1635333100036/work
pyparsing @ file:///tmp/build/80754af9/pyparsing_1635766073266/work
PySocks @ file:///tmp/build/80754af9/pysocks_1605305779399/work
python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work
pytz==2021.3
PyYAML==6.0
requests @ file:///tmp/build/80754af9/requests_1629994808627/work
requests-oauthlib==1.3.1
rsa==4.8
scipy @ file:///tmp/build/80754af9/scipy_1630606796110/work
seaborn @ file:///tmp/build/80754af9/seaborn_1629307859561/work
sip==4.19.13
six @ file:///tmp/build/80754af9/six_1623709665295/work
supervisor==4.2.2
tensorboard==2.8.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
torch==1.8.0
torchaudio==0.8.0a0+a751e1d
torchvision==0.9.0
tornado @ file:///tmp/build/80754af9/tornado_1606942300299/work
tqdm @ file:///tmp/build/80754af9/tqdm_1635330843403/work
typing-extensions @ file:///tmp/build/80754af9/typing_extensions_1631814937681/work
urllib3==1.26.7
Werkzeug @ file:///tmp/build/80754af9/werkzeug_1635505089296/work
yacs @ file:///tmp/build/80754af9/yacs_1634047592950/work
yarg==0.1.9
zipp==3.7.0
zope.event==4.5.0
zope.interface @ file:///tmp/build/80754af9/zope.interface_1625035545636/work

3
start.sh Normal file

@@ -0,0 +1,3 @@
#!/bin/bash
supervisorctl start ieemoo-ai-isempty

2
stop.sh Normal file

@@ -0,0 +1,2 @@
#!/bin/bash
supervisorctl stop ieemoo-ai-isempty

64
test_single.py Executable file

@@ -0,0 +1,64 @@
# coding=utf-8
import os
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
import argparse
from models.modeling import VisionTransformer, CONFIGS
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", choices=["CUB_200_2011", "emptyJudge5", "emptyJudge4"], default="emptyJudge5", help="Which dataset.")
parser.add_argument("--img_size", default=448, type=int, help="Resolution size")
parser.add_argument('--split', type=str, default='overlap', help="Split method") # non-overlap
parser.add_argument('--slide_step', type=int, default=12, help="Slide step for overlap split")
parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value\n")
parser.add_argument("--pretrained_model", type=str, default="output/emptyjudge5_checkpoint.bin", help="load pretrained model")
args = parser.parse_args()
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.nprocs = torch.cuda.device_count()
# Prepare Model
config = CONFIGS["ViT-B_16"]
config.split = args.split
config.slide_step = args.slide_step
cls_dict = {}
num_classes = 0
if args.dataset == "emptyJudge5":
num_classes = 5
cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "fly", 4: "stack"}
elif args.dataset == "emptyJudge4":
num_classes = 4
cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "stack"}
elif args.dataset == "emptyJudge3":
num_classes = 3
cls_dict = {0: "noemp", 1: "yesemp", 2: "hard"}
elif args.dataset == "emptyJudge2":
num_classes = 2
cls_dict = {0: "noemp", 1: "yesemp"}
model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes, smoothing_value=args.smoothing_value)
if args.pretrained_model is not None:
pretrained_model = torch.load(args.pretrained_model, map_location=torch.device('cpu'))['model']
model.load_state_dict(pretrained_model)
model.to(args.device)
model.eval()
# test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
# transforms.CenterCrop((448, 448)),
# transforms.ToTensor(),
# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
test_transform = transforms.Compose([transforms.Resize((448, 448), Image.BILINEAR),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
img = Image.open("img.jpg")
x = test_transform(img)
part_logits = model(x.unsqueeze(0))
probs = torch.nn.Softmax(dim=-1)(part_logits)
top5 = torch.argsort(probs, dim=-1, descending=True)
print("Prediction Label\n")
for idx in top5[0, :5]:
print(f'{probs[0, idx.item()]:.5f} : {cls_dict[idx.item()]}', end='\n')

379
train.py Executable file

@@ -0,0 +1,379 @@
# coding=utf-8
from __future__ import absolute_import, division, print_function
import logging
import argparse
import os
import random
import numpy as np
import time
from datetime import timedelta
import torch
import torch.distributed as dist
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from models.modeling import VisionTransformer, CONFIGS
from utils.scheduler import WarmupLinearSchedule, WarmupCosineSchedule
from utils.data_utils import get_loader
from utils.dist_util import get_world_size
logger = logging.getLogger(__name__)
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def simple_accuracy(preds, labels):
return (preds == labels).mean()
def reduce_mean(tensor, nprocs):
rt = tensor.clone()
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
rt /= nprocs
return rt
def save_model(args, model):
model_to_save = model.module if hasattr(model, 'module') else model
model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.bin" % args.name)
checkpoint = {
'model': model_to_save.state_dict(),
}
torch.save(checkpoint, model_checkpoint)
logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir)
def setup(args):
# Prepare model
config = CONFIGS[args.model_type]
config.split = args.split
config.slide_step = args.slide_step
if args.dataset == "CUB_200_2011":
num_classes = 200
elif args.dataset == "car":
num_classes = 196
elif args.dataset == "nabirds":
num_classes = 555
elif args.dataset == "dog":
num_classes = 120
elif args.dataset == "INat2017":
num_classes = 5089
elif args.dataset == "emptyJudge5":
num_classes = 5
elif args.dataset == "emptyJudge4":
num_classes = 4
elif args.dataset == "emptyJudge3":
num_classes = 3
model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes, smoothing_value=args.smoothing_value)
model.load_from(np.load(args.pretrained_dir))
if args.pretrained_model is not None:
pretrained_model = torch.load(args.pretrained_model)['model']
model.load_state_dict(pretrained_model)
model.to(args.device)
num_params = count_parameters(model)
logger.info("{}".format(config))
logger.info("Training parameters %s", args)
logger.info("Total Parameter: \t%2.1fM" % num_params)
return args, model
def count_parameters(model):
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
return params/1000000
def set_seed(args):
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
def valid(args, model, writer, test_loader, global_step):
eval_losses = AverageMeter()
logger.info("***** Running Validation *****")
# logger.info("val Num steps = %d", len(test_loader))
# logger.info("val Batch size = %d", args.eval_batch_size)
model.eval()
all_preds, all_label = [], []
epoch_iterator = tqdm(test_loader,
desc="Validating... (loss=X.X)",
bar_format="{l_bar}{r_bar}",
dynamic_ncols=True,
disable=args.local_rank not in [-1, 0])
loss_fct = torch.nn.CrossEntropyLoss()
for step, batch in enumerate(epoch_iterator):
batch = tuple(t.to(args.device) for t in batch)
x, y = batch
with torch.no_grad():
logits = model(x)
eval_loss = loss_fct(logits, y)
eval_loss = eval_loss.mean()
eval_losses.update(eval_loss.item())
preds = torch.argmax(logits, dim=-1)
if len(all_preds) == 0:
all_preds.append(preds.detach().cpu().numpy())
all_label.append(y.detach().cpu().numpy())
else:
all_preds[0] = np.append(
all_preds[0], preds.detach().cpu().numpy(), axis=0
)
all_label[0] = np.append(
all_label[0], y.detach().cpu().numpy(), axis=0
)
epoch_iterator.set_description("Validating... (loss=%2.5f)" % eval_losses.val)
all_preds, all_label = all_preds[0], all_label[0]
accuracy = simple_accuracy(all_preds, all_label)
accuracy = torch.tensor(accuracy).to(args.device)
# dist.barrier()
# val_accuracy = reduce_mean(accuracy, args.nprocs)
# val_accuracy = val_accuracy.detach().cpu().numpy()
val_accuracy = accuracy.detach().cpu().numpy()
logger.info("\n")
logger.info("Validation Results")
logger.info("Global Steps: %d" % global_step)
logger.info("Valid Loss: %2.5f" % eval_losses.avg)
logger.info("Valid Accuracy: %2.5f" % val_accuracy)
if args.local_rank in [-1, 0]:
writer.add_scalar("test/accuracy", scalar_value=val_accuracy, global_step=global_step)
return val_accuracy
def train(args, model):
""" Train the model """
if args.local_rank in [-1, 0]:
os.makedirs(args.output_dir, exist_ok=True)
writer = SummaryWriter(log_dir=os.path.join("logs", args.name))
args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
# Prepare dataset
train_loader, test_loader = get_loader(args)
logger.info("train Num steps = %d", len(train_loader))
logger.info("val Num steps = %d", len(test_loader))
# Prepare optimizer and scheduler
optimizer = torch.optim.SGD(model.parameters(),
lr=args.learning_rate,
momentum=0.9,
weight_decay=args.weight_decay)
t_total = args.num_steps
if args.decay_type == "cosine":
scheduler = WarmupCosineSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
else:
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
# Train!
logger.info("***** Running training *****")
logger.info(" Total optimization steps = %d", args.num_steps)
logger.info(" Instantaneous batch size per GPU = %d", args.train_batch_size)
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
args.train_batch_size * args.gradient_accumulation_steps * (
torch.distributed.get_world_size() if args.local_rank != -1 else 1))
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
model.zero_grad()
set_seed(args) # Added here for reproducibility (even between python 2 and 3)
losses = AverageMeter()
global_step, best_acc = 0, 0
start_time = time.time()
while True:
model.train()
epoch_iterator = tqdm(train_loader,
desc="Training (X / X Steps) (loss=X.X)",
bar_format="{l_bar}{r_bar}",
dynamic_ncols=True,
disable=args.local_rank not in [-1, 0])
all_preds, all_label = [], []
for step, batch in enumerate(epoch_iterator):
batch = tuple(t.to(args.device) for t in batch)
x, y = batch
loss, logits = model(x, y)
loss = loss.mean()
preds = torch.argmax(logits, dim=-1)
if len(all_preds) == 0:
all_preds.append(preds.detach().cpu().numpy())
all_label.append(y.detach().cpu().numpy())
else:
all_preds[0] = np.append(
all_preds[0], preds.detach().cpu().numpy(), axis=0
)
all_label[0] = np.append(
all_label[0], y.detach().cpu().numpy(), axis=0
)
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
loss.backward()
if (step + 1) % args.gradient_accumulation_steps == 0:
losses.update(loss.item()*args.gradient_accumulation_steps)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
optimizer.step()
scheduler.step()  # step the LR scheduler after the optimizer (PyTorch >= 1.1 order)
optimizer.zero_grad()
global_step += 1
epoch_iterator.set_description(
"Training (%d / %d Steps) (loss=%2.5f)" % (global_step, t_total, losses.val)
)
if args.local_rank in [-1, 0]:
writer.add_scalar("train/loss", scalar_value=losses.val, global_step=global_step)
writer.add_scalar("train/lr", scalar_value=scheduler.get_lr()[0], global_step=global_step)
if global_step % args.eval_every == 0:
with torch.no_grad():
accuracy = valid(args, model, writer, test_loader, global_step)
if args.local_rank in [-1, 0]:
if best_acc < accuracy:
save_model(args, model)
best_acc = accuracy
logger.info("best accuracy so far: %f" % best_acc)
model.train()
if global_step % t_total == 0:
break
all_preds, all_label = all_preds[0], all_label[0]
accuracy = simple_accuracy(all_preds, all_label)
accuracy = torch.tensor(accuracy).to(args.device)
# dist.barrier()
# train_accuracy = reduce_mean(accuracy, args.nprocs)
# train_accuracy = train_accuracy.detach().cpu().numpy()
train_accuracy = accuracy.detach().cpu().numpy()
logger.info("train accuracy so far: %f" % train_accuracy)
losses.reset()
if global_step % t_total == 0:
break
writer.close()
logger.info("Best Accuracy: \t%f" % best_acc)
logger.info("End Training!")
end_time = time.time()
logger.info("Total Training Time: \t%f" % ((end_time - start_time) / 3600))
def main():
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("--name", required=True,
help="Name of this run. Used for monitoring.")
parser.add_argument("--dataset", choices=["CUB_200_2011", "car", "dog", "nabirds", "INat2017", "emptyJudge5", "emptyJudge4"],
default="CUB_200_2011", help="Which dataset.")
parser.add_argument('--data_root', type=str, default='/data/fineGrained')
parser.add_argument("--model_type", choices=["ViT-B_16", "ViT-B_32", "ViT-L_16", "ViT-L_32", "ViT-H_14"],
default="ViT-B_16",help="Which variant to use.")
parser.add_argument("--pretrained_dir", type=str, default="ckpts/ViT-B_16.npz",
help="Where to search for pretrained ViT models.")
parser.add_argument("--pretrained_model", type=str, default="output/emptyjudge5_checkpoint.bin", help="load pretrained model")
#parser.add_argument("--pretrained_model", type=str, default=None, help="load pretrained model")
parser.add_argument("--output_dir", default="./output", type=str,
help="The output directory where checkpoints will be written.")
parser.add_argument("--img_size", default=448, type=int, help="Resolution size")
parser.add_argument("--train_batch_size", default=64, type=int,
help="Total batch size for training.")
parser.add_argument("--eval_batch_size", default=16, type=int,
help="Total batch size for eval.")
parser.add_argument("--eval_every", default=200, type=int,
help="Run prediction on validation set every so many steps."
"Will always run one evaluation at the end of training.")
parser.add_argument("--learning_rate", default=3e-2, type=float,
help="The initial learning rate for SGD.")
parser.add_argument("--weight_decay", default=0, type=float,
help="Weight deay if we apply some.")
parser.add_argument("--num_steps", default=8000, type=int, #100000
help="Total number of training epochs to perform.")
parser.add_argument("--decay_type", choices=["cosine", "linear"], default="cosine",
help="How to decay the learning rate.")
parser.add_argument("--warmup_steps", default=500, type=int,
help="Step of training to perform learning rate warmup for.")
parser.add_argument("--max_grad_norm", default=1.0, type=float,
help="Max gradient norm.")
parser.add_argument("--local_rank", type=int, default=-1,
help="local_rank for distributed training on gpus")
parser.add_argument('--seed', type=int, default=42,
help="random seed for initialization")
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument('--fp16', action='store_true',
help="Whether to use 16-bit float precision instead of 32-bit")
parser.add_argument('--fp16_opt_level', type=str, default='O2',
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html")
parser.add_argument('--loss_scale', type=float, default=0,
help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
"0 (default value): dynamic loss scaling.\n"
"Positive power of 2: static loss scaling value.\n")
parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value\n")
parser.add_argument('--split', type=str, default='overlap', help="Split method") # non-overlap
parser.add_argument('--slide_step', type=int, default=12, help="Slide step for overlap split")
args = parser.parse_args()
args.data_root = '{}/{}'.format(args.data_root, args.dataset)
# Setup CUDA, GPU & distributed training
if args.local_rank == -1:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('torch.cuda.device_count()>>>>>>>>>>>>>>>>>>>>>>>>>', torch.cuda.device_count())
args.n_gpu = torch.cuda.device_count()
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
torch.distributed.init_process_group(backend='nccl', timeout=timedelta(minutes=60))
args.n_gpu = 1
args.device = device
args.nprocs = torch.cuda.device_count()
# Setup logging
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s" %
(args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1), args.fp16))
# Set seed
set_seed(args)
# Model & Tokenizer Setup
args, model = setup(args)
# Training
train(args, model)
if __name__ == "__main__":
main()
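For reference, an assumed invocation built only from the argparse flags above (no launch command is recorded in this commit): a single-GPU fine-tuning run could look like python train.py --name emptyjudge5 --dataset emptyJudge5 --model_type ViT-B_16 --pretrained_dir ckpts/ViT-B_16.npz --split overlap --num_steps 8000, while distributed runs would be launched through a wrapper that sets --local_rank, such as torch.distributed.launch.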

0
utils/__init__.py Executable file

204
utils/autoaugment.py Executable file

@@ -0,0 +1,204 @@
"""
Copied from https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py
"""
from PIL import Image, ImageEnhance, ImageOps
import numpy as np
import random
__all__ = ['AutoAugImageNetPolicy', 'AutoAugCIFAR10Policy', 'AutoAugSVHNPolicy']
class AutoAugImageNetPolicy(object):
def __init__(self, fillcolor=(128, 128, 128)):
self.policies = [
SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),
SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),
SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor)
]
def __call__(self, img):
policy_idx = random.randint(0, len(self.policies) - 1)
return self.policies[policy_idx](img)
def __repr__(self):
return "AutoAugment ImageNet Policy"
class AutoAugCIFAR10Policy(object):
def __init__(self, fillcolor=(128, 128, 128)):
self.policies = [
SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),
SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),
SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),
SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor),
SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),
SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor)
]
def __call__(self, img):
policy_idx = random.randint(0, len(self.policies) - 1)
return self.policies[policy_idx](img)
def __repr__(self):
return "AutoAugment CIFAR10 Policy"
class AutoAugSVHNPolicy(object):
def __init__(self, fillcolor=(128, 128, 128)):
self.policies = [
SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),
SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),
SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),
SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),
SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
]
def __call__(self, img):
policy_idx = random.randint(0, len(self.policies) - 1)
return self.policies[policy_idx](img)
def __repr__(self):
return "AutoAugment SVHN Policy"
class SubPolicy(object):
def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
ranges = {
"shearX": np.linspace(0, 0.3, 10),
"shearY": np.linspace(0, 0.3, 10),
"translateX": np.linspace(0, 150 / 331, 10),
"translateY": np.linspace(0, 150 / 331, 10),
"rotate": np.linspace(0, 30, 10),
"color": np.linspace(0.0, 0.9, 10),
"posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int),
"solarize": np.linspace(256, 0, 10),
"contrast": np.linspace(0.0, 0.9, 10),
"sharpness": np.linspace(0.0, 0.9, 10),
"brightness": np.linspace(0.0, 0.9, 10),
"autocontrast": [0] * 10,
"equalize": [0] * 10,
"invert": [0] * 10
}
def rotate_with_fill(img, magnitude):
rot = img.convert("RGBA").rotate(magnitude)
return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)
func = {
"shearX": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
Image.BICUBIC, fillcolor=fillcolor),
"shearY": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
Image.BICUBIC, fillcolor=fillcolor),
"translateX": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
fillcolor=fillcolor),
"translateY": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
fillcolor=fillcolor),
"rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
# "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])),
"color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
"posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
"solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
"contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
1 + magnitude * random.choice([-1, 1])),
"sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
1 + magnitude * random.choice([-1, 1])),
"brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
1 + magnitude * random.choice([-1, 1])),
"autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
"equalize": lambda img, magnitude: ImageOps.equalize(img),
"invert": lambda img, magnitude: ImageOps.invert(img)
}
# self.name = "{}_{:.2f}_and_{}_{:.2f}".format(
# operation1, ranges[operation1][magnitude_idx1],
# operation2, ranges[operation2][magnitude_idx2])
self.p1 = p1
self.operation1 = func[operation1]
self.magnitude1 = ranges[operation1][magnitude_idx1]
self.p2 = p2
self.operation2 = func[operation2]
self.magnitude2 = ranges[operation2][magnitude_idx2]
def __call__(self, img):
if random.random() < self.p1:
img = self.operation1(img, self.magnitude1)
if random.random() < self.p2:
img = self.operation2(img, self.magnitude2)
return img
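if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file): each AutoAug*Policy
    # picks one of its two-op SubPolicies uniformly at random per call, so it
    # drops straight into a torchvision pipeline ahead of ToTensor().
    demo = Image.new("RGB", (64, 64), (200, 80, 40))
    out = AutoAugSVHNPolicy()(demo)
    print(out.size, out.mode)  # (64, 64) RGB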

135
utils/data_utils.py Executable file
View File

@ -0,0 +1,135 @@
import logging
from PIL import Image
import os
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, RandomSampler, DistributedSampler, SequentialSampler
from .dataset import CUB, CarsDataset, NABirds, dogs, INat2017, emptyJudge
from .autoaugment import AutoAugImageNetPolicy
logger = logging.getLogger(__name__)
def get_loader(args):
if args.local_rank not in [-1, 0]:
torch.distributed.barrier()
if args.dataset == 'CUB_200_2011':
train_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
transforms.RandomCrop((448, 448)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
test_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
transforms.CenterCrop((448, 448)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
trainset = CUB(root=args.data_root, is_train=True, transform=train_transform)
testset = CUB(root=args.data_root, is_train=False, transform=test_transform)
elif args.dataset == 'car':
trainset = CarsDataset(os.path.join(args.data_root,'devkit/cars_train_annos.mat'),
os.path.join(args.data_root,'cars_train'),
os.path.join(args.data_root,'devkit/cars_meta.mat'),
# cleaned=os.path.join(data_dir,'cleaned.dat'),
transform=transforms.Compose([
transforms.Resize((600, 600), Image.BILINEAR),
transforms.RandomCrop((448, 448)),
transforms.RandomHorizontalFlip(),
AutoAugImageNetPolicy(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
)
testset = CarsDataset(os.path.join(args.data_root,'cars_test_annos_withlabels.mat'),
os.path.join(args.data_root,'cars_test'),
os.path.join(args.data_root,'devkit/cars_meta.mat'),
# cleaned=os.path.join(data_dir,'cleaned_test.dat'),
transform=transforms.Compose([
transforms.Resize((600, 600), Image.BILINEAR),
transforms.CenterCrop((448, 448)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
)
elif args.dataset == 'dog':
train_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
transforms.RandomCrop((448, 448)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
test_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
transforms.CenterCrop((448, 448)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
trainset = dogs(root=args.data_root,
train=True,
cropped=False,
transform=train_transform,
download=False
)
testset = dogs(root=args.data_root,
train=False,
cropped=False,
transform=test_transform,
download=False
)
elif args.dataset == 'nabirds':
train_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
transforms.RandomCrop((448, 448)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
test_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
transforms.CenterCrop((448, 448)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
trainset = NABirds(root=args.data_root, train=True, transform=train_transform)
testset = NABirds(root=args.data_root, train=False, transform=test_transform)
elif args.dataset == 'INat2017':
train_transform=transforms.Compose([transforms.Resize((400, 400), Image.BILINEAR),
transforms.RandomCrop((304, 304)),
transforms.RandomHorizontalFlip(),
AutoAugImageNetPolicy(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
test_transform=transforms.Compose([transforms.Resize((400, 400), Image.BILINEAR),
transforms.CenterCrop((304, 304)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
trainset = INat2017(args.data_root, 'train', train_transform)
testset = INat2017(args.data_root, 'val', test_transform)
    elif args.dataset in ('emptyJudge5', 'emptyJudge4'):
train_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
transforms.RandomCrop((448, 448)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
# test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
# transforms.CenterCrop((448, 448)),
# transforms.ToTensor(),
# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
test_transform = transforms.Compose([transforms.Resize((448, 448), Image.BILINEAR),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        trainset = emptyJudge(root=args.data_root, is_train=True, transform=train_transform)
        testset = emptyJudge(root=args.data_root, is_train=False, transform=test_transform)
    else:
        # Fail loudly instead of hitting a NameError on trainset/testset below.
        raise ValueError("Unknown dataset: {}".format(args.dataset))
if args.local_rank == 0:
torch.distributed.barrier()
train_sampler = RandomSampler(trainset) if args.local_rank == -1 else DistributedSampler(trainset)
test_sampler = SequentialSampler(testset) if args.local_rank == -1 else DistributedSampler(testset)
train_loader = DataLoader(trainset,
sampler=train_sampler,
batch_size=args.train_batch_size,
num_workers=4,
drop_last=True,
pin_memory=True)
test_loader = DataLoader(testset,
sampler=test_sampler,
batch_size=args.eval_batch_size,
num_workers=4,
pin_memory=True) if testset is not None else None
return train_loader, test_loader
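# Usage sketch (illustrative values, not from this repo's configs): get_loader
# expects an argparse-style namespace carrying at least the fields read above,
# i.e. local_rank, dataset, data_root, train_batch_size and eval_batch_size.
#
#   import argparse
#   args = argparse.Namespace(local_rank=-1, dataset='emptyJudge5',
#                             data_root='./data', train_batch_size=16,
#                             eval_batch_size=8)
#   train_loader, test_loader = get_loader(args)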

629
utils/dataset.py Executable file
View File

@ -0,0 +1,629 @@
import os
import json
from os.path import join
import numpy as np
import scipy
from scipy import io
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_url, list_dir, check_integrity, extract_archive, verify_str_arg
class emptyJudge():
def __init__(self, root, is_train=True, data_len=None, transform=None):
self.root = root
self.is_train = is_train
self.transform = transform
img_txt_file = open(os.path.join(self.root, 'images.txt'))
label_txt_file = open(os.path.join(self.root, 'image_class_labels.txt'))
train_val_file = open(os.path.join(self.root, 'train_test_split.txt'))
img_name_list = []
for line in img_txt_file:
img_name_list.append(line[:-1].split(' ')[-1])
label_list = []
for line in label_txt_file:
label_list.append(int(line[:-1].split(' ')[-1]) - 1)
train_test_list = []
for line in train_val_file:
train_test_list.append(int(line[:-1].split(' ')[-1]))
train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
if self.is_train:
            # scipy.misc.imread was removed from SciPy; load with PIL instead.
            self.train_img = [np.array(Image.open(os.path.join(self.root, 'images', train_file))) for train_file in
                              train_file_list[:data_len]]
self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
self.train_imgname = [x for x in train_file_list[:data_len]]
if not self.is_train:
            self.test_img = [np.array(Image.open(os.path.join(self.root, 'images', test_file))) for test_file in
                             test_file_list[:data_len]]
self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
self.test_imgname = [x for x in test_file_list[:data_len]]
def __getitem__(self, index):
if self.is_train:
img, target, imgname = self.train_img[index], self.train_label[index], self.train_imgname[index]
if len(img.shape) == 2:
img = np.stack([img] * 3, 2)
img = Image.fromarray(img, mode='RGB')
if self.transform is not None:
img = self.transform(img)
else:
img, target, imgname = self.test_img[index], self.test_label[index], self.test_imgname[index]
if len(img.shape) == 2:
img = np.stack([img] * 3, 2)
img = Image.fromarray(img, mode='RGB')
if self.transform is not None:
img = self.transform(img)
return img, target
def __len__(self):
if self.is_train:
return len(self.train_label)
else:
return len(self.test_label)
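# The three index files read by emptyJudge (and by CUB below) are assumed to
# follow the CUB-200-2011 layout, one "<id> <value>" pair per line:
#   images.txt              "1 class_a/img_0001.jpg"
#   image_class_labels.txt  "1 3"   (1-based label, shifted to 0-based above)
#   train_test_split.txt    "1 1"   (1 = train, 0 = test)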
class CUB():
def __init__(self, root, is_train=True, data_len=None, transform=None):
self.root = root
self.is_train = is_train
self.transform = transform
img_txt_file = open(os.path.join(self.root, 'images.txt'))
label_txt_file = open(os.path.join(self.root, 'image_class_labels.txt'))
train_val_file = open(os.path.join(self.root, 'train_test_split.txt'))
img_name_list = []
for line in img_txt_file:
img_name_list.append(line[:-1].split(' ')[-1])
label_list = []
for line in label_txt_file:
label_list.append(int(line[:-1].split(' ')[-1]) - 1)
train_test_list = []
for line in train_val_file:
train_test_list.append(int(line[:-1].split(' ')[-1]))
train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
if self.is_train:
            # scipy.misc.imread was removed from SciPy; load with PIL instead.
            self.train_img = [np.array(Image.open(os.path.join(self.root, 'images', train_file))) for train_file in
                              train_file_list[:data_len]]
self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
self.train_imgname = [x for x in train_file_list[:data_len]]
if not self.is_train:
            self.test_img = [np.array(Image.open(os.path.join(self.root, 'images', test_file))) for test_file in
                             test_file_list[:data_len]]
self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
self.test_imgname = [x for x in test_file_list[:data_len]]
def __getitem__(self, index):
if self.is_train:
img, target, imgname = self.train_img[index], self.train_label[index], self.train_imgname[index]
if len(img.shape) == 2:
img = np.stack([img] * 3, 2)
img = Image.fromarray(img, mode='RGB')
if self.transform is not None:
img = self.transform(img)
else:
img, target, imgname = self.test_img[index], self.test_label[index], self.test_imgname[index]
if len(img.shape) == 2:
img = np.stack([img] * 3, 2)
img = Image.fromarray(img, mode='RGB')
if self.transform is not None:
img = self.transform(img)
return img, target
def __len__(self):
if self.is_train:
return len(self.train_label)
else:
return len(self.test_label)
class CarsDataset(Dataset):
def __init__(self, mat_anno, data_dir, car_names, cleaned=None, transform=None):
"""
Args:
mat_anno (string): Path to the MATLAB annotation file.
data_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied
on a sample.
"""
self.full_data_set = io.loadmat(mat_anno)
self.car_annotations = self.full_data_set['annotations']
self.car_annotations = self.car_annotations[0]
if cleaned is not None:
cleaned_annos = []
print("Cleaning up data set (only take pics with rgb chans)...")
clean_files = np.loadtxt(cleaned, dtype=str)
for c in self.car_annotations:
if c[-1][0] in clean_files:
cleaned_annos.append(c)
self.car_annotations = cleaned_annos
self.car_names = scipy.io.loadmat(car_names)['class_names']
self.car_names = np.array(self.car_names[0])
self.data_dir = data_dir
self.transform = transform
def __len__(self):
return len(self.car_annotations)
def __getitem__(self, idx):
img_name = os.path.join(self.data_dir, self.car_annotations[idx][-1][0])
image = Image.open(img_name).convert('RGB')
car_class = self.car_annotations[idx][-2][0][0]
car_class = torch.from_numpy(np.array(car_class.astype(np.float32))).long() - 1
assert car_class < 196
if self.transform:
image = self.transform(image)
# return image, car_class, img_name
return image, car_class
def map_class(self, id):
id = np.ravel(id)
ret = self.car_names[id - 1][0][0]
return ret
def show_batch(self, img_batch, class_batch):
for i in range(img_batch.shape[0]):
ax = plt.subplot(1, img_batch.shape[0], i + 1)
title_str = self.map_class(int(class_batch[i]))
img = np.transpose(img_batch[i, ...], (1, 2, 0))
ax.imshow(img)
            ax.set_title(str(title_str), {'fontsize': 5})
plt.tight_layout()
def make_dataset(directory, image_ids, targets):  # renamed from `dir` to avoid shadowing the builtin
    assert len(image_ids) == len(targets)
    images = []
    directory = os.path.expanduser(directory)
    for i in range(len(image_ids)):
        item = (os.path.join(directory, 'data', 'images', '%s.jpg' % image_ids[i]), targets[i])
        images.append(item)
    return images
def find_classes(classes_file):
    # read classes file, separating out image IDs and class names
    image_ids = []
    targets = []
    with open(classes_file, 'r') as f:
        for line in f:
            split_line = line.split(' ')
            image_ids.append(split_line[0])
            targets.append(' '.join(split_line[1:]))
    # index class names
    classes = np.unique(targets)
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    targets = [class_to_idx[c] for c in targets]
    return image_ids, targets, classes, class_to_idx
class dogs(Dataset):
"""`Stanford Dogs <http://vision.stanford.edu/aditya86/ImageNetDogs/>`_ Dataset.
Args:
root (string): Root directory of dataset where directory
``omniglot-py`` exists.
cropped (bool, optional): If true, the images will be cropped into the bounding box specified
in the annotations
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset tar files from the internet and
puts it in root directory. If the tar files are already downloaded, they are not
downloaded again.
"""
folder = 'dog'
download_url_prefix = 'http://vision.stanford.edu/aditya86/ImageNetDogs'
def __init__(self,
root,
train=True,
cropped=False,
transform=None,
target_transform=None,
download=False):
# self.root = join(os.path.expanduser(root), self.folder)
self.root = root
self.train = train
self.cropped = cropped
self.transform = transform
self.target_transform = target_transform
if download:
self.download()
split = self.load_split()
self.images_folder = join(self.root, 'Images')
self.annotations_folder = join(self.root, 'Annotation')
self._breeds = list_dir(self.images_folder)
if self.cropped:
self._breed_annotations = [[(annotation, box, idx)
for box in self.get_boxes(join(self.annotations_folder, annotation))]
for annotation, idx in split]
self._flat_breed_annotations = sum(self._breed_annotations, [])
self._flat_breed_images = [(annotation+'.jpg', idx) for annotation, box, idx in self._flat_breed_annotations]
else:
self._breed_images = [(annotation+'.jpg', idx) for annotation, idx in split]
self._flat_breed_images = self._breed_images
self.classes = ["Chihuaha",
"Japanese Spaniel",
"Maltese Dog",
"Pekinese",
"Shih-Tzu",
"Blenheim Spaniel",
"Papillon",
"Toy Terrier",
"Rhodesian Ridgeback",
"Afghan Hound",
"Basset Hound",
"Beagle",
"Bloodhound",
"Bluetick",
"Black-and-tan Coonhound",
"Walker Hound",
"English Foxhound",
"Redbone",
"Borzoi",
"Irish Wolfhound",
"Italian Greyhound",
"Whippet",
"Ibizian Hound",
"Norwegian Elkhound",
"Otterhound",
"Saluki",
"Scottish Deerhound",
"Weimaraner",
"Staffordshire Bullterrier",
"American Staffordshire Terrier",
"Bedlington Terrier",
"Border Terrier",
"Kerry Blue Terrier",
"Irish Terrier",
"Norfolk Terrier",
"Norwich Terrier",
"Yorkshire Terrier",
"Wirehaired Fox Terrier",
"Lakeland Terrier",
"Sealyham Terrier",
"Airedale",
"Cairn",
"Australian Terrier",
"Dandi Dinmont",
"Boston Bull",
"Miniature Schnauzer",
"Giant Schnauzer",
"Standard Schnauzer",
"Scotch Terrier",
"Tibetan Terrier",
"Silky Terrier",
"Soft-coated Wheaten Terrier",
"West Highland White Terrier",
"Lhasa",
"Flat-coated Retriever",
"Curly-coater Retriever",
"Golden Retriever",
"Labrador Retriever",
"Chesapeake Bay Retriever",
"German Short-haired Pointer",
"Vizsla",
"English Setter",
"Irish Setter",
"Gordon Setter",
"Brittany",
"Clumber",
"English Springer Spaniel",
"Welsh Springer Spaniel",
"Cocker Spaniel",
"Sussex Spaniel",
"Irish Water Spaniel",
"Kuvasz",
"Schipperke",
"Groenendael",
"Malinois",
"Briard",
"Kelpie",
"Komondor",
"Old English Sheepdog",
"Shetland Sheepdog",
"Collie",
"Border Collie",
"Bouvier des Flandres",
"Rottweiler",
"German Shepard",
"Doberman",
"Miniature Pinscher",
"Greater Swiss Mountain Dog",
"Bernese Mountain Dog",
"Appenzeller",
"EntleBucher",
"Boxer",
"Bull Mastiff",
"Tibetan Mastiff",
"French Bulldog",
"Great Dane",
"Saint Bernard",
"Eskimo Dog",
"Malamute",
"Siberian Husky",
"Affenpinscher",
"Basenji",
"Pug",
"Leonberg",
"Newfoundland",
"Great Pyrenees",
"Samoyed",
"Pomeranian",
"Chow",
"Keeshond",
"Brabancon Griffon",
"Pembroke",
"Cardigan",
"Toy Poodle",
"Miniature Poodle",
"Standard Poodle",
"Mexican Hairless",
"Dingo",
"Dhole",
"African Hunting Dog"]
def __len__(self):
return len(self._flat_breed_images)
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target character class.
"""
image_name, target_class = self._flat_breed_images[index]
image_path = join(self.images_folder, image_name)
image = Image.open(image_path).convert('RGB')
if self.cropped:
image = image.crop(self._flat_breed_annotations[index][1])
if self.transform:
image = self.transform(image)
if self.target_transform:
target_class = self.target_transform(target_class)
return image, target_class
def download(self):
import tarfile
if os.path.exists(join(self.root, 'Images')) and os.path.exists(join(self.root, 'Annotation')):
if len(os.listdir(join(self.root, 'Images'))) == len(os.listdir(join(self.root, 'Annotation'))) == 120:
print('Files already downloaded and verified')
return
for filename in ['images', 'annotation', 'lists']:
tar_filename = filename + '.tar'
url = self.download_url_prefix + '/' + tar_filename
download_url(url, self.root, tar_filename, None)
print('Extracting downloaded file: ' + join(self.root, tar_filename))
with tarfile.open(join(self.root, tar_filename), 'r') as tar_file:
tar_file.extractall(self.root)
os.remove(join(self.root, tar_filename))
@staticmethod
def get_boxes(path):
import xml.etree.ElementTree
e = xml.etree.ElementTree.parse(path).getroot()
boxes = []
for objs in e.iter('object'):
boxes.append([int(objs.find('bndbox').find('xmin').text),
int(objs.find('bndbox').find('ymin').text),
int(objs.find('bndbox').find('xmax').text),
int(objs.find('bndbox').find('ymax').text)])
return boxes
    def load_split(self):
        # Load the .mat once instead of twice per branch.
        if self.train:
            mat = scipy.io.loadmat(join(self.root, 'train_list.mat'))
        else:
            mat = scipy.io.loadmat(join(self.root, 'test_list.mat'))
        split = [item[0][0] for item in mat['annotation_list']]
        labels = [item[0] - 1 for item in mat['labels']]
        return list(zip(split, labels))
def stats(self):
counts = {}
for index in range(len(self._flat_breed_images)):
image_name, target_class = self._flat_breed_images[index]
if target_class not in counts.keys():
counts[target_class] = 1
else:
counts[target_class] += 1
print("%d samples spanning %d classes (avg %f per class)"%(len(self._flat_breed_images), len(counts.keys()), float(len(self._flat_breed_images))/float(len(counts.keys()))))
return counts
class NABirds(Dataset):
"""`NABirds <https://dl.allaboutbirds.org/nabirds>`_ Dataset.
Args:
root (string): Root directory of the dataset.
train (bool, optional): If True, creates dataset from training set, otherwise
creates from test set.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""
base_folder = 'nabirds/images'
def __init__(self, root, train=True, transform=None):
dataset_path = os.path.join(root, 'nabirds')
self.root = root
self.loader = default_loader
self.train = train
self.transform = transform
image_paths = pd.read_csv(os.path.join(dataset_path, 'images.txt'),
sep=' ', names=['img_id', 'filepath'])
image_class_labels = pd.read_csv(os.path.join(dataset_path, 'image_class_labels.txt'),
sep=' ', names=['img_id', 'target'])
# Since the raw labels are non-continuous, map them to new ones
self.label_map = get_continuous_class_map(image_class_labels['target'])
train_test_split = pd.read_csv(os.path.join(dataset_path, 'train_test_split.txt'),
sep=' ', names=['img_id', 'is_training_img'])
data = image_paths.merge(image_class_labels, on='img_id')
self.data = data.merge(train_test_split, on='img_id')
# Load in the train / test split
if self.train:
self.data = self.data[self.data.is_training_img == 1]
else:
self.data = self.data[self.data.is_training_img == 0]
# Load in the class data
self.class_names = load_class_names(dataset_path)
self.class_hierarchy = load_hierarchy(dataset_path)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = self.data.iloc[idx]
path = os.path.join(self.root, self.base_folder, sample.filepath)
target = self.label_map[sample.target]
img = self.loader(path)
if self.transform is not None:
img = self.transform(img)
return img, target
def get_continuous_class_map(class_labels):
    # Sort so the id assignment is deterministic; iterating a raw set yields
    # an arbitrary order that can differ between runs.
    label_set = sorted(set(class_labels))
    return {k: i for i, k in enumerate(label_set)}
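# E.g. raw targets {295, 296, 817} map to {295: 0, 296: 1, 817: 2} -- the
# contiguous ids a classification head expects.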
def load_class_names(dataset_path=''):
names = {}
with open(os.path.join(dataset_path, 'classes.txt')) as f:
for line in f:
pieces = line.strip().split()
class_id = pieces[0]
names[class_id] = ' '.join(pieces[1:])
return names
def load_hierarchy(dataset_path=''):
parents = {}
with open(os.path.join(dataset_path, 'hierarchy.txt')) as f:
for line in f:
pieces = line.strip().split()
child_id, parent_id = pieces
parents[child_id] = parent_id
return parents
class INat2017(VisionDataset):
"""`iNaturalist 2017 <https://github.com/visipedia/inat_comp/blob/master/2017/README.md>`_ Dataset.
Args:
root (string): Root directory of the dataset.
        split (string, optional): The dataset split; supports ``train`` or ``val``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""
base_folder = 'train_val_images/'
file_list = {
'imgs': ('https://storage.googleapis.com/asia_inat_data/train_val/train_val_images.tar.gz',
'train_val_images.tar.gz',
'7c784ea5e424efaec655bd392f87301f'),
'annos': ('https://storage.googleapis.com/asia_inat_data/train_val/train_val2017.zip',
'train_val2017.zip',
'444c835f6459867ad69fcb36478786e7')
}
def __init__(self, root, split='train', transform=None, target_transform=None, download=False):
super(INat2017, self).__init__(root, transform=transform, target_transform=target_transform)
self.loader = default_loader
self.split = verify_str_arg(split, "split", ("train", "val",))
if self._check_exists():
print('Files already downloaded and verified.')
elif download:
if not (os.path.exists(os.path.join(self.root, self.file_list['imgs'][1]))
and os.path.exists(os.path.join(self.root, self.file_list['annos'][1]))):
print('Downloading...')
self._download()
print('Extracting...')
extract_archive(os.path.join(self.root, self.file_list['imgs'][1]))
extract_archive(os.path.join(self.root, self.file_list['annos'][1]))
else:
raise RuntimeError(
'Dataset not found. You can use download=True to download it.')
anno_filename = split + '2017.json'
with open(os.path.join(self.root, anno_filename), 'r') as fp:
all_annos = json.load(fp)
self.annos = all_annos['annotations']
self.images = all_annos['images']
def __getitem__(self, index):
path = os.path.join(self.root, self.images[index]['file_name'])
target = self.annos[index]['category_id']
image = self.loader(path)
if self.transform is not None:
image = self.transform(image)
if self.target_transform is not None:
target = self.target_transform(target)
return image, target
def __len__(self):
return len(self.images)
def _check_exists(self):
return os.path.exists(os.path.join(self.root, self.base_folder))
def _download(self):
for url, filename, md5 in self.file_list.values():
download_url(url, root=self.root, filename=filename)
if not check_integrity(os.path.join(self.root, filename), md5):
raise RuntimeError("File not found or corrupted.")

30
utils/dist_util.py Executable file
View File

@ -0,0 +1,30 @@
import torch.distributed as dist
def get_rank():
if not dist.is_available():
return 0
if not dist.is_initialized():
return 0
return dist.get_rank()
def get_world_size():
if not dist.is_available():
return 1
if not dist.is_initialized():
return 1
return dist.get_world_size()
def is_main_process():
return get_rank() == 0
def format_step(step):
if isinstance(step, str):
return step
s = ""
if len(step) > 0:
s += "Training Epoch: {} ".format(step[0])
if len(step) > 1:
s += "Training Iteration: {} ".format(step[1])
if len(step) > 2:
s += "Validation Iteration: {} ".format(step[2])
return s
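# Example: format_step((3, 120)) returns "Training Epoch: 3 Training Iteration: 120 "
# (each present field appends its value plus a trailing space).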

63
utils/scheduler.py Executable file
View File

@ -0,0 +1,63 @@
import logging
import math
from torch.optim.lr_scheduler import LambdaLR
logger = logging.getLogger(__name__)
class ConstantLRSchedule(LambdaLR):
""" Constant learning rate schedule.
"""
def __init__(self, optimizer, last_epoch=-1):
super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch)
class WarmupConstantSchedule(LambdaLR):
""" Linear warmup and then constant.
Linearly increases learning rate schedule from 0 to 1 over `warmup_steps` training steps.
Keeps learning rate schedule equal to 1. after warmup_steps.
"""
def __init__(self, optimizer, warmup_steps, last_epoch=-1):
self.warmup_steps = warmup_steps
super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
def lr_lambda(self, step):
if step < self.warmup_steps:
return float(step) / float(max(1.0, self.warmup_steps))
return 1.
class WarmupLinearSchedule(LambdaLR):
""" Linear warmup and then linear decay.
Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps.
Linearly decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps.
"""
def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):
self.warmup_steps = warmup_steps
self.t_total = t_total
super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
def lr_lambda(self, step):
if step < self.warmup_steps:
return float(step) / float(max(1, self.warmup_steps))
return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))
class WarmupCosineSchedule(LambdaLR):
""" Linear warmup and then cosine decay.
Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps.
Decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve.
If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup.
"""
def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):
self.warmup_steps = warmup_steps
self.t_total = t_total
self.cycles = cycles
super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
def lr_lambda(self, step):
if step < self.warmup_steps:
return float(step) / float(max(1.0, self.warmup_steps))
# progress after warmup
progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))
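if __name__ == "__main__":
    # Usage sketch with hypothetical hyperparameters: each schedule returns a
    # multiplier on the optimizer's base lr, stepped once per training step.
    import torch
    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=3e-2)
    scheduler = WarmupCosineSchedule(optimizer, warmup_steps=5, t_total=100)
    lrs = []
    for _ in range(100):
        optimizer.step()        # normally preceded by a forward/backward pass
        scheduler.step()
        lrs.append(optimizer.param_groups[0]["lr"])
    print(lrs[0], max(lrs), lrs[-1])  # ramps up to ~3e-2, then decays toward 0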