update
This commit is contained in:
@ -7,14 +7,12 @@ import cv2, base64
|
|||||||
import argparse
|
import argparse
|
||||||
import sys, os
|
import sys, os
|
||||||
import torch
|
import torch
|
||||||
|
from gevent.pywsgi import WSGIServer
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
# import logging.config as log_config
|
from models.modeling import VisionTransformer, CONFIGS
|
||||||
sys.path.insert(0, ".")
|
sys.path.insert(0, ".")
|
||||||
|
|
||||||
<<<<<<< HEAD
|
|
||||||
#Flask对外服务接口
|
|
||||||
=======
|
|
||||||
import logging.config
|
import logging.config
|
||||||
from skywalking import agent, config
|
from skywalking import agent, config
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
|
||||||
@ -32,64 +30,63 @@ def setup_logging(path):
|
|||||||
logger = logging.getLogger("root")
|
logger = logging.getLogger("root")
|
||||||
return logger
|
return logger
|
||||||
logger = setup_logging('utils/logging.json')
|
logger = setup_logging('utils/logging.json')
|
||||||
>>>>>>> develop
|
|
||||||
|
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
#app.use_reloader=False
|
app.use_reloader=False
|
||||||
|
|
||||||
print("Autor:ieemoo_lc&ieemoo_lx")
|
|
||||||
print(torch.__version__)
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args(model_file="../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin"):
|
||||||
|
#def parse_args(model_file="output/emptyjudge5_checkpoint.bin"):
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--img_size", default=600, type=int, help="Resolution size")
|
parser.add_argument("--img_size", default=448, type=int, help="Resolution size")
|
||||||
parser.add_argument('--split', type=str, default='overlap', help="Split method")
|
parser.add_argument('--split', type=str, default='overlap', help="Split method")
|
||||||
parser.add_argument('--slide_step', type=int, default=2, help="Slide step for overlap split")
|
parser.add_argument('--slide_step', type=int, default=12, help="Slide step for overlap split")
|
||||||
parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value")
|
parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value")
|
||||||
#使用自定义VIT
|
parser.add_argument("--pretrained_model", type=str, default=model_file, help="load pretrained model")
|
||||||
parser.add_argument("--pretrained_model", type=str, default="../module/ieemoo-ai-isempty/model/now/ieemooempty_vit_checkpoint.pth", help="load pretrained model")
|
|
||||||
opt, unknown = parser.parse_known_args()
|
opt, unknown = parser.parse_known_args()
|
||||||
return opt
|
return opt
|
||||||
|
|
||||||
|
|
||||||
class Predictor(object):
|
class Predictor(object):
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
self.args = args
|
self.args = args
|
||||||
#self.args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
self.args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
#self.args.device = torch.device("cpu")
|
|
||||||
#print(self.args.device)
|
#print(self.args.device)
|
||||||
#self.args.nprocs = torch.cuda.device_count()
|
self.args.nprocs = torch.cuda.device_count()
|
||||||
self.cls_dict = {}
|
self.cls_dict = {}
|
||||||
self.num_classes = 0
|
self.num_classes = 0
|
||||||
self.model = None
|
self.model = None
|
||||||
self.prepare_model()
|
self.prepare_model()
|
||||||
self.test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
|
self.test_transform = transforms.Compose([transforms.Resize((448, 448), Image.BILINEAR),
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
||||||
|
|
||||||
def prepare_model(self):
|
def prepare_model(self):
|
||||||
# config = CONFIGS["ViT-B_16"]
|
config = CONFIGS["ViT-B_16"]
|
||||||
# config.split = self.args.split
|
config.split = self.args.split
|
||||||
# config.slide_step = self.args.slide_step
|
config.slide_step = self.args.slide_step
|
||||||
# self.num_classes = 5
|
model_name = os.path.basename(self.args.pretrained_model).replace("_checkpoint.bin", "")
|
||||||
# self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "fly", 4: "stack"}
|
#print("use model_name: ", model_name)
|
||||||
# self.model = VisionTransformer(config, self.args.img_size, zero_head=True, num_classes=self.num_classes, smoothing_value=self.args.smoothing_value)
|
self.num_classes = 5
|
||||||
|
self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "fly", 4: "stack"}
|
||||||
# if self.args.pretrained_model is not None:
|
self.model = VisionTransformer(config, self.args.img_size, zero_head=True, num_classes=self.num_classes, smoothing_value=self.args.smoothing_value)
|
||||||
# if not torch.cuda.is_available():
|
if self.args.pretrained_model is not None:
|
||||||
# self.model = torch.load(self.args.pretrained_model)
|
if not torch.cuda.is_available():
|
||||||
# else:
|
pretrained_model = torch.load(self.args.pretrained_model, map_location=torch.device('cpu'))['model']
|
||||||
# self.model = torch.load(self.args.pretrained_model,map_location='cpu')
|
self.model.load_state_dict(pretrained_model)
|
||||||
self.model = torch.load(self.args.pretrained_model,map_location=torch.device('cpu'))
|
else:
|
||||||
|
pretrained_model = torch.load(self.args.pretrained_model)['model']
|
||||||
|
self.model.load_state_dict(pretrained_model)
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
if torch.cuda.is_available():
|
self.model.to(self.args.device)
|
||||||
self.model.to("cuda")
|
#self.model.eval()
|
||||||
|
|
||||||
def normal_predict(self, img_data, result):
|
def normal_predict(self, img_data, result):
|
||||||
# img = Image.open(img_path)
|
# img = Image.open(img_path)
|
||||||
if img_data is None:
|
if img_data is None:
|
||||||
#print('error, img data is None')
|
#print('error, img data is None')
|
||||||
print('error, img data is None')
|
logger.warning('error, img data is None')
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@ -100,59 +97,47 @@ class Predictor(object):
|
|||||||
probs = torch.nn.Softmax(dim=-1)(part_logits)
|
probs = torch.nn.Softmax(dim=-1)(part_logits)
|
||||||
topN = torch.argsort(probs, dim=-1, descending=True).tolist()
|
topN = torch.argsort(probs, dim=-1, descending=True).tolist()
|
||||||
clas_ids = topN[0][0]
|
clas_ids = topN[0][0]
|
||||||
|
|
||||||
print("cur_img result: class id: %d, score: %0.3f" % (clas_ids, probs[0, clas_ids].item()))
|
|
||||||
# if(int(clas_ids)==6 or int(clas_ids)==7):
|
|
||||||
# clas_ids = 0
|
|
||||||
# else:
|
|
||||||
# clas_ids = 1
|
|
||||||
# result["success"] = "true"
|
|
||||||
# result["rst_cls"] = str(clas_ids)
|
|
||||||
|
|
||||||
clas_ids = 0 if 0==int(clas_ids) or 2 == int(clas_ids) or 3 == int(clas_ids) else 1
|
clas_ids = 0 if 0==int(clas_ids) or 2 == int(clas_ids) or 3 == int(clas_ids) else 1
|
||||||
print("cur_img result: class id: %d, score: %0.3f" % (clas_ids, probs[0, clas_ids].item()))
|
#print("cur_img result: class id: %d, score: %0.3f" % (clas_ids, probs[0, clas_ids].item()))
|
||||||
result["success"] = "true"
|
result["success"] = "true"
|
||||||
result["rst_cls"] = str(clas_ids)
|
result["rst_cls"] = str(clas_ids)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
model_file ="../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin"
|
||||||
args = parse_args()
|
#model_file ="output/emptyjudge5_checkpoint.bin"
|
||||||
|
args = parse_args(model_file)
|
||||||
predictor = Predictor(args)
|
predictor = Predictor(args)
|
||||||
|
|
||||||
|
|
||||||
@app.route("/isempty", methods=['POST'])
|
@app.route("/isempty", methods=['POST'])
|
||||||
def get_isempty():
|
def get_isempty():
|
||||||
#print("begin")
|
start = time.time()
|
||||||
|
#print('--------------------EmptyPredict-----------------')
|
||||||
data = request.get_data()
|
data = request.get_data()
|
||||||
|
ip = request.remote_addr
|
||||||
|
#print('------ ip = %s ------' % ip)
|
||||||
|
logger.info(ip)
|
||||||
|
|
||||||
json_data = json.loads(data.decode("utf-8"))
|
json_data = json.loads(data.decode("utf-8"))
|
||||||
|
getdateend = time.time()
|
||||||
|
#print('get date use time: {0:.2f}s'.format(getdateend - start))
|
||||||
|
|
||||||
pic = json_data.get("pic")
|
pic = json_data.get("pic")
|
||||||
imgdata = base64.b64decode(pic)
|
result = {"success": "false",
|
||||||
|
"rst_cls": '-1',
|
||||||
result ={}
|
}
|
||||||
imgdata_np = np.frombuffer(imgdata, dtype='uint8')
|
try:
|
||||||
img_src = cv2.imdecode(imgdata_np, cv2.IMREAD_COLOR)
|
imgdata = base64.b64decode(pic)
|
||||||
img_data = Image.fromarray(np.uint8(img_src))
|
imgdata_np = np.frombuffer(imgdata, dtype='uint8')
|
||||||
result = predictor.normal_predict(img_data, result) # 1==empty, 0==nonEmpty
|
img_src = cv2.imdecode(imgdata_np, cv2.IMREAD_COLOR)
|
||||||
|
img_data = Image.fromarray(np.uint8(img_src))
|
||||||
|
result = predictor.normal_predict(img_data, result) # 1==empty, 0==nonEmpty
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(e)
|
||||||
|
return repr(result)
|
||||||
|
logger.info(repr(result))
|
||||||
return repr(result)
|
return repr(result)
|
||||||
|
|
||||||
def getByte(path):
|
|
||||||
with open(path, 'rb') as f:
|
|
||||||
img_byte = base64.b64encode(f.read())
|
|
||||||
img_str = img_byte.decode('utf-8')
|
|
||||||
return img_str
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
app.run(host='0.0.0.0', port=8888)
|
app.run(host='192.168.1.142', port=8000)
|
||||||
|
|
||||||
# result ={}
|
|
||||||
# imgdata = base64.b64decode(getByte("img.jpg"))
|
|
||||||
# imgdata_np = np.frombuffer(imgdata, dtype='uint8')
|
|
||||||
# img_src = cv2.imdecode(imgdata_np, cv2.IMREAD_COLOR)
|
|
||||||
# img_data = Image.fromarray(np.uint8(img_src))
|
|
||||||
# result = predictor.normal_predict(img_data, result)
|
|
||||||
# print(result)
|
|
||||||
|
75
predict.py
75
predict.py
@ -9,24 +9,22 @@ from sklearn.metrics import f1_score
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from models.modeling import VisionTransformer, CONFIGS
|
from models.modeling import VisionTransformer, CONFIGS
|
||||||
import lightrise
|
|
||||||
|
|
||||||
#模型预测
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--img_size", default=600, type=int, help="Resolution size")
|
parser.add_argument("--img_size", default=448, type=int, help="Resolution size")
|
||||||
parser.add_argument('--split', type=str, default='overlap', help="Split method") # non-overlap
|
parser.add_argument('--split', type=str, default='overlap', help="Split method") # non-overlap
|
||||||
parser.add_argument('--slide_step', type=int, default=2, help="Slide step for overlap split")
|
parser.add_argument('--slide_step', type=int, default=12, help="Slide step for overlap split")
|
||||||
parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value\n")
|
parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value\n")
|
||||||
parser.add_argument("--pretrained_model", type=str, default="../module/ieemoo-ai-isempty/model/new/ieemooempty_vit_checkpoint.pth", help="load pretrained model")
|
parser.add_argument("--pretrained_model", type=str, default="output/emptyjudge5_checkpoint.bin", help="load pretrained model")
|
||||||
#parser.add_argument("--pretrained_model", type=str, default="output/ieemooempty_vit_checkpoint.pth", help="load pretrained model") #使用自定义VIT
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
class Predictor(object):
|
class Predictor(object):
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
self.args = args
|
self.args = args
|
||||||
self.args.device = torch.device("cuda")
|
self.args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
print("self.args.device =", self.args.device)
|
print("self.args.device =", self.args.device)
|
||||||
self.args.nprocs = torch.cuda.device_count()
|
self.args.nprocs = torch.cuda.device_count()
|
||||||
|
|
||||||
@ -34,7 +32,7 @@ class Predictor(object):
|
|||||||
self.num_classes = 0
|
self.num_classes = 0
|
||||||
self.model = None
|
self.model = None
|
||||||
self.prepare_model()
|
self.prepare_model()
|
||||||
self.test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
|
self.test_transform = transforms.Compose([transforms.Resize((448, 448), Image.BILINEAR),
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
||||||
|
|
||||||
@ -42,14 +40,28 @@ class Predictor(object):
|
|||||||
config = CONFIGS["ViT-B_16"]
|
config = CONFIGS["ViT-B_16"]
|
||||||
config.split = self.args.split
|
config.split = self.args.split
|
||||||
config.slide_step = self.args.slide_step
|
config.slide_step = self.args.slide_step
|
||||||
self.num_classes = 5
|
model_name = os.path.basename(self.args.pretrained_model).replace("_checkpoint.bin", "")
|
||||||
self.cls_dict = {0: "noemp", 1: "yesemp"}
|
print("use model_name: ", model_name)
|
||||||
|
if model_name.lower() == "emptyJudge5".lower():
|
||||||
|
self.num_classes = 5
|
||||||
|
self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "fly", 4: "stack"}
|
||||||
|
elif model_name.lower() == "emptyJudge4".lower():
|
||||||
|
self.num_classes = 4
|
||||||
|
self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard", 3: "stack"}
|
||||||
|
elif model_name.lower() == "emptyJudge3".lower():
|
||||||
|
self.num_classes = 3
|
||||||
|
self.cls_dict = {0: "noemp", 1: "yesemp", 2: "hard"}
|
||||||
|
elif model_name.lower() == "emptyJudge2".lower():
|
||||||
|
self.num_classes = 2
|
||||||
|
self.cls_dict = {0: "noemp", 1: "yesemp"}
|
||||||
self.model = VisionTransformer(config, self.args.img_size, zero_head=True, num_classes=self.num_classes, smoothing_value=self.args.smoothing_value)
|
self.model = VisionTransformer(config, self.args.img_size, zero_head=True, num_classes=self.num_classes, smoothing_value=self.args.smoothing_value)
|
||||||
|
|
||||||
if self.args.pretrained_model is not None:
|
if self.args.pretrained_model is not None:
|
||||||
self.model = torch.load(self.args.pretrained_model,map_location='cpu')
|
if not torch.cuda.is_available():
|
||||||
|
pretrained_model = torch.load(self.args.pretrained_model, map_location=torch.device('cpu'))['model']
|
||||||
|
self.model.load_state_dict(pretrained_model)
|
||||||
|
else:
|
||||||
|
pretrained_model = torch.load(self.args.pretrained_model)['model']
|
||||||
|
self.model.load_state_dict(pretrained_model)
|
||||||
self.model.to(self.args.device)
|
self.model.to(self.args.device)
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
@ -61,7 +73,9 @@ class Predictor(object):
|
|||||||
"Image file failed to read: {}".format(img_path))
|
"Image file failed to read: {}".format(img_path))
|
||||||
else:
|
else:
|
||||||
x = self.test_transform(img)
|
x = self.test_transform(img)
|
||||||
part_logits = self.model(x.unsqueeze(0).to(args.device))
|
if torch.cuda.is_available():
|
||||||
|
x = x.cuda()
|
||||||
|
part_logits = self.model(x.unsqueeze(0))
|
||||||
probs = torch.nn.Softmax(dim=-1)(part_logits)
|
probs = torch.nn.Softmax(dim=-1)(part_logits)
|
||||||
topN = torch.argsort(probs, dim=-1, descending=True).tolist()
|
topN = torch.argsort(probs, dim=-1, descending=True).tolist()
|
||||||
clas_ids = topN[0][0]
|
clas_ids = topN[0][0]
|
||||||
@ -75,12 +89,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
y_true = []
|
y_true = []
|
||||||
y_pred = []
|
y_pred = []
|
||||||
test_dir = "./emptyJudge5/images/"
|
test_dir = "/data/pfc/fineGrained/test_5cls"
|
||||||
dir_dict = {"noemp":"0", "yesemp":"1", "hard": "2", "fly": "3", "stack": "4"}
|
dir_dict = {"noemp":"0", "yesemp":"1", "hard": "2", "fly": "3", "stack": "4"}
|
||||||
|
|
||||||
# test_dir = "../emptyJudge2/images"
|
|
||||||
# dir_dict = {"noempty":"0", "empty":"1"}
|
|
||||||
|
|
||||||
total = 0
|
total = 0
|
||||||
num = 0
|
num = 0
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
@ -96,19 +106,6 @@ if __name__ == "__main__":
|
|||||||
cur_pred, pred_score = predictor.normal_predict(cur_img_file)
|
cur_pred, pred_score = predictor.normal_predict(cur_img_file)
|
||||||
|
|
||||||
label = 0 if 2 == int(label) or 3 == int(label) or 4 == int(label) else int(label)
|
label = 0 if 2 == int(label) or 3 == int(label) or 4 == int(label) else int(label)
|
||||||
|
|
||||||
riseresult = lightrise.riseempty(Image.open(cur_img_file))
|
|
||||||
if(label==1):
|
|
||||||
if(int(riseresult["rst_cls"])==1):
|
|
||||||
label=1
|
|
||||||
else:
|
|
||||||
label=0
|
|
||||||
# else:
|
|
||||||
# if(riseresult["rst_cls"]==0):
|
|
||||||
# label=0
|
|
||||||
# else:
|
|
||||||
# label=1
|
|
||||||
|
|
||||||
cur_pred = 0 if 2 == int(cur_pred) or 3 == int(cur_pred) or 4 == int(cur_pred) else int(cur_pred)
|
cur_pred = 0 if 2 == int(cur_pred) or 3 == int(cur_pred) or 4 == int(cur_pred) else int(cur_pred)
|
||||||
y_true.append(int(label))
|
y_true.append(int(label))
|
||||||
y_pred.append(int(cur_pred))
|
y_pred.append(int(cur_pred))
|
||||||
@ -128,18 +125,6 @@ if __name__ == "__main__":
|
|||||||
print(rst_C)
|
print(rst_C)
|
||||||
print(rst_f1)
|
print(rst_f1)
|
||||||
|
|
||||||
'''
|
|
||||||
所有数据集
|
|
||||||
|
|
||||||
The cast of time is :160.738966 seconds
|
|
||||||
The classification accuracy is 0.986836
|
|
||||||
[[4923 58]
|
|
||||||
[ 34 1974]]
|
|
||||||
0.9839851634589902
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
|
|
||||||
'''
|
'''
|
||||||
test_imgs: yesemp=145, noemp=453 大图
|
test_imgs: yesemp=145, noemp=453 大图
|
||||||
|
|
||||||
|
154
prepara_data.py
154
prepara_data.py
@ -1,38 +1,28 @@
|
|||||||
#encoding: utf-8
|
|
||||||
import os
|
import os
|
||||||
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import subprocess
|
import subprocess
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
|
||||||
#生成数据集
|
|
||||||
|
|
||||||
|
|
||||||
# ----------- 改写名称 --------------
|
# ----------- 改写名称 --------------
|
||||||
# index = 0
|
# index = 0
|
||||||
# src_dir = "../emptyJudge2/images/"
|
# src_dir = "/data/fineGrained/emptyJudge5"
|
||||||
# dst_dir = src_dir
|
# dst_dir = src_dir + "_new"
|
||||||
# os.remove('../emptyJudge2/image_class_labels.txt')
|
# os.makedirs(dst_dir, exist_ok=True)
|
||||||
# os.remove('../emptyJudge2/images.txt')
|
|
||||||
# os.remove('../emptyJudge2/train_test_split.txt')
|
|
||||||
# if(os.path.exists(dst_dir)):
|
|
||||||
# pass
|
|
||||||
# else:
|
|
||||||
# os.makedirs(dst_dir)
|
|
||||||
|
|
||||||
# for sub in os.listdir(src_dir):
|
# for sub in os.listdir(src_dir):
|
||||||
# sub_path = os.path.join(src_dir, sub)
|
# sub_path = os.path.join(src_dir, sub)
|
||||||
# print(sub_path)
|
|
||||||
# sub_path_dst = os.path.join(dst_dir, sub)
|
# sub_path_dst = os.path.join(dst_dir, sub)
|
||||||
|
# os.makedirs(sub_path_dst, exist_ok=True)
|
||||||
# for cur_f in os.listdir(sub_path):
|
# for cur_f in os.listdir(sub_path):
|
||||||
# cur_img = os.path.join(sub_path, cur_f)
|
# cur_img = os.path.join(sub_path, cur_f)
|
||||||
# cur_img_dst = os.path.join(sub_path_dst, "image%04d.jpg" % index)
|
# cur_img_dst = os.path.join(sub_path_dst, "a%05d.jpg" % index)
|
||||||
# index += 1
|
# index += 1
|
||||||
# os.system("mv %s %s" % (cur_img, cur_img_dst))
|
# os.system("mv %s %s" % (cur_img, cur_img_dst))
|
||||||
|
|
||||||
|
|
||||||
# ----------- 删除过小图像 --------------
|
# ----------- 删除过小图像 --------------
|
||||||
# src_dir = "../emptyJudge2/images/"
|
# src_dir = "/data/fineGrained/emptyJudge5"
|
||||||
# for sub in os.listdir(src_dir):
|
# for sub in os.listdir(src_dir):
|
||||||
# sub_path = os.path.join(src_dir, sub)
|
# sub_path = os.path.join(src_dir, sub)
|
||||||
# for cur_f in os.listdir(sub_path):
|
# for cur_f in os.listdir(sub_path):
|
||||||
@ -47,59 +37,83 @@ import random
|
|||||||
|
|
||||||
|
|
||||||
# ----------- 获取有效图片并写images.txt --------------
|
# ----------- 获取有效图片并写images.txt --------------
|
||||||
src_dir = "../emptyJudge2/images/"
|
# src_dir = "/data/fineGrained/emptyJudge4/images"
|
||||||
src_dict = {"noempty":"0", "empty":"1"}
|
# src_dict = {"noemp":"0", "yesemp":"1", "hard": "2", "stack": "3"}
|
||||||
all_dict = {"noempty":[], "empty":[]}
|
# all_dict = {"yesemp":[], "noemp":[], "hard": [], "stack": []}
|
||||||
for sub, value in src_dict.items():
|
# for sub, value in src_dict.items():
|
||||||
sub_path = os.path.join(src_dir, sub)
|
# sub_path = os.path.join(src_dir, sub)
|
||||||
for cur_f in os.listdir(sub_path):
|
# for cur_f in os.listdir(sub_path):
|
||||||
all_dict[sub].append(os.path.join(sub, cur_f))
|
# all_dict[sub].append(os.path.join(sub, cur_f))
|
||||||
|
#
|
||||||
yesnum = len(all_dict["empty"])
|
# yesnum = len(all_dict["yesemp"])
|
||||||
#print(yesnum)
|
# nonum = len(all_dict["noemp"])
|
||||||
nonum = len(all_dict["noempty"])
|
# hardnum = len(all_dict["hard"])
|
||||||
#print(nonum)
|
# stacknum = len(all_dict["stack"])
|
||||||
images_txt = "../emptyJudge2/images.txt"
|
# thnum = min(yesnum, nonum, hardnum, stacknum)
|
||||||
index = 0
|
# images_txt = src_dir + ".txt"
|
||||||
|
# index = 1
|
||||||
|
#
|
||||||
def write_images(cur_list, num, fw, index):
|
# def write_images(cur_list, thnum, fw, index):
|
||||||
for feat_path in random.sample(cur_list, num):
|
# for feat_path in random.sample(cur_list, thnum):
|
||||||
fw.write(str(index) + " " + feat_path + "\n")
|
# fw.write(str(index) + " " + feat_path + "\n")
|
||||||
index += 1
|
# index += 1
|
||||||
return index
|
# return index
|
||||||
|
#
|
||||||
with open(images_txt, "w") as fw:
|
# with open(images_txt, "w") as fw:
|
||||||
index = write_images(all_dict["noempty"], nonum, fw, index)
|
# index = write_images(all_dict["noemp"], thnum, fw, index)
|
||||||
index = write_images(all_dict["empty"], yesnum, fw, index)
|
# index = write_images(all_dict["yesemp"], thnum, fw, index)
|
||||||
|
# index = write_images(all_dict["hard"], thnum, fw, index)
|
||||||
|
# index = write_images(all_dict["stack"], thnum, fw, index)
|
||||||
|
|
||||||
# ----------- 写 image_class_labels.txt + train_test_split.txt --------------
|
# ----------- 写 image_class_labels.txt + train_test_split.txt --------------
|
||||||
src_dir = "../emptyJudge2/"
|
# src_dir = "/data/fineGrained/emptyJudge4"
|
||||||
src_dict = {"noempty":"0", "empty":"1"}
|
# src_dict = {"noemp":"0", "yesemp":"1", "hard": "2", "stack": "3"}
|
||||||
images_txt = os.path.join(src_dir, "images.txt")
|
# images_txt = os.path.join(src_dir, "images.txt")
|
||||||
image_class_labels_txt = os.path.join(src_dir, "image_class_labels.txt")
|
# image_class_labels_txt = os.path.join(src_dir, "image_class_labels.txt")
|
||||||
imgs_cnt = 0
|
# imgs_cnt = 0
|
||||||
with open(image_class_labels_txt, "w") as fw:
|
# with open(image_class_labels_txt, "w") as fw:
|
||||||
with open(images_txt, "r") as fr:
|
# with open(images_txt, "r") as fr:
|
||||||
for cur_l in fr:
|
# for cur_l in fr:
|
||||||
imgs_cnt += 1
|
# imgs_cnt += 1
|
||||||
img_index, img_f = cur_l.strip().split(" ")
|
# img_index, img_f = cur_l.strip().split(" ")
|
||||||
folder_name = img_f.split("/")[0]
|
# folder_name = img_f.split("/")[0]
|
||||||
if folder_name in src_dict:
|
# if folder_name in src_dict:
|
||||||
cur_line = img_index + " " + str(int(src_dict[folder_name])+1)
|
# cur_line = img_index + " " + str(int(src_dict[folder_name])+1)
|
||||||
fw.write(cur_line + "\n")
|
# fw.write(cur_line + "\n")
|
||||||
|
#
|
||||||
|
# train_num = int(imgs_cnt*0.85)
|
||||||
|
# print("train_num= ", train_num, ", imgs_cnt= ", imgs_cnt)
|
||||||
|
# all_list = [1]*train_num + [0]*(imgs_cnt-train_num)
|
||||||
|
# assert len(all_list) == imgs_cnt
|
||||||
|
# random.shuffle(all_list)
|
||||||
|
# train_test_split_txt = os.path.join(src_dir, "train_test_split.txt")
|
||||||
|
# with open(train_test_split_txt, "w") as fw:
|
||||||
|
# with open(images_txt, "r") as fr:
|
||||||
|
# for cur_l in fr:
|
||||||
|
# img_index, img_f = cur_l.strip().split(" ")
|
||||||
|
# cur_line = img_index + " " + str(all_list[int(img_index) - 1])
|
||||||
|
# fw.write(cur_line + "\n")
|
||||||
|
|
||||||
|
# ----------- 生成标准测试集 --------------
|
||||||
|
# src_dir = "/data/fineGrained/emptyJudge5/images"
|
||||||
|
# src_dict = {"noemp":"0", "yesemp":"1", "hard": "2", "fly": "3", "stack": "4"}
|
||||||
|
# all_dict = {"noemp":[], "yesemp":[], "hard": [], "fly": [], "stack": []}
|
||||||
|
# for sub, value in src_dict.items():
|
||||||
|
# sub_path = os.path.join(src_dir, sub)
|
||||||
|
# for cur_f in os.listdir(sub_path):
|
||||||
|
# all_dict[sub].append(cur_f)
|
||||||
|
#
|
||||||
|
# dst_dir = src_dir + "_test"
|
||||||
|
# os.makedirs(dst_dir, exist_ok=True)
|
||||||
|
# for sub, value in src_dict.items():
|
||||||
|
# sub_path = os.path.join(src_dir, sub)
|
||||||
|
# sub_path_dst = os.path.join(dst_dir, sub)
|
||||||
|
# os.makedirs(sub_path_dst, exist_ok=True)
|
||||||
|
#
|
||||||
|
# cur_list = all_dict[sub]
|
||||||
|
# test_num = int(len(cur_list) * 0.05)
|
||||||
|
# for cur_f in random.sample(cur_list, test_num):
|
||||||
|
# cur_path = os.path.join(sub_path, cur_f)
|
||||||
|
# cur_path_dst = os.path.join(sub_path_dst, cur_f)
|
||||||
|
# os.system("cp %s %s" % (cur_path, cur_path_dst))
|
||||||
|
|
||||||
train_num = int(imgs_cnt*0.85)
|
|
||||||
print("train_num= ", train_num, ", imgs_cnt= ", imgs_cnt)
|
|
||||||
all_list = [1]*train_num + [0]*(imgs_cnt-train_num)
|
|
||||||
assert len(all_list) == imgs_cnt
|
|
||||||
random.shuffle(all_list)
|
|
||||||
train_test_split_txt = os.path.join(src_dir, "train_test_split.txt")
|
|
||||||
with open(train_test_split_txt, "w") as fw:
|
|
||||||
with open(images_txt, "r") as fr:
|
|
||||||
for cur_l in fr:
|
|
||||||
img_index, img_f = cur_l.strip().split(" ")
|
|
||||||
cur_line = img_index + " " + str(all_list[int(img_index) - 1])
|
|
||||||
fw.write(cur_line + "\n")
|
|
||||||
|
105
train.py
105
train.py
@ -24,9 +24,7 @@ import pdb
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"
|
||||||
|
|
||||||
#计算并存储平均值
|
|
||||||
class AverageMeter(object):
|
class AverageMeter(object):
|
||||||
"""Computes and stores the average and current value"""
|
"""Computes and stores the average and current value"""
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -44,24 +42,19 @@ class AverageMeter(object):
|
|||||||
self.count += n
|
self.count += n
|
||||||
self.avg = self.sum / self.count
|
self.avg = self.sum / self.count
|
||||||
|
|
||||||
#简单准确率
|
|
||||||
def simple_accuracy(preds, labels):
|
def simple_accuracy(preds, labels):
|
||||||
return (preds == labels).mean()
|
return (preds == labels).mean()
|
||||||
|
|
||||||
#求均值
|
|
||||||
def reduce_mean(tensor, nprocs):
|
def reduce_mean(tensor, nprocs):
|
||||||
rt = tensor.clone()
|
rt = tensor.clone()
|
||||||
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
|
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
|
||||||
rt /= nprocs
|
rt /= nprocs
|
||||||
return rt
|
return rt
|
||||||
|
|
||||||
#保存模型
|
|
||||||
def save_model(args, model):
|
def save_model(args, model):
|
||||||
<<<<<<< HEAD
|
|
||||||
model_checkpoint = os.path.join("../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin")
|
|
||||||
torch.save(model, model_checkpoint)
|
|
||||||
logger.info("Saved model checkpoint to [File: %s]", "../module/ieemoo-ai-isempty/model/now/emptyjudge5_checkpoint.bin")
|
|
||||||
=======
|
|
||||||
model_to_save = model.module if hasattr(model, 'module') else model
|
model_to_save = model.module if hasattr(model, 'module') else model
|
||||||
model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.bin" % args.name)
|
model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.bin" % args.name)
|
||||||
checkpoint = {
|
checkpoint = {
|
||||||
@ -79,24 +72,36 @@ def save_eve_model(args, model, eve_name):
|
|||||||
torch.save(checkpoint, model_checkpoint)
|
torch.save(checkpoint, model_checkpoint)
|
||||||
logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir)
|
logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir)
|
||||||
|
|
||||||
>>>>>>> develop
|
|
||||||
|
|
||||||
#根据数据集配置模型
|
|
||||||
def setup(args):
|
def setup(args):
|
||||||
# Prepare model
|
# Prepare model
|
||||||
config = CONFIGS[args.model_type]
|
config = CONFIGS[args.model_type]
|
||||||
config.split = args.split
|
config.split = args.split
|
||||||
config.slide_step = args.slide_step
|
config.slide_step = args.slide_step
|
||||||
|
|
||||||
if args.dataset == "emptyJudge5":
|
if args.dataset == "CUB_200_2011":
|
||||||
|
num_classes = 200
|
||||||
|
elif args.dataset == "car":
|
||||||
|
num_classes = 196
|
||||||
|
elif args.dataset == "nabirds":
|
||||||
|
num_classes = 555
|
||||||
|
elif args.dataset == "dog":
|
||||||
|
num_classes = 120
|
||||||
|
elif args.dataset == "INat2017":
|
||||||
|
num_classes = 5089
|
||||||
|
elif args.dataset == "emptyJudge5":
|
||||||
num_classes = 5
|
num_classes = 5
|
||||||
|
elif args.dataset == "emptyJudge4":
|
||||||
|
num_classes = 4
|
||||||
|
elif args.dataset == "emptyJudge3":
|
||||||
|
num_classes = 3
|
||||||
|
|
||||||
model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes, smoothing_value=args.smoothing_value)
|
model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes, smoothing_value=args.smoothing_value)
|
||||||
|
|
||||||
if args.pretrained_dir is not None:
|
model.load_from(np.load(args.pretrained_dir))
|
||||||
model.load_from(np.load(args.pretrained_dir)) #他人预训练模型
|
|
||||||
if args.pretrained_model is not None:
|
if args.pretrained_model is not None:
|
||||||
model = torch.load(args.pretrained_model) #自己预训练模型
|
pretrained_model = torch.load(args.pretrained_model)['model']
|
||||||
|
model.load_state_dict(pretrained_model)
|
||||||
#model.to(args.device)
|
#model.to(args.device)
|
||||||
#pdb.set_trace()
|
#pdb.set_trace()
|
||||||
num_params = count_parameters(model)
|
num_params = count_parameters(model)
|
||||||
@ -104,15 +109,15 @@ def setup(args):
|
|||||||
logger.info("{}".format(config))
|
logger.info("{}".format(config))
|
||||||
logger.info("Training parameters %s", args)
|
logger.info("Training parameters %s", args)
|
||||||
logger.info("Total Parameter: \t%2.1fM" % num_params)
|
logger.info("Total Parameter: \t%2.1fM" % num_params)
|
||||||
model = torch.nn.DataParallel(model, device_ids=[0]).cuda()
|
model = torch.nn.DataParallel(model, device_ids=[0,1]).cuda()
|
||||||
return args, model
|
return args, model
|
||||||
|
|
||||||
#计算模型参数数量
|
|
||||||
def count_parameters(model):
|
def count_parameters(model):
|
||||||
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
return params/1000000
|
return params/1000000
|
||||||
|
|
||||||
#随机种子
|
|
||||||
def set_seed(args):
|
def set_seed(args):
|
||||||
random.seed(args.seed)
|
random.seed(args.seed)
|
||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
@ -120,7 +125,7 @@ def set_seed(args):
|
|||||||
if args.n_gpu > 0:
|
if args.n_gpu > 0:
|
||||||
torch.cuda.manual_seed_all(args.seed)
|
torch.cuda.manual_seed_all(args.seed)
|
||||||
|
|
||||||
#模型验证
|
|
||||||
def valid(args, model, writer, test_loader, global_step):
|
def valid(args, model, writer, test_loader, global_step):
|
||||||
eval_losses = AverageMeter()
|
eval_losses = AverageMeter()
|
||||||
|
|
||||||
@ -177,7 +182,7 @@ def valid(args, model, writer, test_loader, global_step):
|
|||||||
writer.add_scalar("test/accuracy", scalar_value=val_accuracy, global_step=global_step)
|
writer.add_scalar("test/accuracy", scalar_value=val_accuracy, global_step=global_step)
|
||||||
return val_accuracy
|
return val_accuracy
|
||||||
|
|
||||||
#模型训练
|
|
||||||
def train(args, model):
|
def train(args, model):
|
||||||
""" Train the model """
|
""" Train the model """
|
||||||
if args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
@ -293,54 +298,37 @@ def train(args, model):
|
|||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
logger.info("Total Training Time: \t%f" % ((end_time - start_time) / 3600))
|
logger.info("Total Training Time: \t%f" % ((end_time - start_time) / 3600))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#主函数
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
# Required parameters
|
# Required parameters
|
||||||
parser.add_argument("--name", type=str, default='ieemooempty',
|
parser.add_argument("--name", required=True,
|
||||||
help="Name of this run. Used for monitoring.")
|
help="Name of this run. Used for monitoring.")
|
||||||
parser.add_argument("--dataset", choices=["CUB_200_2011", "car", "dog", "nabirds", "INat2017", "emptyJudge5", "emptyJudge4"],
|
parser.add_argument("--dataset", choices=["CUB_200_2011", "car", "dog", "nabirds", "INat2017", "emptyJudge5", "emptyJudge4"],
|
||||||
<<<<<<< HEAD
|
|
||||||
default="emptyJudge5", help="Which dataset.")
|
|
||||||
parser.add_argument('--data_root', type=str, default='./')
|
|
||||||
=======
|
|
||||||
default="CUB_200_2011", help="Which dataset.")
|
default="CUB_200_2011", help="Which dataset.")
|
||||||
parser.add_argument('--data_root', type=str, default='/data/pfc/fineGrained')
|
parser.add_argument('--data_root', type=str, default='/data/pfc/fineGrained')
|
||||||
>>>>>>> develop
|
|
||||||
parser.add_argument("--model_type", choices=["ViT-B_16", "ViT-B_32", "ViT-L_16", "ViT-L_32", "ViT-H_14"],
|
parser.add_argument("--model_type", choices=["ViT-B_16", "ViT-B_32", "ViT-L_16", "ViT-L_32", "ViT-H_14"],
|
||||||
default="ViT-B_16",help="Which variant to use.")
|
default="ViT-B_16",help="Which variant to use.")
|
||||||
parser.add_argument("--pretrained_dir", type=str, default="./preckpts/ViT-B_16.npz",
|
parser.add_argument("--pretrained_dir", type=str, default="ckpts/ViT-B_16.npz",
|
||||||
help="Where to search for pretrained ViT models.")
|
help="Where to search for pretrained ViT models.")
|
||||||
#parser.add_argument("--pretrained_model", type=str, default="./output/ieemooempty_checkpoint_good.pth", help="load pretrained model") #None
|
parser.add_argument("--pretrained_model", type=str, default="output/emptyjudge5_checkpoint.bin", help="load pretrained model")
|
||||||
# parser.add_argument("--pretrained_dir", type=str, default=None,
|
#parser.add_argument("--pretrained_model", type=str, default=None, help="load pretrained model")
|
||||||
# help="Where to search for pretrained ViT models.")
|
|
||||||
parser.add_argument("--pretrained_model", type=str, default=None, help="load pretrained model") #None
|
|
||||||
parser.add_argument("--output_dir", default="./output", type=str,
|
parser.add_argument("--output_dir", default="./output", type=str,
|
||||||
help="The output directory where checkpoints will be written.")
|
help="The output directory where checkpoints will be written.")
|
||||||
parser.add_argument("--img_size", default=600, type=int, help="Resolution size")
|
parser.add_argument("--img_size", default=448, type=int, help="Resolution size")
|
||||||
parser.add_argument("--train_batch_size", default=8, type=int,
|
parser.add_argument("--train_batch_size", default=64, type=int,
|
||||||
help="Total batch size for training.")
|
help="Total batch size for training.")
|
||||||
parser.add_argument("--eval_batch_size", default=8, type=int,
|
parser.add_argument("--eval_batch_size", default=16, type=int,
|
||||||
help="Total batch size for eval.")
|
help="Total batch size for eval.")
|
||||||
<<<<<<< HEAD
|
|
||||||
parser.add_argument("--eval_every", default=786, type=int,
|
|
||||||
=======
|
|
||||||
parser.add_argument("--eval_every", default=200, type=int, #200
|
parser.add_argument("--eval_every", default=200, type=int, #200
|
||||||
>>>>>>> develop
|
|
||||||
help="Run prediction on validation set every so many steps."
|
help="Run prediction on validation set every so many steps."
|
||||||
"Will always run one evaluation at the end of training.")
|
"Will always run one evaluation at the end of training.")
|
||||||
|
|
||||||
parser.add_argument("--learning_rate", default=3e-2, type=float,
|
parser.add_argument("--learning_rate", default=3e-2, type=float,
|
||||||
help="The initial learning rate for SGD.")
|
help="The initial learning rate for SGD.")
|
||||||
parser.add_argument("--weight_decay", default=0.00001, type=float,
|
parser.add_argument("--weight_decay", default=0, type=float,
|
||||||
help="Weight deay if we apply some.")
|
help="Weight deay if we apply some.")
|
||||||
<<<<<<< HEAD
|
|
||||||
parser.add_argument("--num_steps", default=78600, type=int, #100000
|
|
||||||
=======
|
|
||||||
parser.add_argument("--num_steps", default=40000, type=int, #100000
|
parser.add_argument("--num_steps", default=40000, type=int, #100000
|
||||||
>>>>>>> develop
|
|
||||||
help="Total number of training epochs to perform.")
|
help="Total number of training epochs to perform.")
|
||||||
parser.add_argument("--decay_type", choices=["cosine", "linear"], default="cosine",
|
parser.add_argument("--decay_type", choices=["cosine", "linear"], default="cosine",
|
||||||
help="How to decay the learning rate.")
|
help="How to decay the learning rate.")
|
||||||
@ -355,6 +343,15 @@ def main():
|
|||||||
help="random seed for initialization")
|
help="random seed for initialization")
|
||||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
||||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||||
|
parser.add_argument('--fp16', action='store_true',
|
||||||
|
help="Whether to use 16-bit float precision instead of 32-bit")
|
||||||
|
parser.add_argument('--fp16_opt_level', type=str, default='O2',
|
||||||
|
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||||
|
"See details at https://nvidia.github.io/apex/amp.html")
|
||||||
|
parser.add_argument('--loss_scale', type=float, default=0,
|
||||||
|
help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
|
||||||
|
"0 (default value): dynamic loss scaling.\n"
|
||||||
|
"Positive power of 2: static loss scaling value.\n")
|
||||||
|
|
||||||
parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value\n")
|
parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value\n")
|
||||||
|
|
||||||
@ -366,10 +363,7 @@ def main():
|
|||||||
# Setup CUDA, GPU & distributed training
|
# Setup CUDA, GPU & distributed training
|
||||||
if args.local_rank == -1:
|
if args.local_rank == -1:
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
<<<<<<< HEAD
|
|
||||||
=======
|
|
||||||
#print('torch.cuda.device_count()>>>>>>>>>>>>>>>>>>>>>>>>>', torch.cuda.device_count())
|
#print('torch.cuda.device_count()>>>>>>>>>>>>>>>>>>>>>>>>>', torch.cuda.device_count())
|
||||||
>>>>>>> develop
|
|
||||||
args.n_gpu = torch.cuda.device_count()
|
args.n_gpu = torch.cuda.device_count()
|
||||||
#print('torch.cuda.device_count()>>>>>>>>>>>>>>>>>>>>>>>>>', torch.cuda.device_count())
|
#print('torch.cuda.device_count()>>>>>>>>>>>>>>>>>>>>>>>>>', torch.cuda.device_count())
|
||||||
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||||
@ -384,8 +378,8 @@ def main():
|
|||||||
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||||
datefmt='%m/%d/%Y %H:%M:%S',
|
datefmt='%m/%d/%Y %H:%M:%S',
|
||||||
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||||
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s" %
|
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s" %
|
||||||
(args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1)))
|
(args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1), args.fp16))
|
||||||
|
|
||||||
# Set seed
|
# Set seed
|
||||||
set_seed(args)
|
set_seed(args)
|
||||||
@ -397,5 +391,4 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
torch.cuda.empty_cache()
|
|
||||||
main()
|
main()
|
||||||
|
@ -101,7 +101,7 @@ def get_loader(args):
|
|||||||
testset = INat2017(args.data_root, 'val', test_transform)
|
testset = INat2017(args.data_root, 'val', test_transform)
|
||||||
elif args.dataset == 'emptyJudge5' or args.dataset == 'emptyJudge4':
|
elif args.dataset == 'emptyJudge5' or args.dataset == 'emptyJudge4':
|
||||||
train_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
|
train_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
|
||||||
transforms.RandomCrop((320, 320)),
|
transforms.RandomCrop((448, 448)),
|
||||||
transforms.RandomHorizontalFlip(),
|
transforms.RandomHorizontalFlip(),
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
||||||
@ -109,7 +109,7 @@ def get_loader(args):
|
|||||||
# transforms.CenterCrop((448, 448)),
|
# transforms.CenterCrop((448, 448)),
|
||||||
# transforms.ToTensor(),
|
# transforms.ToTensor(),
|
||||||
# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
||||||
test_transform = transforms.Compose([transforms.Resize((320, 320), Image.BILINEAR),
|
test_transform = transforms.Compose([transforms.Resize((448, 448), Image.BILINEAR),
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
||||||
trainset = emptyJudge(root=args.data_root, is_train=True, transform=train_transform)
|
trainset = emptyJudge(root=args.data_root, is_train=True, transform=train_transform)
|
||||||
|
@ -5,7 +5,7 @@ from os.path import join
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy
|
import scipy
|
||||||
from scipy import io
|
from scipy import io
|
||||||
import imageio
|
import scipy.misc
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
@ -16,7 +16,7 @@ from torchvision.datasets import VisionDataset
|
|||||||
from torchvision.datasets.folder import default_loader
|
from torchvision.datasets.folder import default_loader
|
||||||
from torchvision.datasets.utils import download_url, list_dir, check_integrity, extract_archive, verify_str_arg
|
from torchvision.datasets.utils import download_url, list_dir, check_integrity, extract_archive, verify_str_arg
|
||||||
|
|
||||||
#对各种数据集的底层读取
|
|
||||||
class emptyJudge():
|
class emptyJudge():
|
||||||
def __init__(self, root, is_train=True, data_len=None, transform=None):
|
def __init__(self, root, is_train=True, data_len=None, transform=None):
|
||||||
self.root = root
|
self.root = root
|
||||||
@ -37,12 +37,12 @@ class emptyJudge():
|
|||||||
train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
|
train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
|
||||||
test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
|
test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
|
||||||
if self.is_train:
|
if self.is_train:
|
||||||
self.train_img = [imageio.imread(os.path.join(self.root, 'images', train_file)) for train_file in
|
self.train_img = [scipy.misc.imread(os.path.join(self.root, 'images', train_file)) for train_file in
|
||||||
train_file_list[:data_len]]
|
train_file_list[:data_len]]
|
||||||
self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
|
self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
|
||||||
self.train_imgname = [x for x in train_file_list[:data_len]]
|
self.train_imgname = [x for x in train_file_list[:data_len]]
|
||||||
if not self.is_train:
|
if not self.is_train:
|
||||||
self.test_img = [imageio.imread(os.path.join(self.root, 'images', test_file)) for test_file in
|
self.test_img = [scipy.misc.imread(os.path.join(self.root, 'images', test_file)) for test_file in
|
||||||
test_file_list[:data_len]]
|
test_file_list[:data_len]]
|
||||||
self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
|
self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
|
||||||
self.test_imgname = [x for x in test_file_list[:data_len]]
|
self.test_imgname = [x for x in test_file_list[:data_len]]
|
||||||
@ -51,7 +51,7 @@ class emptyJudge():
|
|||||||
if self.is_train:
|
if self.is_train:
|
||||||
img, target, imgname = self.train_img[index], self.train_label[index], self.train_imgname[index]
|
img, target, imgname = self.train_img[index], self.train_label[index], self.train_imgname[index]
|
||||||
if len(img.shape) == 2:
|
if len(img.shape) == 2:
|
||||||
img = np.stack([img] * 3, 2) #拼接为三维数组,[3,width,highth]
|
img = np.stack([img] * 3, 2)
|
||||||
img = Image.fromarray(img, mode='RGB')
|
img = Image.fromarray(img, mode='RGB')
|
||||||
if self.transform is not None:
|
if self.transform is not None:
|
||||||
img = self.transform(img)
|
img = self.transform(img)
|
||||||
@ -91,12 +91,12 @@ class CUB():
|
|||||||
train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
|
train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
|
||||||
test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
|
test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
|
||||||
if self.is_train:
|
if self.is_train:
|
||||||
self.train_img = [imageio.imread(os.path.join(self.root, 'images', train_file)) for train_file in
|
self.train_img = [scipy.misc.imread(os.path.join(self.root, 'images', train_file)) for train_file in
|
||||||
train_file_list[:data_len]]
|
train_file_list[:data_len]]
|
||||||
self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
|
self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
|
||||||
self.train_imgname = [x for x in train_file_list[:data_len]]
|
self.train_imgname = [x for x in train_file_list[:data_len]]
|
||||||
if not self.is_train:
|
if not self.is_train:
|
||||||
self.test_img = [imageio.imread(os.path.join(self.root, 'images', test_file)) for test_file in
|
self.test_img = [scipy.misc.imread(os.path.join(self.root, 'images', test_file)) for test_file in
|
||||||
test_file_list[:data_len]]
|
test_file_list[:data_len]]
|
||||||
self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
|
self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
|
||||||
self.test_imgname = [x for x in test_file_list[:data_len]]
|
self.test_imgname = [x for x in test_file_list[:data_len]]
|
||||||
|
Reference in New Issue
Block a user