This commit is contained in:
lee
2025-07-02 14:41:12 +08:00
parent 061820c34f
commit 537ed838fc
8 changed files with 124 additions and 58 deletions

View File

@ -11,6 +11,7 @@ import matplotlib.pyplot as plt
# from config import config as conf
from tools.dataset import get_transform
from tools.image_joint import merge_imgs
from configs import trainer_tools
import yaml
@ -22,6 +23,7 @@ embedding_size = conf["base"]["embedding_size"]
img_size = conf["transform"]["img_size"]
device = conf["base"]["device"]
def unique_image(pair_list: str) -> Set[str]:
unique_images = set()
try:
@ -38,7 +40,7 @@ def unique_image(pair_list: str) -> Set[str]:
except IOError as e:
print(f"Error reading pair list file: {e}")
raise
return unique_images
@ -56,11 +58,11 @@ def group_image(images: Set[str], batch_size: int) -> List[List[str]]:
image_list = list(images)
num_images = len(image_list)
batches = []
for i in range(0, num_images, batch_size):
batch_end = min(i + batch_size, num_images)
batches.append(image_list[i:batch_end])
return batches
@ -89,21 +91,21 @@ def test_preprocess(images: list, transform) -> torch.Tensor:
def featurize(
images: List[str],
transform: callable,
net: nn.Module,
device: torch.device,
train: bool = False
images: List[str],
transform: callable,
net: nn.Module,
device: torch.device,
train: bool = False
) -> Dict[str, torch.Tensor]:
try:
# Select appropriate preprocessing
preprocess_fn = _preprocess if train else test_preprocess
# Preprocess and move to device
data = preprocess_fn(images, transform)
data = data.to(device)
net = net.to(device)
# Extract features with automatic mixed precision
with torch.no_grad():
if conf['models']['half']:
@ -111,12 +113,16 @@ def featurize(
features = net(data)
# Create path-to-feature mapping
return {img: feature for img, feature in zip(images, features)}
except Exception as e:
print(f"Error in feature extraction: {e}")
raise
def cosin_metric(x1, x2):
    """Return the cosine similarity between two feature vectors.

    Args:
        x1: First feature vector (any array-like accepted by ``np.asarray``).
        x2: Second feature vector, same length as ``x1``.

    Returns:
        ``dot(x1, x2) / (||x1|| * ||x2||)``. When either input has zero
        norm the similarity is undefined; ``0.0`` is returned instead of
        propagating ``nan``/``inf`` into downstream threshold searches.
    """
    # Promote to float64 so half-precision features cannot overflow or lose
    # precision while the dot product and norms are accumulated.
    x1 = np.asarray(x1, dtype=np.float64)
    x2 = np.asarray(x2, dtype=np.float64)

    denom = np.linalg.norm(x1) * np.linalg.norm(x2)
    if denom == 0:
        # Degenerate (all-zero) vector: avoid a 0/0 division.
        return 0.0
    return np.dot(x1, x2) / denom
def threshold_search(y_score, y_true):
y_score = np.asarray(y_score)
y_true = np.asarray(y_true)
@ -179,22 +185,23 @@ def compute_accuracy_recall(score, labels):
f_labels = (labels == 0)
TN = np.sum(np.logical_and(f_score, f_labels))
FP = np.sum(np.logical_and(np.logical_not(f_score), f_labels))
print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))
# print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))
PrecisePos.append(0 if TP / (TP + FP) == 'nan' else TP / (TP + FP))
PreciseNeg.append(0 if TN == 0 else TN / (TN + FN))
recall.append(0 if TP == 0 else TP / (TP + FN))
recall_TN.append(0 if TN == 0 else TN / (TN + FP))
Correct.append(0 if TP == 0 else (TP + TN) / (TP + FP + TN + FN))
print("Threshold:{} PrecisePos:{},recall:{},PreciseNeg:{},recall_TN:{}".format(th, PrecisePos[-1], recall[-1],
PreciseNeg[-1], recall_TN[-1]))
showHist(Same, Cross)
showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct)
def compute_accuracy(
feature_dict: Dict[str, torch.Tensor],
pair_list: str,
test_root: str
feature_dict: Dict[str, torch.Tensor],
pair_list: str,
test_root: str
) -> Tuple[float, float]:
try:
with open(pair_list, 'r') as f:
@ -205,37 +212,43 @@ def compute_accuracy(
similarities = []
labels = []
for pair in pairs:
pair = pair.strip()
if not pair:
continue
try:
img1, img2, label = pair.split()
img1_path = osp.join(test_root, img1)
img2_path = osp.join(test_root, img2)
# Verify features exist
if img1_path not in feature_dict or img2_path not in feature_dict:
raise ValueError(f"Missing features for image pair: {img1_path}, {img2_path}")
# Get features and compute similarity
feat1 = feature_dict[img1_path].cpu().numpy()
feat2 = feature_dict[img2_path].cpu().numpy()
similarity = cosin_metric(feat1, feat2)
similarities.append(similarity)
labels.append(int(label))
except Exception as e:
print(f"Skipping invalid pair: {pair}. Error: {e}")
continue
# try:
img1, img2, label = pair.split()
img1_path = osp.join(test_root, img1)
img2_path = osp.join(test_root, img2)
# Verify features exist
if img1_path not in feature_dict or img2_path not in feature_dict:
raise ValueError(f"Missing features for image pair: {img1_path}, {img2_path}")
# Get features and compute similarity
feat1 = feature_dict[img1_path].cpu().numpy()
feat2 = feature_dict[img2_path].cpu().numpy()
similarity = cosin_metric(feat1, feat2)
print('{} vs {}: {}'.format(img1_path, img2_path, similarity))
if conf['data']['save_image_joint']:
merge_imgs(img1_path,
img2_path,
conf['data']['image_joint_pth'],
similarity,
label)
similarities.append(similarity)
labels.append(int(label))
# except Exception as e:
# print(f"Skipping invalid pair: {pair}. Error: {e}")
# continue
# Find optimal threshold and accuracy
accuracy, threshold = threshold_search(similarities, labels)
compute_accuracy_recall(np.array(similarities), np.array(labels))
return accuracy, threshold
@ -267,10 +280,10 @@ def compute_group_accuracy(content_list_read):
d = featurize(group[0], conf.test_transform, model, conf.device)
one_group_list.append(d.values())
if data_loaded[-1] == '1':
similarity = deal_group_pair(one_group_list[0], one_group_list[1])
similarity = abs(deal_group_pair(one_group_list[0], one_group_list[1]))
Same.append(similarity)
else:
similarity = deal_group_pair(one_group_list[0], one_group_list[1])
similarity = abs(deal_group_pair(one_group_list[0], one_group_list[1]))
Cross.append(similarity)
allLabel.append(data_loaded[-1])
allSimilarity.extend(similarity)
@ -291,7 +304,17 @@ def init_model():
print('load model {} '.format(conf['models']['backbone']))
if torch.cuda.device_count() > 1 and conf['base']['distributed']:
model = nn.DataParallel(model).to(conf['base']['device'])
###############正常模型加载################
model.load_state_dict(torch.load(conf['models']['model_path'], map_location=conf['base']['device']))
#######################################
####### 对比学习模型临时运用###
# state_dict = torch.load(conf['models']['model_path'], map_location=conf['base']['device'])
# new_state_dict = {}
# for k, v in state_dict.items():
# new_key = k.replace("module.base_model.", "module.")
# new_state_dict[new_key] = v
# model.load_state_dict(new_state_dict, strict=False)
###########################
if conf['models']['half']:
model.half()
first_param_dtype = next(model.parameters()).dtype