训练数据前置处理与提升训练效率
This commit is contained in:
32
test_ori.py
32
test_ori.py
@ -12,6 +12,7 @@ import matplotlib.pyplot as plt
|
||||
# from config import config as conf
|
||||
from tools.dataset import get_transform
|
||||
from tools.image_joint import merge_imgs
|
||||
from tools.getHeatMap import cal_cam
|
||||
from configs import trainer_tools
|
||||
import yaml
|
||||
from datetime import datetime
|
||||
@ -201,10 +202,11 @@ def compute_accuracy_recall(score, labels):
|
||||
|
||||
def compute_accuracy(
|
||||
feature_dict: Dict[str, torch.Tensor],
|
||||
pair_list: str,
|
||||
test_root: str
|
||||
cam: cal_cam,
|
||||
) -> Tuple[float, float]:
|
||||
try:
|
||||
pair_list = conf['data']['test_list']
|
||||
test_root = conf['data']['test_dir']
|
||||
with open(pair_list, 'r') as f:
|
||||
pairs = f.readlines()
|
||||
except IOError as e:
|
||||
@ -220,6 +222,7 @@ def compute_accuracy(
|
||||
continue
|
||||
|
||||
# try:
|
||||
print(f"Processing pair: {pair}")
|
||||
img1, img2, label = pair.split()
|
||||
img1_path = osp.join(test_root, img1)
|
||||
img2_path = osp.join(test_root, img2)
|
||||
@ -236,9 +239,10 @@ def compute_accuracy(
|
||||
if conf['data']['save_image_joint']:
|
||||
merge_imgs(img1_path,
|
||||
img2_path,
|
||||
conf['data']['image_joint_pth'],
|
||||
conf,
|
||||
similarity,
|
||||
label)
|
||||
label,
|
||||
cam)
|
||||
similarities.append(similarity)
|
||||
labels.append(int(label))
|
||||
|
||||
@ -306,7 +310,8 @@ def init_model():
|
||||
if torch.cuda.device_count() > 1 and conf['base']['distributed']:
|
||||
model = nn.DataParallel(model).to(conf['base']['device'])
|
||||
###############正常模型加载################
|
||||
model.load_state_dict(torch.load(conf['models']['model_path'], map_location=conf['base']['device']))
|
||||
model.load_state_dict(torch.load(conf['models']['model_path'],
|
||||
map_location=conf['base']['device']))
|
||||
#######################################
|
||||
####### 对比学习模型临时运用###
|
||||
# state_dict = torch.load(conf['models']['model_path'], map_location=conf['base']['device'])
|
||||
@ -321,7 +326,18 @@ def init_model():
|
||||
first_param_dtype = next(model.parameters()).dtype
|
||||
print("模型的第一个参数的数据类型: {}".format(first_param_dtype))
|
||||
else:
|
||||
model.load_state_dict(torch.load(conf['models']['model_path'], map_location=conf['base']['device']))
|
||||
try:
|
||||
model.load_state_dict(torch.load(conf['models']['model_path'],
|
||||
map_location=conf['base']['device']))
|
||||
except:
|
||||
state_dict = torch.load(conf['models']['model_path'],
|
||||
map_location=conf['base']['device'])
|
||||
new_state_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
new_key = k.replace("module.", "")
|
||||
new_state_dict[new_key] = v
|
||||
model.load_state_dict(new_state_dict, strict=False)
|
||||
|
||||
if conf['models']['half']:
|
||||
model.half()
|
||||
first_param_dtype = next(model.parameters()).dtype
|
||||
@ -332,7 +348,7 @@ def init_model():
|
||||
if __name__ == '__main__':
|
||||
model = init_model()
|
||||
model.eval()
|
||||
|
||||
cam = cal_cam(model, conf)
|
||||
if not conf['data']['group_test']:
|
||||
images = unique_image(conf['data']['test_list'])
|
||||
images = [osp.join(conf['data']['test_dir'], img) for img in images]
|
||||
@ -342,7 +358,7 @@ if __name__ == '__main__':
|
||||
for group in groups:
|
||||
d = featurize(group, test_transform, model, conf['base']['device'])
|
||||
feature_dict.update(d)
|
||||
accuracy, threshold = compute_accuracy(feature_dict, conf['data']['test_list'], conf['data']['test_dir'])
|
||||
accuracy, threshold = compute_accuracy(feature_dict, cam)
|
||||
print(
|
||||
"Test Model: {} Accuracy: {} Threshold: {}".format(conf['models']['model_path'], accuracy, threshold)
|
||||
)
|
||||
|
Reference in New Issue
Block a user