训练数据前置处理与提升训练效率

This commit is contained in:
lee
2025-07-10 14:24:05 +08:00
parent 0701538a73
commit 09f41f6289
15 changed files with 430 additions and 116 deletions

View File

@ -12,6 +12,7 @@ import matplotlib.pyplot as plt
# from config import config as conf
from tools.dataset import get_transform
from tools.image_joint import merge_imgs
from tools.getHeatMap import cal_cam
from configs import trainer_tools
import yaml
from datetime import datetime
@ -201,10 +202,11 @@ def compute_accuracy_recall(score, labels):
def compute_accuracy(
feature_dict: Dict[str, torch.Tensor],
pair_list: str,
test_root: str
cam: cal_cam,
) -> Tuple[float, float]:
try:
pair_list = conf['data']['test_list']
test_root = conf['data']['test_dir']
with open(pair_list, 'r') as f:
pairs = f.readlines()
except IOError as e:
@ -220,6 +222,7 @@ def compute_accuracy(
continue
# try:
print(f"Processing pair: {pair}")
img1, img2, label = pair.split()
img1_path = osp.join(test_root, img1)
img2_path = osp.join(test_root, img2)
@ -236,9 +239,10 @@ def compute_accuracy(
if conf['data']['save_image_joint']:
merge_imgs(img1_path,
img2_path,
conf['data']['image_joint_pth'],
conf,
similarity,
label)
label,
cam)
similarities.append(similarity)
labels.append(int(label))
@ -306,7 +310,8 @@ def init_model():
if torch.cuda.device_count() > 1 and conf['base']['distributed']:
model = nn.DataParallel(model).to(conf['base']['device'])
###############正常模型加载################
model.load_state_dict(torch.load(conf['models']['model_path'], map_location=conf['base']['device']))
model.load_state_dict(torch.load(conf['models']['model_path'],
map_location=conf['base']['device']))
#######################################
####### 对比学习模型临时运用###
# state_dict = torch.load(conf['models']['model_path'], map_location=conf['base']['device'])
@ -321,7 +326,18 @@ def init_model():
first_param_dtype = next(model.parameters()).dtype
print("模型的第一个参数的数据类型: {}".format(first_param_dtype))
else:
model.load_state_dict(torch.load(conf['models']['model_path'], map_location=conf['base']['device']))
try:
model.load_state_dict(torch.load(conf['models']['model_path'],
map_location=conf['base']['device']))
except:
state_dict = torch.load(conf['models']['model_path'],
map_location=conf['base']['device'])
new_state_dict = {}
for k, v in state_dict.items():
new_key = k.replace("module.", "")
new_state_dict[new_key] = v
model.load_state_dict(new_state_dict, strict=False)
if conf['models']['half']:
model.half()
first_param_dtype = next(model.parameters()).dtype
@ -332,7 +348,7 @@ def init_model():
if __name__ == '__main__':
model = init_model()
model.eval()
cam = cal_cam(model, conf)
if not conf['data']['group_test']:
images = unique_image(conf['data']['test_list'])
images = [osp.join(conf['data']['test_dir'], img) for img in images]
@ -342,7 +358,7 @@ if __name__ == '__main__':
for group in groups:
d = featurize(group, test_transform, model, conf['base']['device'])
feature_dict.update(d)
accuracy, threshold = compute_accuracy(feature_dict, conf['data']['test_list'], conf['data']['test_dir'])
accuracy, threshold = compute_accuracy(feature_dict, cam)
print(
"Test Model: {} Accuracy: {} Threshold: {}".format(conf['models']['model_path'], accuracy, threshold)
)