更改
This commit is contained in:
6
.idea/CopilotChatHistory.xml
generated
6
.idea/CopilotChatHistory.xml
generated
@ -3,6 +3,12 @@
|
|||||||
<component name="CopilotChatHistory">
|
<component name="CopilotChatHistory">
|
||||||
<option name="conversations">
|
<option name="conversations">
|
||||||
<list>
|
<list>
|
||||||
|
<Conversation>
|
||||||
|
<option name="createTime" value="1750474299387" />
|
||||||
|
<option name="id" value="0197906617fb7194a0407baae2b1e2eb" />
|
||||||
|
<option name="title" value="新对话 2025年6月21日 10:51:39" />
|
||||||
|
<option name="updateTime" value="1750474299387" />
|
||||||
|
</Conversation>
|
||||||
<Conversation>
|
<Conversation>
|
||||||
<option name="createTime" value="1749793513436" />
|
<option name="createTime" value="1749793513436" />
|
||||||
<option name="id" value="019767d21fdc756ba782b33c8b14cdf1" />
|
<option name="id" value="019767d21fdc756ba782b33c8b14cdf1" />
|
||||||
|
@ -18,20 +18,20 @@ models:
|
|||||||
|
|
||||||
# 训练参数
|
# 训练参数
|
||||||
training:
|
training:
|
||||||
epochs: 300 # 总训练轮次
|
epochs: 600 # 总训练轮次
|
||||||
batch_size: 64 # 批次大小
|
batch_size: 64 # 批次大小
|
||||||
lr: 0.005 # 初始学习率
|
lr: 0.0004 # 初始学习率
|
||||||
optimizer: "sgd" # 优化器类型
|
optimizer: "sgd" # 优化器类型
|
||||||
metric: 'arcface' # 损失函数类型(可选:arcface/cosface/sphereface/softmax)
|
metric: 'arcface' # 损失函数类型(可选:arcface/cosface/sphereface/softmax)
|
||||||
loss: "cross_entropy" # 损失函数类型(可选:cross_entropy/cross_entropy_smooth/center_loss/center_loss_smooth/arcface/cosface/sphereface/softmax)
|
loss: "cross_entropy" # 损失函数类型(可选:cross_entropy/cross_entropy_smooth/center_loss/center_loss_smooth/arcface/cosface/sphereface/softmax)
|
||||||
lr_step: 10 # 学习率调整间隔(epoch)
|
lr_step: 10 # 学习率调整间隔(epoch)
|
||||||
lr_decay: 0.98 # 学习率衰减率
|
lr_decay: 0.95 # 学习率衰减率
|
||||||
weight_decay: 0.0005 # 权重衰减
|
weight_decay: 0.0005 # 权重衰减
|
||||||
scheduler: "cosine_annealing" # 学习率调度器(可选:cosine_annealing/step/none)
|
scheduler: "step" # 学习率调度器(可选:cosine_annealing/step/none)
|
||||||
num_workers: 32 # 数据加载线程数
|
num_workers: 32 # 数据加载线程数
|
||||||
checkpoints: "./checkpoints/resnet18_scatter_6.2/" # 模型保存目录
|
checkpoints: "./checkpoints/resnet18_scatter_6.26/" # 模型保存目录
|
||||||
restore: True
|
restore: True
|
||||||
restore_model: "checkpoints/resnet18_scatter_6.2/best.pth" # 模型恢复路径
|
restore_model: "checkpoints/resnet18_scatter_6.25/best.pth" # 模型恢复路径
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -46,8 +46,8 @@ data:
|
|||||||
train_batch_size: 128 # 训练批次大小
|
train_batch_size: 128 # 训练批次大小
|
||||||
val_batch_size: 100 # 验证批次大小
|
val_batch_size: 100 # 验证批次大小
|
||||||
num_workers: 32 # 数据加载线程数
|
num_workers: 32 # 数据加载线程数
|
||||||
data_train_dir: "../data_center/scatter/train" # 训练数据集根目录
|
data_train_dir: "../data_center/scatter/v2/train" # 训练数据集根目录
|
||||||
data_val_dir: "../data_center/scatter/val" # 验证数据集根目录
|
data_val_dir: "../data_center/scatter/v2/val" # 验证数据集根目录
|
||||||
|
|
||||||
transform:
|
transform:
|
||||||
img_size: 224 # 图像尺寸
|
img_size: 224 # 图像尺寸
|
||||||
@ -59,7 +59,7 @@ transform:
|
|||||||
|
|
||||||
# 日志与监控
|
# 日志与监控
|
||||||
logging:
|
logging:
|
||||||
logging_dir: "./log/2025.6.2-scatter.txt" # 日志保存目录
|
logging_dir: "./log/2025.6.25-scatter.txt" # 日志保存目录
|
||||||
tensorboard: true # 是否启用TensorBoard
|
tensorboard: true # 是否启用TensorBoard
|
||||||
checkpoint_interval: 30 # 检查点保存间隔(epoch)
|
checkpoint_interval: 30 # 检查点保存间隔(epoch)
|
||||||
|
|
||||||
|
@ -13,17 +13,21 @@ base:
|
|||||||
# 模型配置
|
# 模型配置
|
||||||
models:
|
models:
|
||||||
backbone: 'resnet18'
|
backbone: 'resnet18'
|
||||||
channel_ratio: 0.75
|
channel_ratio: 1.0
|
||||||
model_path: "./checkpoints/resnet18_1009/best.pth"
|
model_path: "checkpoints/resnet18_scatter_6.26/best.pth"
|
||||||
half: false # 是否启用半精度测试(fp16)
|
half: false # 是否启用半精度测试(fp16)
|
||||||
|
contrast_learning: false
|
||||||
|
|
||||||
# 数据配置
|
# 数据配置
|
||||||
data:
|
data:
|
||||||
test_batch_size: 128 # 训练批次大小
|
test_batch_size: 128 # 训练批次大小
|
||||||
num_workers: 32 # 数据加载线程数
|
num_workers: 32 # 数据加载线程数
|
||||||
test_dir: "../data_center/contrast_learning/contrast_test_data" # 验证数据集根目录
|
test_dir: "../data_center/scatter/v2/val_extar" # 验证数据集根目录
|
||||||
test_group_json: "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json"
|
test_group_json: "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json"
|
||||||
test_list: "../data_center/contrast_learning/contrast_test_data/cross_same.txt"
|
test_list: "../data_center/scatter/val_extar_cross_same.txt"
|
||||||
|
group_test: false
|
||||||
|
save_image_joint: true
|
||||||
|
image_joint_pth: "./joint_images"
|
||||||
|
|
||||||
transform:
|
transform:
|
||||||
img_size: 224 # 图像尺寸
|
img_size: 224 # 图像尺寸
|
||||||
|
@ -18,8 +18,8 @@ models:
|
|||||||
channel_ratio: 0.75
|
channel_ratio: 0.75
|
||||||
model_path: "../checkpoints/resnet18_1009/best.pth"
|
model_path: "../checkpoints/resnet18_1009/best.pth"
|
||||||
onnx_model: "../checkpoints/resnet18_1009/best.onnx"
|
onnx_model: "../checkpoints/resnet18_1009/best.onnx"
|
||||||
rknn_model: "../checkpoints/resnet18_1009/best_rknn2.3.2.rknn"
|
rknn_model: "../checkpoints/resnet18_1009/best_rknn2.3.2_batch16.rknn"
|
||||||
rknn_batch_size: 1
|
rknn_batch_size: 16
|
||||||
|
|
||||||
# 日志与监控
|
# 日志与监控
|
||||||
logging:
|
logging:
|
||||||
|
105
test_ori.py
105
test_ori.py
@ -11,6 +11,7 @@ import matplotlib.pyplot as plt
|
|||||||
|
|
||||||
# from config import config as conf
|
# from config import config as conf
|
||||||
from tools.dataset import get_transform
|
from tools.dataset import get_transform
|
||||||
|
from tools.image_joint import merge_imgs
|
||||||
from configs import trainer_tools
|
from configs import trainer_tools
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
@ -22,6 +23,7 @@ embedding_size = conf["base"]["embedding_size"]
|
|||||||
img_size = conf["transform"]["img_size"]
|
img_size = conf["transform"]["img_size"]
|
||||||
device = conf["base"]["device"]
|
device = conf["base"]["device"]
|
||||||
|
|
||||||
|
|
||||||
def unique_image(pair_list: str) -> Set[str]:
|
def unique_image(pair_list: str) -> Set[str]:
|
||||||
unique_images = set()
|
unique_images = set()
|
||||||
try:
|
try:
|
||||||
@ -38,7 +40,7 @@ def unique_image(pair_list: str) -> Set[str]:
|
|||||||
except IOError as e:
|
except IOError as e:
|
||||||
print(f"Error reading pair list file: {e}")
|
print(f"Error reading pair list file: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
return unique_images
|
return unique_images
|
||||||
|
|
||||||
|
|
||||||
@ -56,11 +58,11 @@ def group_image(images: Set[str], batch_size: int) -> List[List[str]]:
|
|||||||
image_list = list(images)
|
image_list = list(images)
|
||||||
num_images = len(image_list)
|
num_images = len(image_list)
|
||||||
batches = []
|
batches = []
|
||||||
|
|
||||||
for i in range(0, num_images, batch_size):
|
for i in range(0, num_images, batch_size):
|
||||||
batch_end = min(i + batch_size, num_images)
|
batch_end = min(i + batch_size, num_images)
|
||||||
batches.append(image_list[i:batch_end])
|
batches.append(image_list[i:batch_end])
|
||||||
|
|
||||||
return batches
|
return batches
|
||||||
|
|
||||||
|
|
||||||
@ -89,21 +91,21 @@ def test_preprocess(images: list, transform) -> torch.Tensor:
|
|||||||
|
|
||||||
|
|
||||||
def featurize(
|
def featurize(
|
||||||
images: List[str],
|
images: List[str],
|
||||||
transform: callable,
|
transform: callable,
|
||||||
net: nn.Module,
|
net: nn.Module,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
train: bool = False
|
train: bool = False
|
||||||
) -> Dict[str, torch.Tensor]:
|
) -> Dict[str, torch.Tensor]:
|
||||||
try:
|
try:
|
||||||
# Select appropriate preprocessing
|
# Select appropriate preprocessing
|
||||||
preprocess_fn = _preprocess if train else test_preprocess
|
preprocess_fn = _preprocess if train else test_preprocess
|
||||||
|
|
||||||
# Preprocess and move to device
|
# Preprocess and move to device
|
||||||
data = preprocess_fn(images, transform)
|
data = preprocess_fn(images, transform)
|
||||||
data = data.to(device)
|
data = data.to(device)
|
||||||
net = net.to(device)
|
net = net.to(device)
|
||||||
|
|
||||||
# Extract features with automatic mixed precision
|
# Extract features with automatic mixed precision
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if conf['models']['half']:
|
if conf['models']['half']:
|
||||||
@ -111,12 +113,16 @@ def featurize(
|
|||||||
features = net(data)
|
features = net(data)
|
||||||
# Create path-to-feature mapping
|
# Create path-to-feature mapping
|
||||||
return {img: feature for img, feature in zip(images, features)}
|
return {img: feature for img, feature in zip(images, features)}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error in feature extraction: {e}")
|
print(f"Error in feature extraction: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def cosin_metric(x1, x2):
|
def cosin_metric(x1, x2):
|
||||||
return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))
|
return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))
|
||||||
|
|
||||||
|
|
||||||
def threshold_search(y_score, y_true):
|
def threshold_search(y_score, y_true):
|
||||||
y_score = np.asarray(y_score)
|
y_score = np.asarray(y_score)
|
||||||
y_true = np.asarray(y_true)
|
y_true = np.asarray(y_true)
|
||||||
@ -179,22 +185,23 @@ def compute_accuracy_recall(score, labels):
|
|||||||
f_labels = (labels == 0)
|
f_labels = (labels == 0)
|
||||||
TN = np.sum(np.logical_and(f_score, f_labels))
|
TN = np.sum(np.logical_and(f_score, f_labels))
|
||||||
FP = np.sum(np.logical_and(np.logical_not(f_score), f_labels))
|
FP = np.sum(np.logical_and(np.logical_not(f_score), f_labels))
|
||||||
print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))
|
# print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))
|
||||||
|
|
||||||
PrecisePos.append(0 if TP / (TP + FP) == 'nan' else TP / (TP + FP))
|
PrecisePos.append(0 if TP / (TP + FP) == 'nan' else TP / (TP + FP))
|
||||||
PreciseNeg.append(0 if TN == 0 else TN / (TN + FN))
|
PreciseNeg.append(0 if TN == 0 else TN / (TN + FN))
|
||||||
recall.append(0 if TP == 0 else TP / (TP + FN))
|
recall.append(0 if TP == 0 else TP / (TP + FN))
|
||||||
recall_TN.append(0 if TN == 0 else TN / (TN + FP))
|
recall_TN.append(0 if TN == 0 else TN / (TN + FP))
|
||||||
Correct.append(0 if TP == 0 else (TP + TN) / (TP + FP + TN + FN))
|
Correct.append(0 if TP == 0 else (TP + TN) / (TP + FP + TN + FN))
|
||||||
|
print("Threshold:{} PrecisePos:{},recall:{},PreciseNeg:{},recall_TN:{}".format(th, PrecisePos[-1], recall[-1],
|
||||||
|
PreciseNeg[-1], recall_TN[-1]))
|
||||||
showHist(Same, Cross)
|
showHist(Same, Cross)
|
||||||
showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct)
|
showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct)
|
||||||
|
|
||||||
|
|
||||||
def compute_accuracy(
|
def compute_accuracy(
|
||||||
feature_dict: Dict[str, torch.Tensor],
|
feature_dict: Dict[str, torch.Tensor],
|
||||||
pair_list: str,
|
pair_list: str,
|
||||||
test_root: str
|
test_root: str
|
||||||
) -> Tuple[float, float]:
|
) -> Tuple[float, float]:
|
||||||
try:
|
try:
|
||||||
with open(pair_list, 'r') as f:
|
with open(pair_list, 'r') as f:
|
||||||
@ -205,37 +212,43 @@ def compute_accuracy(
|
|||||||
|
|
||||||
similarities = []
|
similarities = []
|
||||||
labels = []
|
labels = []
|
||||||
|
|
||||||
for pair in pairs:
|
for pair in pairs:
|
||||||
pair = pair.strip()
|
pair = pair.strip()
|
||||||
if not pair:
|
if not pair:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
# try:
|
||||||
img1, img2, label = pair.split()
|
img1, img2, label = pair.split()
|
||||||
img1_path = osp.join(test_root, img1)
|
img1_path = osp.join(test_root, img1)
|
||||||
img2_path = osp.join(test_root, img2)
|
img2_path = osp.join(test_root, img2)
|
||||||
|
|
||||||
# Verify features exist
|
# Verify features exist
|
||||||
if img1_path not in feature_dict or img2_path not in feature_dict:
|
if img1_path not in feature_dict or img2_path not in feature_dict:
|
||||||
raise ValueError(f"Missing features for image pair: {img1_path}, {img2_path}")
|
raise ValueError(f"Missing features for image pair: {img1_path}, {img2_path}")
|
||||||
|
|
||||||
# Get features and compute similarity
|
# Get features and compute similarity
|
||||||
feat1 = feature_dict[img1_path].cpu().numpy()
|
feat1 = feature_dict[img1_path].cpu().numpy()
|
||||||
feat2 = feature_dict[img2_path].cpu().numpy()
|
feat2 = feature_dict[img2_path].cpu().numpy()
|
||||||
similarity = cosin_metric(feat1, feat2)
|
similarity = cosin_metric(feat1, feat2)
|
||||||
|
print('{} vs {}: {}'.format(img1_path, img2_path, similarity))
|
||||||
similarities.append(similarity)
|
if conf['data']['save_image_joint']:
|
||||||
labels.append(int(label))
|
merge_imgs(img1_path,
|
||||||
|
img2_path,
|
||||||
except Exception as e:
|
conf['data']['image_joint_pth'],
|
||||||
print(f"Skipping invalid pair: {pair}. Error: {e}")
|
similarity,
|
||||||
continue
|
label)
|
||||||
|
similarities.append(similarity)
|
||||||
|
labels.append(int(label))
|
||||||
|
|
||||||
|
# except Exception as e:
|
||||||
|
# print(f"Skipping invalid pair: {pair}. Error: {e}")
|
||||||
|
# continue
|
||||||
|
|
||||||
# Find optimal threshold and accuracy
|
# Find optimal threshold and accuracy
|
||||||
accuracy, threshold = threshold_search(similarities, labels)
|
accuracy, threshold = threshold_search(similarities, labels)
|
||||||
compute_accuracy_recall(np.array(similarities), np.array(labels))
|
compute_accuracy_recall(np.array(similarities), np.array(labels))
|
||||||
|
|
||||||
return accuracy, threshold
|
return accuracy, threshold
|
||||||
|
|
||||||
|
|
||||||
@ -267,10 +280,10 @@ def compute_group_accuracy(content_list_read):
|
|||||||
d = featurize(group[0], conf.test_transform, model, conf.device)
|
d = featurize(group[0], conf.test_transform, model, conf.device)
|
||||||
one_group_list.append(d.values())
|
one_group_list.append(d.values())
|
||||||
if data_loaded[-1] == '1':
|
if data_loaded[-1] == '1':
|
||||||
similarity = deal_group_pair(one_group_list[0], one_group_list[1])
|
similarity = abs(deal_group_pair(one_group_list[0], one_group_list[1]))
|
||||||
Same.append(similarity)
|
Same.append(similarity)
|
||||||
else:
|
else:
|
||||||
similarity = deal_group_pair(one_group_list[0], one_group_list[1])
|
similarity = abs(deal_group_pair(one_group_list[0], one_group_list[1]))
|
||||||
Cross.append(similarity)
|
Cross.append(similarity)
|
||||||
allLabel.append(data_loaded[-1])
|
allLabel.append(data_loaded[-1])
|
||||||
allSimilarity.extend(similarity)
|
allSimilarity.extend(similarity)
|
||||||
@ -291,7 +304,17 @@ def init_model():
|
|||||||
print('load model {} '.format(conf['models']['backbone']))
|
print('load model {} '.format(conf['models']['backbone']))
|
||||||
if torch.cuda.device_count() > 1 and conf['base']['distributed']:
|
if torch.cuda.device_count() > 1 and conf['base']['distributed']:
|
||||||
model = nn.DataParallel(model).to(conf['base']['device'])
|
model = nn.DataParallel(model).to(conf['base']['device'])
|
||||||
|
###############正常模型加载################
|
||||||
model.load_state_dict(torch.load(conf['models']['model_path'], map_location=conf['base']['device']))
|
model.load_state_dict(torch.load(conf['models']['model_path'], map_location=conf['base']['device']))
|
||||||
|
#######################################
|
||||||
|
####### 对比学习模型临时运用###
|
||||||
|
# state_dict = torch.load(conf['models']['model_path'], map_location=conf['base']['device'])
|
||||||
|
# new_state_dict = {}
|
||||||
|
# for k, v in state_dict.items():
|
||||||
|
# new_key = k.replace("module.base_model.", "module.")
|
||||||
|
# new_state_dict[new_key] = v
|
||||||
|
# model.load_state_dict(new_state_dict, strict=False)
|
||||||
|
###########################
|
||||||
if conf['models']['half']:
|
if conf['models']['half']:
|
||||||
model.half()
|
model.half()
|
||||||
first_param_dtype = next(model.parameters()).dtype
|
first_param_dtype = next(model.parameters()).dtype
|
||||||
|
@ -188,7 +188,7 @@ class PairGenerator:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
original_path = '/home/lc/data_center/contrast_learning/contrast_test_data/test'
|
original_path = '/home/lc/data_center/scatter/val_extar'
|
||||||
parent_dir = str(Path(original_path).parent)
|
parent_dir = str(Path(original_path).parent)
|
||||||
generator = PairGenerator()
|
generator = PairGenerator()
|
||||||
|
|
||||||
|
33
tools/image_joint.py
Normal file
33
tools/image_joint.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def merge_imgs(img1_path, img2_path, save_path, similar=None, label=None):
|
||||||
|
position = (50, 50) # 文字的左上角坐标
|
||||||
|
color = (255, 0, 0) # 红色文字,格式为 RGB
|
||||||
|
if not os.path.exists(os.sep.join([save_path, str(label)])):
|
||||||
|
os.makedirs(os.sep.join([save_path, str(label)]))
|
||||||
|
save_path = os.sep.join([save_path, str(label)])
|
||||||
|
img_name = os.path.basename(img1_path).split('.')[0]+'_'+os.path.basename(img2_path).split('.')[0]+'.png'
|
||||||
|
img1 = Image.open(img1_path)
|
||||||
|
img2 = Image.open(img2_path)
|
||||||
|
img1 = img1.resize((224,224))
|
||||||
|
img2 = img2.resize((224,224))
|
||||||
|
print('img1_path', img1)
|
||||||
|
print('img2_path', img2)
|
||||||
|
assert img1.height == img2.height
|
||||||
|
|
||||||
|
new_img = Image.new('RGB', (img1.width + img2.width + 10, img1.height))
|
||||||
|
|
||||||
|
# print('new_img', new_img)
|
||||||
|
new_img.paste(img1, (0, 0))
|
||||||
|
new_img.paste(img2, (img1.width + 10, 0))
|
||||||
|
|
||||||
|
if similar is not None:
|
||||||
|
similar = str(similar)+'_'+str(label)
|
||||||
|
draw = ImageDraw.Draw(new_img)
|
||||||
|
draw.text(position, str(similar), color, font_size=36)
|
||||||
|
os.makedirs(save_path, exist_ok=True)
|
||||||
|
img_save = os.path.join(save_path, img_name)
|
||||||
|
new_img.save(img_save)
|
||||||
|
|
@ -12,7 +12,7 @@ import matplotlib.pyplot as plt
|
|||||||
from configs import trainer_tools
|
from configs import trainer_tools
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
with open('configs/compare.yml', 'r') as f:
|
with open('configs/scatter.yml', 'r') as f:
|
||||||
conf = yaml.load(f, Loader=yaml.FullLoader)
|
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
|
|
||||||
# Data Setup
|
# Data Setup
|
||||||
|
Reference in New Issue
Block a user