update
.gitignore (vendored, executable file, 7 lines)
@@ -0,0 +1,7 @@
.idea/
ckpts/
logs/
models/__pycache__/
utils/__pycache__/
output/
attention_data/

LICENSE (executable file, 21 lines)
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2021 Ju He

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

docs/TransFG.pdf (binary, executable file; not shown)

docs/TransFG.png (binary, executable file; 733 KiB; not shown)
ieemoo-ai-isempty.py (normal file, 136 lines)
@@ -0,0 +1,136 @@
# -*- coding: utf-8 -*-
import sys
sys.path.insert(0, ".")  # moved before the models import so the path tweak actually takes effect

from flask import request, Flask
import numpy as np
import json
import time
import cv2, base64
import argparse
import os
import torch
from PIL import Image
from torchvision import transforms
from models.modeling import VisionTransformer, CONFIGS


app = Flask(__name__)
app.use_reloader = False  # note: Flask only honors use_reloader as an app.run() argument, not as an attribute


def parse_args(model_file="ckpts/emptyjudge5_checkpoint.bin"):
    parser = argparse.ArgumentParser()
    parser.add_argument("--img_size", default=448, type=int, help="Resolution size")
    parser.add_argument('--split', type=str, default='overlap', help="Split method")
    parser.add_argument('--slide_step', type=int, default=12, help="Slide step for overlap split")
    parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value")
    parser.add_argument("--pretrained_model", type=str, default=model_file, help="load pretrained model")
    opt, unknown = parser.parse_known_args()
    return opt


class Predictor(object):
    def __init__(self, args):
        self.args = args
        self.args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(self.args.device)
        self.args.nprocs = torch.cuda.device_count()
        self.cls_dict = {}
        self.num_classes = 0
        self.model = None
        self.prepare_model()
        self.test_transform = transforms.Compose([transforms.Resize((448, 448), Image.BILINEAR),
                                                  transforms.ToTensor(),
                                                  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    def prepare_model(self):
        config = CONFIGS["ViT-B_16"]
        config.split = self.args.split
        config.slide_step = self.args.slide_step
        model_name = os.path.basename(self.args.pretrained_model).replace("_checkpoint.bin", "")
        print("use model_name: ", model_name)
        self.num_classes = 5
        self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "fly", 4: "stack"}
        self.model = VisionTransformer(config, self.args.img_size, zero_head=True, num_classes=self.num_classes, smoothing_value=self.args.smoothing_value)
        if self.args.pretrained_model is not None:
            if not torch.cuda.is_available():
                pretrained_model = torch.load(self.args.pretrained_model, map_location=torch.device('cpu'))['model']
            else:
                pretrained_model = torch.load(self.args.pretrained_model)['model']
            self.model.load_state_dict(pretrained_model)
        self.model.eval()
        self.model.to(self.args.device)

    def normal_predict(self, img_data, result):
        if img_data is None:
            print('error, img data is None')
            return result
        with torch.no_grad():
            x = self.test_transform(img_data)
            if torch.cuda.is_available():
                x = x.cuda()
            part_logits = self.model(x.unsqueeze(0))
            probs = torch.nn.Softmax(dim=-1)(part_logits)
            topN = torch.argsort(probs, dim=-1, descending=True).tolist()
            clas_ids = topN[0][0]
            # fold "hard" (2) and "fly" (3) into the non-empty class (0); "stack" (4) counts as empty (1) here
            clas_ids = 0 if int(clas_ids) in (0, 2, 3) else 1
            print("cur_img result: class id: %d, score: %0.3f" % (clas_ids, probs[0, clas_ids].item()))
            result["success"] = "true"
            result["rst_cls"] = str(clas_ids)
        return result


model_file = "/data/ieemoo/emptypredict_pfc_FG/ckpts/emptyjudge5_checkpoint.bin"
args = parse_args(model_file)
predictor = Predictor(args)


@app.route("/isempty", methods=['POST'])
def get_isempty():
    start = time.time()
    print('--------------------EmptyPredict-----------------')
    data = request.get_data()
    ip = request.remote_addr
    print('------ ip = %s ------' % ip)

    json_data = json.loads(data.decode("utf-8"))
    getdateend = time.time()
    print('get data use time: {0:.2f}s'.format(getdateend - start))

    pic = json_data.get("pic")
    result = {"success": "false",
              "rst_cls": '-1',
              }
    try:
        imgdata = base64.b64decode(pic)
        imgdata_np = np.frombuffer(imgdata, dtype='uint8')
        img_src = cv2.imdecode(imgdata_np, cv2.IMREAD_COLOR)
        img_src = cv2.cvtColor(img_src, cv2.COLOR_BGR2RGB)  # fix: cv2 decodes to BGR, but the model was trained on PIL (RGB) images
        img_data = Image.fromarray(np.uint8(img_src))
        result = predictor.normal_predict(img_data, result)  # 1 == empty, 0 == nonEmpty
    except Exception:  # was a bare except; narrowed so KeyboardInterrupt/SystemExit still propagate
        return repr(result)
    return repr(result)


if __name__ == "__main__":
    app.run()
    # app.run("0.0.0.0", port=8083)

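For reference, a minimal client for this endpoint could look like the following (a sketch: it assumes the service is running at Flask's default http://127.0.0.1:5000 and that the `requests` package is installed; `test.jpg` is a placeholder path):

    import base64, json, requests

    with open("test.jpg", "rb") as f:
        pic = base64.b64encode(f.read()).decode("utf-8")

    resp = requests.post("http://127.0.0.1:5000/isempty", data=json.dumps({"pic": pic}))
    print(resp.text)  # e.g. {'success': 'true', 'rst_cls': '1'}
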
init.sh (normal file, 4 lines)
@@ -0,0 +1,4 @@
# Note: "conda activate" typically requires "source /opt/miniconda3/etc/profile.d/conda.sh"
# first when run from a non-interactive script.
/opt/miniconda3/bin/conda activate ieemoo

/opt/miniconda3/envs/ieemoo/bin/pip install -r requirements.txt

models/configs.py (executable file, 76 lines)
@@ -0,0 +1,76 @@
import ml_collections


def get_testing():
    """Returns a minimal configuration for testing."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 1
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 1
    config.transformer.num_heads = 1
    config.transformer.num_layers = 1
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config


def get_b16_config():
    """Returns the ViT-B/16 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.split = 'non-overlap'
    config.slide_step = 12
    config.hidden_size = 768
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 3072
    config.transformer.num_heads = 12
    config.transformer.num_layers = 12
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config


def get_b32_config():
    """Returns the ViT-B/32 configuration."""
    config = get_b16_config()
    config.patches.size = (32, 32)
    return config


def get_l16_config():
    """Returns the ViT-L/16 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 1024
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 4096
    config.transformer.num_heads = 16
    config.transformer.num_layers = 24
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config


def get_l32_config():
    """Returns the ViT-L/32 configuration."""
    config = get_l16_config()
    config.patches.size = (32, 32)
    return config


def get_h14_config():
    """Returns the ViT-H/14 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (14, 14)})
    config.hidden_size = 1280
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 5120
    config.transformer.num_heads = 16
    config.transformer.num_layers = 32
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config

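Note that only get_b16_config() defines the `split` and `slide_step` fields used by the patch embedding; the scripts in this commit always start from ViT-B_16 and override them, roughly like this (illustrative sketch; other variants would need those fields added before use):

    from models.configs import get_b16_config

    config = get_b16_config()
    config.split = 'overlap'    # or 'non-overlap'
    config.slide_step = 12      # conv stride used when split == 'overlap'
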
models/modeling.py (executable file, 390 lines)
@@ -0,0 +1,390 @@
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math

from os.path import join as pjoin

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage

import models.configs as configs

logger = logging.getLogger(__name__)

ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"


def np2th(weights, conv=False):
    """Possibly convert HWIO to OIHW."""
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)


def swish(x):
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}


class LabelSmoothing(nn.Module):
    """NLL loss with label smoothing."""

    def __init__(self, smoothing=0.0):
        """
        Constructor for the LabelSmoothing module.
        :param smoothing: label smoothing factor
        """
        super(LabelSmoothing, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, x, target):
        logprobs = torch.nn.functional.log_softmax(x, dim=-1)
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()


class Attention(nn.Module):
    def __init__(self, config):
        super(Attention, self).__init__()
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)

        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        weights = attention_probs  # keep the pre-dropout attention map for part selection
        attention_probs = self.attn_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights


class Mlp(nn.Module):
    def __init__(self, config):
        super(Mlp, self).__init__()
        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class Embeddings(nn.Module):
    """Construct the embeddings from patch, position embeddings."""

    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        self.hybrid = None
        img_size = _pair(img_size)

        patch_size = _pair(config.patches["size"])
        if config.split == 'non-overlap':
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.patch_embeddings = Conv2d(in_channels=in_channels,
                                           out_channels=config.hidden_size,
                                           kernel_size=patch_size,
                                           stride=patch_size)
        elif config.split == 'overlap':
            n_patches = ((img_size[0] - patch_size[0]) // config.slide_step + 1) * ((img_size[1] - patch_size[1]) // config.slide_step + 1)
            self.patch_embeddings = Conv2d(in_channels=in_channels,
                                           out_channels=config.hidden_size,
                                           kernel_size=patch_size,
                                           stride=(config.slide_step, config.slide_step))
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches + 1, config.hidden_size))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))

        self.dropout = Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        B = x.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1)

        if self.hybrid:
            x = self.hybrid_model(x)
        x = self.patch_embeddings(x)
        x = x.flatten(2)
        x = x.transpose(-1, -2)
        x = torch.cat((cls_tokens, x), dim=1)

        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings


class Block(nn.Module):
    def __init__(self, config):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config)

    def forward(self, x):
        h = x
        x = self.attention_norm(x)
        x, weights = self.attn(x)
        x = x + h

        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        return x, weights

    def load_from(self, weights, n_block):
        ROOT = f"Transformer/encoderblock_{n_block}"
        with torch.no_grad():
            query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()

            query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
            key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
            value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
            out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)

            self.attn.query.weight.copy_(query_weight)
            self.attn.key.weight.copy_(key_weight)
            self.attn.value.weight.copy_(value_weight)
            self.attn.out.weight.copy_(out_weight)
            self.attn.query.bias.copy_(query_bias)
            self.attn.key.bias.copy_(key_bias)
            self.attn.value.bias.copy_(value_bias)
            self.attn.out.bias.copy_(out_bias)

            mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
            mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
            mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
            mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()

            self.ffn.fc1.weight.copy_(mlp_weight_0)
            self.ffn.fc2.weight.copy_(mlp_weight_1)
            self.ffn.fc1.bias.copy_(mlp_bias_0)
            self.ffn.fc2.bias.copy_(mlp_bias_1)

            self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
            self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
            self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
            self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))


class Part_Attention(nn.Module):
    def __init__(self):
        super(Part_Attention, self).__init__()

    def forward(self, x):
        length = len(x)
        last_map = x[0]
        for i in range(1, length):
            last_map = torch.matmul(x[i], last_map)
        last_map = last_map[:, :, 0, 1:]  # rolled-out attention from the CLS token to every patch token

        max_val, max_inx = last_map.max(2)
        return max_val, max_inx


class Encoder(nn.Module):
    def __init__(self, config):
        super(Encoder, self).__init__()
        self.layer = nn.ModuleList()
        for _ in range(config.transformer["num_layers"] - 1):
            layer = Block(config)
            self.layer.append(copy.deepcopy(layer))
        self.part_select = Part_Attention()
        self.part_layer = Block(config)
        self.part_norm = LayerNorm(config.hidden_size, eps=1e-6)

    def forward(self, hidden_states):
        attn_weights = []
        for layer in self.layer:
            hidden_states, weights = layer(hidden_states)
            attn_weights.append(weights)
        part_num, part_inx = self.part_select(attn_weights)
        part_inx = part_inx + 1  # shift indices past the CLS token
        parts = []
        B, num = part_inx.shape
        for i in range(B):
            parts.append(hidden_states[i, part_inx[i, :]])
        parts = torch.stack(parts).squeeze(1)
        concat = torch.cat((hidden_states[:, 0].unsqueeze(1), parts), dim=1)
        part_states, part_weights = self.part_layer(concat)
        part_encoded = self.part_norm(part_states)

        return part_encoded


class Transformer(nn.Module):
    def __init__(self, config, img_size):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config)

    def forward(self, input_ids):
        embedding_output = self.embeddings(input_ids)
        part_encoded = self.encoder(embedding_output)
        return part_encoded


class VisionTransformer(nn.Module):
    def __init__(self, config, img_size=224, num_classes=21843, smoothing_value=0, zero_head=False):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.smoothing_value = smoothing_value
        self.zero_head = zero_head
        self.classifier = config.classifier
        self.transformer = Transformer(config, img_size)
        self.part_head = Linear(config.hidden_size, num_classes)

    def forward(self, x, labels=None):
        part_tokens = self.transformer(x)
        part_logits = self.part_head(part_tokens[:, 0])

        if labels is not None:
            if self.smoothing_value == 0:
                loss_fct = CrossEntropyLoss()
            else:
                loss_fct = LabelSmoothing(self.smoothing_value)
            part_loss = loss_fct(part_logits.view(-1, self.num_classes), labels.view(-1))
            contrast_loss = con_loss(part_tokens[:, 0], labels.view(-1))
            loss = part_loss + contrast_loss
            return loss, part_logits
        else:
            return part_logits

    def load_from(self, weights):
        with torch.no_grad():
            self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
            self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))
            self.transformer.embeddings.cls_token.copy_(np2th(weights["cls"]))
            self.transformer.encoder.part_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
            self.transformer.encoder.part_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))

            posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
            posemb_new = self.transformer.embeddings.position_embeddings
            if posemb.size() == posemb_new.size():
                self.transformer.embeddings.position_embeddings.copy_(posemb)
            else:
                logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
                ntok_new = posemb_new.size(1)

                if self.classifier == "token":
                    posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
                    ntok_new -= 1
                else:
                    posemb_tok, posemb_grid = posemb[:, :0], posemb[0]

                gs_old = int(np.sqrt(len(posemb_grid)))
                gs_new = int(np.sqrt(ntok_new))
                print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
                posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)

                zoom = (gs_new / gs_old, gs_new / gs_old, 1)
                posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)
                posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
                posemb = np.concatenate([posemb_tok, posemb_grid], axis=1)
                self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))

            for bname, block in self.transformer.encoder.named_children():
                if not bname.startswith('part'):
                    for uname, unit in block.named_children():
                        unit.load_from(weights, n_block=uname)

            if self.transformer.embeddings.hybrid:
                self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(weights["conv_root/kernel"], conv=True))
                gn_weight = np2th(weights["gn_root/scale"]).view(-1)
                gn_bias = np2th(weights["gn_root/bias"]).view(-1)
                self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
                self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)

                for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
                    for uname, unit in block.named_children():
                        unit.load_from(weights, n_block=bname, n_unit=uname)


def con_loss(features, labels):
    B, _ = features.shape
    features = F.normalize(features)
    cos_matrix = features.mm(features.t())
    pos_label_matrix = torch.stack([labels == labels[i] for i in range(B)]).float()
    neg_label_matrix = 1 - pos_label_matrix
    pos_cos_matrix = 1 - cos_matrix
    neg_cos_matrix = cos_matrix - 0.4  # hinge: negative pairs are only penalized above cosine 0.4
    neg_cos_matrix[neg_cos_matrix < 0] = 0
    loss = (pos_cos_matrix * pos_label_matrix).sum() + (neg_cos_matrix * neg_label_matrix).sum()
    loss /= (B * B)
    return loss


CONFIGS = {
    'ViT-B_16': configs.get_b16_config(),
    'ViT-B_32': configs.get_b32_config(),
    'ViT-L_16': configs.get_l16_config(),
    'ViT-L_32': configs.get_l32_config(),
    'ViT-H_14': configs.get_h14_config(),
    'testing': configs.get_testing(),
}

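A quick shape check for the model above (a sketch: randomly initialized weights, CPU only; the ViT-B_16 config already defines split='non-overlap'):

    import torch
    from models.modeling import VisionTransformer, CONFIGS

    config = CONFIGS["ViT-B_16"]                # split='non-overlap', slide_step=12
    model = VisionTransformer(config, img_size=448, zero_head=True, num_classes=5)
    model.eval()

    x = torch.randn(2, 3, 448, 448)             # a batch of two RGB images
    with torch.no_grad():
        logits = model(x)                       # no labels -> logits only
    print(logits.shape)                         # torch.Size([2, 5])
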
predict.py (executable file, 153 lines)
@@ -0,0 +1,153 @@
import numpy as np
import cv2
import time
import os
import argparse
import torch
from sklearn.metrics import confusion_matrix
from sklearn.metrics import f1_score
from PIL import Image
from torchvision import transforms
from models.modeling import VisionTransformer, CONFIGS


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--img_size", default=448, type=int, help="Resolution size")
    parser.add_argument('--split', type=str, default='overlap', help="Split method")  # non-overlap
    parser.add_argument('--slide_step', type=int, default=12, help="Slide step for overlap split")
    parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value")
    parser.add_argument("--pretrained_model", type=str, default="output/emptyjudge5_checkpoint.bin", help="load pretrained model")
    return parser.parse_args()


class Predictor(object):
    def __init__(self, args):
        self.args = args
        self.args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print("self.args.device =", self.args.device)
        self.args.nprocs = torch.cuda.device_count()

        self.cls_dict = {}
        self.num_classes = 0
        self.model = None
        self.prepare_model()
        self.test_transform = transforms.Compose([transforms.Resize((448, 448), Image.BILINEAR),
                                                  transforms.ToTensor(),
                                                  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    def prepare_model(self):
        config = CONFIGS["ViT-B_16"]
        config.split = self.args.split
        config.slide_step = self.args.slide_step
        model_name = os.path.basename(self.args.pretrained_model).replace("_checkpoint.bin", "").lower()
        print("use model_name: ", model_name)
        if model_name == "emptyjudge5":
            self.num_classes = 5
            self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "fly", 4: "stack"}
        elif model_name == "emptyjudge4":
            self.num_classes = 4
            self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "stack"}
        elif model_name == "emptyjudge3":
            self.num_classes = 3
            self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard"}
        elif model_name == "emptyjudge2":
            self.num_classes = 2
            self.cls_dict = {0: "noemp", 1: "yesemp"}
        self.model = VisionTransformer(config, self.args.img_size, zero_head=True, num_classes=self.num_classes, smoothing_value=self.args.smoothing_value)
        if self.args.pretrained_model is not None:
            if not torch.cuda.is_available():
                pretrained_model = torch.load(self.args.pretrained_model, map_location=torch.device('cpu'))['model']
            else:
                pretrained_model = torch.load(self.args.pretrained_model)['model']
            self.model.load_state_dict(pretrained_model)
        self.model.to(self.args.device)
        self.model.eval()

    def normal_predict(self, img_path):
        img = Image.open(img_path)
        if img is None:
            print("Image file failed to read: {}".format(img_path))
        else:
            with torch.no_grad():  # inference only; avoids building the autograd graph
                x = self.test_transform(img)
                if torch.cuda.is_available():
                    x = x.cuda()
                part_logits = self.model(x.unsqueeze(0))
                probs = torch.nn.Softmax(dim=-1)(part_logits)
                topN = torch.argsort(probs, dim=-1, descending=True).tolist()
                clas_ids = topN[0][0]
                return clas_ids, probs[0, clas_ids].item()


if __name__ == "__main__":
    args = parse_args()
    predictor = Predictor(args)

    y_true = []
    y_pred = []
    test_dir = "/data/pfc/fineGrained/test_5cls"
    dir_dict = {"noemp": "0", "yesemp": "1", "hard": "2", "fly": "3", "stack": "4"}
    total = 0
    num = 0
    t0 = time.time()
    for dir_name, label in dir_dict.items():
        cur_folder = os.path.join(test_dir, dir_name)
        errorPath = os.path.join(test_dir, dir_name + "_error")
        # os.makedirs(errorPath, exist_ok=True)
        for cur_file in os.listdir(cur_folder):
            total += 1
            print("%d processing: %s" % (total, cur_file))
            cur_img_file = os.path.join(cur_folder, cur_file)
            error_img_dst = os.path.join(errorPath, cur_file)
            cur_pred, pred_score = predictor.normal_predict(cur_img_file)

            # collapse "hard" (2), "fly" (3) and "stack" (4) into the non-empty class for the binary metric
            bin_label = 0 if int(label) in (2, 3, 4) else int(label)
            bin_pred = 0 if int(cur_pred) in (2, 3, 4) else int(cur_pred)
            y_true.append(bin_label)
            y_pred.append(bin_pred)
            if bin_label == bin_pred:
                num += 1
            # else:
            #     print(cur_file, "predict: ", bin_pred, "true: ", bin_label, "pred_score:", pred_score)
            #     os.system("cp %s %s" % (cur_img_file, error_img_dst))
    t1 = time.time()
    print('The cost of time is: %f seconds' % (t1 - t0))
    rate = float(num) / total
    print('The classification accuracy is %f' % rate)

    rst_C = confusion_matrix(y_true, y_pred)
    rst_f1 = f1_score(y_true, y_pred, average='macro')
    print(rst_C)
    print(rst_f1)

'''
test_imgs: yesemp=145, noemp=453 (large images)

output/emptyjudge5_checkpoint.bin
The classification accuracy is 0.976589
[[446   7]   1.5%
 [  7 138]]  4.8%
0.968135799649844

output/emptyjudge4_checkpoint.bin
The classification accuracy is 0.976589
[[450   3]   0.6%
 [ 11 134]]  7.5%
0.9675186616384996

test_5cls: yesemp=319, noemp=925 (small images)

output/emptyjudge4_checkpoint.bin
The classification accuracy is 0.937299
[[885  40]   4.3%
 [ 38 281]]  11.9%
0.9179586038961038
'''

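Typical invocation (the test directory is hard-coded to /data/pfc/fineGrained/test_5cls above; --split must match the setting the checkpoint was trained with, since it changes the patch count and hence the position-embedding shape):

    python predict.py --split overlap --slide_step 12 --pretrained_model output/emptyjudge5_checkpoint.bin
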
prepara_data.py (executable file, 119 lines)
@@ -0,0 +1,119 @@
import os
import cv2
import numpy as np
import subprocess
import random


# ----------- Rename files --------------
# index = 0
# src_dir = "/data/fineGrained/emptyJudge5"
# dst_dir = src_dir + "_new"
# os.makedirs(dst_dir, exist_ok=True)
# for sub in os.listdir(src_dir):
#     sub_path = os.path.join(src_dir, sub)
#     sub_path_dst = os.path.join(dst_dir, sub)
#     os.makedirs(sub_path_dst, exist_ok=True)
#     for cur_f in os.listdir(sub_path):
#         cur_img = os.path.join(sub_path, cur_f)
#         cur_img_dst = os.path.join(sub_path_dst, "a%05d.jpg" % index)
#         index += 1
#         os.system("mv %s %s" % (cur_img, cur_img_dst))


# ----------- Remove undersized images --------------
# src_dir = "/data/fineGrained/emptyJudge5"
# for sub in os.listdir(src_dir):
#     sub_path = os.path.join(src_dir, sub)
#     for cur_f in os.listdir(sub_path):
#         filepath = os.path.join(sub_path, cur_f)
#         res = subprocess.check_output(['file', filepath])
#         pp = res.decode("utf-8").split(",")[-2]
#         height = int(pp.split("x")[1])
#         width = int(pp.split("x")[0])
#         min_l = min(height, width)
#         if min_l <= 448:
#             os.system("rm %s" % filepath)


# ----------- Collect valid images and write images.txt --------------
# src_dir = "/data/fineGrained/emptyJudge4/images"
# src_dict = {"noemp": "0", "yesemp": "1", "hard": "2", "stack": "3"}
# all_dict = {"yesemp": [], "noemp": [], "hard": [], "stack": []}
# for sub, value in src_dict.items():
#     sub_path = os.path.join(src_dir, sub)
#     for cur_f in os.listdir(sub_path):
#         all_dict[sub].append(os.path.join(sub, cur_f))
#
# yesnum = len(all_dict["yesemp"])
# nonum = len(all_dict["noemp"])
# hardnum = len(all_dict["hard"])
# stacknum = len(all_dict["stack"])
# thnum = min(yesnum, nonum, hardnum, stacknum)
# images_txt = src_dir + ".txt"
# index = 1
#
# def write_images(cur_list, thnum, fw, index):
#     for feat_path in random.sample(cur_list, thnum):
#         fw.write(str(index) + " " + feat_path + "\n")
#         index += 1
#     return index
#
# with open(images_txt, "w") as fw:
#     index = write_images(all_dict["noemp"], thnum, fw, index)
#     index = write_images(all_dict["yesemp"], thnum, fw, index)
#     index = write_images(all_dict["hard"], thnum, fw, index)
#     index = write_images(all_dict["stack"], thnum, fw, index)

# ----------- Write image_class_labels.txt + train_test_split.txt --------------
# src_dir = "/data/fineGrained/emptyJudge4"
# src_dict = {"noemp": "0", "yesemp": "1", "hard": "2", "stack": "3"}
# images_txt = os.path.join(src_dir, "images.txt")
# image_class_labels_txt = os.path.join(src_dir, "image_class_labels.txt")
# imgs_cnt = 0
# with open(image_class_labels_txt, "w") as fw:
#     with open(images_txt, "r") as fr:
#         for cur_l in fr:
#             imgs_cnt += 1
#             img_index, img_f = cur_l.strip().split(" ")
#             folder_name = img_f.split("/")[0]
#             if folder_name in src_dict:
#                 cur_line = img_index + " " + str(int(src_dict[folder_name]) + 1)
#                 fw.write(cur_line + "\n")
#
# train_num = int(imgs_cnt * 0.85)
# print("train_num= ", train_num, ", imgs_cnt= ", imgs_cnt)
# all_list = [1] * train_num + [0] * (imgs_cnt - train_num)
# assert len(all_list) == imgs_cnt
# random.shuffle(all_list)
# train_test_split_txt = os.path.join(src_dir, "train_test_split.txt")
# with open(train_test_split_txt, "w") as fw:
#     with open(images_txt, "r") as fr:
#         for cur_l in fr:
#             img_index, img_f = cur_l.strip().split(" ")
#             cur_line = img_index + " " + str(all_list[int(img_index) - 1])
#             fw.write(cur_line + "\n")

# ----------- Build a held-out test set --------------
# src_dir = "/data/fineGrained/emptyJudge5/images"
# src_dict = {"noemp": "0", "yesemp": "1", "hard": "2", "fly": "3", "stack": "4"}
# all_dict = {"noemp": [], "yesemp": [], "hard": [], "fly": [], "stack": []}
# for sub, value in src_dict.items():
#     sub_path = os.path.join(src_dir, sub)
#     for cur_f in os.listdir(sub_path):
#         all_dict[sub].append(cur_f)
#
# dst_dir = src_dir + "_test"
# os.makedirs(dst_dir, exist_ok=True)
# for sub, value in src_dict.items():
#     sub_path = os.path.join(src_dir, sub)
#     sub_path_dst = os.path.join(dst_dir, sub)
#     os.makedirs(sub_path_dst, exist_ok=True)
#
#     cur_list = all_dict[sub]
#     test_num = int(len(cur_list) * 0.05)
#     for cur_f in random.sample(cur_list, test_num):
#         cur_path = os.path.join(sub_path, cur_f)
#         cur_path_dst = os.path.join(sub_path_dst, cur_f)
#         os.system("cp %s %s" % (cur_path, cur_path_dst))

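For orientation, the CUB-style metadata these snippets emit looks like this (illustrative entries; image_class_labels.txt stores the folder label + 1, and train_test_split.txt stores 1 for train, 0 for test):

    images.txt                 image_class_labels.txt     train_test_split.txt
    1 noemp/a00001.jpg         1 1                        1 1
    2 noemp/a00002.jpg         2 1                        2 0
    3 yesemp/a01024.jpg        3 2                        3 1
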
requirements.txt (executable file, 80 lines)
@@ -0,0 +1,80 @@
absl-py==1.0.0
Bottleneck==1.3.2
brotlipy==0.7.0
cachetools==5.0.0
certifi==2021.10.8
cffi @ file:///tmp/build/80754af9/cffi_1625807838443/work
charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
click==8.0.3
contextlib2==21.6.0
cryptography @ file:///tmp/build/80754af9/cryptography_1635366571107/work
cycler @ file:///tmp/build/80754af9/cycler_1637851556182/work
docopt==0.6.2
esdk-obs-python==3.21.8
faiss==1.7.1
Flask @ file:///tmp/build/80754af9/flask_1634118196080/work
fonttools==4.25.0
gevent @ file:///tmp/build/80754af9/gevent_1628273677693/work
google-auth==2.6.0
google-auth-oauthlib==0.4.6
greenlet @ file:///tmp/build/80754af9/greenlet_1628887725296/work
grpcio==1.44.0
gunicorn==20.1.0
h5py @ file:///tmp/build/80754af9/h5py_1637138879700/work
idna @ file:///tmp/build/80754af9/idna_1637925883363/work
importlib-metadata==4.11.3
itsdangerous @ file:///tmp/build/80754af9/itsdangerous_1621432558163/work
Jinja2 @ file:///tmp/build/80754af9/jinja2_1635780242639/work
kiwisolver @ file:///tmp/build/80754af9/kiwisolver_1612282420641/work
Markdown==3.3.6
MarkupSafe @ file:///tmp/build/80754af9/markupsafe_1621528148836/work
matplotlib @ file:///tmp/build/80754af9/matplotlib-suite_1638289681807/work
mkl-fft==1.3.1
mkl-random @ file:///tmp/build/80754af9/mkl_random_1626186064646/work
mkl-service==2.4.0
ml-collections==0.1.0
munkres==1.1.4
numexpr @ file:///tmp/build/80754af9/numexpr_1618856167419/work
numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1634095647912/work
oauthlib==3.2.0
olefile @ file:///Users/ktietz/demo/mc3/conda-bld/olefile_1629805411829/work
opencv-python==4.5.4.60
packaging @ file:///tmp/build/80754af9/packaging_1637314298585/work
pandas==1.3.4
Pillow==8.4.0
pipreqs==0.4.11
protobuf==3.19.4
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
pycryptodome==3.10.1
pyOpenSSL @ file:///tmp/build/80754af9/pyopenssl_1635333100036/work
pyparsing @ file:///tmp/build/80754af9/pyparsing_1635766073266/work
PySocks @ file:///tmp/build/80754af9/pysocks_1605305779399/work
python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work
pytz==2021.3
PyYAML==6.0
requests @ file:///tmp/build/80754af9/requests_1629994808627/work
requests-oauthlib==1.3.1
rsa==4.8
scipy @ file:///tmp/build/80754af9/scipy_1630606796110/work
seaborn @ file:///tmp/build/80754af9/seaborn_1629307859561/work
sip==4.19.13
six @ file:///tmp/build/80754af9/six_1623709665295/work
supervisor==4.2.2
tensorboard==2.8.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
torch==1.8.0
torchaudio==0.8.0a0+a751e1d
torchvision==0.9.0
tornado @ file:///tmp/build/80754af9/tornado_1606942300299/work
tqdm @ file:///tmp/build/80754af9/tqdm_1635330843403/work
typing-extensions @ file:///tmp/build/80754af9/typing_extensions_1631814937681/work
urllib3==1.26.7
Werkzeug @ file:///tmp/build/80754af9/werkzeug_1635505089296/work
yacs @ file:///tmp/build/80754af9/yacs_1634047592950/work
yarg==0.1.9
zipp==3.7.0
zope.event==4.5.0
zope.interface @ file:///tmp/build/80754af9/zope.interface_1625035545636/work

test_single.py (executable file, 64 lines)
@@ -0,0 +1,64 @@
# coding=utf-8
import os
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
import argparse
from models.modeling import VisionTransformer, CONFIGS


parser = argparse.ArgumentParser()
parser.add_argument("--dataset", choices=["CUB_200_2011", "emptyJudge5", "emptyJudge4"], default="emptyJudge5", help="Which dataset.")
parser.add_argument("--img_size", default=448, type=int, help="Resolution size")
parser.add_argument('--split', type=str, default='overlap', help="Split method")  # non-overlap
parser.add_argument('--slide_step', type=int, default=12, help="Slide step for overlap split")
parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value")
parser.add_argument("--pretrained_model", type=str, default="output/emptyjudge5_checkpoint.bin", help="load pretrained model")
args = parser.parse_args()

args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.nprocs = torch.cuda.device_count()

# Prepare model
config = CONFIGS["ViT-B_16"]
config.split = args.split
config.slide_step = args.slide_step

cls_dict = {}
num_classes = 0
if args.dataset == "emptyJudge5":
    num_classes = 5
    cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "fly", 4: "stack"}
elif args.dataset == "emptyJudge4":
    num_classes = 4
    cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "stack"}
elif args.dataset == "emptyJudge3":
    num_classes = 3
    cls_dict = {0: "noemp", 1: "yesemp", 2: "hard"}
elif args.dataset == "emptyJudge2":
    num_classes = 2
    cls_dict = {0: "noemp", 1: "yesemp"}
model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes, smoothing_value=args.smoothing_value)
if args.pretrained_model is not None:
    pretrained_model = torch.load(args.pretrained_model, map_location=torch.device('cpu'))['model']
    model.load_state_dict(pretrained_model)
model.to(args.device)
model.eval()

# test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
#                                      transforms.CenterCrop((448, 448)),
#                                      transforms.ToTensor(),
#                                      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
test_transform = transforms.Compose([transforms.Resize((448, 448), Image.BILINEAR),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
img = Image.open("img.jpg")
x = test_transform(img)
part_logits = model(x.unsqueeze(0))

probs = torch.nn.Softmax(dim=-1)(part_logits)
top5 = torch.argsort(probs, dim=-1, descending=True)
print("Prediction Label\n")
for idx in top5[0, :5]:
    print(f'{probs[0, idx.item()]:.5f} : {cls_dict[idx.item()]}')

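This script reads a hard-coded img.jpg from the working directory; a typical run (assuming the checkpoint exists):

    python test_single.py --dataset emptyJudge5 --pretrained_model output/emptyjudge5_checkpoint.bin
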
train.py (executable file, 379 lines)
@@ -0,0 +1,379 @@
|
||||
# coding=utf-8
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import logging
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from tqdm import tqdm
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from models.modeling import VisionTransformer, CONFIGS
|
||||
from utils.scheduler import WarmupLinearSchedule, WarmupCosineSchedule
|
||||
from utils.data_utils import get_loader
|
||||
from utils.dist_util import get_world_size
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
"""Computes and stores the average and current value"""
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
def simple_accuracy(preds, labels):
|
||||
return (preds == labels).mean()
|
||||
|
||||
|
||||
def reduce_mean(tensor, nprocs):
|
||||
rt = tensor.clone()
|
||||
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
|
||||
rt /= nprocs
|
||||
return rt
|
||||
|
||||
|
||||
def save_model(args, model):
|
||||
model_to_save = model.module if hasattr(model, 'module') else model
|
||||
model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.bin" % args.name)
|
||||
checkpoint = {
|
||||
'model': model_to_save.state_dict(),
|
||||
}
|
||||
torch.save(checkpoint, model_checkpoint)
|
||||
logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir)
|
||||
|
||||
|
||||
def setup(args):
|
||||
# Prepare model
|
||||
config = CONFIGS[args.model_type]
|
||||
config.split = args.split
|
||||
config.slide_step = args.slide_step
|
||||
|
||||
if args.dataset == "CUB_200_2011":
|
||||
num_classes = 200
|
||||
elif args.dataset == "car":
|
||||
num_classes = 196
|
||||
elif args.dataset == "nabirds":
|
||||
num_classes = 555
|
||||
elif args.dataset == "dog":
|
||||
num_classes = 120
|
||||
elif args.dataset == "INat2017":
|
||||
num_classes = 5089
|
||||
elif args.dataset == "emptyJudge5":
|
||||
num_classes = 5
|
||||
elif args.dataset == "emptyJudge4":
|
||||
num_classes = 4
|
||||
elif args.dataset == "emptyJudge3":
|
||||
num_classes = 3
|
||||
|
||||
model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes, smoothing_value=args.smoothing_value)
|
||||
|
||||
model.load_from(np.load(args.pretrained_dir))
|
||||
if args.pretrained_model is not None:
|
||||
pretrained_model = torch.load(args.pretrained_model)['model']
|
||||
model.load_state_dict(pretrained_model)
|
||||
model.to(args.device)
|
||||
num_params = count_parameters(model)
|
||||
|
||||
logger.info("{}".format(config))
|
||||
logger.info("Training parameters %s", args)
|
||||
logger.info("Total Parameter: \t%2.1fM" % num_params)
|
||||
return args, model
|
||||
|
||||
|
||||
def count_parameters(model):
|
||||
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
return params/1000000
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
|
||||
def valid(args, model, writer, test_loader, global_step):
|
||||
eval_losses = AverageMeter()
|
||||
|
||||
logger.info("***** Running Validation *****")
|
||||
# logger.info("val Num steps = %d", len(test_loader))
|
||||
# logger.info("val Batch size = %d", args.eval_batch_size)
|
||||
|
||||
model.eval()
|
||||
all_preds, all_label = [], []
|
||||
epoch_iterator = tqdm(test_loader,
|
||||
desc="Validating... (loss=X.X)",
|
||||
bar_format="{l_bar}{r_bar}",
|
||||
dynamic_ncols=True,
|
||||
disable=args.local_rank not in [-1, 0])
|
||||
loss_fct = torch.nn.CrossEntropyLoss()
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
x, y = batch
|
||||
with torch.no_grad():
|
||||
logits = model(x)
|
||||
|
||||
eval_loss = loss_fct(logits, y)
|
||||
eval_loss = eval_loss.mean()
|
||||
eval_losses.update(eval_loss.item())
|
||||
|
||||
preds = torch.argmax(logits, dim=-1)
|
||||
|
||||
if len(all_preds) == 0:
|
||||
all_preds.append(preds.detach().cpu().numpy())
|
||||
all_label.append(y.detach().cpu().numpy())
|
||||
else:
|
||||
all_preds[0] = np.append(
|
||||
all_preds[0], preds.detach().cpu().numpy(), axis=0
|
||||
)
|
||||
all_label[0] = np.append(
|
||||
all_label[0], y.detach().cpu().numpy(), axis=0
|
||||
)
|
||||
epoch_iterator.set_description("Validating... (loss=%2.5f)" % eval_losses.val)
|
||||
|
||||
all_preds, all_label = all_preds[0], all_label[0]
|
||||
accuracy = simple_accuracy(all_preds, all_label)
|
||||
accuracy = torch.tensor(accuracy).to(args.device)
|
||||
# dist.barrier()
|
||||
# val_accuracy = reduce_mean(accuracy, args.nprocs)
|
||||
# val_accuracy = val_accuracy.detach().cpu().numpy()
|
||||
val_accuracy = accuracy.detach().cpu().numpy()
|
||||
|
||||
logger.info("\n")
|
||||
logger.info("Validation Results")
|
||||
logger.info("Global Steps: %d" % global_step)
|
||||
logger.info("Valid Loss: %2.5f" % eval_losses.avg)
|
||||
logger.info("Valid Accuracy: %2.5f" % val_accuracy)
|
||||
if args.local_rank in [-1, 0]:
|
||||
writer.add_scalar("test/accuracy", scalar_value=val_accuracy, global_step=global_step)
|
||||
return val_accuracy
|
||||
|
||||
|
||||
def train(args, model):
|
||||
""" Train the model """
|
||||
if args.local_rank in [-1, 0]:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
writer = SummaryWriter(log_dir=os.path.join("logs", args.name))
|
||||
|
||||
args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
|
||||
|
||||
# Prepare dataset
|
||||
train_loader, test_loader = get_loader(args)
|
||||
logger.info("train Num steps = %d", len(train_loader))
|
||||
logger.info("val Num steps = %d", len(test_loader))
|
||||
# Prepare optimizer and scheduler
|
||||
optimizer = torch.optim.SGD(model.parameters(),
|
||||
lr=args.learning_rate,
|
||||
momentum=0.9,
|
||||
weight_decay=args.weight_decay)
|
||||
t_total = args.num_steps
|
||||
if args.decay_type == "cosine":
|
||||
scheduler = WarmupCosineSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
|
||||
else:
|
||||
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Total optimization steps = %d", args.num_steps)
|
||||
logger.info(" Instantaneous batch size per GPU = %d", args.train_batch_size)
|
||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size * args.gradient_accumulation_steps * (
|
||||
torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||
|
||||
model.zero_grad()
|
||||
set_seed(args) # Added here for reproducibility (even between python 2 and 3)
|
||||
losses = AverageMeter()
|
||||
global_step, best_acc = 0, 0
|
||||
start_time = time.time()
|
||||
while True:
|
||||
model.train()
|
||||
epoch_iterator = tqdm(train_loader,
|
||||
desc="Training (X / X Steps) (loss=X.X)",
|
||||
bar_format="{l_bar}{r_bar}",
|
||||
dynamic_ncols=True,
|
||||
disable=args.local_rank not in [-1, 0])
|
||||
all_preds, all_label = [], []
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
x, y = batch
|
||||
|
||||
loss, logits = model(x, y)
|
||||
loss = loss.mean()
|
||||
|
||||
preds = torch.argmax(logits, dim=-1)
|
||||
|
||||
if len(all_preds) == 0:
|
||||
all_preds.append(preds.detach().cpu().numpy())
|
||||
all_label.append(y.detach().cpu().numpy())
|
||||
else:
|
||||
all_preds[0] = np.append(
|
||||
all_preds[0], preds.detach().cpu().numpy(), axis=0
|
||||
)
|
||||
all_label[0] = np.append(
|
||||
all_label[0], y.detach().cpu().numpy(), axis=0
|
||||
)
|
||||
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
loss.backward()
|
||||
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||
losses.update(loss.item()*args.gradient_accumulation_steps)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||
scheduler.step()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
global_step += 1
|
||||
|
||||
epoch_iterator.set_description(
|
||||
"Training (%d / %d Steps) (loss=%2.5f)" % (global_step, t_total, losses.val)
|
||||
)
|
||||
if args.local_rank in [-1, 0]:
|
||||
writer.add_scalar("train/loss", scalar_value=losses.val, global_step=global_step)
|
||||
writer.add_scalar("train/lr", scalar_value=scheduler.get_lr()[0], global_step=global_step)
|
||||
if global_step % args.eval_every == 0:
|
||||
with torch.no_grad():
|
||||
accuracy = valid(args, model, writer, test_loader, global_step)
|
||||
if args.local_rank in [-1, 0]:
|
||||
if best_acc < accuracy:
|
||||
save_model(args, model)
|
||||
best_acc = accuracy
|
||||
logger.info("best accuracy so far: %f" % best_acc)
|
||||
model.train()
|
||||
|
||||
if global_step % t_total == 0:
|
||||
break
|
||||
all_preds, all_label = all_preds[0], all_label[0]
|
||||
accuracy = simple_accuracy(all_preds, all_label)
|
||||
accuracy = torch.tensor(accuracy).to(args.device)
|
||||
# dist.barrier()
|
||||
# train_accuracy = reduce_mean(accuracy, args.nprocs)
|
||||
# train_accuracy = train_accuracy.detach().cpu().numpy()
|
||||
train_accuracy = accuracy.detach().cpu().numpy()
|
||||
logger.info("train accuracy so far: %f" % train_accuracy)
|
||||
losses.reset()
|
||||
if global_step % t_total == 0:
|
||||
break
|
||||
|
||||
writer.close()
|
||||
logger.info("Best Accuracy: \t%f" % best_acc)
|
||||
logger.info("End Training!")
|
||||
end_time = time.time()
|
||||
logger.info("Total Training Time: \t%f" % ((end_time - start_time) / 3600))
|
||||
|
||||
|
||||
def main():
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--name", required=True,
                        help="Name of this run. Used for monitoring.")
    parser.add_argument("--dataset", choices=["CUB_200_2011", "car", "dog", "nabirds", "INat2017", "emptyJudge5", "emptyJudge4"],
                        default="CUB_200_2011", help="Which dataset.")
    parser.add_argument('--data_root', type=str, default='/data/fineGrained')
    parser.add_argument("--model_type", choices=["ViT-B_16", "ViT-B_32", "ViT-L_16", "ViT-L_32", "ViT-H_14"],
                        default="ViT-B_16", help="Which variant to use.")
    parser.add_argument("--pretrained_dir", type=str, default="ckpts/ViT-B_16.npz",
                        help="Where to search for pretrained ViT models.")
    parser.add_argument("--pretrained_model", type=str, default="output/emptyjudge5_checkpoint.bin", help="load pretrained model")
    # parser.add_argument("--pretrained_model", type=str, default=None, help="load pretrained model")
    parser.add_argument("--output_dir", default="./output", type=str,
                        help="The output directory where checkpoints will be written.")
    parser.add_argument("--img_size", default=448, type=int, help="Resolution size")
    parser.add_argument("--train_batch_size", default=64, type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size", default=16, type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--eval_every", default=200, type=int,
                        help="Run prediction on validation set every so many steps. "
                             "Will always run one evaluation at the end of training.")

    parser.add_argument("--learning_rate", default=3e-2, type=float,
                        help="The initial learning rate for SGD.")
    parser.add_argument("--weight_decay", default=0, type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--num_steps", default=8000, type=int,  # 100000
                        help="Total number of training steps to perform.")
    parser.add_argument("--decay_type", choices=["cosine", "linear"], default="cosine",
                        help="How to decay the learning rate.")
    parser.add_argument("--warmup_steps", default=500, type=int,
                        help="Steps of training to perform learning rate warmup for.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")

    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O2',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument('--loss_scale', type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")

    parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value\n")

    parser.add_argument('--split', type=str, default='overlap', help="Split method")  # non-overlap
    parser.add_argument('--slide_step', type=int, default=12, help="Slide step for overlap split")
    args = parser.parse_args()

    args.data_root = '{}/{}'.format(args.data_root, args.dataset)
    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print('torch.cuda.device_count():', torch.cuda.device_count())
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl', timeout=timedelta(minutes=60))
        args.n_gpu = 1
    args.device = device
    args.nprocs = torch.cuda.device_count()

    # Setup logging
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s" %
                   (args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1), args.fp16))

    # Set seed
    set_seed(args)

    # Model Setup
    args, model = setup(args)
    # Training
    train(args, model)


if __name__ == "__main__":
    main()
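Usage note (not part of the diff): assuming this file is the repository's training entry point saved as train.py — the file name is not visible in this hunk — and that the paths below are placeholders, a typical launch looks like:

$ python train.py --name emptyjudge5_run --dataset emptyJudge5 --data_root /data/fineGrained --pretrained_dir ckpts/ViT-B_16.npz
$ python -m torch.distributed.launch --nproc_per_node=4 train.py --name emptyjudge5_ddp --dataset emptyJudge5 --fp16

The distributed launcher passes --local_rank to each process, which triggers the NCCL init branch above.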
0
utils/__init__.py
Executable file
0
utils/__init__.py
Executable file
204
utils/autoaugment.py
Executable file
204
utils/autoaugment.py
Executable file
@ -0,0 +1,204 @@
"""
Copy from https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py
"""

from PIL import Image, ImageEnhance, ImageOps
import numpy as np
import random

__all__ = ['AutoAugImageNetPolicy', 'AutoAugCIFAR10Policy', 'AutoAugSVHNPolicy']


class AutoAugImageNetPolicy(object):
    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
            SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
            SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),

            SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
            SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
            SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
            SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
            SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),

            SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
            SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
            SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
            SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
            SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),

            SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
            SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
            SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
            SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
            SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),

            SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
            SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
            SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
            SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor)
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment ImageNet Policy"


class AutoAugCIFAR10Policy(object):
    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
            SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
            SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
            SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),

            SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
            SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
            SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
            SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),

            SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
            SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
            SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
            SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
            SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),

            SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
            SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor),
            SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
            SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
            SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),

            SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
            SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor)
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment CIFAR10 Policy"


class AutoAugSVHNPolicy(object):
    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
            SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
            SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),

            SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
            SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
            SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
            SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),

            SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
            SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
            SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
            SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
            SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),

            SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
            SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
            SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
            SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
            SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),

            SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
            SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
            SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
            SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
            SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment SVHN Policy"


class SubPolicy(object):
    def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
        ranges = {
            "shearX": np.linspace(0, 0.3, 10),
            "shearY": np.linspace(0, 0.3, 10),
            "translateX": np.linspace(0, 150 / 331, 10),
            "translateY": np.linspace(0, 150 / 331, 10),
            "rotate": np.linspace(0, 30, 10),
            "color": np.linspace(0.0, 0.9, 10),
            # plain int: the np.int alias was deprecated and later removed from NumPy
            "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int),
            "solarize": np.linspace(256, 0, 10),
            "contrast": np.linspace(0.0, 0.9, 10),
            "sharpness": np.linspace(0.0, 0.9, 10),
            "brightness": np.linspace(0.0, 0.9, 10),
            "autocontrast": [0] * 10,
            "equalize": [0] * 10,
            "invert": [0] * 10
        }

        def rotate_with_fill(img, magnitude):
            rot = img.convert("RGBA").rotate(magnitude)
            return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)

        func = {
            "shearX": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            "shearY": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            "translateX": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
                fillcolor=fillcolor),
            "translateY": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
                fillcolor=fillcolor),
            "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
            # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])),
            "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
            "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
            "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
            "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
            "equalize": lambda img, magnitude: ImageOps.equalize(img),
            "invert": lambda img, magnitude: ImageOps.invert(img)
        }

        # self.name = "{}_{:.2f}_and_{}_{:.2f}".format(
        #     operation1, ranges[operation1][magnitude_idx1],
        #     operation2, ranges[operation2][magnitude_idx2])
        self.p1 = p1
        self.operation1 = func[operation1]
        self.magnitude1 = ranges[operation1][magnitude_idx1]
        self.p2 = p2
        self.operation2 = func[operation2]
        self.magnitude2 = ranges[operation2][magnitude_idx2]

    def __call__(self, img):
        if random.random() < self.p1:
            img = self.operation1(img, self.magnitude1)
        if random.random() < self.p2:
            img = self.operation2(img, self.magnitude2)
        return img
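Usage note (not part of the diff): each policy object is a plain callable on PIL images, so it composes with torchvision transforms exactly as utils/data_utils.py below does for the car and INat2017 pipelines. A minimal sketch with a placeholder image path:

from PIL import Image
from torchvision import transforms
from utils.autoaugment import AutoAugImageNetPolicy

# One of the 25 sub-policies is sampled per call, so repeated calls on the
# same image generally produce different augmentations.
aug = transforms.Compose([
    transforms.Resize((600, 600), Image.BILINEAR),
    transforms.RandomCrop((448, 448)),
    AutoAugImageNetPolicy(),
    transforms.ToTensor(),
])
img = aug(Image.open("example.jpg").convert("RGB"))  # placeholder path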
135
utils/data_utils.py
Executable file
135
utils/data_utils.py
Executable file
@ -0,0 +1,135 @@
import logging
from PIL import Image
import os

import torch

from torchvision import transforms
from torch.utils.data import DataLoader, RandomSampler, DistributedSampler, SequentialSampler

from .dataset import CUB, CarsDataset, NABirds, dogs, INat2017, emptyJudge
from .autoaugment import AutoAugImageNetPolicy

logger = logging.getLogger(__name__)


def get_loader(args):
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()

    if args.dataset == 'CUB_200_2011':
        train_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
                                              transforms.RandomCrop((448, 448)),
                                              transforms.RandomHorizontalFlip(),
                                              transforms.ToTensor(),
                                              transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
                                             transforms.CenterCrop((448, 448)),
                                             transforms.ToTensor(),
                                             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        trainset = CUB(root=args.data_root, is_train=True, transform=train_transform)
        testset = CUB(root=args.data_root, is_train=False, transform=test_transform)
    elif args.dataset == 'car':
        trainset = CarsDataset(os.path.join(args.data_root, 'devkit/cars_train_annos.mat'),
                               os.path.join(args.data_root, 'cars_train'),
                               os.path.join(args.data_root, 'devkit/cars_meta.mat'),
                               # cleaned=os.path.join(data_dir, 'cleaned.dat'),
                               transform=transforms.Compose([
                                   transforms.Resize((600, 600), Image.BILINEAR),
                                   transforms.RandomCrop((448, 448)),
                                   transforms.RandomHorizontalFlip(),
                                   AutoAugImageNetPolicy(),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
                               )
        testset = CarsDataset(os.path.join(args.data_root, 'cars_test_annos_withlabels.mat'),
                              os.path.join(args.data_root, 'cars_test'),
                              os.path.join(args.data_root, 'devkit/cars_meta.mat'),
                              # cleaned=os.path.join(data_dir, 'cleaned_test.dat'),
                              transform=transforms.Compose([
                                  transforms.Resize((600, 600), Image.BILINEAR),
                                  transforms.CenterCrop((448, 448)),
                                  transforms.ToTensor(),
                                  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
                              )
    elif args.dataset == 'dog':
        train_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
                                              transforms.RandomCrop((448, 448)),
                                              transforms.RandomHorizontalFlip(),
                                              transforms.ToTensor(),
                                              transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
                                             transforms.CenterCrop((448, 448)),
                                             transforms.ToTensor(),
                                             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        trainset = dogs(root=args.data_root,
                        train=True,
                        cropped=False,
                        transform=train_transform,
                        download=False
                        )
        testset = dogs(root=args.data_root,
                       train=False,
                       cropped=False,
                       transform=test_transform,
                       download=False
                       )
    elif args.dataset == 'nabirds':
        train_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
                                              transforms.RandomCrop((448, 448)),
                                              transforms.RandomHorizontalFlip(),
                                              transforms.ToTensor(),
                                              transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
                                             transforms.CenterCrop((448, 448)),
                                             transforms.ToTensor(),
                                             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        trainset = NABirds(root=args.data_root, train=True, transform=train_transform)
        testset = NABirds(root=args.data_root, train=False, transform=test_transform)
    elif args.dataset == 'INat2017':
        train_transform = transforms.Compose([transforms.Resize((400, 400), Image.BILINEAR),
                                              transforms.RandomCrop((304, 304)),
                                              transforms.RandomHorizontalFlip(),
                                              AutoAugImageNetPolicy(),
                                              transforms.ToTensor(),
                                              transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        test_transform = transforms.Compose([transforms.Resize((400, 400), Image.BILINEAR),
                                             transforms.CenterCrop((304, 304)),
                                             transforms.ToTensor(),
                                             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        trainset = INat2017(args.data_root, 'train', train_transform)
        testset = INat2017(args.data_root, 'val', test_transform)
    elif args.dataset == 'emptyJudge5' or args.dataset == 'emptyJudge4':
        train_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
                                              transforms.RandomCrop((448, 448)),
                                              transforms.RandomHorizontalFlip(),
                                              transforms.ToTensor(),
                                              transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        # test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
        #                                      transforms.CenterCrop((448, 448)),
        #                                      transforms.ToTensor(),
        #                                      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        test_transform = transforms.Compose([transforms.Resize((448, 448), Image.BILINEAR),
                                             transforms.ToTensor(),
                                             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        trainset = emptyJudge(root=args.data_root, is_train=True, transform=train_transform)
        testset = emptyJudge(root=args.data_root, is_train=False, transform=test_transform)

    if args.local_rank == 0:
        torch.distributed.barrier()

    train_sampler = RandomSampler(trainset) if args.local_rank == -1 else DistributedSampler(trainset)
    test_sampler = SequentialSampler(testset) if args.local_rank == -1 else DistributedSampler(testset)
    train_loader = DataLoader(trainset,
                              sampler=train_sampler,
                              batch_size=args.train_batch_size,
                              num_workers=4,
                              drop_last=True,
                              pin_memory=True)
    test_loader = DataLoader(testset,
                             sampler=test_sampler,
                             batch_size=args.eval_batch_size,
                             num_workers=4,
                             pin_memory=True) if testset is not None else None

    return train_loader, test_loader
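Usage note (not part of the diff): get_loader only reads five fields from args, so a minimal sketch can drive it without the full argparse setup (the path is a placeholder and the dataset index files are assumed to exist on disk):

from types import SimpleNamespace
from utils.data_utils import get_loader

args = SimpleNamespace(local_rank=-1, dataset='emptyJudge5',
                       data_root='/data/fineGrained/emptyJudge5',  # placeholder
                       train_batch_size=64, eval_batch_size=16)
train_loader, test_loader = get_loader(args)
images, labels = next(iter(train_loader))  # images: (B, 3, 448, 448), labels: (B,)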
629
utils/dataset.py
Executable file
629
utils/dataset.py
Executable file
@ -0,0 +1,629 @@
import os
import json
from os.path import join

import numpy as np
import scipy
from scipy import io
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_url, list_dir, check_integrity, extract_archive, verify_str_arg


class emptyJudge():
    def __init__(self, root, is_train=True, data_len=None, transform=None):
        self.root = root
        self.is_train = is_train
        self.transform = transform
        img_txt_file = open(os.path.join(self.root, 'images.txt'))
        label_txt_file = open(os.path.join(self.root, 'image_class_labels.txt'))
        train_val_file = open(os.path.join(self.root, 'train_test_split.txt'))
        img_name_list = []
        for line in img_txt_file:
            img_name_list.append(line[:-1].split(' ')[-1])
        label_list = []
        for line in label_txt_file:
            label_list.append(int(line[:-1].split(' ')[-1]) - 1)
        train_test_list = []
        for line in train_val_file:
            train_test_list.append(int(line[:-1].split(' ')[-1]))
        train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
        test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
        if self.is_train:
            # scipy.misc.imread was removed from SciPy; load via PIL into an ndarray instead
            self.train_img = [np.array(Image.open(os.path.join(self.root, 'images', train_file)))
                              for train_file in train_file_list[:data_len]]
            self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
            self.train_imgname = [x for x in train_file_list[:data_len]]
        if not self.is_train:
            self.test_img = [np.array(Image.open(os.path.join(self.root, 'images', test_file)))
                             for test_file in test_file_list[:data_len]]
            self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
            self.test_imgname = [x for x in test_file_list[:data_len]]

    def __getitem__(self, index):
        if self.is_train:
            img, target, imgname = self.train_img[index], self.train_label[index], self.train_imgname[index]
            if len(img.shape) == 2:
                img = np.stack([img] * 3, 2)
            img = Image.fromarray(img, mode='RGB')
            if self.transform is not None:
                img = self.transform(img)
        else:
            img, target, imgname = self.test_img[index], self.test_label[index], self.test_imgname[index]
            if len(img.shape) == 2:
                img = np.stack([img] * 3, 2)
            img = Image.fromarray(img, mode='RGB')
            if self.transform is not None:
                img = self.transform(img)
        return img, target

    def __len__(self):
        if self.is_train:
            return len(self.train_label)
        else:
            return len(self.test_label)


class CUB():
    def __init__(self, root, is_train=True, data_len=None, transform=None):
        self.root = root
        self.is_train = is_train
        self.transform = transform
        img_txt_file = open(os.path.join(self.root, 'images.txt'))
        label_txt_file = open(os.path.join(self.root, 'image_class_labels.txt'))
        train_val_file = open(os.path.join(self.root, 'train_test_split.txt'))
        img_name_list = []
        for line in img_txt_file:
            img_name_list.append(line[:-1].split(' ')[-1])
        label_list = []
        for line in label_txt_file:
            label_list.append(int(line[:-1].split(' ')[-1]) - 1)
        train_test_list = []
        for line in train_val_file:
            train_test_list.append(int(line[:-1].split(' ')[-1]))
        train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
        test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
        if self.is_train:
            # scipy.misc.imread was removed from SciPy; load via PIL into an ndarray instead
            self.train_img = [np.array(Image.open(os.path.join(self.root, 'images', train_file)))
                              for train_file in train_file_list[:data_len]]
            self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
            self.train_imgname = [x for x in train_file_list[:data_len]]
        if not self.is_train:
            self.test_img = [np.array(Image.open(os.path.join(self.root, 'images', test_file)))
                             for test_file in test_file_list[:data_len]]
            self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
            self.test_imgname = [x for x in test_file_list[:data_len]]

    def __getitem__(self, index):
        if self.is_train:
            img, target, imgname = self.train_img[index], self.train_label[index], self.train_imgname[index]
            if len(img.shape) == 2:
                img = np.stack([img] * 3, 2)
            img = Image.fromarray(img, mode='RGB')
            if self.transform is not None:
                img = self.transform(img)
        else:
            img, target, imgname = self.test_img[index], self.test_label[index], self.test_imgname[index]
            if len(img.shape) == 2:
                img = np.stack([img] * 3, 2)
            img = Image.fromarray(img, mode='RGB')
            if self.transform is not None:
                img = self.transform(img)

        return img, target

    def __len__(self):
        if self.is_train:
            return len(self.train_label)
        else:
            return len(self.test_label)


class CarsDataset(Dataset):
    def __init__(self, mat_anno, data_dir, car_names, cleaned=None, transform=None):
        """
        Args:
            mat_anno (string): Path to the MATLAB annotation file.
            data_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """

        self.full_data_set = io.loadmat(mat_anno)
        self.car_annotations = self.full_data_set['annotations']
        self.car_annotations = self.car_annotations[0]

        if cleaned is not None:
            cleaned_annos = []
            print("Cleaning up data set (only take pics with rgb chans)...")
            clean_files = np.loadtxt(cleaned, dtype=str)
            for c in self.car_annotations:
                if c[-1][0] in clean_files:
                    cleaned_annos.append(c)
            self.car_annotations = cleaned_annos

        self.car_names = scipy.io.loadmat(car_names)['class_names']
        self.car_names = np.array(self.car_names[0])

        self.data_dir = data_dir
        self.transform = transform

    def __len__(self):
        return len(self.car_annotations)

    def __getitem__(self, idx):
        img_name = os.path.join(self.data_dir, self.car_annotations[idx][-1][0])
        image = Image.open(img_name).convert('RGB')
        car_class = self.car_annotations[idx][-2][0][0]
        car_class = torch.from_numpy(np.array(car_class.astype(np.float32))).long() - 1
        assert car_class < 196

        if self.transform:
            image = self.transform(image)

        # return image, car_class, img_name
        return image, car_class

    def map_class(self, id):
        id = np.ravel(id)
        ret = self.car_names[id - 1][0][0]
        return ret

    def show_batch(self, img_batch, class_batch):

        for i in range(img_batch.shape[0]):
            ax = plt.subplot(1, img_batch.shape[0], i + 1)
            title_str = self.map_class(int(class_batch[i]))
            img = np.transpose(img_batch[i, ...], (1, 2, 0))
            ax.imshow(img)
            ax.set_title(title_str.__str__(), {'fontsize': 5})
        plt.tight_layout()


def make_dataset(dir, image_ids, targets):
    assert (len(image_ids) == len(targets))
    images = []
    dir = os.path.expanduser(dir)
    for i in range(len(image_ids)):
        item = (os.path.join(dir, 'data', 'images', '%s.jpg' % image_ids[i]), targets[i])
        images.append(item)
    return images


def find_classes(classes_file):
    # read classes file, separating out image IDs and class names
    image_ids = []
    targets = []
    f = open(classes_file, 'r')
    for line in f:
        split_line = line.split(' ')
        image_ids.append(split_line[0])
        targets.append(' '.join(split_line[1:]))
    f.close()

    # index class names
    classes = np.unique(targets)
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    targets = [class_to_idx[c] for c in targets]
    return (image_ids, targets, classes, class_to_idx)


class dogs(Dataset):
    """`Stanford Dogs <http://vision.stanford.edu/aditya86/ImageNetDogs/>`_ Dataset.
    Args:
        root (string): Root directory of the dataset.
        cropped (bool, optional): If true, the images will be cropped into the bounding box specified
            in the annotations
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset tar files from the internet and
            puts them in root directory. If the tar files are already downloaded, they are not
            downloaded again.
    """
    folder = 'dog'
    download_url_prefix = 'http://vision.stanford.edu/aditya86/ImageNetDogs'

    def __init__(self,
                 root,
                 train=True,
                 cropped=False,
                 transform=None,
                 target_transform=None,
                 download=False):

        # self.root = join(os.path.expanduser(root), self.folder)
        self.root = root
        self.train = train
        self.cropped = cropped
        self.transform = transform
        self.target_transform = target_transform

        if download:
            self.download()

        split = self.load_split()

        self.images_folder = join(self.root, 'Images')
        self.annotations_folder = join(self.root, 'Annotation')
        self._breeds = list_dir(self.images_folder)

        if self.cropped:
            self._breed_annotations = [[(annotation, box, idx)
                                        for box in self.get_boxes(join(self.annotations_folder, annotation))]
                                       for annotation, idx in split]
            self._flat_breed_annotations = sum(self._breed_annotations, [])

            self._flat_breed_images = [(annotation + '.jpg', idx) for annotation, box, idx in self._flat_breed_annotations]
        else:
            self._breed_images = [(annotation + '.jpg', idx) for annotation, idx in split]

            self._flat_breed_images = self._breed_images

        self.classes = [
            "Chihuaha", "Japanese Spaniel", "Maltese Dog", "Pekinese", "Shih-Tzu",
            "Blenheim Spaniel", "Papillon", "Toy Terrier", "Rhodesian Ridgeback", "Afghan Hound",
            "Basset Hound", "Beagle", "Bloodhound", "Bluetick", "Black-and-tan Coonhound",
            "Walker Hound", "English Foxhound", "Redbone", "Borzoi", "Irish Wolfhound",
            "Italian Greyhound", "Whippet", "Ibizian Hound", "Norwegian Elkhound", "Otterhound",
            "Saluki", "Scottish Deerhound", "Weimaraner", "Staffordshire Bullterrier",
            "American Staffordshire Terrier", "Bedlington Terrier", "Border Terrier",
            "Kerry Blue Terrier", "Irish Terrier", "Norfolk Terrier", "Norwich Terrier",
            "Yorkshire Terrier", "Wirehaired Fox Terrier", "Lakeland Terrier", "Sealyham Terrier",
            "Airedale", "Cairn", "Australian Terrier", "Dandi Dinmont", "Boston Bull",
            "Miniature Schnauzer", "Giant Schnauzer", "Standard Schnauzer", "Scotch Terrier",
            "Tibetan Terrier", "Silky Terrier", "Soft-coated Wheaten Terrier",
            "West Highland White Terrier", "Lhasa", "Flat-coated Retriever", "Curly-coater Retriever",
            "Golden Retriever", "Labrador Retriever", "Chesapeake Bay Retriever",
            "German Short-haired Pointer", "Vizsla", "English Setter", "Irish Setter", "Gordon Setter",
            "Brittany", "Clumber", "English Springer Spaniel", "Welsh Springer Spaniel",
            "Cocker Spaniel", "Sussex Spaniel", "Irish Water Spaniel", "Kuvasz", "Schipperke",
            "Groenendael", "Malinois", "Briard", "Kelpie", "Komondor", "Old English Sheepdog",
            "Shetland Sheepdog", "Collie", "Border Collie", "Bouvier des Flandres", "Rottweiler",
            "German Shepard", "Doberman", "Miniature Pinscher", "Greater Swiss Mountain Dog",
            "Bernese Mountain Dog", "Appenzeller", "EntleBucher", "Boxer", "Bull Mastiff",
            "Tibetan Mastiff", "French Bulldog", "Great Dane", "Saint Bernard", "Eskimo Dog",
            "Malamute", "Siberian Husky", "Affenpinscher", "Basenji", "Pug", "Leonberg",
            "Newfoundland", "Great Pyrenees", "Samoyed", "Pomeranian", "Chow", "Keeshond",
            "Brabancon Griffon", "Pembroke", "Cardigan", "Toy Poodle", "Miniature Poodle",
            "Standard Poodle", "Mexican Hairless", "Dingo", "Dhole", "African Hunting Dog"]

    def __len__(self):
        return len(self._flat_breed_images)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target character class.
        """
        image_name, target_class = self._flat_breed_images[index]
        image_path = join(self.images_folder, image_name)
        image = Image.open(image_path).convert('RGB')

        if self.cropped:
            image = image.crop(self._flat_breed_annotations[index][1])

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            target_class = self.target_transform(target_class)

        return image, target_class

    def download(self):
        import tarfile

        if os.path.exists(join(self.root, 'Images')) and os.path.exists(join(self.root, 'Annotation')):
            if len(os.listdir(join(self.root, 'Images'))) == len(os.listdir(join(self.root, 'Annotation'))) == 120:
                print('Files already downloaded and verified')
                return

        for filename in ['images', 'annotation', 'lists']:
            tar_filename = filename + '.tar'
            url = self.download_url_prefix + '/' + tar_filename
            download_url(url, self.root, tar_filename, None)
            print('Extracting downloaded file: ' + join(self.root, tar_filename))
            with tarfile.open(join(self.root, tar_filename), 'r') as tar_file:
                tar_file.extractall(self.root)
            os.remove(join(self.root, tar_filename))

    @staticmethod
    def get_boxes(path):
        import xml.etree.ElementTree
        e = xml.etree.ElementTree.parse(path).getroot()
        boxes = []
        for objs in e.iter('object'):
            boxes.append([int(objs.find('bndbox').find('xmin').text),
                          int(objs.find('bndbox').find('ymin').text),
                          int(objs.find('bndbox').find('xmax').text),
                          int(objs.find('bndbox').find('ymax').text)])
        return boxes

    def load_split(self):
        if self.train:
            split = scipy.io.loadmat(join(self.root, 'train_list.mat'))['annotation_list']
            labels = scipy.io.loadmat(join(self.root, 'train_list.mat'))['labels']
        else:
            split = scipy.io.loadmat(join(self.root, 'test_list.mat'))['annotation_list']
            labels = scipy.io.loadmat(join(self.root, 'test_list.mat'))['labels']

        split = [item[0][0] for item in split]
        labels = [item[0] - 1 for item in labels]
        return list(zip(split, labels))

    def stats(self):
        counts = {}
        for index in range(len(self._flat_breed_images)):
            image_name, target_class = self._flat_breed_images[index]
            if target_class not in counts.keys():
                counts[target_class] = 1
            else:
                counts[target_class] += 1

        print("%d samples spanning %d classes (avg %f per class)" % (len(self._flat_breed_images),
                                                                     len(counts.keys()),
                                                                     float(len(self._flat_breed_images)) / float(len(counts.keys()))))
        return counts


class NABirds(Dataset):
    """`NABirds <https://dl.allaboutbirds.org/nabirds>`_ Dataset.

    Args:
        root (string): Root directory of the dataset.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """
    base_folder = 'nabirds/images'

    def __init__(self, root, train=True, transform=None):
        dataset_path = os.path.join(root, 'nabirds')
        self.root = root
        self.loader = default_loader
        self.train = train
        self.transform = transform

        image_paths = pd.read_csv(os.path.join(dataset_path, 'images.txt'),
                                  sep=' ', names=['img_id', 'filepath'])
        image_class_labels = pd.read_csv(os.path.join(dataset_path, 'image_class_labels.txt'),
                                         sep=' ', names=['img_id', 'target'])
        # Since the raw labels are non-continuous, map them to new ones
        self.label_map = get_continuous_class_map(image_class_labels['target'])
        train_test_split = pd.read_csv(os.path.join(dataset_path, 'train_test_split.txt'),
                                       sep=' ', names=['img_id', 'is_training_img'])
        data = image_paths.merge(image_class_labels, on='img_id')
        self.data = data.merge(train_test_split, on='img_id')
        # Load in the train / test split
        if self.train:
            self.data = self.data[self.data.is_training_img == 1]
        else:
            self.data = self.data[self.data.is_training_img == 0]

        # Load in the class data
        self.class_names = load_class_names(dataset_path)
        self.class_hierarchy = load_hierarchy(dataset_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data.iloc[idx]
        path = os.path.join(self.root, self.base_folder, sample.filepath)
        target = self.label_map[sample.target]
        img = self.loader(path)

        if self.transform is not None:
            img = self.transform(img)
        return img, target


def get_continuous_class_map(class_labels):
    label_set = set(class_labels)
    return {k: i for i, k in enumerate(label_set)}


def load_class_names(dataset_path=''):
    names = {}

    with open(os.path.join(dataset_path, 'classes.txt')) as f:
        for line in f:
            pieces = line.strip().split()
            class_id = pieces[0]
            names[class_id] = ' '.join(pieces[1:])

    return names


def load_hierarchy(dataset_path=''):
    parents = {}

    with open(os.path.join(dataset_path, 'hierarchy.txt')) as f:
        for line in f:
            pieces = line.strip().split()
            child_id, parent_id = pieces
            parents[child_id] = parent_id

    return parents


class INat2017(VisionDataset):
    """`iNaturalist 2017 <https://github.com/visipedia/inat_comp/blob/master/2017/README.md>`_ Dataset.
    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``train``, or ``val``.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """
    base_folder = 'train_val_images/'
    file_list = {
        'imgs': ('https://storage.googleapis.com/asia_inat_data/train_val/train_val_images.tar.gz',
                 'train_val_images.tar.gz',
                 '7c784ea5e424efaec655bd392f87301f'),
        'annos': ('https://storage.googleapis.com/asia_inat_data/train_val/train_val2017.zip',
                  'train_val2017.zip',
                  '444c835f6459867ad69fcb36478786e7')
    }

    def __init__(self, root, split='train', transform=None, target_transform=None, download=False):
        super(INat2017, self).__init__(root, transform=transform, target_transform=target_transform)
        self.loader = default_loader
        self.split = verify_str_arg(split, "split", ("train", "val",))

        if self._check_exists():
            print('Files already downloaded and verified.')
        elif download:
            if not (os.path.exists(os.path.join(self.root, self.file_list['imgs'][1]))
                    and os.path.exists(os.path.join(self.root, self.file_list['annos'][1]))):
                print('Downloading...')
                self._download()
            print('Extracting...')
            extract_archive(os.path.join(self.root, self.file_list['imgs'][1]))
            extract_archive(os.path.join(self.root, self.file_list['annos'][1]))
        else:
            raise RuntimeError(
                'Dataset not found. You can use download=True to download it.')
        anno_filename = split + '2017.json'
        with open(os.path.join(self.root, anno_filename), 'r') as fp:
            all_annos = json.load(fp)

        self.annos = all_annos['annotations']
        self.images = all_annos['images']

    def __getitem__(self, index):
        path = os.path.join(self.root, self.images[index]['file_name'])
        target = self.annos[index]['category_id']

        image = self.loader(path)
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target

    def __len__(self):
        return len(self.images)

    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, self.base_folder))

    def _download(self):
        for url, filename, md5 in self.file_list.values():
            download_url(url, root=self.root, filename=filename)
            if not check_integrity(os.path.join(self.root, filename), md5):
                raise RuntimeError("File not found or corrupted.")
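Layout note (not part of the diff; inferred from the emptyJudge/CUB readers above, file contents illustrative): both classes expect three space-separated index files plus an images/ tree under data_root. Labels are 1-based on disk and shifted to 0-based in memory, and a 1 in train_test_split.txt marks a training image.

# <data_root>/images.txt              e.g. "1 class_a/0001.jpg"
# <data_root>/image_class_labels.txt  e.g. "1 1"
# <data_root>/train_test_split.txt    e.g. "1 1"
# <data_root>/images/class_a/0001.jpg
from utils.dataset import emptyJudge
ds = emptyJudge(root='/data/fineGrained/emptyJudge5', is_train=True)  # placeholder root
img, target = ds[0]  # PIL image (no transform given here), 0-based int label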
30
utils/dist_util.py
Executable file
30
utils/dist_util.py
Executable file
@ -0,0 +1,30 @@
import torch.distributed as dist


def get_rank():
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def get_world_size():
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def is_main_process():
    return get_rank() == 0


def format_step(step):
    if isinstance(step, str):
        return step
    s = ""
    if len(step) > 0:
        s += "Training Epoch: {} ".format(step[0])
    if len(step) > 1:
        s += "Training Iteration: {} ".format(step[1])
    if len(step) > 2:
        s += "Validation Iteration: {} ".format(step[2])
    return s
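Usage note (not part of the diff): these helpers are safe to call whether or not torch.distributed has been initialized, which makes them convenient guards for rank-0-only work:

from utils.dist_util import get_world_size, is_main_process

# Only rank 0 should write logs/checkpoints; every rank still trains.
if is_main_process():
    print("world size:", get_world_size())  # 1 when running without distributed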
63
utils/scheduler.py
Executable file
63
utils/scheduler.py
Executable file
@ -0,0 +1,63 @@
import logging
import math

from torch.optim.lr_scheduler import LambdaLR

logger = logging.getLogger(__name__)


class ConstantLRSchedule(LambdaLR):
    """ Constant learning rate schedule.
    """
    def __init__(self, optimizer, last_epoch=-1):
        super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch)


class WarmupConstantSchedule(LambdaLR):
    """ Linear warmup and then constant.
        Linearly increases the learning rate multiplier from 0 to 1 over `warmup_steps` training steps.
        Keeps the multiplier equal to 1. after warmup_steps.
    """
    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
        self.warmup_steps = warmup_steps
        super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        if step < self.warmup_steps:
            return float(step) / float(max(1.0, self.warmup_steps))
        return 1.


class WarmupLinearSchedule(LambdaLR):
    """ Linear warmup and then linear decay.
        Linearly increases the learning rate multiplier from 0 to 1 over `warmup_steps` training steps.
        Linearly decreases it from 1. to 0. over the remaining `t_total - warmup_steps` steps.
    """
    def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        if step < self.warmup_steps:
            return float(step) / float(max(1, self.warmup_steps))
        return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))


class WarmupCosineSchedule(LambdaLR):
    """ Linear warmup and then cosine decay.
        Linearly increases the learning rate multiplier from 0 to 1 over `warmup_steps` training steps.
        Then decreases it from 1. to 0. over the remaining `t_total - warmup_steps` steps following a cosine curve.
        `cycles` controls the number of cosine cycles after warmup; the default 0.5 is half a period,
        i.e. a monotone decay from 1. to 0.
    """
    def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        self.cycles = cycles
        super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        if step < self.warmup_steps:
            return float(step) / float(max(1.0, self.warmup_steps))
        # progress after warmup
        progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
        return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))
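Usage note (not part of the diff): the schedules multiply the optimizer's base LR by a factor in [0, 1] and are stepped once per training step, matching --warmup_steps/--num_steps in the training script above. A self-contained sketch with a toy model:

import torch
from utils.scheduler import WarmupCosineSchedule

model = torch.nn.Linear(10, 2)  # toy model just to exercise the schedule
optimizer = torch.optim.SGD(model.parameters(), lr=3e-2, momentum=0.9)
scheduler = WarmupCosineSchedule(optimizer, warmup_steps=500, t_total=8000)

for step in range(8000):
    # forward/backward and optimizer.step() would go here
    scheduler.step()  # LR ramps up to 3e-2 by step 500, then cosine-decays toward 0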