更改
This commit is contained in:
6
.idea/CopilotChatHistory.xml
generated
6
.idea/CopilotChatHistory.xml
generated
@ -3,6 +3,12 @@
|
||||
<component name="CopilotChatHistory">
|
||||
<option name="conversations">
|
||||
<list>
|
||||
<Conversation>
|
||||
<option name="createTime" value="1750474299387" />
|
||||
<option name="id" value="0197906617fb7194a0407baae2b1e2eb" />
|
||||
<option name="title" value="新对话 2025年6月21日 10:51:39" />
|
||||
<option name="updateTime" value="1750474299387" />
|
||||
</Conversation>
|
||||
<Conversation>
|
||||
<option name="createTime" value="1749793513436" />
|
||||
<option name="id" value="019767d21fdc756ba782b33c8b14cdf1" />
|
||||
|
@ -18,20 +18,20 @@ models:
|
||||
|
||||
# 训练参数
|
||||
training:
|
||||
epochs: 300 # 总训练轮次
|
||||
epochs: 600 # 总训练轮次
|
||||
batch_size: 64 # 批次大小
|
||||
lr: 0.005 # 初始学习率
|
||||
lr: 0.0004 # 初始学习率
|
||||
optimizer: "sgd" # 优化器类型
|
||||
metric: 'arcface' # 损失函数类型(可选:arcface/cosface/sphereface/softmax)
|
||||
loss: "cross_entropy" # 损失函数类型(可选:cross_entropy/cross_entropy_smooth/center_loss/center_loss_smooth/arcface/cosface/sphereface/softmax)
|
||||
lr_step: 10 # 学习率调整间隔(epoch)
|
||||
lr_decay: 0.98 # 学习率衰减率
|
||||
lr_decay: 0.95 # 学习率衰减率
|
||||
weight_decay: 0.0005 # 权重衰减
|
||||
scheduler: "cosine_annealing" # 学习率调度器(可选:cosine_annealing/step/none)
|
||||
scheduler: "step" # 学习率调度器(可选:cosine_annealing/step/none)
|
||||
num_workers: 32 # 数据加载线程数
|
||||
checkpoints: "./checkpoints/resnet18_scatter_6.2/" # 模型保存目录
|
||||
checkpoints: "./checkpoints/resnet18_scatter_6.26/" # 模型保存目录
|
||||
restore: True
|
||||
restore_model: "checkpoints/resnet18_scatter_6.2/best.pth" # 模型恢复路径
|
||||
restore_model: "checkpoints/resnet18_scatter_6.25/best.pth" # 模型恢复路径
|
||||
|
||||
|
||||
|
||||
@ -46,8 +46,8 @@ data:
|
||||
train_batch_size: 128 # 训练批次大小
|
||||
val_batch_size: 100 # 验证批次大小
|
||||
num_workers: 32 # 数据加载线程数
|
||||
data_train_dir: "../data_center/scatter/train" # 训练数据集根目录
|
||||
data_val_dir: "../data_center/scatter/val" # 验证数据集根目录
|
||||
data_train_dir: "../data_center/scatter/v2/train" # 训练数据集根目录
|
||||
data_val_dir: "../data_center/scatter/v2/val" # 验证数据集根目录
|
||||
|
||||
transform:
|
||||
img_size: 224 # 图像尺寸
|
||||
@ -59,7 +59,7 @@ transform:
|
||||
|
||||
# 日志与监控
|
||||
logging:
|
||||
logging_dir: "./log/2025.6.2-scatter.txt" # 日志保存目录
|
||||
logging_dir: "./log/2025.6.25-scatter.txt" # 日志保存目录
|
||||
tensorboard: true # 是否启用TensorBoard
|
||||
checkpoint_interval: 30 # 检查点保存间隔(epoch)
|
||||
|
||||
|
@ -13,17 +13,21 @@ base:
|
||||
# 模型配置
|
||||
models:
|
||||
backbone: 'resnet18'
|
||||
channel_ratio: 0.75
|
||||
model_path: "./checkpoints/resnet18_1009/best.pth"
|
||||
channel_ratio: 1.0
|
||||
model_path: "checkpoints/resnet18_scatter_6.26/best.pth"
|
||||
half: false # 是否启用半精度测试(fp16)
|
||||
contrast_learning: false
|
||||
|
||||
# 数据配置
|
||||
data:
|
||||
test_batch_size: 128 # 训练批次大小
|
||||
num_workers: 32 # 数据加载线程数
|
||||
test_dir: "../data_center/contrast_learning/contrast_test_data" # 验证数据集根目录
|
||||
test_dir: "../data_center/scatter/v2/val_extar" # 验证数据集根目录
|
||||
test_group_json: "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json"
|
||||
test_list: "../data_center/contrast_learning/contrast_test_data/cross_same.txt"
|
||||
test_list: "../data_center/scatter/val_extar_cross_same.txt"
|
||||
group_test: false
|
||||
save_image_joint: true
|
||||
image_joint_pth: "./joint_images"
|
||||
|
||||
transform:
|
||||
img_size: 224 # 图像尺寸
|
||||
|
@ -18,8 +18,8 @@ models:
|
||||
channel_ratio: 0.75
|
||||
model_path: "../checkpoints/resnet18_1009/best.pth"
|
||||
onnx_model: "../checkpoints/resnet18_1009/best.onnx"
|
||||
rknn_model: "../checkpoints/resnet18_1009/best_rknn2.3.2.rknn"
|
||||
rknn_batch_size: 1
|
||||
rknn_model: "../checkpoints/resnet18_1009/best_rknn2.3.2_batch16.rknn"
|
||||
rknn_batch_size: 16
|
||||
|
||||
# 日志与监控
|
||||
logging:
|
||||
|
41
test_ori.py
41
test_ori.py
@ -11,6 +11,7 @@ import matplotlib.pyplot as plt
|
||||
|
||||
# from config import config as conf
|
||||
from tools.dataset import get_transform
|
||||
from tools.image_joint import merge_imgs
|
||||
from configs import trainer_tools
|
||||
import yaml
|
||||
|
||||
@ -22,6 +23,7 @@ embedding_size = conf["base"]["embedding_size"]
|
||||
img_size = conf["transform"]["img_size"]
|
||||
device = conf["base"]["device"]
|
||||
|
||||
|
||||
def unique_image(pair_list: str) -> Set[str]:
|
||||
unique_images = set()
|
||||
try:
|
||||
@ -115,8 +117,12 @@ def featurize(
|
||||
except Exception as e:
|
||||
print(f"Error in feature extraction: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def cosin_metric(x1, x2):
|
||||
return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))
|
||||
|
||||
|
||||
def threshold_search(y_score, y_true):
|
||||
y_score = np.asarray(y_score)
|
||||
y_true = np.asarray(y_true)
|
||||
@ -179,14 +185,15 @@ def compute_accuracy_recall(score, labels):
|
||||
f_labels = (labels == 0)
|
||||
TN = np.sum(np.logical_and(f_score, f_labels))
|
||||
FP = np.sum(np.logical_and(np.logical_not(f_score), f_labels))
|
||||
print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))
|
||||
# print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))
|
||||
|
||||
PrecisePos.append(0 if TP / (TP + FP) == 'nan' else TP / (TP + FP))
|
||||
PreciseNeg.append(0 if TN == 0 else TN / (TN + FN))
|
||||
recall.append(0 if TP == 0 else TP / (TP + FN))
|
||||
recall_TN.append(0 if TN == 0 else TN / (TN + FP))
|
||||
Correct.append(0 if TP == 0 else (TP + TN) / (TP + FP + TN + FN))
|
||||
|
||||
print("Threshold:{} PrecisePos:{},recall:{},PreciseNeg:{},recall_TN:{}".format(th, PrecisePos[-1], recall[-1],
|
||||
PreciseNeg[-1], recall_TN[-1]))
|
||||
showHist(Same, Cross)
|
||||
showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct)
|
||||
|
||||
@ -211,7 +218,7 @@ def compute_accuracy(
|
||||
if not pair:
|
||||
continue
|
||||
|
||||
try:
|
||||
# try:
|
||||
img1, img2, label = pair.split()
|
||||
img1_path = osp.join(test_root, img1)
|
||||
img2_path = osp.join(test_root, img2)
|
||||
@ -224,13 +231,19 @@ def compute_accuracy(
|
||||
feat1 = feature_dict[img1_path].cpu().numpy()
|
||||
feat2 = feature_dict[img2_path].cpu().numpy()
|
||||
similarity = cosin_metric(feat1, feat2)
|
||||
|
||||
print('{} vs {}: {}'.format(img1_path, img2_path, similarity))
|
||||
if conf['data']['save_image_joint']:
|
||||
merge_imgs(img1_path,
|
||||
img2_path,
|
||||
conf['data']['image_joint_pth'],
|
||||
similarity,
|
||||
label)
|
||||
similarities.append(similarity)
|
||||
labels.append(int(label))
|
||||
|
||||
except Exception as e:
|
||||
print(f"Skipping invalid pair: {pair}. Error: {e}")
|
||||
continue
|
||||
# except Exception as e:
|
||||
# print(f"Skipping invalid pair: {pair}. Error: {e}")
|
||||
# continue
|
||||
|
||||
# Find optimal threshold and accuracy
|
||||
accuracy, threshold = threshold_search(similarities, labels)
|
||||
@ -267,10 +280,10 @@ def compute_group_accuracy(content_list_read):
|
||||
d = featurize(group[0], conf.test_transform, model, conf.device)
|
||||
one_group_list.append(d.values())
|
||||
if data_loaded[-1] == '1':
|
||||
similarity = deal_group_pair(one_group_list[0], one_group_list[1])
|
||||
similarity = abs(deal_group_pair(one_group_list[0], one_group_list[1]))
|
||||
Same.append(similarity)
|
||||
else:
|
||||
similarity = deal_group_pair(one_group_list[0], one_group_list[1])
|
||||
similarity = abs(deal_group_pair(one_group_list[0], one_group_list[1]))
|
||||
Cross.append(similarity)
|
||||
allLabel.append(data_loaded[-1])
|
||||
allSimilarity.extend(similarity)
|
||||
@ -291,7 +304,17 @@ def init_model():
|
||||
print('load model {} '.format(conf['models']['backbone']))
|
||||
if torch.cuda.device_count() > 1 and conf['base']['distributed']:
|
||||
model = nn.DataParallel(model).to(conf['base']['device'])
|
||||
###############正常模型加载################
|
||||
model.load_state_dict(torch.load(conf['models']['model_path'], map_location=conf['base']['device']))
|
||||
#######################################
|
||||
####### 对比学习模型临时运用###
|
||||
# state_dict = torch.load(conf['models']['model_path'], map_location=conf['base']['device'])
|
||||
# new_state_dict = {}
|
||||
# for k, v in state_dict.items():
|
||||
# new_key = k.replace("module.base_model.", "module.")
|
||||
# new_state_dict[new_key] = v
|
||||
# model.load_state_dict(new_state_dict, strict=False)
|
||||
###########################
|
||||
if conf['models']['half']:
|
||||
model.half()
|
||||
first_param_dtype = next(model.parameters()).dtype
|
||||
|
@ -188,7 +188,7 @@ class PairGenerator:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
original_path = '/home/lc/data_center/contrast_learning/contrast_test_data/test'
|
||||
original_path = '/home/lc/data_center/scatter/val_extar'
|
||||
parent_dir = str(Path(original_path).parent)
|
||||
generator = PairGenerator()
|
||||
|
||||
|
33
tools/image_joint.py
Normal file
33
tools/image_joint.py
Normal file
@ -0,0 +1,33 @@
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import os
|
||||
|
||||
|
||||
def merge_imgs(img1_path, img2_path, save_path, similar=None, label=None):
|
||||
position = (50, 50) # 文字的左上角坐标
|
||||
color = (255, 0, 0) # 红色文字,格式为 RGB
|
||||
if not os.path.exists(os.sep.join([save_path, str(label)])):
|
||||
os.makedirs(os.sep.join([save_path, str(label)]))
|
||||
save_path = os.sep.join([save_path, str(label)])
|
||||
img_name = os.path.basename(img1_path).split('.')[0]+'_'+os.path.basename(img2_path).split('.')[0]+'.png'
|
||||
img1 = Image.open(img1_path)
|
||||
img2 = Image.open(img2_path)
|
||||
img1 = img1.resize((224,224))
|
||||
img2 = img2.resize((224,224))
|
||||
print('img1_path', img1)
|
||||
print('img2_path', img2)
|
||||
assert img1.height == img2.height
|
||||
|
||||
new_img = Image.new('RGB', (img1.width + img2.width + 10, img1.height))
|
||||
|
||||
# print('new_img', new_img)
|
||||
new_img.paste(img1, (0, 0))
|
||||
new_img.paste(img2, (img1.width + 10, 0))
|
||||
|
||||
if similar is not None:
|
||||
similar = str(similar)+'_'+str(label)
|
||||
draw = ImageDraw.Draw(new_img)
|
||||
draw.text(position, str(similar), color, font_size=36)
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
img_save = os.path.join(save_path, img_name)
|
||||
new_img.save(img_save)
|
||||
|
@ -12,7 +12,7 @@ import matplotlib.pyplot as plt
|
||||
from configs import trainer_tools
|
||||
import yaml
|
||||
|
||||
with open('configs/compare.yml', 'r') as f:
|
||||
with open('configs/scatter.yml', 'r') as f:
|
||||
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
# Data Setup
|
||||
|
Reference in New Issue
Block a user