更改
This commit is contained in:
105
test_ori.py
105
test_ori.py
@ -11,6 +11,7 @@ import matplotlib.pyplot as plt
|
||||
|
||||
# from config import config as conf
|
||||
from tools.dataset import get_transform
|
||||
from tools.image_joint import merge_imgs
|
||||
from configs import trainer_tools
|
||||
import yaml
|
||||
|
||||
@ -22,6 +23,7 @@ embedding_size = conf["base"]["embedding_size"]
|
||||
img_size = conf["transform"]["img_size"]
|
||||
device = conf["base"]["device"]
|
||||
|
||||
|
||||
def unique_image(pair_list: str) -> Set[str]:
|
||||
unique_images = set()
|
||||
try:
|
||||
@ -38,7 +40,7 @@ def unique_image(pair_list: str) -> Set[str]:
|
||||
except IOError as e:
|
||||
print(f"Error reading pair list file: {e}")
|
||||
raise
|
||||
|
||||
|
||||
return unique_images
|
||||
|
||||
|
||||
@ -56,11 +58,11 @@ def group_image(images: Set[str], batch_size: int) -> List[List[str]]:
|
||||
image_list = list(images)
|
||||
num_images = len(image_list)
|
||||
batches = []
|
||||
|
||||
|
||||
for i in range(0, num_images, batch_size):
|
||||
batch_end = min(i + batch_size, num_images)
|
||||
batches.append(image_list[i:batch_end])
|
||||
|
||||
|
||||
return batches
|
||||
|
||||
|
||||
@ -89,21 +91,21 @@ def test_preprocess(images: list, transform) -> torch.Tensor:
|
||||
|
||||
|
||||
def featurize(
|
||||
images: List[str],
|
||||
transform: callable,
|
||||
net: nn.Module,
|
||||
device: torch.device,
|
||||
train: bool = False
|
||||
images: List[str],
|
||||
transform: callable,
|
||||
net: nn.Module,
|
||||
device: torch.device,
|
||||
train: bool = False
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
try:
|
||||
# Select appropriate preprocessing
|
||||
preprocess_fn = _preprocess if train else test_preprocess
|
||||
|
||||
|
||||
# Preprocess and move to device
|
||||
data = preprocess_fn(images, transform)
|
||||
data = data.to(device)
|
||||
net = net.to(device)
|
||||
|
||||
|
||||
# Extract features with automatic mixed precision
|
||||
with torch.no_grad():
|
||||
if conf['models']['half']:
|
||||
@ -111,12 +113,16 @@ def featurize(
|
||||
features = net(data)
|
||||
# Create path-to-feature mapping
|
||||
return {img: feature for img, feature in zip(images, features)}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in feature extraction: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def cosin_metric(x1, x2):
|
||||
return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))
|
||||
|
||||
|
||||
def threshold_search(y_score, y_true):
|
||||
y_score = np.asarray(y_score)
|
||||
y_true = np.asarray(y_true)
|
||||
@ -179,22 +185,23 @@ def compute_accuracy_recall(score, labels):
|
||||
f_labels = (labels == 0)
|
||||
TN = np.sum(np.logical_and(f_score, f_labels))
|
||||
FP = np.sum(np.logical_and(np.logical_not(f_score), f_labels))
|
||||
print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))
|
||||
# print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))
|
||||
|
||||
PrecisePos.append(0 if TP / (TP + FP) == 'nan' else TP / (TP + FP))
|
||||
PreciseNeg.append(0 if TN == 0 else TN / (TN + FN))
|
||||
recall.append(0 if TP == 0 else TP / (TP + FN))
|
||||
recall_TN.append(0 if TN == 0 else TN / (TN + FP))
|
||||
Correct.append(0 if TP == 0 else (TP + TN) / (TP + FP + TN + FN))
|
||||
|
||||
print("Threshold:{} PrecisePos:{},recall:{},PreciseNeg:{},recall_TN:{}".format(th, PrecisePos[-1], recall[-1],
|
||||
PreciseNeg[-1], recall_TN[-1]))
|
||||
showHist(Same, Cross)
|
||||
showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct)
|
||||
|
||||
|
||||
def compute_accuracy(
|
||||
feature_dict: Dict[str, torch.Tensor],
|
||||
pair_list: str,
|
||||
test_root: str
|
||||
feature_dict: Dict[str, torch.Tensor],
|
||||
pair_list: str,
|
||||
test_root: str
|
||||
) -> Tuple[float, float]:
|
||||
try:
|
||||
with open(pair_list, 'r') as f:
|
||||
@ -205,37 +212,43 @@ def compute_accuracy(
|
||||
|
||||
similarities = []
|
||||
labels = []
|
||||
|
||||
|
||||
for pair in pairs:
|
||||
pair = pair.strip()
|
||||
if not pair:
|
||||
continue
|
||||
|
||||
try:
|
||||
img1, img2, label = pair.split()
|
||||
img1_path = osp.join(test_root, img1)
|
||||
img2_path = osp.join(test_root, img2)
|
||||
|
||||
# Verify features exist
|
||||
if img1_path not in feature_dict or img2_path not in feature_dict:
|
||||
raise ValueError(f"Missing features for image pair: {img1_path}, {img2_path}")
|
||||
|
||||
# Get features and compute similarity
|
||||
feat1 = feature_dict[img1_path].cpu().numpy()
|
||||
feat2 = feature_dict[img2_path].cpu().numpy()
|
||||
similarity = cosin_metric(feat1, feat2)
|
||||
|
||||
similarities.append(similarity)
|
||||
labels.append(int(label))
|
||||
|
||||
except Exception as e:
|
||||
print(f"Skipping invalid pair: {pair}. Error: {e}")
|
||||
continue
|
||||
|
||||
# try:
|
||||
img1, img2, label = pair.split()
|
||||
img1_path = osp.join(test_root, img1)
|
||||
img2_path = osp.join(test_root, img2)
|
||||
|
||||
# Verify features exist
|
||||
if img1_path not in feature_dict or img2_path not in feature_dict:
|
||||
raise ValueError(f"Missing features for image pair: {img1_path}, {img2_path}")
|
||||
|
||||
# Get features and compute similarity
|
||||
feat1 = feature_dict[img1_path].cpu().numpy()
|
||||
feat2 = feature_dict[img2_path].cpu().numpy()
|
||||
similarity = cosin_metric(feat1, feat2)
|
||||
print('{} vs {}: {}'.format(img1_path, img2_path, similarity))
|
||||
if conf['data']['save_image_joint']:
|
||||
merge_imgs(img1_path,
|
||||
img2_path,
|
||||
conf['data']['image_joint_pth'],
|
||||
similarity,
|
||||
label)
|
||||
similarities.append(similarity)
|
||||
labels.append(int(label))
|
||||
|
||||
# except Exception as e:
|
||||
# print(f"Skipping invalid pair: {pair}. Error: {e}")
|
||||
# continue
|
||||
|
||||
# Find optimal threshold and accuracy
|
||||
accuracy, threshold = threshold_search(similarities, labels)
|
||||
compute_accuracy_recall(np.array(similarities), np.array(labels))
|
||||
|
||||
|
||||
return accuracy, threshold
|
||||
|
||||
|
||||
@ -267,10 +280,10 @@ def compute_group_accuracy(content_list_read):
|
||||
d = featurize(group[0], conf.test_transform, model, conf.device)
|
||||
one_group_list.append(d.values())
|
||||
if data_loaded[-1] == '1':
|
||||
similarity = deal_group_pair(one_group_list[0], one_group_list[1])
|
||||
similarity = abs(deal_group_pair(one_group_list[0], one_group_list[1]))
|
||||
Same.append(similarity)
|
||||
else:
|
||||
similarity = deal_group_pair(one_group_list[0], one_group_list[1])
|
||||
similarity = abs(deal_group_pair(one_group_list[0], one_group_list[1]))
|
||||
Cross.append(similarity)
|
||||
allLabel.append(data_loaded[-1])
|
||||
allSimilarity.extend(similarity)
|
||||
@ -291,7 +304,17 @@ def init_model():
|
||||
print('load model {} '.format(conf['models']['backbone']))
|
||||
if torch.cuda.device_count() > 1 and conf['base']['distributed']:
|
||||
model = nn.DataParallel(model).to(conf['base']['device'])
|
||||
###############正常模型加载################
|
||||
model.load_state_dict(torch.load(conf['models']['model_path'], map_location=conf['base']['device']))
|
||||
#######################################
|
||||
####### 对比学习模型临时运用###
|
||||
# state_dict = torch.load(conf['models']['model_path'], map_location=conf['base']['device'])
|
||||
# new_state_dict = {}
|
||||
# for k, v in state_dict.items():
|
||||
# new_key = k.replace("module.base_model.", "module.")
|
||||
# new_state_dict[new_key] = v
|
||||
# model.load_state_dict(new_state_dict, strict=False)
|
||||
###########################
|
||||
if conf['models']['half']:
|
||||
model.half()
|
||||
first_param_dtype = next(model.parameters()).dtype
|
||||
|
Reference in New Issue
Block a user