Compare commits

19 Commits
dev ... master

Author SHA1 Message Date
lee
c978787ff8 多机并行计算 2025-08-18 10:14:05 +08:00
lee
99a204ee22 多机并行计算 2025-08-14 10:09:54 +08:00
lee
bc896fc688 修改Dataloader提升训练效率 2025-08-07 11:00:36 +08:00
lee
27ffb62223 修改Dataloader提升训练效率 2025-08-07 10:56:32 +08:00
lee
ebba07d1ca 修改Dataloader提升训练效率 2025-08-07 10:52:42 +08:00
lee
3392d76e38 智能秤分析 2025-08-06 17:03:28 +08:00
lee
54898e30ec 数据分析 2025-07-17 14:33:18 +08:00
lee
09f41f6289 训练数据前置处理与提升训练效率 2025-07-10 14:24:05 +08:00
lee
0701538a73 散称训练数据前置处理 2025-07-07 15:19:22 +08:00
lee
6640f2bc5e 训练代码优化 2025-07-03 15:16:58 +08:00
lee
bcbabd9313 并行训练代码优化 2025-07-03 14:20:37 +08:00
lee
5deaf4727f 并行训练代码优化 2025-07-02 18:02:28 +08:00
lee
2219c0a303 更改 2025-07-02 14:53:37 +08:00
lee
537ed838fc 更改 2025-07-02 14:41:12 +08:00
lee
061820c34f 更改 2025-06-19 17:36:24 +08:00
lee
bf9604ec29 更改 2025-06-18 11:30:55 +08:00
lee
180a41ae90 更改 2025-06-13 13:22:41 +08:00
lee
e27e6c3d5b 更改 2025-06-13 10:57:02 +08:00
lee
1803f319a5 增加学习率调度方式 2025-06-13 10:45:53 +08:00
37 changed files with 3070 additions and 732 deletions

1
.gitignore vendored
View File

@ -8,4 +8,5 @@ loss/
checkpoints/
search_library/
quant_imgs/
electronic_imgs/
README.md

View File

@ -3,6 +3,120 @@
<component name="CopilotChatHistory">
<option name="conversations">
<list>
<Conversation>
<option name="createTime" value="1755228773977" />
<option name="id" value="0198abc99e597020bf8aa3ef78bc8bd3" />
<option name="title" value="新对话 2025年8月15日 11:32:53" />
<option name="updateTime" value="1755228773977" />
</Conversation>
<Conversation>
<option name="createTime" value="1755227620606" />
<option name="id" value="0198abb804fe7bf8ab3ac9ecfeae6d3f" />
<option name="title" value="新对话 2025年8月15日 11:13:40" />
<option name="updateTime" value="1755227620606" />
</Conversation>
<Conversation>
<option name="createTime" value="1755219481041" />
<option name="id" value="0198ab3bd1d17216b0dab33158ff294e" />
<option name="title" value="新对话 2025年8月15日 08:58:01" />
<option name="updateTime" value="1755219481041" />
</Conversation>
<Conversation>
<option name="createTime" value="1754286137102" />
<option name="id" value="0198739a1f0e75c38b0579ade7b34050" />
<option name="title" value="新对话 2025年8月04日 13:42:17" />
<option name="updateTime" value="1754286137102" />
</Conversation>
<Conversation>
<option name="createTime" value="1753932970546" />
<option name="id" value="01985e8d3a3170bf871ba640afdf246d" />
<option name="title" value="新对话 2025年7月31日 11:36:10" />
<option name="updateTime" value="1753932970546" />
</Conversation>
<Conversation>
<option name="createTime" value="1753932554257" />
<option name="id" value="01985e86e01170d6a09dca496e3dad46" />
<option name="title" value="新对话 2025年7月31日 11:29:14" />
<option name="updateTime" value="1753932554257" />
</Conversation>
<Conversation>
<option name="createTime" value="1753680371881" />
<option name="id" value="01984f7ee0a9779aabcd3f1671b815b3" />
<option name="title" value="新对话 2025年7月28日 13:26:11" />
<option name="updateTime" value="1753680371881" />
</Conversation>
<Conversation>
<option name="createTime" value="1753405176017" />
<option name="id" value="01983f17b8d173dda926d0ffa5422bbf" />
<option name="title" value="新对话 2025年7月25日 08:59:36" />
<option name="updateTime" value="1753405176017" />
</Conversation>
<Conversation>
<option name="createTime" value="1753065086744" />
<option name="id" value="01982ad25f18712f862c5c18b627f40d" />
<option name="title" value="新对话 2025年7月21日 10:31:26" />
<option name="updateTime" value="1753065086744" />
</Conversation>
<Conversation>
<option name="createTime" value="1752195523240" />
<option name="id" value="0197f6fde2a87e68b893b3a36dfc838f" />
<option name="title" value="新对话 2025年7月11日 08:58:43" />
<option name="updateTime" value="1752195523240" />
</Conversation>
<Conversation>
<option name="createTime" value="1752114061266" />
<option name="id" value="0197f222dfd27515a3dbfea638532ee5" />
<option name="title" value="新对话 2025年7月10日 10:21:01" />
<option name="updateTime" value="1752114061266" />
</Conversation>
<Conversation>
<option name="createTime" value="1751970991660" />
<option name="id" value="0197e99bce2c7a569dee594fb9b6e152" />
<option name="title" value="新对话 2025年7月08日 18:36:31" />
<option name="updateTime" value="1751970991660" />
</Conversation>
<Conversation>
<option name="createTime" value="1751441743239" />
<option name="id" value="0197ca101d8771bd80f2bc4aaf1a8f19" />
<option name="title" value="新对话 2025年7月02日 15:35:43" />
<option name="updateTime" value="1751441743239" />
</Conversation>
<Conversation>
<option name="createTime" value="1751441398488" />
<option name="id" value="0197ca0adad875168de40d792dcb7b4c" />
<option name="title" value="新对话 2025年7月02日 15:29:58" />
<option name="updateTime" value="1751441398488" />
</Conversation>
<Conversation>
<option name="createTime" value="1750474299387" />
<option name="id" value="0197906617fb7194a0407baae2b1e2eb" />
<option name="title" value="新对话 2025年6月21日 10:51:39" />
<option name="updateTime" value="1750474299387" />
</Conversation>
<Conversation>
<option name="createTime" value="1749793513436" />
<option name="id" value="019767d21fdc756ba782b33c8b14cdf1" />
<option name="title" value="新对话 2025年6月13日 13:45:13" />
<option name="updateTime" value="1749793513436" />
</Conversation>
<Conversation>
<option name="createTime" value="1749792408202" />
<option name="id" value="019767c1428a7aa28682039d57d19778" />
<option name="title" value="新对话 2025年6月13日 13:26:48" />
<option name="updateTime" value="1749792408202" />
</Conversation>
<Conversation>
<option name="createTime" value="1749718122230" />
<option name="id" value="01976353bef6703884544447c919013c" />
<option name="title" value="新对话 2025年6月12日 16:48:42" />
<option name="updateTime" value="1749718122230" />
</Conversation>
<Conversation>
<option name="createTime" value="1749648208122" />
<option name="id" value="01975f28f0fa7128afe7feddcdedb740" />
<option name="title" value="新对话 2025年6月11日 21:23:28" />
<option name="updateTime" value="1749648208122" />
</Conversation>
<Conversation>
<option name="createTime" value="1749522765718" />
<option name="id" value="019757aed78e777c96c4b7007ff2fecc" />
@ -57,16 +171,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -91,16 +196,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
</list>
@ -135,16 +231,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -169,16 +256,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -203,16 +281,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -237,16 +306,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -271,16 +331,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -305,16 +356,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -339,16 +381,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -373,16 +406,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -407,16 +431,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -441,16 +456,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -475,16 +481,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -509,16 +506,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -543,16 +531,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -577,16 +556,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -611,16 +581,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -645,16 +606,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -679,16 +631,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -713,16 +656,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -773,16 +707,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -807,16 +732,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
<Turn>
@ -841,16 +757,7 @@
</option>
<option name="status" value="SUCCESS" />
<option name="variables">
<list>
<CodebaseVariable>
<option name="selectedPlaceHolder">
<Object />
</option>
<option name="selectedVariable">
<Object />
</option>
</CodebaseVariable>
</list>
<list />
</option>
</Turn>
</list>

File diff suppressed because one or more lines are too long

View File

@ -20,20 +20,23 @@ models:
# 训练参数
training:
epochs: 600 # 总训练轮次
epochs: 400 # 总训练轮次
batch_size: 128 # 批次大小
lr: 0.001 # 初始学习率
lr: 0.01 # 初始学习率
optimizer: "sgd" # 优化器类型
metric: 'arcface' # 损失函数类型可选arcface/cosface/sphereface/softmax
loss: "cross_entropy" # 损失函数类型可选cross_entropy/cross_entropy_smooth/center_loss/center_loss_smooth/arcface/cosface/sphereface/softmax
lr_step: 10 # 学习率调整间隔epoch
lr_decay: 0.98 # 学习率衰减率
lr_step: 5 # 学习率调整间隔epoch
lr_decay: 0.95 # 学习率衰减率
weight_decay: 0.0005 # 权重衰减
scheduler: "cosine_annealing" # 学习率调度器可选cosine_annealing/step/none
scheduler: "step" # 学习率调度器可选cosine/cosine_warm/step/None
num_workers: 32 # 数据加载线程数
checkpoints: "./checkpoints/resnet18_test/" # 模型保存目录
checkpoints: "./checkpoints/resnet18_pdd_test/" # 模型保存目录
restore: false
restore_model: "resnet18_test/epoch_600.pth" # 模型恢复路径
restore_model: "./checkpoints/resnet50_electornic_20250807/best.pth" # 模型恢复路径
cosine_t_0: 10 # 初始周期长度
cosine_t_mult: 1 # 周期长度倍率
cosine_eta_min: 0.00001 # 最小学习率
# 验证参数
validation:
@ -46,8 +49,8 @@ data:
train_batch_size: 128 # 训练批次大小
val_batch_size: 128 # 验证批次大小
num_workers: 32 # 数据加载线程数
data_train_dir: "../data_center/contrast_learning/data_base/train" # 训练数据集根目录
data_val_dir: "../data_center/contrast_learning/data_base/val" # 验证数据集根目录
data_train_dir: "../data_center/electornic/v1/train" # 训练数据集根目录
data_val_dir: "../data_center/electornic/v1/val" # 验证数据集根目录
transform:
img_size: 224 # 图像尺寸
@ -59,11 +62,13 @@ transform:
# 日志与监控
logging:
logging_dir: "./logs" # 日志保存目录
logging_dir: "./logs/resnet50_electornic_log" # 日志保存目录
tensorboard: true # 是否启用TensorBoard
checkpoint_interval: 30 # 检查点保存间隔epoch
# 分布式训练(可选)
distributed:
enabled: false # 是否启用分布式训练
enabled: true # 是否启用分布式训练
backend: "nccl" # 分布式后端nccl/gloo
node_rank: 0 # 节点编号
node_num: 2 # 共计几个节点 一般几台机器就有几个节点

View File

@ -51,7 +51,7 @@ data:
dataset: "imagenet" # 数据集名称(示例用,可替换为实际数据集)
train_batch_size: 128 # 训练批次大小
val_batch_size: 100 # 验证批次大小
num_workers: 4 # 数据加载线程数
num_workers: 16 # 数据加载线程数
data_train_dir: "../data_center/contrast_learning/data_base/train" # 训练数据集根目录
data_val_dir: "../data_center/contrast_learning/data_base/val" # 验证数据集根目录

View File

@ -0,0 +1,54 @@
# configs/similar_analysis.yml
# 专为模型训练对比设计的配置文件
# 支持对比不同训练策略如蒸馏vs独立训练
# 基础配置
base:
experiment_name: "model_comparison" # 实验名称(用于结果保存目录)
device: "cuda" # 训练设备cuda/cpu
embedding_size: 256 # 特征维度
pin_memory: true # 是否启用pin_memory
distributed: true # 是否启用分布式训练
# 模型配置
models:
backbone: 'resnet18'
channel_ratio: 0.75
model_path: "../checkpoints/resnet18_20250715_scale=0.75_sub/best.pth"
# model_path: "../checkpoints/resnet18_1009/best.pth"
heatmap:
feature_layer: "layer4"
show_heatmap: true
# 数据配置
data:
dataset: "imagenet" # 数据集名称(示例用,可替换为实际数据集)
train_batch_size: 128 # 训练批次大小
val_batch_size: 8 # 验证批次大小
num_workers: 32 # 数据加载线程数
data_dir: "/home/lc/data_center/image_analysis/pic_pic_similar_maxtrix_new"
image_joint_pth: "/home/lc/data_center/image_analysis/error_compare_result"
total_pkl: "/home/lc/data_center/image_analysis/pic_pic_similar_maxtrix_new/total.pkl"
result_txt: "/home/lc/data_center/image_analysis/pic_pic_similar_maxtrix_new/result.txt"
transform:
img_size: 224 # 图像尺寸
img_mean: 0.5 # 图像均值
img_std: 0.5 # 图像方差
RandomHorizontalFlip: 0.5 # 随机水平翻转概率
RandomRotation: 180 # 随机旋转角度
ColorJitter: 0.5 # 随机颜色抖动强度
# 日志与监控
logging:
logging_dir: "./logs/resnet18_scale=0.75_nosub_log" # 日志保存目录
tensorboard: true # 是否启用TensorBoard
checkpoint_interval: 30 # 检查点保存间隔epoch
#event:
# oneToOne_max_th: 0.9
# oneToSn_min_th: 0.6
# event_save_dir: "/home/lc/works/realtime_yolov10s/online_yolov10s_resnetv11_20250702/yolos_tracking"
# stdlib_image_path: "/testDataAndLogs/module_test_record/comparison/标准图测试数据/pic/stlib_base"
# pickle_path: "event.pickle"

View File

@ -18,20 +18,20 @@ models:
# 训练参数
training:
epochs: 300 # 总训练轮次
epochs: 800 # 总训练轮次
batch_size: 64 # 批次大小
lr: 0.005 # 初始学习率
lr: 0.01 # 初始学习率
optimizer: "sgd" # 优化器类型
metric: 'arcface' # 损失函数类型可选arcface/cosface/sphereface/softmax
loss: "cross_entropy" # 损失函数类型可选cross_entropy/cross_entropy_smooth/center_loss/center_loss_smooth/arcface/cosface/sphereface/softmax
lr_step: 10 # 学习率调整间隔epoch
lr_decay: 0.98 # 学习率衰减率
lr_decay: 0.95 # 学习率衰减率
weight_decay: 0.0005 # 权重衰减
scheduler: "cosine_annealing" # 学习率调度器可选cosine_annealing/step/none
scheduler: "step" # 学习率调度器可选cosine_annealing/step/none
num_workers: 32 # 数据加载线程数
checkpoints: "./checkpoints/resnet18_scatter_6.2/" # 模型保存目录
restore: True
restore_model: "checkpoints/resnet18_scatter_6.2/best.pth" # 模型恢复路径
checkpoints: "./checkpoints/resnet18_scatter_7.4/" # 模型保存目录
restore: false
restore_model: "checkpoints/resnet18_scatter_6.25/best.pth" # 模型恢复路径
@ -46,8 +46,8 @@ data:
train_batch_size: 128 # 训练批次大小
val_batch_size: 100 # 验证批次大小
num_workers: 32 # 数据加载线程数
data_train_dir: "../data_center/scatter/train" # 训练数据集根目录
data_val_dir: "../data_center/scatter/val" # 验证数据集根目录
data_train_dir: "../data_center/scatter/v4/train" # 训练数据集根目录
data_val_dir: "../data_center/scatter/v4/val" # 验证数据集根目录
transform:
img_size: 224 # 图像尺寸
@ -59,7 +59,7 @@ transform:
# 日志与监控
logging:
logging_dir: "./log/2025.6.2-scatter.txt" # 日志保存目录
logging_dir: "./log/2025.7.4-scatter.txt" # 日志保存目录
tensorboard: true # 是否启用TensorBoard
checkpoint_interval: 30 # 检查点保存间隔epoch

19
configs/scatter_data.yml Normal file
View File

@ -0,0 +1,19 @@
# configs/scatter_data.yml
# 专为散称前处理的配置文件
# 数据配置
data:
dataset: "imagenet" # 数据集名称(示例用,可替换为实际数据集)
source_dir: "../../data_center/electornic/source" # 原始数据
train_dir: "../../data_center/electornic/v1/train" # 训练数据集根目录
val_dir: "../../data_center/electornic/v1/val" # 验证数据集根目录
extra_dir: "../../data_center/electornic/v1/extra" # 额外数据集目录(小样本归档)
split_ratio: 0.9
max_files: 10 # 数据集小于该阈值则归纳至extra
# 日志与监控
logging:
logging_dir: "./log/2025.7.4-scatter.txt" # 日志保存目录
log_level: "info" # 日志级别debug/info/warning/error

View File

@ -0,0 +1,54 @@
# configs/similar_analysis.yml
# 专为模型训练对比设计的配置文件
# 支持对比不同训练策略如蒸馏vs独立训练
# 基础配置
base:
experiment_name: "model_comparison" # 实验名称(用于结果保存目录)
device: "cuda" # 训练设备cuda/cpu
embedding_size: 256 # 特征维度
pin_memory: true # 是否启用pin_memory
distributed: true # 是否启用分布式训练
# 模型配置
models:
backbone: 'resnet18'
channel_ratio: 0.75
# model_path: "../checkpoints/resnet18_1009/best.pth"
model_path: "../checkpoints/resnet18_20250715_scale=0.75_sub/best.pth"
heatmap:
feature_layer: "layer4"
show_heatmap: true
# 数据配置
data:
dataset: "imagenet" # 数据集名称(示例用,可替换为实际数据集)
train_batch_size: 128 # 训练批次大小
val_batch_size: 8 # 验证批次大小
num_workers: 32 # 数据加载线程数
data_dir: "/home/lc/data_center/image_analysis/error_compare_subimg"
image_joint_pth: "/home/lc/data_center/image_analysis/error_compare_result"
transform:
img_size: 224 # 图像尺寸
img_mean: 0.5 # 图像均值
img_std: 0.5 # 图像方差
RandomHorizontalFlip: 0.5 # 随机水平翻转概率
RandomRotation: 180 # 随机旋转角度
ColorJitter: 0.5 # 随机颜色抖动强度
# 日志与监控
logging:
logging_dir: "./logs/resnet18_scale=0.75_nosub_log" # 日志保存目录
tensorboard: true # 是否启用TensorBoard
checkpoint_interval: 30 # 检查点保存间隔epoch
event:
oneToOneTxt: "/home/lc/detecttracking/oneToOne.txt"
oneToSnTxt: "/home/lc/detecttracking/oneToSn.txt"
oneToOne_max_th: 0.9
oneToSn_min_th: 0.6
event_save_dir: "/home/lc/works/realtime_yolov10s/online_yolov10s_resnetv11_20250702/yolos_tracking"
stdlib_image_path: "/testDataAndLogs/module_test_record/comparison/标准图测试数据/pic/stlib_base"
pickle_path: "event.pickle"

32
configs/sub_data.yml Normal file
View File

@ -0,0 +1,32 @@
# configs/sub_data.yml
# 专为对比模型训练的数据集设计的配置文件
# 支持对比不同训练策略如蒸馏vs独立训练
# 数据配置
data:
source_dir: "../../data_center/contrast_data/total" # 原始数据集目录
train_dir: "../../data_center/contrast_data/v1/train" # 训练数据集根目录
val_dir: "../../data_center/contrast_data/v1/val" # 验证数据集根目录
data_extra_dir: "../../data_center/contrast_data/v1/extra"
max_files_ratio: 0.1
min_files: 10
split_ratio: 0.9
combine_scr_dir: "../../data_center/contrast_data/v1/val" # 合并数据集源目录
combine_dst_dir: "../../data_center/contrast_data/v2/val" # 合并数据集目标目录
extend:
extend_same_dir: true
extend_extra: true
extend_extra_dir: "../../data_center/contrast_data/v1/extra" # 扩展测试集数据
extend_train: true
extend_train_dir: "../../data_center/contrast_data/v1/train" # 训练接数据扩展
limit:
count_limit: true
limit_count: 200
limit_dir: "../../data_center/contrast_data/v1/train" # 限制单个样本数量
control:
combine: true # 是否进行子类数据集合并
split: false # 子类数据集拆解与扩增

View File

@ -8,23 +8,28 @@ base:
log_level: "info" # 日志级别debug/info/warning/error
embedding_size: 256 # 特征维度
pin_memory: true # 是否启用pin_memory
distributed: true # 是否启用分布式训练
distributed: false # 是否启用分布式训练 启用热力图时不能用分布式训练
# 模型配置
models:
backbone: 'resnet18'
channel_ratio: 1.0
model_path: "./checkpoints/resnet18_scatter_6.2/best.pth"
model_path: "checkpoints/resnet18_electornic_20250806/best.pth"
#resnet18_20250715_scale=0.75_sub
#resnet18_20250718_scale=0.75_nosub
half: false # 是否启用半精度测试fp16
contrast_learning: false
# 数据配置
data:
group_test: False # 是否按组进行测试
test_batch_size: 128 # 训练批次大小
num_workers: 32 # 数据加载线程数
test_dir: "../data_center/scatter/" # 验证数据集根目录
test_dir: "../data_center/electornic/v1/val" # 验证数据集根目录
test_group_json: "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json"
test_list: "../data_center/scatter/val_pair.txt"
test_list: "../data_center/electornic/v1/cross_same.txt"
group_test: false
save_image_joint: false
image_joint_pth: "./joint_images"
transform:
img_size: 224 # 图像尺寸
@ -34,6 +39,11 @@ transform:
RandomRotation: 180 # 随机旋转角度
ColorJitter: 0.5 # 随机颜色抖动强度
heatmap:
image_joint_pth: "./heatmap_joint_images"
feature_layer: "layer4"
show_heatmap: true
save:
save_dir: ""
save_name: ""

29
configs/transform.yml Normal file
View File

@ -0,0 +1,29 @@
# configs/transform.yml
# pth转换onnx配置文件
# 基础配置
base:
experiment_name: "model_comparison" # 实验名称(用于结果保存目录)
seed: 42 # 随机种子(保证可复现性)
device: "cuda" # 训练设备cuda/cpu
log_level: "info" # 日志级别debug/info/warning/error
embedding_size: 256 # 特征维度
pin_memory: true # 是否启用pin_memory
distributed: true # 是否启用分布式训练
dataset: "./dataset_electornic.txt" # 数据集名称
# 模型配置
models:
backbone: 'resnet101'
channel_ratio: 1.0
model_path: "../checkpoints/resnet101_electornic_20250807/best.pth"
onnx_model: "../checkpoints/resnet101_electornic_20250807/best.onnx"
rknn_model: "../checkpoints/resnet101_electornic_20250807/resnet101_electornic_3588.rknn"
rknn_batch_size: 1
# 日志与监控
logging:
logging_dir: "./logs" # 日志保存目录
tensorboard: true # 是否启用TensorBoard
checkpoint_interval: 30 # 检查点保存间隔epoch

View File

@ -1,4 +1,4 @@
from model import (resnet18, mobilevit_s, MobileNetV3_Small, MobileNetV3_Large, mobilenet_v1,
from model import (resnet18, resnet34, resnet50, resnet101, mobilevit_s, MobileNetV3_Small, MobileNetV3_Large, mobilenet_v1,
PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5)
from timm.models import vit_base_patch16_224 as vit_base_16
from model.metric import ArcFace, CosFace
@ -13,7 +13,10 @@ class trainer_tools:
def get_backbone(self):
backbone_mapping = {
'resnet18': lambda: resnet18(scale=self.conf['models']['channel_ratio']),
'resnet18': lambda: resnet18(scale=self.conf['models']['channel_ratio'], pretrained=True),
'resnet34': lambda: resnet34(scale=self.conf['models']['channel_ratio'], pretrained=True),
'resnet50': lambda: resnet50(scale=self.conf['models']['channel_ratio'], pretrained=True),
'resnet101': lambda: resnet101(scale=self.conf['models']['channel_ratio'], pretrained=True),
'mobilevit_s': lambda: mobilevit_s(),
'mobilenetv3_small': lambda: MobileNetV3_Small(),
'PPLCNET_x1_0': lambda: PPLCNET_x1_0(),
@ -54,3 +57,24 @@ class trainer_tools:
)
}
return optimizer_mapping
def get_scheduler(self, optimizer):
    """Build the name -> factory mapping of LR schedulers for *optimizer*.

    Mirrors get_backbone/get_optimizer: callers look up
    ``conf['training']['scheduler']`` in the returned dict and invoke the lambda.

    :param optimizer: the torch optimizer the scheduler will drive
    :return: dict mapping scheduler name to a zero-arg factory
    """
    scheduler_mapping = {
        'step': lambda: optim.lr_scheduler.StepLR(
            optimizer,
            step_size=self.conf['training']['lr_step'],
            gamma=self.conf['training']['lr_decay']
        ),
        'cosine': lambda: optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=self.conf['training']['epochs'],
            # .get for consistency with 'cosine_warm': a config without
            # cosine_eta_min falls back to 0 instead of raising KeyError
            eta_min=self.conf['training'].get('cosine_eta_min', 0)
        ),
        'cosine_warm': lambda: optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=self.conf['training'].get('cosine_t_0', 10),
            T_mult=self.conf['training'].get('cosine_t_mult', 1),
            eta_min=self.conf['training'].get('cosine_eta_min', 0)
        )
    }
    return scheduler_mapping

View File

@ -14,7 +14,7 @@ base:
models:
backbone: 'resnet18'
channel_ratio: 0.75
checkpoints: "../checkpoints/resnet18_1009/best.pth"
checkpoints: "../checkpoints/resnet18_20250718_scale=0.75_nosub/best.pth"
# 数据配置
data:
@ -22,7 +22,7 @@ data:
test_batch_size: 128 # 验证批次大小
num_workers: 32 # 数据加载线程数
half: true # 是否启用半精度数据
img_dirs_path: "/shareData/temp_data/comparison/Hangzhou_Yunhe/base_data/05-09"
img_dirs_path: "/home/lc/data_center/baseStlib/pic/stlib_base" # base标准库图片存储路径
# img_dirs_path: "/home/lc/contrast_nettest/data/feature_json"
xlsx_pth: false # 过滤商品, 默认None不进行过滤
@ -41,7 +41,8 @@ logging:
checkpoint_interval: 30 # 检查点保存间隔epoch
save:
json_bin: "../search_library/yunhedian_05-09.json" # 保存整个json文件
json_path: "../data/feature_json_compare/" # 保存单个json文件
json_bin: "../search_library/resnet101_electronic.json" # 保存整个json文件
json_path: "/home/lc/data_center/baseStlib/feature_json/stlib_base_resnet18_sub_1k_合并" # 保存单个json文件路径
error_barcodes: "error_barcodes.txt"
barcodes_statistics: "../search_library/barcodes_statistics.txt"
create_single_json: true # 是否保存单个json文件

View File

View File

@ -0,0 +1,25 @@
import os
import shutil
def combine_dirs(conf):
    """Merge class sub-directories into a target root, grouped by barcode prefix.

    Walks ``data.combine_scr_dir``; every directory named like
    ``<barcode>_<suffix>`` is merged into ``data.combine_dst_dir/<barcode>``
    by copying each contained file.

    :param conf: config dict providing ``data.combine_scr_dir`` and
        ``data.combine_dst_dir``
    """
    source_root = conf['data']['combine_scr_dir']
    target_root = conf['data']['combine_dst_dir']
    for roots, dir_names, files in os.walk(source_root):
        for dir_name in dir_names:
            source_dir = os.path.join(roots, dir_name)
            # group by the part before the first '_' (e.g. "123_a" -> "123")
            target_dir = os.path.join(target_root, dir_name.split('_')[0])
            # makedirs (not mkdir): also creates a missing target_root parent
            # and tolerates the directory already existing
            os.makedirs(target_dir, exist_ok=True)
            for filename in os.listdir(source_dir):
                print(filename)
                source_file = os.sep.join([source_dir, filename])
                target_file = os.sep.join([target_dir, filename])
                shutil.copy(source_file, target_file)
# if __name__ == '__main__':
# source_root = r'scatter_mini'
# target_root = r'C:\Users\123\Desktop\scatter-1'
# # combine_dirs(conf)

View File

@ -0,0 +1,111 @@
import os
import shutil
from pathlib import Path
import yaml
def count_files(directory):
    """Return the number of regular files directly inside *directory* (non-recursive).

    Returns 0 (after printing a notice) when the directory cannot be read.
    """
    try:
        # scandir avoids a second stat() per entry vs. listdir + isfile
        with os.scandir(directory) as entries:
            return sum(1 for entry in entries if entry.is_file())
    except OSError as e:
        # narrow to OSError: only filesystem failures are expected here
        print(f"无法统计目录 {directory}: {e}")
        return 0
def clear_empty_dirs(path):
    """Delete every empty sub-directory below *path*.

    Walks bottom-up so that a directory emptied by removing its children
    is itself removed on the same pass.

    :param path: root directory to sweep
    """
    for parent, subdirs, _ in os.walk(path, topdown=False):
        for name in subdirs:
            candidate = os.path.join(parent, name)
            try:
                if os.listdir(candidate):
                    continue  # still has content, keep it
                os.rmdir(candidate)
                print(f"Deleted empty directory: {candidate}")
            except Exception as e:
                print(f"Error: {e.strerror}")
def get_max_files(conf):
    """Compute the file-count threshold separating "small" classes.

    Collects the per-class (leaf-directory) file counts under
    ``data.source_dir``, sorts them, and picks the value at quantile
    ``data.max_files_ratio``. The result is floored at ``data.min_files``.

    :param conf: config dict with ``data.source_dir``, ``data.max_files_ratio``,
        ``data.min_files``
    :return: int threshold
    """
    max_files_ratio = conf['data']['max_files_ratio']
    files_number = []
    for root, dirs, files in os.walk(conf['data']['source_dir']):
        if len(dirs) == 0:  # a leaf directory corresponds to one class
            if len(files) == 0:
                print(root, dirs, files)  # flag unexpectedly empty class dirs
            files_number.append(len(files))
    files_number.sort()
    if files_number:
        # clamp the quantile index so ratio == 1.0 cannot index past the end
        idx = min(int(max_files_ratio * len(files_number)), len(files_number) - 1)
        max_files = files_number[idx]
    else:
        max_files = 0  # no leaf dirs found; min_files floor applies below
    print(f"max_files: {max_files}")
    if max_files < conf['data']['min_files']:
        max_files = conf['data']['min_files']
    return max_files
def megre_subdirs(pth):
    # NOTE(review): the name keeps the original "megre" typo -- callers
    # (e.g. split_subdirs) reference it by this exact name.
    """Flatten one level of nesting: move grandchild dirs up to *pth*, then drop empties."""
    # NOTE(review): this walks *pth* while moving directories into it, so the
    # outer os.walk may observe a changing tree; presumably only the first
    # yielded level matters here -- confirm against real data layouts.
    for roots, dir_names, files in os.walk(pth):
        print(f"image {dir_names}")
        for image in dir_names:
            inner_dir_path = os.path.join(pth, image)
            for inner_roots, inner_dirs, inner_files in os.walk(inner_dir_path):
                for inner_dir in inner_dirs:
                    src_dir = os.path.join(inner_roots, inner_dir)
                    # destination is directly under pth (one level up)
                    dest_dir = os.path.join(pth, inner_dir)
                    # shutil.copytree(src_dir, dest_dir)
                    # move (not copy) -- shutil.move raises if dest_dir exists
                    shutil.move(src_dir, dest_dir)
                    print(f"Copied {inner_dir} to {pth}")
    # sweep away the now-empty intermediate directories
    clear_empty_dirs(pth)
# def split_subdirs(source_dir, target_dir, max_files=10):
def split_subdirs(conf):
    """Route each class directory into train or extra by its file count.

    Classes with at most the computed ``max_files`` files are copied to
    ``data.data_extra_dir``; the rest go to ``data.train_dir``. Destinations
    that already exist are skipped.

    :param conf: config dict with ``data.source_dir``, ``data.data_extra_dir``,
        ``data.train_dir`` plus the keys consumed by get_max_files
    """
    source_dir = conf['data']['source_dir']
    target_extra_dir = conf['data']['data_extra_dir']
    train_dir = conf['data']['train_dir']
    max_files = get_max_files(conf)
    megre_subdirs(source_dir)  # flatten nested class dirs first
    Path(target_extra_dir).mkdir(parents=True, exist_ok=True)
    print(f"开始处理目录: {source_dir}")
    print(f"目标目录: {target_extra_dir}")
    print(f"筛选条件: 文件数 ≤ {max_files}\n")
    for subdir in os.listdir(source_dir):
        subdir_path = os.path.join(source_dir, subdir)
        if not os.path.isdir(subdir_path):
            continue  # ignore stray files at the top level
        try:
            file_count = count_files(subdir_path)
            print(f"复制 {subdir} (包含 {file_count} 个文件)")
            # small classes -> extra, everything else -> train
            dest_root = target_extra_dir if file_count <= max_files else train_dir
            dest_path = os.path.join(dest_root, subdir)
            if os.path.exists(dest_path):
                print(f"目录已存在,跳过: {dest_path}")
                continue
            print(f"复制 {subdir} (包含 {file_count} 个文件) 至 {dest_path}")
            shutil.copytree(subdir_path, dest_path)
        except Exception as e:
            print(f"处理目录 {subdir} 时出错: {e}")
    print("\n处理完成")
if __name__ == "__main__":
    # Load the data-split configuration.
    with open('../configs/sub_data.yml', 'r') as f:
        conf = yaml.load(f, Loader=yaml.FullLoader)
    # Run the copy/split step.
    split_subdirs(conf)

View File

@ -0,0 +1,72 @@
import os
import shutil
import random
from pathlib import Path
import yaml
def is_image_file(filename):
    """Return True when *filename* ends with a known image extension (case-insensitive)."""
    return filename.lower().endswith(
        ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff')
    )
def split_directory(conf):
    """Move a random (1 - split_ratio) share of images per class from train to val.

    Mirrors the class sub-directory layout of ``data.train_dir`` under
    ``data.val_dir`` and moves the tail of a shuffled file list across.

    :param conf: config dict with ``data.train_dir``, ``data.val_dir``,
        ``data.split_ratio``
    """
    train_dir = conf['data']['train_dir']
    val_dir = conf['data']['val_dir']
    split_ratio = conf['data']['split_ratio']
    Path(val_dir).mkdir(parents=True, exist_ok=True)
    for root, _, files in os.walk(train_dir):
        rel_path = os.path.relpath(root, train_dir)
        if rel_path == '.':
            continue  # skip the train root itself
        val_subdir = os.path.join(val_dir, rel_path)
        os.makedirs(val_subdir, exist_ok=True)
        image_files = [name for name in files if is_image_file(name)]
        if not image_files:
            continue
        random.shuffle(image_files)
        split_point = int(len(image_files) * split_ratio)
        # everything past the split point becomes validation data
        for name in image_files[split_point:]:
            src = os.path.join(root, name)
            dst = os.path.join(val_subdir, name)
            shutil.move(src, dst)
        print(f"处理完成: {rel_path} (共 {len(image_files)} 个图像, 训练集: {split_point}, 验证集: {len(image_files)-split_point})")
def control_train_number():
    """Placeholder: cap per-class training-sample counts (not implemented yet)."""
    pass
if __name__ == "__main__":
    # Example hard-coded paths kept for reference:
    # TRAIN_DIR = "/home/lc/data_center/electornic/v1/train"
    # VAL_DIR = "/home/lc/data_center/electornic/v1/val"
    with open('../configs/scatter_data.yml', 'r') as f:
        conf = yaml.load(f, Loader=yaml.FullLoader)
    print("开始分割数据集...")
    split_directory(conf)
    print("数据集分割完成")

View File

@ -0,0 +1,192 @@
import os
import random
import shutil
from PIL import Image, ImageEnhance
class ImageExtendProcessor:
    """Offline image-augmentation helper.

    Expands under-populated classes with rotated / cropped / brightness-shifted
    copies and trims over-populated classes by random deletion.

    :param conf: optional configuration mapping; only ``control_number()``
        reads it (``extend`` and ``limit`` sections).  It defaults to ``None``
        so the class can be used directly through ``image_extend`` without a
        config — the module's ``__main__`` constructs it with no arguments,
        which previously raised ``TypeError``.
    """

    def __init__(self, conf=None):
        self.conf = conf

    def is_image_file(self, filename):
        """Return True when *filename* carries a known image extension
        (case-insensitive suffix check only)."""
        image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff')
        return filename.lower().endswith(image_extensions)

    def random_cute_image(self, image_path, output_path, ratio=0.8):
        """Save a randomly positioned crop of the image.

        ("cute" is kept for interface compatibility; it means "cut"/crop.)

        :param image_path: input image path
        :param output_path: where the cropped copy is written
        :param ratio: fraction of each side kept (default 0.8)
        :return: True on success, False when the image could not be processed
        """
        try:
            with Image.open(image_path) as img:
                width, height = img.size
                new_width = int(width * ratio)
                new_height = int(height * ratio)
                # Random top-left corner such that the crop stays inside the image.
                left = random.randint(0, width - new_width)
                top = random.randint(0, height - new_height)
                right = left + new_width
                bottom = top + new_height
                cropped_img = img.crop((left, top, right, bottom))
                cropped_img.save(output_path)
            return True
        except Exception as e:
            print(f"处理图像 {image_path} 时出错: {e}")
            return False

    def random_brightness(self, image_path, output_path, brightness_factor=None):
        """Save a brightness-adjusted copy of the image.

        :param brightness_factor: multiplier; drawn uniformly from
            [0.5, 1.5) when None.
        :return: True on success, False on failure.
        """
        try:
            with Image.open(image_path) as img:
                enhancer = ImageEnhance.Brightness(img)
                if brightness_factor is None:
                    brightness_factor = random.uniform(0.5, 1.5)
                brightened_img = enhancer.enhance(brightness_factor)
                brightened_img.save(output_path)
            return True
        except Exception as e:
            print(f"处理图像 {image_path} 时出错: {e}")
            return False

    def rotate_image(self, image_path, output_path, degrees):
        """Rotate the image by *degrees* (canvas expanded to fit) and save it.

        :return: True on success, False on failure.
        """
        try:
            with Image.open(image_path) as img:
                rotated = img.rotate(degrees, expand=True)
                rotated.save(output_path)
            return True
        except Exception as e:
            print(f"处理图像 {image_path} 时出错: {e}")
            return False

    def process_extra_directory(self, src_dir, dst_dir, same_directory, dir_name):
        """Augment every image of one class directory.

        'extra' mode writes rotations, crops and brightness variants;
        'train' mode writes rotations only.

        :param src_dir: source class directory
        :param dst_dir: destination directory (created if missing)
        :param same_directory: True when dst_dir is src_dir (skip the copy)
        :param dir_name: 'extra' or 'train' — selects the augmentation set
        """
        if not os.path.exists(dst_dir):
            os.makedirs(dst_dir)
        image_files = [f for f in os.listdir(src_dir)
                       if self.is_image_file(f) and os.path.isfile(os.path.join(src_dir, f))]
        for img_file in image_files:
            src_path = os.path.join(src_dir, img_file)
            base_name, ext = os.path.splitext(img_file)
            if not same_directory:
                # Copy the original over when writing to a separate folder.
                shutil.copy2(src_path, os.path.join(dst_dir, img_file))
            if dir_name == 'extra':
                for angle in [90, 180, 270]:
                    dst_path = os.path.join(dst_dir, f"{base_name}_rotated_{angle}{ext}")
                    self.rotate_image(src_path, dst_path, angle)
                for ratio in [0.8, 0.85, 0.9]:
                    dst_path = os.path.join(dst_dir, f"{base_name}_cute_{ratio}{ext}")
                    self.random_cute_image(src_path, dst_path, ratio)
                for brightness_factor in [0.8, 0.9, 1.0]:
                    dst_path = os.path.join(dst_dir, f"{base_name}_brightness_{brightness_factor}{ext}")
                    self.random_brightness(src_path, dst_path, brightness_factor)
            elif dir_name == 'train':
                for angle in [90, 180, 270]:
                    dst_path = os.path.join(dst_dir, f"{base_name}_rotated_{angle}{ext}")
                    self.rotate_image(src_path, dst_path, angle)

    def image_extend(self, src_dir, dst_dir, same_directory=False, dir_name=None):
        """Augment every class subdirectory of *src_dir*.

        :param src_dir: root containing one subdirectory per class
        :param dst_dir: root for augmented output (ignored when same_directory)
        :param same_directory: write augmented files next to the originals
        :param dir_name: 'extra' augments every class; 'train' only augments
            classes with fewer than 50 files; other values do nothing
        """
        if same_directory:
            n_dst_dir = src_dir
            print(f"处理目录 {src_dir} 中的图像文件 保存至同一目录下")
        else:
            n_dst_dir = dst_dir
            print(f"处理目录 {src_dir} 中的图像文件 保存至不同目录下")
        for src_subdir in os.listdir(src_dir):
            src_subdir_path = os.path.join(src_dir, src_subdir)
            # Stray top-level files used to crash os.listdir below; skip them.
            if not os.path.isdir(src_subdir_path):
                continue
            dst_subdir_path = os.path.join(n_dst_dir, src_subdir)
            if dir_name == 'extra':
                self.process_extra_directory(src_subdir_path,
                                             dst_subdir_path,
                                             same_directory,
                                             dir_name)
            if dir_name == 'train':
                # Only top up classes with fewer than 50 files.
                if len(os.listdir(src_subdir_path)) < 50:
                    self.process_extra_directory(src_subdir_path,
                                                 dst_subdir_path,
                                                 same_directory,
                                                 dir_name)

    def random_remove_image(self, subdir_path, max_count=1000):
        """Randomly delete images from *subdir_path* until at most
        *max_count* remain.

        :param subdir_path: directory trimmed in place
        :param max_count: maximum number of images to keep (default 1000)
        """
        image_files = [f for f in os.listdir(subdir_path)
                       if self.is_image_file(f) and os.path.isfile(os.path.join(subdir_path, f))]
        current_count = len(image_files)
        if current_count <= max_count:
            print(f"无需处理 {subdir_path} (包含 {current_count} 个图像)")
            return
        remove_count = current_count - max_count
        files_to_remove = random.sample(image_files, remove_count)
        for file in files_to_remove:
            file_path = os.path.join(subdir_path, file)
            os.remove(file_path)
            print(f"已删除 {file_path}")

    def control_number(self):
        """Run the augmentation / trimming steps selected in ``self.conf``.

        Requires a config mapping with ``extend`` and ``limit`` sections;
        an instance built without one cannot use this method.
        """
        if self.conf['extend']['extend_extra']:
            self.image_extend(self.conf['extend']['extend_extra_dir'],
                              '',
                              same_directory=self.conf['extend']['extend_same_dir'],
                              dir_name='extra')
        if self.conf['extend']['extend_train']:
            self.image_extend(self.conf['extend']['extend_train_dir'],
                              '',
                              same_directory=self.conf['extend']['extend_same_dir'],
                              dir_name='train')
        if self.conf['limit']['count_limit']:
            self.random_remove_image(self.conf['limit']['limit_dir'],
                                     max_count=self.conf['limit']['limit_count'])
if __name__ == "__main__":
    # Smoke test: augment ./scatter_mini into ./scatter_add.
    src_dir = "./scatter_mini"
    dst_dir = "./scatter_add"
    # ImageExtendProcessor.__init__ takes a config argument; the old no-arg
    # call raised TypeError.  image_extend() never reads the config, so
    # passing None explicitly is safe here.
    image_ex = ImageExtendProcessor(None)
    image_ex.image_extend(src_dir, dst_dir, same_directory=False)

View File

@ -0,0 +1,21 @@
from create_extra import split_subdirs
from data_split import split_directory
from extend import ImageExtendProcessor
from combine_sub_class import combine_dirs
import yaml
def data_preprocessing(conf):
    """Run the configured dataset-preparation pipeline.

    When ``conf['control']['split']`` is truthy: split subdirectories,
    run the augmentation/trimming pass and carve out the validation split.
    When ``conf['control']['combine']`` is truthy: merge sub-classes back
    together.

    :param conf: configuration mapping loaded from YAML.
    """
    if conf['control']['split']:
        split_subdirs(conf)
        # Augment / trim class folders according to the 'extend'/'limit' config.
        image_ex = ImageExtendProcessor(conf)
        image_ex.control_number()
        # Move a fraction of each class into the validation directory.
        split_directory(conf)
    if conf['control']['combine']:
        combine_dirs(conf)
if __name__ == '__main__':
    # Load the preprocessing configuration and run the pipeline.
    with open('../configs/sub_data.yml', 'r') as f:
        conf = yaml.load(f, Loader=yaml.FullLoader)
    data_preprocessing(conf)

View File

@ -4,7 +4,7 @@ from .mobilevit import mobilevit_s
from .metric import ArcFace, CosFace
from .loss import FocalLoss
from .resbam import resnet
from .resnet_pre import resnet18, resnet34, resnet50, resnet14, CustomResNet18
from .resnet_pre import resnet18, resnet34, resnet50, resnet101, resnet152,resnet14, CustomResNet18
from .mobilenet_v2 import mobilenet_v2
from .mobilenet_v3 import MobileNetV3_Small, MobileNetV3_Large
# from .mobilenet_v1 import mobilenet_v1

View File

@ -205,7 +205,7 @@ class ResNet(nn.Module):
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
print("ResNet scale: >>>>>>>>>> ", scale)
print("通道剪枝 {}".format(scale))
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
@ -222,13 +222,13 @@ class ResNet(nn.Module):
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.adaptiveMaxPool = nn.AdaptiveMaxPool2d((1, 1))
self.maxpool2 = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
nn.MaxPool2d(kernel_size=2, stride=1, padding=0)
)
# self.adaptiveMaxPool = nn.AdaptiveMaxPool2d((1, 1))
# self.maxpool2 = nn.Sequential(
# nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
# nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
# nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
# nn.MaxPool2d(kernel_size=2, stride=1, padding=0)
# )
self.layer1 = self._make_layer(block, int(64 * scale), layers[0])
self.layer2 = self._make_layer(block, int(128 * scale), layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
@ -368,7 +368,7 @@ def resnet18(pretrained=True, progress=True, **kwargs):
**kwargs)
def resnet34(pretrained=False, progress=True, **kwargs):
def resnet34(pretrained=True, progress=True, **kwargs):
r"""ResNet-34 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
@ -380,7 +380,7 @@ def resnet34(pretrained=False, progress=True, **kwargs):
**kwargs)
def resnet50(pretrained=False, progress=True, **kwargs):
def resnet50(pretrained=True, progress=True, **kwargs):
r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
@ -392,7 +392,7 @@ def resnet50(pretrained=False, progress=True, **kwargs):
**kwargs)
def resnet101(pretrained=False, progress=True, **kwargs):
def resnet101(pretrained=True, progress=True, **kwargs):
r"""ResNet-101 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

View File

@ -11,10 +11,13 @@ import matplotlib.pyplot as plt
# from config import config as conf
from tools.dataset import get_transform
from tools.image_joint import merge_imgs
from tools.getHeatMap import cal_cam
from configs import trainer_tools
import yaml
from datetime import datetime
with open('configs/test.yml', 'r') as f:
with open('../configs/test.yml', 'r') as f:
conf = yaml.load(f, Loader=yaml.FullLoader)
# Constants from config
@ -22,6 +25,7 @@ embedding_size = conf["base"]["embedding_size"]
img_size = conf["transform"]["img_size"]
device = conf["base"]["device"]
def unique_image(pair_list: str) -> Set[str]:
unique_images = set()
try:
@ -115,8 +119,12 @@ def featurize(
except Exception as e:
print(f"Error in feature extraction: {e}")
raise
def cosin_metric(x1, x2):
    """Cosine similarity between two 1-D feature vectors."""
    denom = np.linalg.norm(x1) * np.linalg.norm(x2)
    return np.dot(x1, x2) / denom
def threshold_search(y_score, y_true):
y_score = np.asarray(y_score)
y_true = np.asarray(y_true)
@ -133,7 +141,7 @@ def threshold_search(y_score, y_true):
def showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct):
x = np.linspace(start=0, stop=1.0, num=50, endpoint=True).tolist()
x = np.linspace(start=-1, stop=1.0, num=100, endpoint=True).tolist()
plt.figure(figsize=(10, 6))
plt.plot(x, recall, color='red', label='recall:TP/TPFN')
plt.plot(x, recall_TN, color='black', label='recall_TN:TN/TNFP')
@ -143,6 +151,7 @@ def showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct):
plt.legend()
plt.xlabel('threshold')
# plt.ylabel('Similarity')
plt.grid(True, linestyle='--', alpha=0.5)
plt.savefig('grid.png')
plt.show()
@ -154,19 +163,19 @@ def showHist(same, cross):
Cross = np.array(cross)
fig, axs = plt.subplots(2, 1)
axs[0].hist(Same, bins=50, edgecolor='black')
axs[0].set_xlim([-0.1, 1])
axs[0].hist(Same, bins=100, edgecolor='black')
axs[0].set_xlim([-1, 1])
axs[0].set_title('Same Barcode')
axs[1].hist(Cross, bins=50, edgecolor='black')
axs[1].set_xlim([-0.1, 1])
axs[1].hist(Cross, bins=100, edgecolor='black')
axs[1].set_xlim([-1, 1])
axs[1].set_title('Cross Barcode')
plt.savefig('plot.png')
def compute_accuracy_recall(score, labels):
th = 0.1
squence = np.linspace(-1, 1, num=50)
squence = np.linspace(-1, 1, num=100)
recall, PrecisePos, PreciseNeg, recall_TN, Correct = [], [], [], [], []
Same = score[:len(score) // 2]
Cross = score[len(score) // 2:]
@ -179,24 +188,26 @@ def compute_accuracy_recall(score, labels):
f_labels = (labels == 0)
TN = np.sum(np.logical_and(f_score, f_labels))
FP = np.sum(np.logical_and(np.logical_not(f_score), f_labels))
print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))
# print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))
PrecisePos.append(0 if TP / (TP + FP) == 'nan' else TP / (TP + FP))
PreciseNeg.append(0 if TN == 0 else TN / (TN + FN))
recall.append(0 if TP == 0 else TP / (TP + FN))
recall_TN.append(0 if TN == 0 else TN / (TN + FP))
Correct.append(0 if TP == 0 else (TP + TN) / (TP + FP + TN + FN))
print("Threshold:{} PrecisePos:{},recall:{},PreciseNeg:{},recall_TN:{}".format(th, PrecisePos[-1], recall[-1],
PreciseNeg[-1], recall_TN[-1]))
showHist(Same, Cross)
showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct)
def compute_accuracy(
feature_dict: Dict[str, torch.Tensor],
pair_list: str,
test_root: str
cam: cal_cam,
) -> Tuple[float, float]:
try:
pair_list = conf['data']['test_list']
test_root = conf['data']['test_dir']
with open(pair_list, 'r') as f:
pairs = f.readlines()
except IOError as e:
@ -211,7 +222,8 @@ def compute_accuracy(
if not pair:
continue
try:
# try:
print(f"Processing pair: {pair}")
img1, img2, label = pair.split()
img1_path = osp.join(test_root, img1)
img2_path = osp.join(test_root, img2)
@ -224,13 +236,20 @@ def compute_accuracy(
feat1 = feature_dict[img1_path].cpu().numpy()
feat2 = feature_dict[img2_path].cpu().numpy()
similarity = cosin_metric(feat1, feat2)
print('{} vs {}: {}'.format(img1_path, img2_path, similarity))
if conf['data']['save_image_joint']:
merge_imgs(img1_path,
img2_path,
conf,
similarity,
label,
cam)
similarities.append(similarity)
labels.append(int(label))
except Exception as e:
print(f"Skipping invalid pair: {pair}. Error: {e}")
continue
# except Exception as e:
# print(f"Skipping invalid pair: {pair}. Error: {e}")
# continue
# Find optimal threshold and accuracy
accuracy, threshold = threshold_search(similarities, labels)
@ -267,10 +286,10 @@ def compute_group_accuracy(content_list_read):
d = featurize(group[0], conf.test_transform, model, conf.device)
one_group_list.append(d.values())
if data_loaded[-1] == '1':
similarity = deal_group_pair(one_group_list[0], one_group_list[1])
similarity = abs(deal_group_pair(one_group_list[0], one_group_list[1]))
Same.append(similarity)
else:
similarity = deal_group_pair(one_group_list[0], one_group_list[1])
similarity = abs(deal_group_pair(one_group_list[0], one_group_list[1]))
Cross.append(similarity)
allLabel.append(data_loaded[-1])
allSimilarity.extend(similarity)
@ -291,14 +310,36 @@ def init_model():
print('load model {} '.format(conf['models']['backbone']))
if torch.cuda.device_count() > 1 and conf['base']['distributed']:
model = nn.DataParallel(model).to(conf['base']['device'])
model.load_state_dict(torch.load(conf['models']['model_path'], map_location=conf['base']['device']))
###############正常模型加载################
model.load_state_dict(torch.load(conf['models']['model_path'],
map_location=conf['base']['device']))
#######################################
####### 对比学习模型临时运用###
# state_dict = torch.load(conf['models']['model_path'], map_location=conf['base']['device'])
# new_state_dict = {}
# for k, v in state_dict.items():
# new_key = k.replace("module.base_model.", "module.")
# new_state_dict[new_key] = v
# model.load_state_dict(new_state_dict, strict=False)
###########################
if conf['models']['half']:
model.half()
first_param_dtype = next(model.parameters()).dtype
print("模型的第一个参数的数据类型: {}".format(first_param_dtype))
else:
model.load_state_dict(torch.load(conf['model']['model_path'], map_location=conf['base']['device']))
if conf.model_half:
try:
model.load_state_dict(torch.load(conf['models']['model_path'],
map_location=conf['base']['device']))
except:
state_dict = torch.load(conf['models']['model_path'],
map_location=conf['base']['device'])
new_state_dict = {}
for k, v in state_dict.items():
new_key = k.replace("module.", "")
new_state_dict[new_key] = v
model.load_state_dict(new_state_dict, strict=False)
if conf['models']['half']:
model.half()
first_param_dtype = next(model.parameters()).dtype
print("模型的第一个参数的数据类型: {}".format(first_param_dtype))
@ -308,7 +349,7 @@ def init_model():
if __name__ == '__main__':
model = init_model()
model.eval()
cam = cal_cam(model, conf)
if not conf['data']['group_test']:
images = unique_image(conf['data']['test_list'])
images = [osp.join(conf['data']['test_dir'], img) for img in images]
@ -318,7 +359,7 @@ if __name__ == '__main__':
for group in groups:
d = featurize(group, test_transform, model, conf['base']['device'])
feature_dict.update(d)
accuracy, threshold = compute_accuracy(feature_dict, conf['data']['test_list'], conf['data']['test_dir'])
accuracy, threshold = compute_accuracy(feature_dict, cam)
print(
"Test Model: {} Accuracy: {} Threshold: {}".format(conf['models']['model_path'], accuracy, threshold)
)

View File

@ -5,12 +5,14 @@ import torchvision.transforms as T
# from config import config as conf
import torch
def pad_to_square(img):
    """Pad a PIL image with zero (black) borders so it becomes square.

    The result is max(w, h) x max(w, h) with the original content centred.
    The old code used ``(max_wh - w) // 2`` on both sides, which came up one
    pixel short of square whenever the width/height deficit was odd; the
    right/bottom pads now absorb the remainder.
    """
    w, h = img.size
    max_wh = max(w, h)
    left = (max_wh - w) // 2
    top = (max_wh - h) // 2
    right = max_wh - w - left
    bottom = max_wh - h - top
    padding = [left, top, right, bottom]  # (left, top, right, bottom)
    return F.pad(img, padding, fill=0, padding_mode='constant')
def get_transform(cfg):
train_transform = T.Compose([
T.Lambda(pad_to_square), # 补边
@ -24,7 +26,7 @@ def get_transform(cfg):
T.Normalize(mean=[cfg['transform']['img_mean']], std=[cfg['transform']['img_std']]),
])
test_transform = T.Compose([
# T.Lambda(pad_to_square), # 补边
T.Lambda(pad_to_square), # 补边
T.ToTensor(),
T.Resize((cfg['transform']['img_size'], cfg['transform']['img_size']), antialias=True),
T.ConvertImageDtype(torch.float32),
@ -32,37 +34,73 @@ def get_transform(cfg):
])
return train_transform, test_transform
def load_data(training=True, cfg=None):
def load_data(training=True, cfg=None, return_dataset=False):
train_transform, test_transform = get_transform(cfg)
if training:
dataroot = cfg['data']['data_train_dir']
transform = train_transform
# transform = conf.train_transform
# transform.yml = conf.train_transform
batch_size = cfg['data']['train_batch_size']
else:
dataroot = cfg['data']['data_val_dir']
# transform = conf.test_transform
# transform.yml = conf.test_transform
transform = test_transform
batch_size = cfg['data']['val_batch_size']
data = ImageFolder(dataroot, transform=transform)
class_num = len(data.classes)
if return_dataset:
return data, class_num
else:
loader = DataLoader(data,
batch_size=batch_size,
shuffle=True,
shuffle=True if training else False,
pin_memory=cfg['base']['pin_memory'],
num_workers=cfg['data']['num_workers'],
drop_last=True)
return loader, class_num
class MultiEpochsDataLoader(torch.utils.data.DataLoader):
    """DataLoader that keeps its worker processes alive across epochs.

    The stock DataLoader tears down and respawns its workers on every epoch;
    this subclass swaps the batch sampler for an infinite ``_RepeatSampler``
    and reuses one underlying iterator, so workers start only once.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Temporarily clear the base class's name-mangled "initialized" flag
        # so batch_sampler can be replaced after construction (DataLoader
        # forbids attribute changes once initialized).
        self._DataLoader__initialized = False
        self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        # Single persistent iterator shared by every subsequent __iter__ call.
        self.iterator = super().__iter__()

    def __len__(self):
        # Length of the wrapped single-epoch sampler, not the infinite wrapper.
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        # Draw exactly one epoch's worth of batches from the shared iterator.
        for i in range(len(self)):
            yield next(self.iterator)
class _RepeatSampler(object):
    """Sampler wrapper that replays the wrapped sampler forever.

    Used by MultiEpochsDataLoader so the underlying iterator never exhausts,
    which avoids re-creating worker processes every epoch.
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            for batch in self.sampler:
                yield batch
# def load_gift_data(action):
# train_data = ImageFolder(conf.train_gift_root, transform=conf.train_transform)
# train_data = ImageFolder(conf.train_gift_root, transform.yml=conf.train_transform)
# train_dataset = DataLoader(train_data, batch_size=conf.train_gift_batchsize, shuffle=True,
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
# val_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
# val_data = ImageFolder(conf.test_gift_root, transform.yml=conf.test_transform)
# val_dataset = DataLoader(val_data, batch_size=conf.val_gift_batchsize, shuffle=True,
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
# test_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
# test_data = ImageFolder(conf.test_gift_root, transform.yml=conf.test_transform)
# test_dataset = DataLoader(test_data, batch_size=conf.test_gift_batchsize, shuffle=True,
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
# return train_dataset, val_dataset, test_dataset

View File

@ -1,10 +1,10 @@
./quant_imgs/20179457_20240924-110903_back_addGood_b82d2842766e_80_15583929052_tid-8_fid-72_bid-3.jpg
./quant_imgs/6928926002103_20240309-195044_front_returnGood_70f75407ef0e_225_18120111822_14_01.jpg
./quant_imgs/6928926002103_20240309-212145_front_returnGood_70f75407ef0e_225_18120111822_11_01.jpg
./quant_imgs/6928947479083_20241017-133830_front_returnGood_5478c9a48b7e_10_13799009402_tid-1_fid-20_bid-1.jpg
./quant_imgs/6928947479083_20241018-110450_front_addGood_5478c9a48c28_165_13773168720_tid-6_fid-36_bid-1.jpg
./quant_imgs/6930044166421_20240117-141516_c6a23f41-5b16-44c6-a03e-c32c25763442_back_returnGood_6930044166421_17_01.jpg
./quant_imgs/6930044166421_20240308-150916_back_returnGood_70f75407ef0e_175_13815402763_7_01.jpg
./quant_imgs/6930044168920_20240117-165633_3303629b-5fbd-423b-913d-8a64c1aa51dc_front_addGood_6930044168920_26_01.jpg
./quant_imgs/6930058201507_20240305-175434_front_addGood_70f75407ef0e_95_18120111822_28_01.jpg
./quant_imgs/6930639267885_20241014-120446_back_addGood_5478c9a48c3e_135_13773168720_tid-5_fid-99_bid-0.jpg
../quant_imgs/20179457_20240924-110903_back_addGood_b82d2842766e_80_15583929052_tid-8_fid-72_bid-3.jpg
../quant_imgs/6928926002103_20240309-195044_front_returnGood_70f75407ef0e_225_18120111822_14_01.jpg
../quant_imgs/6928926002103_20240309-212145_front_returnGood_70f75407ef0e_225_18120111822_11_01.jpg
../quant_imgs/6928947479083_20241017-133830_front_returnGood_5478c9a48b7e_10_13799009402_tid-1_fid-20_bid-1.jpg
../quant_imgs/6928947479083_20241018-110450_front_addGood_5478c9a48c28_165_13773168720_tid-6_fid-36_bid-1.jpg
../quant_imgs/6930044166421_20240117-141516_c6a23f41-5b16-44c6-a03e-c32c25763442_back_returnGood_6930044166421_17_01.jpg
../quant_imgs/6930044166421_20240308-150916_back_returnGood_70f75407ef0e_175_13815402763_7_01.jpg
../quant_imgs/6930044168920_20240117-165633_3303629b-5fbd-423b-913d-8a64c1aa51dc_front_addGood_6930044168920_26_01.jpg
../quant_imgs/6930058201507_20240305-175434_front_addGood_70f75407ef0e_95_18120111822_28_01.jpg
../quant_imgs/6930639267885_20241014-120446_back_addGood_5478c9a48c3e_135_13773168720_tid-5_fid-99_bid-0.jpg

View File

@ -0,0 +1,23 @@
../electronic_imgs/0.jpg
../electronic_imgs/1.jpg
../electronic_imgs/2.jpg
../electronic_imgs/3.jpg
../electronic_imgs/4.jpg
../electronic_imgs/5.jpg
../electronic_imgs/6.jpg
../electronic_imgs/7.jpg
../electronic_imgs/8.jpg
../electronic_imgs/9.jpg
../electronic_imgs/10.jpg
../electronic_imgs/11.jpg
../electronic_imgs/12.jpg
../electronic_imgs/13.jpg
../electronic_imgs/14.jpg
../electronic_imgs/15.jpg
../electronic_imgs/16.jpg
../electronic_imgs/17.jpg
../electronic_imgs/18.jpg
../electronic_imgs/19.jpg
../electronic_imgs/20.jpg
../electronic_imgs/21.jpg
../electronic_imgs/22.jpg

View File

@ -0,0 +1,144 @@
from similar_analysis import SimilarAnalysis
import os
import pickle
from tools.image_joint import merge_imgs
class EventSimilarAnalysis(SimilarAnalysis):
    """Mine false-negative / false-positive events from the 1:1 and 1:SN
    comparison reports, cache image features in a pickle, and render
    side-by-side contrast images for the offending pairs."""

    def __init__(self):
        super(EventSimilarAnalysis, self).__init__()
        # (event_name, barcode) tuples the model got wrong, per protocol.
        self.fn_one2one_event, self.fp_one2one_event = self.One2one_similar_analysis()
        self.fn_one2sn_event, self.fp_one2sn_event = self.One2Sn_similar_analysis()
        if os.path.exists(self.conf['event']['pickle_path']):
            print('pickle file exists')
        else:
            # Only collect image paths when the feature cache must be built.
            self.target_image = self.get_path()

    def get_path(self):
        """Collect every event sub-image and standard-library barcode image
        referenced by the mined events (deduplicated)."""
        events = [self.fn_one2one_event, self.fp_one2one_event,
                  self.fn_one2sn_event, self.fp_one2sn_event]
        event_image_path = []
        barcode_image_path = []
        for event in events:
            for event_name, bcd in event:
                event_sub_image = os.sep.join([self.conf['event']['event_save_dir'],
                                               event_name,
                                               'subimgs'])
                barcode_images = os.sep.join([self.conf['event']['stdlib_image_path'],
                                              bcd])
                for image_name in os.listdir(event_sub_image):
                    event_image_path.append(os.sep.join([event_sub_image, image_name]))
                for barcode in os.listdir(barcode_images):
                    barcode_image_path.append(os.sep.join([barcode_images, barcode]))
        return list(set(event_image_path + barcode_image_path))

    def write_dict_to_pickle(self, data):
        """Persist the feature dictionary to the configured pickle path."""
        with open(self.conf['event']['pickle_path'], 'wb') as file:
            pickle.dump(data, file)

    def get_dict_to_pickle(self):
        """Load and return the cached feature dictionary."""
        with open(self.conf['event']['pickle_path'], 'rb') as f:
            data = pickle.load(f)
        return data

    def create_total_feature(self):
        """Extract features for every collected image and cache them to the
        pickle file (slow; only needs to run once)."""
        feature_dicts = self.get_feature_map(self.target_image)
        self.write_dict_to_pickle(feature_dicts)
        print(feature_dicts)

    def One2one_similar_analysis(self):
        """Parse the 1:1 report; each line is
        ``label event_name barcode simi1 simi2``.

        :return: (fn_events, fp_events) as lists of (event_name, barcode).
        """
        fn_event, fp_event = [], []
        with open(self.conf['event']['oneToOneTxt'], 'r') as f:
            lines = f.readlines()
        for line in lines:
            print(line.strip().split(' '))
            event_infor = line.strip().split(' ')
            label = event_infor[0]
            event_name = event_infor[1]
            bcd = event_infor[2]
            simi1 = event_infor[3]
            simi2 = event_infor[4]
            if label == 'same' and float(simi2) < self.conf['event']['oneToOne_max_th']:
                print(event_name, bcd, simi1)
                fn_event.append((event_name, bcd))
            # NOTE(review): the 'diff' branch compares against the 1:SN
            # threshold key ('oneToSn_min_th') inside the 1:1 analysis —
            # looks like a wrong config key; confirm before changing.
            elif label == 'diff' and float(simi2) > self.conf['event']['oneToSn_min_th']:
                fp_event.append((event_name, bcd))
        return fn_event, fp_event

    def One2Sn_similar_analysis(self):
        """Parse the 1:SN report; each line is ``label event_name barcode simi``.

        NOTE(review): this reads conf['event']['oneToOneTxt'] — the same file
        as the 1:1 analysis.  Verify whether a dedicated 1:SN report path was
        intended.

        :return: (fn_events, fp_events) as lists of (event_name, barcode).
        """
        fn_event, fp_event = [], []
        with open(self.conf['event']['oneToOneTxt'], 'r') as f:
            lines = f.readlines()
        for line in lines:
            print(line.strip().split(' '))
            event_infor = line.strip().split(' ')
            label = event_infor[0]
            event_name = event_infor[1]
            bcd = event_infor[2]
            simi = event_infor[3]
            if label == 'fn':
                print(event_name, bcd, simi)
                fn_event.append((event_name, bcd))
            elif label == 'fp':
                fp_event.append((event_name, bcd))
        return fn_event, fp_event

    def save_joint_image(self, img_pth1, img_pth2, feature_dicts, record):
        """Compute the similarity of one image pair and, when it falls on the
        interesting side of the 0.8 threshold for this record type, save a
        merged comparison image (with CAM overlay) for manual inspection."""
        feature_dict1 = feature_dicts[img_pth1]
        feature_dict2 = feature_dicts[img_pth2]
        similarity = self.get_similarity(feature_dict1.cpu().numpy(),
                                         feature_dict2.cpu().numpy())
        dir_name = img_pth1.split('/')[-3]
        save_path = os.sep.join([self.conf['data']['image_joint_pth'], dir_name, record])
        if "fp" in record:
            # False positives only matter when the similarity is high.
            if similarity > 0.8:
                merge_imgs(img_pth1,
                           img_pth2,
                           self.conf,
                           similarity,
                           label=None,
                           cam=self.cam,
                           save_path=save_path)
        else:
            # False negatives only matter when the similarity is low.
            if similarity < 0.8:
                merge_imgs(img_pth1,
                           img_pth2,
                           self.conf,
                           similarity,
                           label=None,
                           cam=self.cam,
                           save_path=save_path)
        print(similarity)

    def get_contrast(self, feature_dicts):
        """Render comparison images for every mined FP/FN event pair."""
        events_compare = [self.fp_one2one_event, self.fn_one2one_event, self.fp_one2sn_event, self.fn_one2sn_event]
        event_record = ['fp_one2one', 'fn_one2one', 'fp_one2sn', 'fn_one2sn']
        for event_compare, record in zip(events_compare, event_record):
            for img, img_std in event_compare:
                imgs_pth1 = os.sep.join([self.conf['event']['event_save_dir'],
                                         img,
                                         'subimgs'])
                imgs_pth2 = os.sep.join([self.conf['event']['stdlib_image_path'],
                                         img_std])
                for img1 in os.listdir(imgs_pth1):
                    for img2 in os.listdir(imgs_pth2):
                        img_pth1 = os.sep.join([imgs_pth1, img1])
                        img_pth2 = os.sep.join([imgs_pth2, img2])
                        try:
                            self.save_joint_image(img_pth1, img_pth2, feature_dicts, record)
                        except Exception as e:
                            # BUGFIX: the original had `continue` before
                            # print(e), so failures were silently swallowed.
                            print(e)
                            continue
if __name__ == '__main__':
    event_similar_analysis = EventSimilarAnalysis()
    if os.path.exists(event_similar_analysis.conf['event']['pickle_path']):
        print('pickle file exists')
    else:
        # Build the feature pickle; slow — only needs to run once per dataset.
        event_similar_analysis.create_total_feature()
    feature_dicts = event_similar_analysis.get_dict_to_pickle()
    # all_compare_img = event_similar_analysis.get_image_map()
    event_similar_analysis.get_contrast(feature_dicts)  # render comparison results

164
tools/getHeatMap.py Normal file
View File

@ -0,0 +1,164 @@
# -*- coding: UTF-8 -*-
import os
import torch
from torchvision import models
import torch.nn as nn
import torchvision.transforms as tfs
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
# from tools.config import cfg
# from comparative.tools.initmodel import initSimilarityModel
import yaml
from dataset import get_transform
class cal_cam(nn.Module):
    """Grad-CAM style heat-map generator.

    Hooks the configured feature layer of *model*, backpropagates from the
    highest-scoring output and combines the captured gradients with the
    feature maps into a 224x224 heat-map overlay.
    """

    def __init__(self, model, conf):
        super(cal_cam, self).__init__()
        self.conf = conf
        self.device = self.conf['base']['device']
        self.model = model
        self.model.to(self.device)
        # Name of the layer whose gradients are required.
        self.feature_layer = conf['heatmap']['feature_layer']
        # Gradients recorded by the backward hook (most recent last).
        self.gradient = []
        # Feature maps captured at the hooked layer.
        self.output = []
        _, self.transform = get_transform(self.conf)

    def get_conf(self, yaml_pth):
        """Load a YAML config file and return it as a mapping."""
        with open(yaml_pth, 'r') as f:
            conf = yaml.load(f, Loader=yaml.FullLoader)
        return conf

    def save_grad(self, grad):
        """Backward-hook callback: record the hooked layer's gradient."""
        self.gradient.append(grad)

    def get_grad(self):
        """Return the most recently captured gradient (on CPU)."""
        return self.gradient[-1].cpu().data

    def get_feature(self):
        """Return the most recently captured feature map."""
        return self.output[-1][0]

    def process_img(self, input):
        """Apply the test transform and add a batch dimension."""
        input = self.transform(input)
        input = input.unsqueeze(0)
        return input

    # Compute the gradient of the hooked conv layer; returns the gradient,
    # the layer's feature map, and the gradient w.r.t. the input.
    def getGrad(self, input_):
        self.gradient = []  # clear gradients from the previous call
        self.output = []  # clear feature maps from the previous call
        # print(f"cuda.memory_allocated 1 {torch.cuda.memory_allocated()/ (1024 ** 3)}G")
        input_ = input_.to(self.device).requires_grad_(True)
        num = 1
        # Walk the model's top-level modules by hand so the configured layer
        # can be hooked mid-forward.
        # NOTE(review): assumes a flat module layout like torchvision ResNet
        # (first module consumes the raw input; an "avgpool" module precedes
        # the classifier) — confirm for other backbones.
        for name, module in self.model._modules.items():
            # print(f'module_name: {name}')
            # print(f'module: {module}')
            if (num == 1):
                input = module(input_)
                num = num + 1
                continue
            # This is the layer whose feature map we want
            if (name == self.feature_layer):
                input = module(input)
                input.register_hook(self.save_grad)
                self.output.append([input])
            # About to reach the fully connected layer
            elif (name == "avgpool"):
                input = module(input)
                input = input.reshape(input.shape[0], -1)
            # Ordinary layer
            else:
                input = module(input)
        # print(f"cuda.memory_allocated 2 {torch.cuda.memory_allocated() / (1024 ** 3)}G")
        # Here `input` is the output of the final fully connected layer
        index = torch.max(input, dim=-1)[1]
        one_hot = torch.zeros((1, input.shape[-1]), dtype=torch.float32)
        one_hot[0][index] = 1
        confidenct = one_hot * input.cpu()
        confidenct = torch.sum(confidenct, dim=-1).requires_grad_(True)
        # print(f"cuda.memory_allocated 3 {torch.cuda.memory_allocated() / (1024 ** 3)}G")
        # Clear all existing gradients
        self.model.zero_grad()
        # Backpropagate to obtain the gradients
        grad_output = torch.ones_like(confidenct)
        confidenct.backward(grad_output)
        # Fetch the gradient of the hooked feature map
        grad_val = self.get_grad()
        feature = self.get_feature()
        # print(f"cuda.memory_allocated 4 {torch.cuda.memory_allocated() / (1024 ** 3)}G")
        return grad_val, feature, input_.grad

    # Compute the CAM
    def getCam(self, grad_val, feature):
        # Global-average-pool the gradient of each channel
        alpha = torch.mean(grad_val, dim=(2, 3)).cpu()
        feature = feature.cpu()
        # Weight each channel's feature map by its pooled gradient and sum
        cam = torch.zeros((feature.shape[2], feature.shape[3]), dtype=torch.float32)
        for idx in range(alpha.shape[1]):
            cam = cam + alpha[0][idx] * feature[0][idx]
        # Apply ReLU
        cam = np.maximum(cam.detach().numpy(), 0)
        # plt.imshow(cam)
        # plt.colorbar()
        # plt.savefig("cam.jpg")
        # Upscale the CAM region to the input image size
        cam_ = cv2.resize(cam, (224, 224))
        cam_ = cam_ - np.min(cam_)
        cam_ = cam_ / np.max(cam_)
        # plt.imshow(cam_)
        # plt.savefig("cam_.jpg")
        cam = torch.from_numpy(cam)
        return cam, cam_

    def show_img(self, cam_, img, heatmap_save_pth, imgname):
        """Blend the heat map onto *img* and write the overlay to disk."""
        heatmap = cv2.applyColorMap(np.uint8(255 * cam_), cv2.COLORMAP_JET)
        cam_img = 0.3 * heatmap + 0.7 * np.float32(img)
        # cv2.imwrite("img.jpg", cam_img)
        cv2.imwrite(os.sep.join([heatmap_save_pth, imgname]), cam_img)

    def get_hot_map(self, img_pth):
        """Return a PIL image of the heat-map overlay for the image at *img_pth*."""
        img = Image.open(img_pth)
        img = img.resize((224, 224))
        input = self.process_img(img)
        grad_val, feature, input_grad = self.getGrad(input)
        cam, cam_ = self.getCam(grad_val, feature)
        heatmap = cv2.applyColorMap(np.uint8(255 * cam_), cv2.COLORMAP_JET)
        cam_img = 0.3 * heatmap + 0.7 * np.float32(img)
        cam_img = Image.fromarray(np.uint8(cam_img))
        return cam_img

    # def __call__(self, img_root, heatmap_save_pth):
    #     for imgname in os.listdir(img_root):
    #         img = Image.open(os.sep.join([img_root, imgname]))
    #         img = img.resize((224, 224))
    #         # plt.imshow(img)
    #         # plt.savefig("airplane.jpg")
    #         input = self.process_img(img)
    #         grad_val, feature, input_grad = self.getGrad(input)
    #         cam, cam_ = self.getCam(grad_val, feature)
    #         self.show_img(cam_, img, heatmap_save_pth, imgname)
    #         return cam
if __name__ == "__main__":
    # NOTE(review): cal_cam.__init__ requires (model, conf); this no-arg call
    # raises TypeError — supply a model and config before running.
    cam = cal_cam()
    img_root = "test_img/"
    heatmap_save_pth = "heatmap_result"
    # NOTE(review): __call__ is commented out in the class above, so this
    # invocation would also fail — confirm whether it should be restored.
    cam(img_root, heatmap_save_pth)

213
tools/getpairs.py Normal file
View File

@ -0,0 +1,213 @@
import os
import random
import json
from pathlib import Path
from typing import List, Tuple, Dict, Optional
import logging
class PairGenerator:
"""Generate positive and negative image pairs for contrastive learning."""
def __init__(self, original_path):
    """Remember the dataset root, configure logging and normalise filenames.

    :param original_path: root directory of per-class image folders.
    """
    self._setup_logging()
    self.original_path = original_path
    # Strip spaces from file/dir names and drop stale 'rotate' augmentations.
    self._delete_space()
def _delete_space(self):  # normalise image filenames under original_path
    """Walk ``original_path``: remove leftover 'rotate' augmentation images
    and strip spaces from image file and directory names.

    Fixes two defects in the old version:
    * ``endswith('.jpg' or '.png')`` evaluated to ``endswith('.jpg')``, so
      .png files were never renamed;
    * a 'rotate' file was renamed first and then removed under its original
      (stale) name, raising FileNotFoundError when the name contained spaces.
    """
    print(self.original_path)
    for root, dirs, files in os.walk(self.original_path):
        for file_name in files:
            path = os.path.join(root, file_name)
            # Drop augmented rotation images before any renaming.
            if 'rotate' in file_name:
                os.remove(path)
                continue
            if file_name.endswith(('.jpg', '.png')):
                n_file_name = file_name.replace(' ', '')
                os.rename(path, os.path.join(root, n_file_name))
        # NOTE(review): renaming directories while os.walk is iterating them
        # means the walk may still descend using the old names — confirm no
        # class folder actually contains spaces, or collect renames first.
        for dir_name in dirs:
            n_dir_name = dir_name.replace(' ', '')
            os.rename(os.path.join(root, dir_name), os.path.join(root, n_dir_name))
def _setup_logging(self):
    """Configure logging settings (INFO level, timestamped format) and
    attach a module-level logger to the instance."""
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )
    self.logger = logging.getLogger(__name__)
def _get_image_files(self, root_dir: str) -> Dict[str, List[str]]:
    """Scan *root_dir* and return {subfolder path: [file paths]}.

    Only immediate subdirectories are scanned; only plain files are listed.

    :raises ValueError: when *root_dir* is not an existing directory.
    """
    root = Path(root_dir)
    if not root.is_dir():
        raise ValueError(f"Invalid directory: {root_dir}")
    files_by_folder: Dict[str, List[str]] = {}
    for folder in root.iterdir():
        if folder.is_dir():
            files_by_folder[str(folder)] = [
                str(entry) for entry in folder.iterdir() if entry.is_file()
            ]
    return files_by_folder
def _generate_same_pairs(
self,
files_dict: Dict[str, List[str]],
num_pairs: int,
min_size: int, # min_size is the minimum number of images per folder
group_size: Optional[int] = None
) -> List[Tuple[str, str, int]]:
"""Generate positive pairs from same folder."""
pairs = []
for folder, files in files_dict.items():
if len(files) < 2:
continue
if group_size:
# Group mode: generate all possible pairs within group
for i in range(0, len(files), group_size):
group = files[i:i + group_size]
pairs.extend([
(group[i], group[j], 1)
for i in range(len(group))
for j in range(i + 1, len(group))
])
else:
# Individual mode: random pairs
try:
pairs.extend(self._random_pairs(files, min(min_size, len(files) // 2)))
except ValueError as e:
self.logger.warning(f"Skipping folder {folder}: {str(e)}")
random.shuffle(pairs)
return pairs[:num_pairs]
def _generate_cross_pairs(
self,
files_dict: Dict[str, List[str]],
num_pairs: int
) -> List[Tuple[str, str, int]]:
"""Generate negative pairs from different folders."""
folders = list(files_dict.keys())
pairs = []
existing_pairs = set()
while len(pairs) < num_pairs and len(folders) >= 2:
folder1, folder2 = random.sample(folders, 2)
file1 = random.choice(files_dict[folder1])
file2 = random.choice(files_dict[folder2])
pair_key = (file1, file2)
reverse_key = (file2, file1)
if pair_key not in existing_pairs and reverse_key not in existing_pairs:
pairs.append((file1, file2, 0))
existing_pairs.add(pair_key)
existing_pairs.add(reverse_key)
return pairs
def _random_pairs(self, files: List[str], num_pairs: int) -> List[Tuple[str, str, int]]:
"""Generate random pairs from file list."""
max_possible = len(files) // 2
if max_possible == 0:
return []
actual_pairs = min(num_pairs, max_possible)
indices = random.sample(range(len(files)), 2 * actual_pairs)
indices.sort()
return [(files[i], files[i + 1], 1) for i in range(0, len(indices), 2)]
def get_pairs(self,
root_dir: str,
num_pairs: int = 6000,
output_txt: Optional[str] = None
) -> List[Tuple[str, str, int]]:
"""
Generate individual image pairs with labels (1=same, 0=different).
Args:
root_dir: Directory containing subfolders of images
num_pairs: Number of pairs to generate
output_txt: Optional path to save pairs as txt file
Returns:
List of (path1, path2, label) tuples
"""
files_dict = self._get_image_files(root_dir)
same_pairs = self._generate_same_pairs(files_dict, num_pairs, min_size=30)
cross_pairs = self._generate_cross_pairs(files_dict, len(same_pairs))
pairs = same_pairs + cross_pairs
self.logger.info(f"Generated {len(pairs)} pairs ({len(same_pairs)} positive, {len(cross_pairs)} negative)")
if output_txt:
try:
with open(output_txt, 'w') as f:
for file1, file2, label in pairs:
f.write(f"{file1} {file2} {label}\n")
self.logger.info(f"Saved pairs to {output_txt}")
except IOError as e:
self.logger.warning(f"Failed to write pairs to {output_txt}: {str(e)}")
return pairs
def get_group_pairs(
self,
root_dir: str,
img_num: int = 20,
group_num: int = 10,
num_pairs: int = 5000
) -> List[Tuple[str, str, int]]:
"""
Generate grouped image pairs with labels (1=same, 0=different).
Args:
root_dir: Directory containing subfolders of images
img_num: Minimum images required per folder
group_num: Number of images per group
num_pairs: Number of pairs to generate
Returns:
List of (path1, path2, label) tuples
"""
# Filter folders with enough images
files_dict = {
k: v for k, v in self._get_image_files(root_dir).items()
if len(v) >= img_num
}
# Split into groups
grouped_files = {}
for folder, files in files_dict.items():
random.shuffle(files)
grouped_files[folder] = [
files[i:i + group_num]
for i in range(0, len(files), group_num)
]
# Generate pairs
same_pairs = self._generate_same_pairs(
grouped_files, num_pairs, min_size=30, group_size=group_num
)
cross_pairs = self._generate_cross_pairs(
grouped_files, len(same_pairs)
)
pairs = same_pairs + cross_pairs
self.logger.info(f"Generated {len(pairs)} group pairs")
# Save to JSON
with open("cross_same.json", 'w') as f:
json.dump(pairs, f)
return pairs
if __name__ == "__main__":
original_path = '/home/lc/data_center/electornic/v1/val'
parent_dir = str(Path(original_path).parent)
generator = PairGenerator(original_path)
# Example usage:
pairs = generator.get_pairs(original_path,
output_txt=os.sep.join([parent_dir, 'cross_same.txt'])) # Individual pairs
# groups = generator.get_group_pairs('val') # Group pairs

71
tools/image_joint.py Normal file
View File

@ -0,0 +1,71 @@
from PIL import Image, ImageDraw, ImageFont
from tools.getHeatMap import cal_cam
import os
def merge_imgs(img1_path, img2_path, conf, similar=None, label=None, cam=None, save_path=None):
    """Concatenate two images (optionally with Grad-CAM heatmaps) and save the montage.

    Layout is side-by-side when conf['heatmap']['show_heatmap'] is false,
    otherwise a 2x2 grid with the resized originals on top and their heatmaps
    below. Pairs whose similarity contradicts their label are not saved.

    Args:
        img1_path, img2_path: paths of the two images to join.
        conf: config dict providing data/heatmap sections.
        similar: optional similarity score drawn onto the montage.
        label: optional pair label ('1' same / '0' different) used both as a
            subdirectory name and for the save/skip decision.
        cam: heatmap generator with a get_hot_map(path) method; required when
            heatmaps are enabled.
        save_path: output directory; defaults to conf['data']['image_joint_pth'].
    """
    save = True
    position = (50, 50)   # top-left corner of the similarity text
    color = (255, 0, 0)   # red, RGB
    if save_path is None:
        save_path = conf['data']['image_joint_pth']
    if not conf['heatmap']['show_heatmap']:
        img1 = Image.open(img1_path).resize((224, 224))
        img2 = Image.open(img2_path).resize((224, 224))
        new_img = Image.new('RGB', (img1.width + img2.width + 10, img1.height))
    else:
        assert cam is not None, 'cam is None'
        img1 = cam.get_hot_map(img1_path)
        img2 = cam.get_hot_map(img2_path)
        img1_ori = Image.open(img1_path).resize((224, 224))
        img2_ori = Image.open(img2_path).resize((224, 224))
        # NOTE(review): canvas height mixes img1.height with img2.width — only
        # correct for square images; confirm heatmap outputs are 224x224.
        new_img = Image.new('RGB',
                            (img1.width + img2.width + 10,
                             img1.height + img2.width + 10))
    if label is not None:
        # Keep same/different pairs in separate subdirectories.
        labelled_dir = os.sep.join([save_path, str(label)])
        os.makedirs(labelled_dir, exist_ok=True)
        save_path = labelled_dir
    # Truncate long basenames so the joined filename stays manageable.
    img_name = (os.path.basename(img1_path).split('.')[0][:30] + '_'
                + os.path.basename(img2_path).split('.')[0][:30] + '.png')
    assert img1.height == img2.height
    if not conf['heatmap']['show_heatmap']:
        new_img.paste(img1, (0, 0))
        new_img.paste(img2, (img1.width + 10, 0))
    else:
        new_img.paste(img1_ori, (10, 10))
        new_img.paste(img2_ori, (img2_ori.width + 20, 10))
        new_img.paste(img1, (10, img1.height + 20))
        new_img.paste(img2, (img2.width + 20, img2.height + 20))
    if similar is not None:
        # Only keep pairs whose similarity is "interesting" for the label:
        # same-label pairs in the ambiguous band, different-label pairs above 0.25.
        if label == '1' and (similar > 0.5 or similar < 0.25):
            save = False
        elif label == '0' and similar > 0.25:
            save = False
        similar = str(similar) + '_' + str(label)
        draw = ImageDraw.Draw(new_img)
        draw.text(position, str(similar), color, font_size=36)  # font_size requires Pillow >= 10
    os.makedirs(save_path, exist_ok=True)
    img_save = os.path.join(save_path, img_name)
    if save:
        new_img.save(img_save)

View File

@ -2,17 +2,29 @@ import pdb
import torch
import torch.nn as nn
from model import resnet18
from config import config as conf
# from config import config as conf
from collections import OrderedDict
from configs import trainer_tools
import cv2
import yaml
def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth'):
# 定义模型
if model_name == 'resnet18':
model = resnet18(scale=0.75)
def tranform_onnx_model():
# # 定义模型
# if model_name == 'resnet18':
# model = resnet18(scale=0.75)
print('model_name >>> {}'.format(model_name))
if conf.multiple_cards:
with open('../configs/transform.yml', 'r') as f:
conf = yaml.load(f, Loader=yaml.FullLoader)
tr_tools = trainer_tools(conf)
backbone_mapping = tr_tools.get_backbone()
if conf['models']['backbone'] in backbone_mapping:
model = backbone_mapping[conf['models']['backbone']]().to(conf['base']['device'])
else:
raise ValueError('不支持该模型: {}'.format({conf['models']['backbone']}))
pretrained_weights = conf['models']['model_path']
print('model_name >>> {}'.format(conf['models']['backbone']))
if conf['base']['distributed']:
model = model.to(torch.device('cpu'))
checkpoint = torch.load(pretrained_weights)
new_state_dict = OrderedDict()
@ -22,22 +34,8 @@ def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth
model.load_state_dict(new_state_dict)
else:
model.load_state_dict(torch.load(pretrained_weights, map_location=torch.device('cpu')))
# try:
# model.load_state_dict(torch.load(pretrained_weights, map_location=torch.device('cpu')))
# except Exception as e:
# print(e)
# # model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(pretrained_weights, map_location='cpu').items()})
# model = nn.DataParallel(model).to(conf.device)
# model.load_state_dict(torch.load(conf.test_model, map_location=torch.device('cpu')))
# 转换为ONNX
if model_name == 'gift_type2':
input_shape = [1, 64, 13, 13]
elif model_name == 'gift_type3':
input_shape = [1, 3, 224, 224]
else:
# 假设输入数据的大小是通道数*高度*宽度例如3*224*224
input_shape = [1, 3, 224, 224]
img = cv2.imread('./dog_224x224.jpg')
@ -59,5 +57,4 @@ def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth
if __name__ == '__main__':
tranform_onnx_model(model_name='resnet18', # ['resnet18', 'gift_type2', 'gift_type3'] #gift_type2指resnet18中间数据判断gift3_type3指resnet原图计算推理
pretrained_weights='./checkpoints/resnet18_scale=1.0/best.pth')
tranform_onnx_model()

View File

@ -6,15 +6,14 @@ import time
import sys
import numpy as np
import cv2
from config import config as conf
from rknn.api import RKNN
import config
import yaml
with open('../configs/transform.yml', 'r') as f:
conf = yaml.load(f, Loader=yaml.FullLoader)
# ONNX_MODEL = 'resnet50v2.onnx'
# RKNN_MODEL = 'resnet50v2.rknn'
ONNX_MODEL = 'checkpoints/resnet18_scale=1.0/best.onnx'
RKNN_MODEL = 'checkpoints/resnet18_scale=1.0/best.rknn'
ONNX_MODEL = conf['models']['onnx_model']
RKNN_MODEL = conf['models']['rknn_model']
# ONNX_MODEL = 'v3_small_0424.onnx'
@ -100,7 +99,8 @@ if __name__ == '__main__':
target_platform='rk3588',
model_pruning=False,
compress_weight=False,
single_core_mode=True)
single_core_mode=True,
enable_flash_attention=True)
# rknn.config(
# mean_values=[[127.5, 127.5, 127.5]], # 对于单通道图像,可以设置为 [[127.5]]
# std_values=[[127.5, 127.5, 127.5]], # 对于单通道图像,可以设置为 [[127.5]]
@ -122,7 +122,10 @@ if __name__ == '__main__':
# Build model
print('--> Building model')
ret = rknn.build(do_quantization=True, dataset='./dataset.txt')
ret = rknn.build(do_quantization=True, # True
# dataset='./dataset.txt',
dataset=conf['base']['dataset'],
rknn_batch_size=conf['models']['rknn_batch_size'])
# ret = rknn.build(do_quantization=False, dataset='./dataset.txt')
if ret != 0:
print('Build model failed!')

View File

@ -0,0 +1,242 @@
from similar_analysis import SimilarAnalysis
import os
import pickle
from tools.image_joint import merge_imgs
import yaml
from PIL import Image
import torch
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import matplotlib.pyplot as plt
'''
轨迹图与标准库之间的相似度分析
1.用于生成轨迹图与标准库中所有图片的相似度
2.用于分析轨迹图与标准库比对选取策略的判断
'''
class picDirSimilarAnalysis(SimilarAnalysis):
    """Similarity analysis between trajectory images and a standard library.

    Features for every image under conf['data']['data_dir'] are cached in a
    pickle file; per-barcode similarity matrices between 'tracking' and
    'standard_slim' subtrees are then computed and summarized to a text file.
    """

    def __init__(self):
        super(picDirSimilarAnalysis, self).__init__()
        # Config path is resolved relative to the tools/ directory.
        with open('../configs/pic_pic_similar.yml', 'r') as f:
            self.conf = yaml.load(f, Loader=yaml.FullLoader)
        # Build the feature cache once; subsequent runs reuse the pickle.
        if not os.path.exists(self.conf['data']['total_pkl']):
            self.create_total_pkl()
        if os.path.exists(self.conf['data']['total_pkl']):
            self.all_dicts = self.load_dict_from_pkl()

    def is_image_file(self, filename):
        """Return True when the filename has a known image extension."""
        image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff')
        return filename.lower().endswith(image_extensions)

    def create_total_pkl(self):
        """Extract features for every image under data_dir and pickle them.

        Failed extractions are stored as None so the path is still recorded.
        """
        all_images_feature_dict = {}
        for roots, dirs, files in os.walk(self.conf['data']['data_dir']):
            for file_name in files:
                if self.is_image_file(file_name):
                    try:
                        print(f"处理图像 {os.sep.join([roots, file_name])}")
                        feature = self.extract_features(os.sep.join([roots, file_name]))
                    except Exception as e:
                        print(f"处理图像 {os.sep.join([roots, file_name])} 时出错: {e}")
                        feature = None
                    all_images_feature_dict[os.sep.join([roots, file_name])] = feature
        # Re-check existence so an existing cache is never overwritten.
        if not os.path.exists(self.conf['data']['total_pkl']):
            with open(self.conf['data']['total_pkl'], 'wb') as f:
                pickle.dump(all_images_feature_dict, f)

    def load_dict_from_pkl(self):
        """Load the {image_path: feature} cache from the pickle file."""
        with open(self.conf['data']['total_pkl'], 'rb') as f:
            data = pickle.load(f)
        print(f"字典已从 {self.conf['data']['total_pkl']} 加载")
        return data

    def get_image_files(self, folder_path):
        """Recursively collect all image paths under folder_path."""
        image_files = []
        for root, _, files in os.walk(folder_path):
            for file in files:
                if self.is_image_file(file):
                    image_files.append(os.path.join(root, file))
        return image_files

    def extract_features(self, image_path):
        """Extract the feature vector for a single image via the base class."""
        feature_dict = self.get_feature(image_path)
        return feature_dict[image_path]

    def create_one_similarity_matrix(self, folder1_path, folder2_path):
        """Build the cosine-similarity matrix between two image folders.

        Features are looked up in the preloaded cache (self.all_dicts);
        images whose feature is missing/None are dropped from the matrix.

        Returns:
            (similarity_matrix, valid_images1, valid_images2)

        Raises:
            ValueError: if either folder has no images or no usable features.
        """
        images1 = self.get_image_files(folder1_path)
        images2 = self.get_image_files(folder2_path)
        print(f"文件夹1 ({folder1_path}) 包含 {len(images1)} 张图像")
        print(f"文件夹2 ({folder2_path}) 包含 {len(images2)} 张图像")
        if len(images1) == 0 or len(images2) == 0:
            raise ValueError("至少有一个文件夹中没有图像文件")
        # Look up folder-1 features in the cache.
        features1 = []
        print("正在提取文件夹1中的图像特征...")
        for i, img_path in enumerate(images1):
            try:
                feature = self.all_dicts[img_path]
                features1.append(feature.cpu().numpy())
            except Exception as e:
                print(f"处理图像 {img_path} 时出错: {e}")
                features1.append(None)
        # Look up folder-2 features in the cache.
        features2 = []
        print("正在提取文件夹2中的图像特征...")
        for i, img_path in enumerate(images2):
            try:
                feature = self.all_dicts[img_path]
                features2.append(feature.cpu().numpy())
            except Exception as e:
                print(f"处理图像 {img_path} 时出错: {e}")
                features2.append(None)
        # Drop images whose feature extraction failed.
        valid_features1 = []
        valid_images1 = []
        for i, feature in enumerate(features1):
            if feature is not None:
                valid_features1.append(feature)
                valid_images1.append(images1[i])
        valid_features2 = []
        valid_images2 = []
        for i, feature in enumerate(features2):
            if feature is not None:
                valid_features2.append(feature)
                valid_images2.append(images2[i])
        if len(valid_features1) == 0 or len(valid_features2) == 0:
            raise ValueError("没有成功处理任何图像")
        print("正在计算相似度矩阵...")
        similarity_matrix = cosine_similarity(valid_features1, valid_features2)
        return similarity_matrix, valid_images1, valid_images2

    def get_group_similarity_matrix(self, folder_path):
        """For each barcode dir, compare tracking vs standard_slim images.

        Expects folder_path to contain parallel 'tracking' and 'standard_slim'
        trees with matching barcode subdirectories. Summary statistics of each
        similarity matrix are appended to conf['data']['result_txt'].
        """
        tracking_folder = os.sep.join([folder_path, 'tracking'])
        standard_folder = os.sep.join([folder_path, 'standard_slim'])
        for dir_name in os.listdir(tracking_folder):
            tracking_dir = os.sep.join([tracking_folder, dir_name])
            standard_dir = os.sep.join([standard_folder, dir_name])
            similarity_matrix, valid_images1, valid_images2 = self.create_one_similarity_matrix(tracking_dir,
                                                                                               standard_dir)
            mean_similarity = np.mean(similarity_matrix)
            std_similarity = np.std(similarity_matrix)
            max_similarity = np.max(similarity_matrix)
            min_similarity = np.min(similarity_matrix)
            print(f"文件夹 {dir_name} 的相似度矩阵已计算完成 "
                  f"均值:{mean_similarity} 标准差:{std_similarity} 最大值:{max_similarity} 最小值:{min_similarity}")
            result = f"{os.path.basename(standard_folder)} {dir_name} {mean_similarity:.3f} {std_similarity:.3f} {max_similarity:.3f} {min_similarity:.3f}"
            with open(self.conf['data']['result_txt'], 'a') as f:
                f.write(result + '\n')
def read_result_txt():
    """Parse the similarity result file and plot per-statistic histograms.

    Each result line has the form:
        <standard_name> <barcode> <mean> <std> <max> <min>
    Columns 2..5 are grouped per barcode and handed to get_histogram, one
    chart per statistic (Mean/Std/Max/Min).
    """
    with open('../configs/pic_pic_similar.yml', 'r') as f:
        conf = yaml.load(f, Loader=yaml.FullLoader)
    # BUGFIX: the original called f.close() inside/after each `with` block,
    # operating on an already-closed file; the context managers handle it.
    rows = []
    with open(conf['data']['result_txt'], 'r') as f:
        for line in f:
            line = line.strip()
            if line:
                rows.append(line.split(' '))
    rows = np.array(rows)
    print(rows)
    labels = ['Mean', 'Std', 'Max', 'Min']
    # Columns 2..5 hold mean/std/max/min respectively.
    for value_num in range(2, 6):
        dicts = {}
        for barcode, value in zip(rows[:, 1], rows[:, value_num]):
            dicts.setdefault(barcode, []).append(float(value))
        get_histogram(dicts, labels[value_num - 2])
def get_histogram(data, label=None):
    """Draw a grouped bar chart comparing the first two values per category.

    Args:
        data: {category: [standard_value, standard_slim_value, ...]} — each
            category must provide at least two values.
        label: optional chart title.
    """
    categories = list(data)
    standard_vals = [data[cat][0] for cat in categories]
    slim_vals = [data[cat][1] for cat in categories]
    # One pair of bars per category, offset around each tick.
    x = np.arange(len(categories))
    width = 0.35
    fig, ax = plt.subplots(figsize=(10, 6))
    bar_groups = [
        ax.bar(x - width / 2, standard_vals, width, label='standard', color='red', alpha=0.7),
        ax.bar(x + width / 2, slim_vals, width, label='standard_slim', color='green', alpha=0.7),
    ]
    # Annotate every bar with its value, slightly above the top.
    for bars in bar_groups:
        for bar in bars:
            height = bar.get_height()
            ax.annotate(f'{height:.3f}',
                        xy=(bar.get_x() + bar.get_width() / 2, height),
                        xytext=(0, 3),
                        textcoords="offset points",
                        ha='center', va='bottom',
                        fontsize=12)
    ax.set_xlabel('barcode')
    ax.set_ylabel('Values')
    ax.set_title(label if label is not None else '')
    ax.set_xticks(x)
    ax.set_xticklabels(categories)
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
if __name__ == '__main__':
    # Build tracking-vs-standard similarity matrices for every barcode
    # directory and append the summary statistics to the result file.
    picTopic_matrix = picDirSimilarAnalysis()
    picTopic_matrix.get_group_similarity_matrix('/home/lc/data_center/image_analysis/pic_pic_similar_maxtrix_new')
    # read_result_txt()

107
tools/similar_analysis.py Normal file
View File

@ -0,0 +1,107 @@
from configs.utils import trainer_tools
from test_ori import group_image, featurize, cosin_metric
from tools.dataset import get_transform
from tools.getHeatMap import cal_cam
from tools.image_joint import merge_imgs
import torch.nn as nn
import torch
from collections import ChainMap
import yaml
import os
class SimilarAnalysis:
    """Pairwise image-similarity analysis driven by configs/similar_analysis.yml.

    Loads a backbone and its weights, extracts features for images under the
    configured data directory, computes cosine similarities between images of
    the two subdirectories of each leaf folder, and saves annotated montages
    (with Grad-CAM heatmaps) for high-similarity pairs.
    """

    def __init__(self):
        # Config path is resolved relative to the tools/ directory.
        with open('../configs/similar_analysis.yml', 'r') as f:
            self.conf = yaml.load(f, Loader=yaml.FullLoader)
        self.model = self.initialize_model(self.conf)
        _, self.test_transform = get_transform(self.conf)
        self.cam = cal_cam(self.model, self.conf)

    def initialize_model(self, conf):
        """Build the configured backbone, load its weights, return it in eval mode.

        Raises:
            ValueError: when the configured backbone is not supported.
        """
        tr_tools = trainer_tools(conf)
        backbone_mapping = tr_tools.get_backbone()
        print('model_path {}'.format(conf['models']['model_path']))
        if conf['models']['backbone'] in backbone_mapping:
            model = backbone_mapping[conf['models']['backbone']]()
        else:
            raise ValueError('不支持该模型: {}'.format({conf['models']['backbone']}))
        try:
            model.load_state_dict(torch.load(conf['models']['model_path'],
                                             map_location=conf['base']['device']))
        except Exception:
            # BUGFIX: was a bare `except:` which also swallowed
            # KeyboardInterrupt/SystemExit. On a strict-load failure fall back
            # to stripping the DataParallel 'module.' prefix.
            state_dict = torch.load(conf['models']['model_path'],
                                    map_location=conf['base']['device'])
            new_state_dict = {}
            for k, v in state_dict.items():
                new_key = k.replace("module.", "")
                new_state_dict[new_key] = v
            model.load_state_dict(new_state_dict, strict=False)
        return model.eval()

    def get_feature(self, img_pth):
        """Return the {img_pth: feature} dict for a single image."""
        group = group_image([img_pth], self.conf['data']['val_batch_size'])
        feature = featurize(group[0], self.test_transform, self.model, self.conf['base']['device'])
        return feature

    def get_similarity(self, feature_dict1, feature_dict2):
        """Return the cosine similarity between two feature vectors."""
        similarity = cosin_metric(feature_dict1, feature_dict2)
        print(f"Similarity: {similarity}")
        return similarity

    def get_feature_map(self, all_imgs):
        """Extract features for every path in all_imgs into one combined dict."""
        feature_dicts = {}
        for img_pth in all_imgs:
            print(f"Processing {img_pth}")
            feature_dict = self.get_feature(img_pth)
            feature_dicts = dict(ChainMap(feature_dict, feature_dicts))
        return feature_dicts

    def get_image_map(self):
        """Build the cross-product of image pairs between the two
        subdirectories of every folder that has exactly two subdirectories."""
        all_compare_img = []
        for root, dirs, files in os.walk(self.conf['data']['data_dir']):
            if len(dirs) == 2:
                dir_pth_1 = os.sep.join([root, dirs[0]])
                dir_pth_2 = os.sep.join([root, dirs[1]])
                for img_name_1 in os.listdir(dir_pth_1):
                    for img_name_2 in os.listdir(dir_pth_2):
                        all_compare_img.append((os.sep.join([dir_pth_1, img_name_1]),
                                                os.sep.join([dir_pth_2, img_name_2])))
        return all_compare_img

    def create_total_feature(self):
        """Collect every image path from the two-subdirectory leaf folders."""
        all_imgs = []
        for root, dirs, files in os.walk(self.conf['data']['data_dir']):
            if len(dirs) == 2:
                for dir_name in dirs:
                    dir_pth = os.sep.join([root, dir_name])
                    for img_name in os.listdir(dir_pth):
                        all_imgs.append(os.sep.join([dir_pth, img_name]))
        return all_imgs

    def get_contrast_result(self, feature_dicts, all_compare_img):
        """Compare every pair and save a montage when similarity > 0.7."""
        for img_pth1, img_pth2 in all_compare_img:
            feature_dict1 = feature_dicts[img_pth1]
            feature_dict2 = feature_dicts[img_pth2]
            similarity = self.get_similarity(feature_dict1.cpu().numpy(),
                                             feature_dict2.cpu().numpy())
            # Group output montages by the grandparent directory name.
            dir_name = img_pth1.split('/')[-3]
            save_path = os.sep.join([self.conf['data']['image_joint_pth'], dir_name])
            if similarity > 0.7:
                merge_imgs(img_pth1,
                           img_pth2,
                           self.conf,
                           similarity,
                           label=None,
                           cam=self.cam,
                           save_path=save_path)
            print(similarity)
if __name__ == '__main__':
    # Extract features for every image, then compare all cross-directory pairs.
    analyzer = SimilarAnalysis()
    image_paths = analyzer.create_total_feature()
    features = analyzer.get_feature_map(image_paths)
    pair_list = analyzer.get_image_map()
    analyzer.get_contrast_result(features, pair_list)

View File

@ -4,7 +4,7 @@ import logging
import numpy as np
from typing import Dict, List, Optional, Tuple
from tools.dataset import get_transform
from model import resnet18
from model import resnet18, resnet34, resnet50, resnet101
import torch
from PIL import Image
import pandas as pd
@ -50,7 +50,16 @@ class FeatureExtractor:
raise FileNotFoundError(f"Model weights file not found: {model_path}")
# Initialize model
model = resnet18().to(self.conf['base']['device'])
if conf['models']['backbone'] == 'resnet18':
model = resnet18(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
elif conf['models']['backbone'] == 'resnet34':
model = resnet34(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
elif conf['models']['backbone'] == 'resnet50':
model = resnet50(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
elif conf['models']['backbone'] == 'resnet101':
model = resnet101(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
else:
print("不支持的模型: {}".format(conf['models']['backbone']))
# Handle multi-GPU case
if conf['base']['distributed']:
@ -168,7 +177,7 @@ class FeatureExtractor:
# Validate input directory
if not os.path.isdir(folder):
raise ValueError(f"Invalid directory: {folder}")
i = 0
# Process each barcode directory
for root, dirs, files in tqdm(os.walk(folder), desc="Scanning directories"):
if not dirs: # Leaf directory (contains images)
@ -180,14 +189,16 @@ class FeatureExtractor:
ori_barcode = basename
barcode = basename
# Apply filter if provided
i += 1
print(ori_barcode, i)
if filter and ori_barcode not in filter:
continue
elif len(ori_barcode) > 13 or len(ori_barcode) < 8:
logger.warning(f"Skipping invalid barcode {ori_barcode}")
with open(conf['save']['error_barcodes'], 'a') as f:
f.write(ori_barcode + '\n')
f.close()
continue
# elif len(ori_barcode) > 13 or len(ori_barcode) < 8: # barcode筛选长度
# logger.warning(f"Skipping invalid barcode {ori_barcode}")
# with open(conf['save']['error_barcodes'], 'a') as f:
# f.write(ori_barcode + '\n')
# f.close()
# continue
# Process image files
if files:
@ -299,7 +310,8 @@ class FeatureExtractor:
dicts['value'] = truncated_imgs_list
if create_single_json:
# json_path = os.path.join("./search_library/v8021_overseas/", str(barcode_list[i]) + '.json')
json_path = os.path.join(self.conf['save']['json_path'], str(barcode_list[i]) + '.json')
json_path = os.path.join(self.conf['save']['json_path'],
str(barcode_list[i]) + '.json')
with open(json_path, 'w') as json_file:
json.dump(dicts, json_file)
else:
@ -317,8 +329,10 @@ class FeatureExtractor:
with open(conf['save']['barcodes_statistics'], 'w', encoding='utf-8') as f:
for barcode in os.listdir(pth):
print("barcode length >> {}".format(len(barcode)))
if len(barcode) > 13 or len(barcode) < 8:
continue
# if len(barcode) > 13 or len(barcode) < 8: # barcode筛选长度
# continue
if filter is not None:
f.writelines(barcode + '\n')
if barcode in filter:
@ -407,5 +421,5 @@ if __name__ == "__main__":
column_values = extractor.get_shop_barcodes(conf['data']['xlsx_pth'])
imgs_dict = extractor.get_files(conf['data']['img_dirs_path'],
filter=column_values,
create_single_json=False) # False
create_single_json=conf['save']['create_single_json']) # False
extractor.statisticsBarcodes(conf['data']['img_dirs_path'], column_values)

View File

@ -3,140 +3,365 @@ import os.path as osp
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from model.loss import FocalLoss
from tools.dataset import load_data
from tools.dataset import load_data, MultiEpochsDataLoader
import matplotlib.pyplot as plt
from configs import trainer_tools
import yaml
from datetime import datetime
with open('configs/scatter.yml', 'r') as f:
conf = yaml.load(f, Loader=yaml.FullLoader)
# Data Setup
train_dataloader, class_num = load_data(training=True, cfg=conf)
val_dataloader, _ = load_data(training=False, cfg=conf)
def load_configuration(config_path='configs/compare.yml'):
    """Read a YAML configuration file and return its contents as a dict."""
    with open(config_path, 'r') as cfg_file:
        config = yaml.load(cfg_file, Loader=yaml.FullLoader)
    return config
tr_tools = trainer_tools(conf)
backbone_mapping = tr_tools.get_backbone()
metric_mapping = tr_tools.get_metric(class_num)
if conf['models']['backbone'] in backbone_mapping:
model = backbone_mapping[conf['models']['backbone']]().to(conf['base']['device'])
else:
def initialize_model_and_metric(conf, class_num):
"""初始化模型和度量方法"""
tr_tools = trainer_tools(conf)
backbone_mapping = tr_tools.get_backbone()
metric_mapping = tr_tools.get_metric(class_num)
if conf['models']['backbone'] in backbone_mapping:
model = backbone_mapping[conf['models']['backbone']]()
else:
raise ValueError('不支持该模型: {}'.format({conf['models']['backbone']}))
if conf['training']['metric'] in metric_mapping:
if conf['training']['metric'] in metric_mapping:
metric = metric_mapping[conf['training']['metric']]()
else:
else:
raise ValueError('不支持的metric类型: {}'.format(conf['training']['metric']))
if torch.cuda.device_count() > 1 and conf['base']['distributed']:
print("Let's use", torch.cuda.device_count(), "GPUs!")
model = nn.DataParallel(model)
metric = nn.DataParallel(metric)
return model, metric
# Training Setup
if conf['training']['loss'] == 'focal_loss':
criterion = FocalLoss(gamma=2)
else:
criterion = nn.CrossEntropyLoss()
optimizer_mapping = tr_tools.get_optimizer(model, metric)
if conf['training']['optimizer'] in optimizer_mapping:
def setup_optimizer_and_scheduler(conf, model, metric):
"""设置优化器和学习率调度器"""
tr_tools = trainer_tools(conf)
optimizer_mapping = tr_tools.get_optimizer(model, metric)
if conf['training']['optimizer'] in optimizer_mapping:
optimizer = optimizer_mapping[conf['training']['optimizer']]()
scheduler = optim.lr_scheduler.StepLR(
optimizer,
step_size=conf['training']['lr_step'],
gamma=conf['training']['lr_decay']
)
else:
scheduler_mapping = tr_tools.get_scheduler(optimizer)
scheduler = scheduler_mapping[conf['training']['scheduler']]()
print('使用{}优化器 使用{}调度器'.format(conf['training']['optimizer'],
conf['training']['scheduler']))
return optimizer, scheduler
else:
raise ValueError('不支持的优化器类型: {}'.format(conf['training']['optimizer']))
# Checkpoints Setup
checkpoints = conf['training']['checkpoints']
os.makedirs(checkpoints, exist_ok=True)
if __name__ == '__main__':
print('backbone>{} '.format(conf['models']['backbone']),
'metric>{} '.format(conf['training']['metric']),
'checkpoints>{} '.format(conf['training']['checkpoints']),
)
def setup_loss_function(conf):
    """Build the training criterion: FocalLoss(gamma=2) when the config asks
    for 'focal_loss', otherwise plain CrossEntropyLoss."""
    use_focal = conf['training']['loss'] == 'focal_loss'
    return FocalLoss(gamma=2) if use_focal else nn.CrossEntropyLoss()
def train_one_epoch(model, metric, criterion, optimizer, dataloader, device, scaler, conf):
    """Run a single training epoch and return the mean batch loss."""
    model.train()
    running_loss = 0.0
    # Non-softmax metrics (e.g. margin-based heads) also consume the labels.
    needs_labels = conf['training']['metric'] != 'softmax'
    for batch, targets in tqdm(dataloader, desc="Training", ascii=True, total=len(dataloader)):
        batch = batch.to(device)
        targets = targets.to(device)
        embeddings = model(batch)
        logits = metric(embeddings, targets) if needs_labels else metric(embeddings)
        loss = criterion(logits, targets)
        optimizer.zero_grad()
        # GradScaler-driven backward/step (mixed-precision friendly).
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item()
    return running_loss / len(dataloader)
def validate(model, metric, criterion, dataloader, device, conf):
    """Evaluate on the validation set and return the mean batch loss."""
    model.eval()
    total_loss = 0.0
    needs_labels = conf['training']['metric'] != 'softmax'
    with torch.no_grad():
        for batch, targets in tqdm(dataloader, desc="Validating", ascii=True, total=len(dataloader)):
            batch = batch.to(device)
            targets = targets.to(device)
            embeddings = model(batch)
            logits = metric(embeddings, targets) if needs_labels else metric(embeddings)
            total_loss += criterion(logits, targets).item()
    return total_loss / len(dataloader)
def save_model(model, path, is_parallel):
    """Persist model weights, unwrapping .module for DataParallel/DDP models."""
    state_dict = model.module.state_dict() if is_parallel else model.state_dict()
    torch.save(state_dict, path)
def log_training_info(log_path, log_info):
    """Append one line of training info to the log file."""
    with open(log_path, 'a') as log_file:
        log_file.write(f"{log_info}\n")
def initialize_training_components(distributed=False):
    """Build every object needed for training and return them in one dict.

    In non-distributed mode this wires up dataloaders, model, metric head,
    criterion, optimizer, scheduler, checkpoint dir and a GradScaler. In
    distributed mode only the config is filled in; per-rank setup happens
    elsewhere (the other keys stay None).
    """
    # Load configuration.
    conf = load_configuration()
    # Component registry shared by both code paths.
    components = {
        'conf': conf,
        'distributed': distributed,
        'device': None,
        'train_dataloader': None,
        'val_dataloader': None,
        'model': None,
        'metric': None,
        'criterion': None,
        'optimizer': None,
        'scheduler': None,
        'checkpoints': None,
        'scaler': None
    }
    # Non-distributed: create all components directly.
    if not distributed:
        # NOTE(review): with return_dataset=True load_data appears to return
        # the raw Dataset (wrapped below in MultiEpochsDataLoader) — confirm.
        train_dataloader, class_num = load_data(training=True, cfg=conf, return_dataset=True)
        val_dataloader, _ = load_data(training=False, cfg=conf, return_dataset=True)
        train_dataloader = MultiEpochsDataLoader(train_dataloader,
                                                 batch_size=conf['data']['train_batch_size'],
                                                 shuffle=True,
                                                 num_workers=conf['data']['num_workers'],
                                                 pin_memory=conf['base']['pin_memory'],
                                                 drop_last=True)
        val_dataloader = MultiEpochsDataLoader(val_dataloader,
                                               batch_size=conf['data']['val_batch_size'],
                                               shuffle=False,
                                               num_workers=conf['data']['num_workers'],
                                               pin_memory=conf['base']['pin_memory'],
                                               drop_last=False)
        # Model and metric head, moved to the configured device.
        model, metric = initialize_model_and_metric(conf, class_num)
        device = conf['base']['device']
        model = model.to(device)
        metric = metric.to(device)
        # Loss, optimizer and LR scheduler.
        criterion = setup_loss_function(conf)
        optimizer, scheduler = setup_optimizer_and_scheduler(conf, model, metric)
        # Checkpoint directory.
        checkpoints = conf['training']['checkpoints']
        os.makedirs(checkpoints, exist_ok=True)
        # GradScaler for mixed-precision training.
        scaler = torch.cuda.amp.GradScaler()
        components.update({
            'train_dataloader': train_dataloader,
            'val_dataloader': val_dataloader,
            'model': model,
            'metric': metric,
            'criterion': criterion,
            'optimizer': optimizer,
            'scheduler': scheduler,
            'checkpoints': checkpoints,
            'scaler': scaler,
            'device': device
        })
    return components
def run_training_loop(components):
    """Drive the full epoch loop: train, validate, checkpoint, log, plot.

    *components* is the dict produced by ``initialize_training_components``
    or ``run_training``.  This version removes the merge residue that left
    the old inline training loop duplicated next to the helper-based one
    (losses were appended twice per epoch, the model was saved twice, and
    the loss lists no longer matched ``epochs`` for plotting).
    """
    conf = components['conf']
    train_dataloader = components['train_dataloader']
    val_dataloader = components['val_dataloader']
    model = components['model']
    metric = components['metric']
    criterion = components['criterion']
    optimizer = components['optimizer']
    scheduler = components['scheduler']
    checkpoints = components['checkpoints']
    scaler = components['scaler']
    device = components['device']
    # Optional: only present for distributed runs; lets us reshuffle per epoch.
    train_sampler = components.get('train_sampler')

    train_losses = []
    val_losses = []
    epochs = []
    # Start at +inf so the first completed epoch always yields a best.pth.
    best_val_loss = float('inf')

    if conf['training']['restore']:
        print('load pretrain model: {}'.format(conf['training']['restore_model']))
        model.load_state_dict(torch.load(conf['training']['restore_model'], map_location=device))

    # Works for both nn.DataParallel and DistributedDataParallel wrappers.
    is_parallel = hasattr(model, 'module')

    for e in range(conf['training']['epochs']):
        if train_sampler is not None:
            # Without set_epoch, DistributedSampler shuffles identically
            # every epoch.
            train_sampler.set_epoch(e)

        train_loss_avg = train_one_epoch(model, metric, criterion, optimizer,
                                         train_dataloader, device, scaler, conf)
        train_losses.append(train_loss_avg)
        epochs.append(e)

        val_loss_avg = validate(model, metric, criterion, val_dataloader, device, conf)
        val_losses.append(val_loss_avg)

        if val_loss_avg < best_val_loss:
            save_model(model, osp.join(checkpoints, 'best.pth'), is_parallel)
            best_val_loss = val_loss_avg

        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        log_info = ("[{:%Y-%m-%d %H:%M:%S}] Epoch {}/{}, train_loss: {}, val_loss: {} lr:{}"
                    .format(datetime.now(),
                            e,
                            conf['training']['epochs'],
                            train_loss_avg,
                            val_loss_avg,
                            current_lr))
        print(log_info)
        log_training_info(conf['logging']['logging_dir'], log_info)
        print("%d个epoch的学习率%f" % (e, current_lr))

    # Save the final model regardless of validation performance.
    save_model(model, osp.join(checkpoints, 'last.pth'), is_parallel)

    # Loss curves; create the output directory first so savefig cannot fail.
    os.makedirs('loss', exist_ok=True)
    plt.plot(epochs, train_losses, color='blue', label='Train Loss')
    plt.plot(epochs, val_losses, color='red', label='Validation Loss')
    plt.legend()
    plt.savefig('loss/mobilenetv3Large_2250_0316.png')
def main():
    """Entry point: dispatch to distributed or single-machine training."""
    conf = load_configuration()

    if not conf['base']['distributed']:
        # Single-machine path: build components in-process and train.
        run_training_loop(initialize_training_components(distributed=False))
        return

    # Distributed path: one worker per local GPU.  mp.spawn injects
    # local_rank as the first argument of run_training; the remaining
    # arguments come from the tuple below.
    gpus_per_node = torch.cuda.device_count()
    total_workers = int(conf['distributed']['node_num']) * gpus_per_node
    mp.spawn(run_training,
             args=(conf['distributed']['node_rank'],
                   gpus_per_node,
                   total_workers,
                   conf),
             nprocs=gpus_per_node,
             join=True)
def run_training(local_rank, node_rank, local_size, world_size, conf):
    """Per-GPU worker entry point, launched via mp.spawn.

    Args:
        local_rank: GPU index on this node (injected by mp.spawn).
        node_rank: index of this node among all participating nodes.
        local_size: number of GPUs on this node.
        world_size: total number of worker processes across all nodes.
        conf: parsed YAML configuration dictionary.
    """
    # Global rank of this worker across every node.
    rank = local_rank + node_rank * local_size
    os.environ['RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    # Rendezvous address/port: prefer values already set in the
    # environment, then the config, finally localhost:12355.  The previous
    # hard-coded 'localhost' made genuine multi-node runs impossible even
    # though node_rank/world_size support multiple machines.
    os.environ.setdefault('MASTER_ADDR',
                          str(conf.get('distributed', {}).get('master_addr', 'localhost')))
    os.environ.setdefault('MASTER_PORT',
                          str(conf.get('distributed', {}).get('master_port', '12355')))
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(local_rank)
    device = torch.device('cuda', local_rank)

    # Raw datasets (not loaders) so a DistributedSampler can be attached.
    train_dataset, class_num = load_data(training=True, cfg=conf, return_dataset=True)
    val_dataset, _ = load_data(training=False, cfg=conf, return_dataset=True)

    model, metric = initialize_model_and_metric(conf, class_num)
    model = model.to(device)
    metric = metric.to(device)
    model = DDP(model, device_ids=[local_rank], output_device=local_rank)
    metric = DDP(metric, device_ids=[local_rank], output_device=local_rank)

    criterion = setup_loss_function(conf)
    optimizer, scheduler = setup_optimizer_and_scheduler(conf, model, metric)

    checkpoints = conf['training']['checkpoints']
    os.makedirs(checkpoints, exist_ok=True)

    # GradScaler for mixed-precision training.
    scaler = torch.cuda.amp.GradScaler()

    train_sampler = DistributedSampler(train_dataset, shuffle=True)
    val_sampler = DistributedSampler(val_dataset, shuffle=False)
    train_dataloader = MultiEpochsDataLoader(
        train_dataset,
        batch_size=conf['data']['train_batch_size'],
        sampler=train_sampler,
        num_workers=conf['data']['num_workers'],
        pin_memory=conf['base']['pin_memory'],
        drop_last=True
    )
    val_dataloader = MultiEpochsDataLoader(
        val_dataset,
        batch_size=conf['data']['val_batch_size'],
        sampler=val_sampler,
        num_workers=conf['data']['num_workers'],
        pin_memory=conf['base']['pin_memory'],
        drop_last=False
    )

    components = {
        'conf': conf,
        'train_dataloader': train_dataloader,
        'val_dataloader': val_dataloader,
        'model': model,
        'metric': metric,
        'criterion': criterion,
        'optimizer': optimizer,
        'scheduler': scheduler,
        'checkpoints': checkpoints,
        'scaler': scaler,
        'device': device,
        'distributed': True,  # always true inside an mp.spawn worker
        # Exposed so the training loop can call set_epoch() each epoch;
        # otherwise every epoch shuffles identically.
        'train_sampler': train_sampler,
    }
    run_training_loop(components)
# Script entry point: main() dispatches to single-machine or distributed
# training based on the configuration.
if __name__ == '__main__':
    main()

233
train_compare.py.bak Normal file
View File

@ -0,0 +1,233 @@
import os
import os.path as osp
import torch
import torch.nn as nn
from tqdm import tqdm
from model.loss import FocalLoss
from tools.dataset import load_data
import matplotlib.pyplot as plt
from configs import trainer_tools
import yaml
from datetime import datetime
def load_configuration(config_path='configs/scatter.yml'):
    """Read the YAML training configuration at *config_path* into a dict."""
    with open(config_path, 'r') as cfg_file:
        conf = yaml.load(cfg_file, Loader=yaml.FullLoader)
    return conf
def initialize_model_and_metric(conf, class_num):
    """Instantiate the backbone network and the metric head named in *conf*.

    Raises:
        ValueError: if the configured backbone or metric is not supported.
    """
    tools = trainer_tools(conf)
    backbones = tools.get_backbone()
    metrics = tools.get_metric(class_num)

    backbone_name = conf['models']['backbone']
    if backbone_name not in backbones:
        raise ValueError('不支持该模型: {}'.format({backbone_name}))
    model = backbones[backbone_name]()

    metric_name = conf['training']['metric']
    if metric_name not in metrics:
        raise ValueError('不支持的metric类型: {}'.format(metric_name))
    metric = metrics[metric_name]()

    return model, metric
def setup_optimizer_and_scheduler(conf, model, metric):
    """Create the optimizer and LR scheduler named in the config.

    Args:
        conf: parsed configuration dict.
        model: backbone whose parameters are optimized.
        metric: metric head whose parameters are optimized alongside.

    Returns:
        (optimizer, scheduler) tuple.

    Raises:
        ValueError: if the configured optimizer or scheduler is unknown.
    """
    tr_tools = trainer_tools(conf)
    optimizer_mapping = tr_tools.get_optimizer(model, metric)
    optimizer_name = conf['training']['optimizer']
    if optimizer_name not in optimizer_mapping:
        raise ValueError('不支持的优化器类型: {}'.format(optimizer_name))
    optimizer = optimizer_mapping[optimizer_name]()

    scheduler_mapping = tr_tools.get_scheduler(optimizer)
    scheduler_name = conf['training']['scheduler']
    # Previously an unknown scheduler escaped as a bare KeyError; fail with
    # a descriptive error, mirroring the optimizer check above.
    if scheduler_name not in scheduler_mapping:
        raise ValueError('不支持的scheduler类型: {}'.format(scheduler_name))
    scheduler = scheduler_mapping[scheduler_name]()

    print('使用{}优化器 使用{}调度器'.format(optimizer_name, scheduler_name))
    return optimizer, scheduler
def setup_loss_function(conf):
    """Return the configured criterion: FocalLoss or plain cross-entropy."""
    use_focal = conf['training']['loss'] == 'focal_loss'
    return FocalLoss(gamma=2) if use_focal else nn.CrossEntropyLoss()
def train_one_epoch(model, metric, criterion, optimizer, dataloader, device, scaler, conf):
    """Run one full training pass over *dataloader*; return the mean batch loss."""
    model.train()
    # Margin-based metric heads (arcface etc.) need the labels to build
    # their logits; a plain softmax head does not.
    needs_labels = conf['training']['metric'] != 'softmax'
    running_loss = 0.0
    progress = tqdm(dataloader, desc="Training", ascii=True, total=len(dataloader))
    for batch, targets in progress:
        batch = batch.to(device)
        targets = targets.to(device)
        # Forward pass under autocast for mixed precision.
        with torch.cuda.amp.autocast():
            feats = model(batch)
            logits = metric(feats, targets) if needs_labels else metric(feats)
            loss = criterion(logits, targets)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item()
    return running_loss / len(dataloader)
def validate(model, metric, criterion, dataloader, device, conf):
    """Evaluate on *dataloader* without gradients; return the mean batch loss."""
    model.eval()
    needs_labels = conf['training']['metric'] != 'softmax'
    total_loss = 0.0
    with torch.no_grad():
        for batch, targets in tqdm(dataloader, desc="Validating",
                                   ascii=True, total=len(dataloader)):
            batch = batch.to(device)
            targets = targets.to(device)
            feats = model(batch)
            logits = metric(feats, targets) if needs_labels else metric(feats)
            total_loss += criterion(logits, targets).item()
    return total_loss / len(dataloader)
def save_model(model, path, is_parallel):
    """Persist model weights to *path*; unwrap ``.module`` when parallel-wrapped."""
    target = model.module if is_parallel else model
    torch.save(target.state_dict(), path)
def log_training_info(log_path, log_info):
    """Append one line of training progress to the log file.

    Opens the file in UTF-8 explicitly: log lines may contain non-ASCII
    text (Chinese status messages elsewhere in this script), and the
    platform's default locale encoding is not guaranteed to handle it.
    """
    with open(log_path, 'a', encoding='utf-8') as f:
        f.write(log_info + '\n')
def initialize_training_components():
    """Assemble config, data loaders, model, loss, optimizer and checkpoint dir.

    Returns a dict of every object the training loop needs.
    """
    conf = load_configuration()

    train_loader, class_num = load_data(training=True, cfg=conf)
    val_loader, _ = load_data(training=False, cfg=conf)

    model, metric = initialize_model_and_metric(conf, class_num)
    device = conf['base']['device']
    model = model.to(device)
    metric = metric.to(device)

    # Mirror the model across local GPUs with DataParallel when requested.
    if torch.cuda.device_count() > 1 and conf['base']['distributed']:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
        metric = nn.DataParallel(metric)

    criterion = setup_loss_function(conf)
    optimizer, scheduler = setup_optimizer_and_scheduler(conf, model, metric)

    checkpoints = conf['training']['checkpoints']
    os.makedirs(checkpoints, exist_ok=True)

    return {
        'conf': conf,
        'train_dataloader': train_loader,
        'val_dataloader': val_loader,
        'model': model,
        'metric': metric,
        'criterion': criterion,
        'optimizer': optimizer,
        'scheduler': scheduler,
        'checkpoints': checkpoints,
        # GradScaler for mixed-precision training.
        'scaler': torch.cuda.amp.GradScaler(),
        'device': device,
    }
def run_training_loop(components):
    """Drive the epoch loop: train, validate, checkpoint, log, plot.

    *components* is the dict produced by ``initialize_training_components``.
    """
    conf = components['conf']
    train_dataloader = components['train_dataloader']
    val_dataloader = components['val_dataloader']
    model = components['model']
    metric = components['metric']
    criterion = components['criterion']
    optimizer = components['optimizer']
    scheduler = components['scheduler']
    checkpoints = components['checkpoints']
    scaler = components['scaler']
    device = components['device']

    train_losses = []
    val_losses = []
    epochs = []
    # Start at +inf so the first completed epoch always yields a best.pth.
    # The previous fixed threshold of 100 silently skipped checkpointing
    # whenever validation loss started above 100.
    best_val_loss = float('inf')

    if conf['training']['restore']:
        print('load pretrain model: {}'.format(conf['training']['restore_model']))
        model.load_state_dict(torch.load(conf['training']['restore_model'], map_location=device))

    is_parallel = isinstance(model, nn.DataParallel)
    for e in range(conf['training']['epochs']):
        train_loss_avg = train_one_epoch(model, metric, criterion, optimizer,
                                         train_dataloader, device, scaler, conf)
        train_losses.append(train_loss_avg)
        epochs.append(e)

        val_loss_avg = validate(model, metric, criterion, val_dataloader, device, conf)
        val_losses.append(val_loss_avg)

        if val_loss_avg < best_val_loss:
            save_model(model, osp.join(checkpoints, 'best.pth'), is_parallel)
            best_val_loss = val_loss_avg

        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        log_info = ("[{:%Y-%m-%d %H:%M:%S}] Epoch {}/{}, train_loss: {}, val_loss: {} lr:{}"
                    .format(datetime.now(),
                            e,
                            conf['training']['epochs'],
                            train_loss_avg,
                            val_loss_avg,
                            current_lr))
        print(log_info)
        log_training_info(conf['logging']['logging_dir'], log_info)
        print("%d个epoch的学习率%f" % (e, current_lr))

    # Save the final model regardless of validation performance.
    save_model(model, osp.join(checkpoints, 'last.pth'), is_parallel)

    # Loss curves; create the output directory first so savefig cannot fail.
    os.makedirs('loss', exist_ok=True)
    plt.plot(epochs, train_losses, color='blue', label='Train Loss')
    plt.plot(epochs, val_losses, color='red', label='Validation Loss')
    plt.legend()
    plt.savefig('loss/mobilenetv3Large_2250_0316.png')
if __name__ == '__main__':
    # Build all training components (config, data, model, optimizer, ...).
    components = initialize_training_components()
    # Run the full train/validate loop with checkpointing and logging.
    run_training_loop(components)