commit 537ed838fc (parent 061820c34f)
Author: lee
Date:   2025-07-02 14:41:12 +08:00

8 changed files with 124 additions and 58 deletions

View File

@@ -3,6 +3,12 @@
   <component name="CopilotChatHistory">
     <option name="conversations">
       <list>
+        <Conversation>
+          <option name="createTime" value="1750474299387" />
+          <option name="id" value="0197906617fb7194a0407baae2b1e2eb" />
+          <option name="title" value="新对话 2025年6月21日 10:51:39" />
+          <option name="updateTime" value="1750474299387" />
+        </Conversation>
         <Conversation>
           <option name="createTime" value="1749793513436" />
           <option name="id" value="019767d21fdc756ba782b33c8b14cdf1" />

View File

@@ -18,20 +18,20 @@ models:
 # Training parameters
 training:
-  epochs: 300                   # total training epochs
+  epochs: 600                   # total training epochs
   batch_size: 64                # batch size
-  lr: 0.005                     # initial learning rate
+  lr: 0.0004                    # initial learning rate
   optimizer: "sgd"              # optimizer type
-  metric: 'arcface'             # loss type, one of: arcface/cosface/sphereface/softmax
+  loss: "cross_entropy"         # loss type, one of: cross_entropy/cross_entropy_smooth/center_loss/center_loss_smooth/arcface/cosface/sphereface/softmax
   lr_step: 10                   # LR adjustment interval (epochs)
-  lr_decay: 0.98                # LR decay rate
+  lr_decay: 0.95                # LR decay rate
   weight_decay: 0.0005          # weight decay
-  scheduler: "cosine_annealing" # LR scheduler, one of: cosine_annealing/step/none
+  scheduler: "step"             # LR scheduler, one of: cosine_annealing/step/none
   num_workers: 32               # number of data-loading workers
-  checkpoints: "./checkpoints/resnet18_scatter_6.2/"   # model save directory
+  checkpoints: "./checkpoints/resnet18_scatter_6.26/"  # model save directory
   restore: True
-  restore_model: "checkpoints/resnet18_scatter_6.2/best.pth"   # model restore path
+  restore_model: "checkpoints/resnet18_scatter_6.25/best.pth"  # model restore path
@@ -46,8 +46,8 @@ data:
   train_batch_size: 128         # training batch size
   val_batch_size: 100           # validation batch size
   num_workers: 32               # number of data-loading workers
-  data_train_dir: "../data_center/scatter/train"     # training dataset root directory
-  data_val_dir: "../data_center/scatter/val"         # validation dataset root directory
+  data_train_dir: "../data_center/scatter/v2/train"  # training dataset root directory
+  data_val_dir: "../data_center/scatter/v2/val"      # validation dataset root directory

 transform:
   img_size: 224                 # image size
@@ -59,7 +59,7 @@ transform:

 # Logging & monitoring
 logging:
-  logging_dir: "./log/2025.6.2-scatter.txt"   # log file path
+  logging_dir: "./log/2025.6.25-scatter.txt"  # log file path
   tensorboard: true             # enable TensorBoard
   checkpoint_interval: 30       # checkpoint save interval (epochs)
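The net effect of this file: double the training run (300 to 600 epochs), drop the initial LR by an order of magnitude (0.005 to 0.0004), rename the metric key to loss and switch from ArcFace to plain cross-entropy, and move from cosine annealing to step decay. A minimal sketch of how a training script would consume the changed keys, assuming PyYAML loading as in the test scripts later in this commit (the config filename is an assumption; the diff shows only the file contents):

import yaml

# Assumed path; this diff does not name the file
with open('configs/scatter.yml', 'r') as f:
    conf = yaml.load(f, Loader=yaml.FullLoader)

epochs = conf['training']['epochs']             # 600 after this commit
lr = conf['training']['lr']                     # 0.0004 after this commit
loss_name = conf['training']['loss']            # "cross_entropy"
scheduler_name = conf['training']['scheduler']  # "step"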

View File

@@ -13,17 +13,21 @@ base:

 # Model configuration
 models:
   backbone: 'resnet18'
-  channel_ratio: 0.75
-  model_path: "./checkpoints/resnet18_1009/best.pth"
+  channel_ratio: 1.0
+  model_path: "checkpoints/resnet18_scatter_6.26/best.pth"
   half: false                   # enable half-precision (fp16) testing
+  contrast_learning: false

 # Data configuration
 data:
   test_batch_size: 128          # test batch size
   num_workers: 32               # number of data-loading workers
-  test_dir: "../data_center/contrast_learning/contrast_test_data"  # validation dataset root directory
+  test_dir: "../data_center/scatter/v2/val_extar"                  # validation dataset root directory
   test_group_json: "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json"
-  test_list: "../data_center/contrast_learning/contrast_test_data/cross_same.txt"
+  test_list: "../data_center/scatter/val_extar_cross_same.txt"
+  group_test: false
+  save_image_joint: true
+  image_joint_pth: "./joint_images"

 transform:
   img_size: 224                 # image size

View File

@@ -18,8 +18,8 @@ models:
   channel_ratio: 0.75
   model_path: "../checkpoints/resnet18_1009/best.pth"
   onnx_model: "../checkpoints/resnet18_1009/best.onnx"
-  rknn_model: "../checkpoints/resnet18_1009/best_rknn2.3.2.rknn"
-  rknn_batch_size: 1
+  rknn_model: "../checkpoints/resnet18_1009/best_rknn2.3.2_batch16.rknn"
+  rknn_batch_size: 16

 # Logging & monitoring
 logging:
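Bumping rknn_batch_size from 1 to 16 pairs with the new *_batch16.rknn filename: the ONNX model is rebuilt as a fixed batch-16 RKNN model so sixteen crops go to the NPU per call. A hedged sketch of that conversion step with rknn-toolkit2 (the target platform and quantization settings are assumptions; only the paths and batch size come from this config):

from rknn.api import RKNN

rknn = RKNN()
rknn.config(target_platform='rk3566')  # assumed target, not stated in this commit
rknn.load_onnx(model='../checkpoints/resnet18_1009/best.onnx')
# rknn_batch_size here must match the batch size used at inference time
rknn.build(do_quantization=False, rknn_batch_size=16)
rknn.export_rknn('../checkpoints/resnet18_1009/best_rknn2.3.2_batch16.rknn')
rknn.release()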

View File

@@ -11,6 +11,7 @@ import matplotlib.pyplot as plt
 # from config import config as conf
 from tools.dataset import get_transform
+from tools.image_joint import merge_imgs
 from configs import trainer_tools
 import yaml
@@ -22,6 +23,7 @@ embedding_size = conf["base"]["embedding_size"]
 img_size = conf["transform"]["img_size"]
 device = conf["base"]["device"]

 def unique_image(pair_list: str) -> Set[str]:
     unique_images = set()
     try:
@@ -38,7 +40,7 @@ def unique_image(pair_list: str) -> Set[str]:
     except IOError as e:
         print(f"Error reading pair list file: {e}")
         raise

     return unique_images
@@ -56,11 +58,11 @@ def group_image(images: Set[str], batch_size: int) -> List[List[str]]:
     image_list = list(images)
     num_images = len(image_list)
     batches = []

     for i in range(0, num_images, batch_size):
         batch_end = min(i + batch_size, num_images)
         batches.append(image_list[i:batch_end])

     return batches
@@ -89,21 +91,21 @@ def test_preprocess(images: list, transform) -> torch.Tensor:
 def featurize(
-        images: List[str],
-        transform: callable,
-        net: nn.Module,
-        device: torch.device,
-        train: bool = False
+    images: List[str],
+    transform: callable,
+    net: nn.Module,
+    device: torch.device,
+    train: bool = False
 ) -> Dict[str, torch.Tensor]:
     try:
         # Select appropriate preprocessing
         preprocess_fn = _preprocess if train else test_preprocess

         # Preprocess and move to device
         data = preprocess_fn(images, transform)
         data = data.to(device)
         net = net.to(device)

         # Extract features with automatic mixed precision
         with torch.no_grad():
             if conf['models']['half']:
@@ -111,12 +113,16 @@ def featurize(
             features = net(data)

         # Create path-to-feature mapping
         return {img: feature for img, feature in zip(images, features)}

     except Exception as e:
         print(f"Error in feature extraction: {e}")
         raise


def cosin_metric(x1, x2):
    return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))


def threshold_search(y_score, y_true):
    y_score = np.asarray(y_score)
    y_true = np.asarray(y_true)
@@ -179,22 +185,23 @@ def compute_accuracy_recall(score, labels):
         f_labels = (labels == 0)
         TN = np.sum(np.logical_and(f_score, f_labels))
         FP = np.sum(np.logical_and(np.logical_not(f_score), f_labels))
-        print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))
+        # print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))

         PrecisePos.append(0 if TP / (TP + FP) == 'nan' else TP / (TP + FP))
         PreciseNeg.append(0 if TN == 0 else TN / (TN + FN))
         recall.append(0 if TP == 0 else TP / (TP + FN))
         recall_TN.append(0 if TN == 0 else TN / (TN + FP))
         Correct.append(0 if TP == 0 else (TP + TN) / (TP + FP + TN + FN))
+        print("Threshold:{} PrecisePos:{},recall:{},PreciseNeg:{},recall_TN:{}".format(th, PrecisePos[-1], recall[-1],
+                                                                                       PreciseNeg[-1], recall_TN[-1]))
     showHist(Same, Cross)
     showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct)
 def compute_accuracy(
-        feature_dict: Dict[str, torch.Tensor],
-        pair_list: str,
-        test_root: str
+    feature_dict: Dict[str, torch.Tensor],
+    pair_list: str,
+    test_root: str
 ) -> Tuple[float, float]:
     try:
         with open(pair_list, 'r') as f:
@@ -205,37 +212,43 @@ def compute_accuracy(
         similarities = []
         labels = []

         for pair in pairs:
             pair = pair.strip()
             if not pair:
                 continue

-            try:
-                img1, img2, label = pair.split()
-                img1_path = osp.join(test_root, img1)
-                img2_path = osp.join(test_root, img2)
-
-                # Verify features exist
-                if img1_path not in feature_dict or img2_path not in feature_dict:
-                    raise ValueError(f"Missing features for image pair: {img1_path}, {img2_path}")
-
-                # Get features and compute similarity
-                feat1 = feature_dict[img1_path].cpu().numpy()
-                feat2 = feature_dict[img2_path].cpu().numpy()
-                similarity = cosin_metric(feat1, feat2)
-
-                similarities.append(similarity)
-                labels.append(int(label))
-            except Exception as e:
-                print(f"Skipping invalid pair: {pair}. Error: {e}")
-                continue
+            # try:
+            img1, img2, label = pair.split()
+            img1_path = osp.join(test_root, img1)
+            img2_path = osp.join(test_root, img2)
+
+            # Verify features exist
+            if img1_path not in feature_dict or img2_path not in feature_dict:
+                raise ValueError(f"Missing features for image pair: {img1_path}, {img2_path}")
+
+            # Get features and compute similarity
+            feat1 = feature_dict[img1_path].cpu().numpy()
+            feat2 = feature_dict[img2_path].cpu().numpy()
+            similarity = cosin_metric(feat1, feat2)
+            print('{} vs {}: {}'.format(img1_path, img2_path, similarity))
+            if conf['data']['save_image_joint']:
+                merge_imgs(img1_path,
+                           img2_path,
+                           conf['data']['image_joint_pth'],
+                           similarity,
+                           label)
+            similarities.append(similarity)
+            labels.append(int(label))
+            # except Exception as e:
+            #     print(f"Skipping invalid pair: {pair}. Error: {e}")
+            #     continue

         # Find optimal threshold and accuracy
         accuracy, threshold = threshold_search(similarities, labels)
         compute_accuracy_recall(np.array(similarities), np.array(labels))
         return accuracy, threshold
@@ -267,10 +280,10 @@ def compute_group_accuracy(content_list_read):
         d = featurize(group[0], conf.test_transform, model, conf.device)
         one_group_list.append(d.values())
     if data_loaded[-1] == '1':
-        similarity = deal_group_pair(one_group_list[0], one_group_list[1])
+        similarity = abs(deal_group_pair(one_group_list[0], one_group_list[1]))
         Same.append(similarity)
     else:
-        similarity = deal_group_pair(one_group_list[0], one_group_list[1])
+        similarity = abs(deal_group_pair(one_group_list[0], one_group_list[1]))
         Cross.append(similarity)
     allLabel.append(data_loaded[-1])
     allSimilarity.extend(similarity)
@@ -291,7 +304,17 @@ def init_model():
     print('load model {} '.format(conf['models']['backbone']))
     if torch.cuda.device_count() > 1 and conf['base']['distributed']:
         model = nn.DataParallel(model).to(conf['base']['device'])
+    ############### normal model loading ################
     model.load_state_dict(torch.load(conf['models']['model_path'], map_location=conf['base']['device']))
+    #######################################
+    ####### temporary use for contrastive-learning checkpoints ###
+    # state_dict = torch.load(conf['models']['model_path'], map_location=conf['base']['device'])
+    # new_state_dict = {}
+    # for k, v in state_dict.items():
+    #     new_key = k.replace("module.base_model.", "module.")
+    #     new_state_dict[new_key] = v
+    # model.load_state_dict(new_state_dict, strict=False)
+    ###########################
     if conf['models']['half']:
         model.half()
     first_param_dtype = next(model.parameters()).dtype
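One pre-existing quirk the diff above leaves untouched: in compute_accuracy_recall, the guard `0 if TP / (TP + FP) == 'nan' else TP / (TP + FP)` compares a float against the string 'nan', so it never fires, and a zero denominator raises ZeroDivisionError before the comparison runs. A safe-division sketch (not part of this commit) matching the pattern the other metrics in that loop already use:

def safe_div(num, den):
    # Return 0.0 instead of raising when the denominator is zero
    return num / den if den else 0.0

# e.g. PrecisePos.append(safe_div(TP, TP + FP))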

View File

@@ -188,7 +188,7 @@ class PairGenerator:
 if __name__ == "__main__":
-    original_path = '/home/lc/data_center/contrast_learning/contrast_test_data/test'
+    original_path = '/home/lc/data_center/scatter/val_extar'
     parent_dir = str(Path(original_path).parent)
     generator = PairGenerator()
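For reference, the pair list this generator produces and compute_accuracy consumes is a whitespace-separated text file, one pair per line, matching the pair.split() parsing above: two image paths relative to test_dir, then 1 for a same-item pair or 0 for a cross pair. Illustrative lines (file names are made up):

class_a/img_001.jpg class_a/img_007.jpg 1
class_a/img_001.jpg class_b/img_013.jpg 0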

tools/image_joint.py (new file, 33 lines)
View File

@@ -0,0 +1,33 @@
+from PIL import Image, ImageDraw, ImageFont
+import os
+
+
+def merge_imgs(img1_path, img2_path, save_path, similar=None, label=None):
+    position = (50, 50)  # top-left corner of the text
+    color = (255, 0, 0)  # red text, as RGB
+    if not os.path.exists(os.sep.join([save_path, str(label)])):
+        os.makedirs(os.sep.join([save_path, str(label)]))
+    save_path = os.sep.join([save_path, str(label)])
+    img_name = os.path.basename(img1_path).split('.')[0] + '_' + os.path.basename(img2_path).split('.')[0] + '.png'
+    img1 = Image.open(img1_path)
+    img2 = Image.open(img2_path)
+    img1 = img1.resize((224, 224))
+    img2 = img2.resize((224, 224))
+    print('img1_path', img1)
+    print('img2_path', img2)
+    assert img1.height == img2.height
+    new_img = Image.new('RGB', (img1.width + img2.width + 10, img1.height))
+    # print('new_img', new_img)
+    new_img.paste(img1, (0, 0))
+    new_img.paste(img2, (img1.width + 10, 0))
+    if similar is not None:
+        similar = str(similar) + '_' + str(label)
+        draw = ImageDraw.Draw(new_img)
+        draw.text(position, str(similar), color, font_size=36)
+    os.makedirs(save_path, exist_ok=True)
+    img_save = os.path.join(save_path, img_name)
+    new_img.save(img_save)
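A hedged usage sketch of the new helper, with hypothetical paths (in this commit it is called from compute_accuracy, with save_path taken from conf['data']['image_joint_pth']):

from tools.image_joint import merge_imgs

# Writes ./joint_images/1/a_b.png: both crops resized to 224x224 and pasted
# side by side, with "0.83_1" drawn in red at (50, 50)
merge_imgs('pairs/a.jpg', 'pairs/b.jpg', './joint_images', similar=0.83, label=1)

Note that the font_size keyword of ImageDraw.text requires a reasonably recent Pillow (10.1 or later); older versions need an explicit ImageFont passed via the font argument.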

View File

@@ -12,7 +12,7 @@ import matplotlib.pyplot as plt
 from configs import trainer_tools
 import yaml

-with open('configs/compare.yml', 'r') as f:
+with open('configs/scatter.yml', 'r') as f:
     conf = yaml.load(f, Loader=yaml.FullLoader)

 # Data Setup
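Hardcoding the config filename means switching experiments requires editing source, which is exactly how compare.yml became scatter.yml in this hunk. A common alternative (not part of this commit) is to take the path from the command line:

import argparse
import yaml

# Select the experiment config at launch instead of editing the script
parser = argparse.ArgumentParser()
parser.add_argument('--config', default='configs/scatter.yml')
args = parser.parse_args()

with open(args.config, 'r') as f:
    conf = yaml.load(f, Loader=yaml.FullLoader)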