Compare commits
19 Commits
Author | SHA1 | Date | |
---|---|---|---|
c978787ff8 | |||
99a204ee22 | |||
bc896fc688 | |||
27ffb62223 | |||
ebba07d1ca | |||
3392d76e38 | |||
54898e30ec | |||
09f41f6289 | |||
0701538a73 | |||
6640f2bc5e | |||
bcbabd9313 | |||
5deaf4727f | |||
2219c0a303 | |||
537ed838fc | |||
061820c34f | |||
bf9604ec29 | |||
180a41ae90 | |||
e27e6c3d5b | |||
1803f319a5 |
1
.gitignore
vendored
1
.gitignore
vendored
@ -8,4 +8,5 @@ loss/
|
||||
checkpoints/
|
||||
search_library/
|
||||
quant_imgs/
|
||||
electronic_imgs/
|
||||
README.md
|
367
.idea/CopilotChatHistory.xml
generated
367
.idea/CopilotChatHistory.xml
generated
@ -3,6 +3,120 @@
|
||||
<component name="CopilotChatHistory">
|
||||
<option name="conversations">
|
||||
<list>
|
||||
<Conversation>
|
||||
<option name="createTime" value="1755228773977" />
|
||||
<option name="id" value="0198abc99e597020bf8aa3ef78bc8bd3" />
|
||||
<option name="title" value="新对话 2025年8月15日 11:32:53" />
|
||||
<option name="updateTime" value="1755228773977" />
|
||||
</Conversation>
|
||||
<Conversation>
|
||||
<option name="createTime" value="1755227620606" />
|
||||
<option name="id" value="0198abb804fe7bf8ab3ac9ecfeae6d3f" />
|
||||
<option name="title" value="新对话 2025年8月15日 11:13:40" />
|
||||
<option name="updateTime" value="1755227620606" />
|
||||
</Conversation>
|
||||
<Conversation>
|
||||
<option name="createTime" value="1755219481041" />
|
||||
<option name="id" value="0198ab3bd1d17216b0dab33158ff294e" />
|
||||
<option name="title" value="新对话 2025年8月15日 08:58:01" />
|
||||
<option name="updateTime" value="1755219481041" />
|
||||
</Conversation>
|
||||
<Conversation>
|
||||
<option name="createTime" value="1754286137102" />
|
||||
<option name="id" value="0198739a1f0e75c38b0579ade7b34050" />
|
||||
<option name="title" value="新对话 2025年8月04日 13:42:17" />
|
||||
<option name="updateTime" value="1754286137102" />
|
||||
</Conversation>
|
||||
<Conversation>
|
||||
<option name="createTime" value="1753932970546" />
|
||||
<option name="id" value="01985e8d3a3170bf871ba640afdf246d" />
|
||||
<option name="title" value="新对话 2025年7月31日 11:36:10" />
|
||||
<option name="updateTime" value="1753932970546" />
|
||||
</Conversation>
|
||||
<Conversation>
|
||||
<option name="createTime" value="1753932554257" />
|
||||
<option name="id" value="01985e86e01170d6a09dca496e3dad46" />
|
||||
<option name="title" value="新对话 2025年7月31日 11:29:14" />
|
||||
<option name="updateTime" value="1753932554257" />
|
||||
</Conversation>
|
||||
<Conversation>
|
||||
<option name="createTime" value="1753680371881" />
|
||||
<option name="id" value="01984f7ee0a9779aabcd3f1671b815b3" />
|
||||
<option name="title" value="新对话 2025年7月28日 13:26:11" />
|
||||
<option name="updateTime" value="1753680371881" />
|
||||
</Conversation>
|
||||
<Conversation>
|
||||
<option name="createTime" value="1753405176017" />
|
||||
<option name="id" value="01983f17b8d173dda926d0ffa5422bbf" />
|
||||
<option name="title" value="新对话 2025年7月25日 08:59:36" />
|
||||
<option name="updateTime" value="1753405176017" />
|
||||
</Conversation>
|
||||
<Conversation>
|
||||
<option name="createTime" value="1753065086744" />
|
||||
<option name="id" value="01982ad25f18712f862c5c18b627f40d" />
|
||||
<option name="title" value="新对话 2025年7月21日 10:31:26" />
|
||||
<option name="updateTime" value="1753065086744" />
|
||||
</Conversation>
|
||||
<Conversation>
|
||||
<option name="createTime" value="1752195523240" />
|
||||
<option name="id" value="0197f6fde2a87e68b893b3a36dfc838f" />
|
||||
<option name="title" value="新对话 2025年7月11日 08:58:43" />
|
||||
<option name="updateTime" value="1752195523240" />
|
||||
</Conversation>
|
||||
<Conversation>
|
||||
<option name="createTime" value="1752114061266" />
|
||||
<option name="id" value="0197f222dfd27515a3dbfea638532ee5" />
|
||||
<option name="title" value="新对话 2025年7月10日 10:21:01" />
|
||||
<option name="updateTime" value="1752114061266" />
|
||||
</Conversation>
|
||||
<Conversation>
|
||||
<option name="createTime" value="1751970991660" />
|
||||
<option name="id" value="0197e99bce2c7a569dee594fb9b6e152" />
|
||||
<option name="title" value="新对话 2025年7月08日 18:36:31" />
|
||||
<option name="updateTime" value="1751970991660" />
|
||||
</Conversation>
|
||||
<Conversation>
|
||||
<option name="createTime" value="1751441743239" />
|
||||
<option name="id" value="0197ca101d8771bd80f2bc4aaf1a8f19" />
|
||||
<option name="title" value="新对话 2025年7月02日 15:35:43" />
|
||||
<option name="updateTime" value="1751441743239" />
|
||||
</Conversation>
|
||||
<Conversation>
|
||||
<option name="createTime" value="1751441398488" />
|
||||
<option name="id" value="0197ca0adad875168de40d792dcb7b4c" />
|
||||
<option name="title" value="新对话 2025年7月02日 15:29:58" />
|
||||
<option name="updateTime" value="1751441398488" />
|
||||
</Conversation>
|
||||
<Conversation>
|
||||
<option name="createTime" value="1750474299387" />
|
||||
<option name="id" value="0197906617fb7194a0407baae2b1e2eb" />
|
||||
<option name="title" value="新对话 2025年6月21日 10:51:39" />
|
||||
<option name="updateTime" value="1750474299387" />
|
||||
</Conversation>
|
||||
<Conversation>
|
||||
<option name="createTime" value="1749793513436" />
|
||||
<option name="id" value="019767d21fdc756ba782b33c8b14cdf1" />
|
||||
<option name="title" value="新对话 2025年6月13日 13:45:13" />
|
||||
<option name="updateTime" value="1749793513436" />
|
||||
</Conversation>
|
||||
<Conversation>
|
||||
<option name="createTime" value="1749792408202" />
|
||||
<option name="id" value="019767c1428a7aa28682039d57d19778" />
|
||||
<option name="title" value="新对话 2025年6月13日 13:26:48" />
|
||||
<option name="updateTime" value="1749792408202" />
|
||||
</Conversation>
|
||||
<Conversation>
|
||||
<option name="createTime" value="1749718122230" />
|
||||
<option name="id" value="01976353bef6703884544447c919013c" />
|
||||
<option name="title" value="新对话 2025年6月12日 16:48:42" />
|
||||
<option name="updateTime" value="1749718122230" />
|
||||
</Conversation>
|
||||
<Conversation>
|
||||
<option name="createTime" value="1749648208122" />
|
||||
<option name="id" value="01975f28f0fa7128afe7feddcdedb740" />
|
||||
<option name="title" value="新对话 2025年6月11日 21:23:28" />
|
||||
<option name="updateTime" value="1749648208122" />
|
||||
</Conversation>
|
||||
<Conversation>
|
||||
<option name="createTime" value="1749522765718" />
|
||||
<option name="id" value="019757aed78e777c96c4b7007ff2fecc" />
|
||||
@ -57,16 +171,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -91,16 +196,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
</list>
|
||||
@ -135,16 +231,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -169,16 +256,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -203,16 +281,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -237,16 +306,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -271,16 +331,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -305,16 +356,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -339,16 +381,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -373,16 +406,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -407,16 +431,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -441,16 +456,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -475,16 +481,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -509,16 +506,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -543,16 +531,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -577,16 +556,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -611,16 +581,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -645,16 +606,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -679,16 +631,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -713,16 +656,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -773,16 +707,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -807,16 +732,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
<Turn>
|
||||
@ -841,16 +757,7 @@
|
||||
</option>
|
||||
<option name="status" value="SUCCESS" />
|
||||
<option name="variables">
|
||||
<list>
|
||||
<CodebaseVariable>
|
||||
<option name="selectedPlaceHolder">
|
||||
<Object />
|
||||
</option>
|
||||
<option name="selectedVariable">
|
||||
<Object />
|
||||
</option>
|
||||
</CodebaseVariable>
|
||||
</list>
|
||||
<list />
|
||||
</option>
|
||||
</Turn>
|
||||
</list>
|
||||
|
720
.idea/CopilotWebChatHistory.xml
generated
720
.idea/CopilotWebChatHistory.xml
generated
File diff suppressed because one or more lines are too long
@ -20,20 +20,23 @@ models:
|
||||
|
||||
# 训练参数
|
||||
training:
|
||||
epochs: 600 # 总训练轮次
|
||||
epochs: 400 # 总训练轮次
|
||||
batch_size: 128 # 批次大小
|
||||
lr: 0.001 # 初始学习率
|
||||
lr: 0.01 # 初始学习率
|
||||
optimizer: "sgd" # 优化器类型
|
||||
metric: 'arcface' # 损失函数类型(可选:arcface/cosface/sphereface/softmax)
|
||||
loss: "cross_entropy" # 损失函数类型(可选:cross_entropy/cross_entropy_smooth/center_loss/center_loss_smooth/arcface/cosface/sphereface/softmax)
|
||||
lr_step: 10 # 学习率调整间隔(epoch)
|
||||
lr_decay: 0.98 # 学习率衰减率
|
||||
lr_step: 5 # 学习率调整间隔(epoch)
|
||||
lr_decay: 0.95 # 学习率衰减率
|
||||
weight_decay: 0.0005 # 权重衰减
|
||||
scheduler: "cosine_annealing" # 学习率调度器(可选:cosine_annealing/step/none)
|
||||
scheduler: "step" # 学习率调度器(可选:cosine/cosine_warm/step/None)
|
||||
num_workers: 32 # 数据加载线程数
|
||||
checkpoints: "./checkpoints/resnet18_test/" # 模型保存目录
|
||||
checkpoints: "./checkpoints/resnet18_pdd_test/" # 模型保存目录
|
||||
restore: false
|
||||
restore_model: "resnet18_test/epoch_600.pth" # 模型恢复路径
|
||||
restore_model: "./checkpoints/resnet50_electornic_20250807/best.pth" # 模型恢复路径
|
||||
cosine_t_0: 10 # 初始周期长度
|
||||
cosine_t_mult: 1 # 周期长度倍率
|
||||
cosine_eta_min: 0.00001 # 最小学习率
|
||||
|
||||
# 验证参数
|
||||
validation:
|
||||
@ -46,8 +49,8 @@ data:
|
||||
train_batch_size: 128 # 训练批次大小
|
||||
val_batch_size: 128 # 验证批次大小
|
||||
num_workers: 32 # 数据加载线程数
|
||||
data_train_dir: "../data_center/contrast_learning/data_base/train" # 训练数据集根目录
|
||||
data_val_dir: "../data_center/contrast_learning/data_base/val" # 验证数据集根目录
|
||||
data_train_dir: "../data_center/electornic/v1/train" # 训练数据集根目录
|
||||
data_val_dir: "../data_center/electornic/v1/val" # 验证数据集根目录
|
||||
|
||||
transform:
|
||||
img_size: 224 # 图像尺寸
|
||||
@ -59,11 +62,13 @@ transform:
|
||||
|
||||
# 日志与监控
|
||||
logging:
|
||||
logging_dir: "./logs" # 日志保存目录
|
||||
logging_dir: "./logs/resnet50_electornic_log" # 日志保存目录
|
||||
tensorboard: true # 是否启用TensorBoard
|
||||
checkpoint_interval: 30 # 检查点保存间隔(epoch)
|
||||
|
||||
# 分布式训练(可选)
|
||||
distributed:
|
||||
enabled: false # 是否启用分布式训练
|
||||
enabled: true # 是否启用分布式训练
|
||||
backend: "nccl" # 分布式后端(nccl/gloo)
|
||||
node_rank: 0 # 节点编号
|
||||
node_num: 2 # 共计几个节点 一般几台机器就有几个节点
|
||||
|
@ -51,7 +51,7 @@ data:
|
||||
dataset: "imagenet" # 数据集名称(示例用,可替换为实际数据集)
|
||||
train_batch_size: 128 # 训练批次大小
|
||||
val_batch_size: 100 # 验证批次大小
|
||||
num_workers: 4 # 数据加载线程数
|
||||
num_workers: 16 # 数据加载线程数
|
||||
data_train_dir: "../data_center/contrast_learning/data_base/train" # 训练数据集根目录
|
||||
data_val_dir: "../data_center/contrast_learning/data_base/val" # 验证数据集根目录
|
||||
|
||||
|
54
configs/pic_pic_similar.yml
Normal file
54
configs/pic_pic_similar.yml
Normal file
@ -0,0 +1,54 @@
|
||||
# configs/similar_analysis.yml
|
||||
# 专为模型训练对比设计的配置文件
|
||||
# 支持对比不同训练策略(如蒸馏vs独立训练)
|
||||
|
||||
# 基础配置
|
||||
base:
|
||||
experiment_name: "model_comparison" # 实验名称(用于结果保存目录)
|
||||
device: "cuda" # 训练设备(cuda/cpu)
|
||||
embedding_size: 256 # 特征维度
|
||||
pin_memory: true # 是否启用pin_memory
|
||||
distributed: true # 是否启用分布式训练
|
||||
|
||||
|
||||
# 模型配置
|
||||
models:
|
||||
backbone: 'resnet18'
|
||||
channel_ratio: 0.75
|
||||
model_path: "../checkpoints/resnet18_20250715_scale=0.75_sub/best.pth"
|
||||
# model_path: "../checkpoints/resnet18_1009/best.pth"
|
||||
|
||||
heatmap:
|
||||
feature_layer: "layer4"
|
||||
show_heatmap: true
|
||||
# 数据配置
|
||||
data:
|
||||
dataset: "imagenet" # 数据集名称(示例用,可替换为实际数据集)
|
||||
train_batch_size: 128 # 训练批次大小
|
||||
val_batch_size: 8 # 验证批次大小
|
||||
num_workers: 32 # 数据加载线程数
|
||||
data_dir: "/home/lc/data_center/image_analysis/pic_pic_similar_maxtrix_new"
|
||||
image_joint_pth: "/home/lc/data_center/image_analysis/error_compare_result"
|
||||
total_pkl: "/home/lc/data_center/image_analysis/pic_pic_similar_maxtrix_new/total.pkl"
|
||||
result_txt: "/home/lc/data_center/image_analysis/pic_pic_similar_maxtrix_new/result.txt"
|
||||
|
||||
transform:
|
||||
img_size: 224 # 图像尺寸
|
||||
img_mean: 0.5 # 图像均值
|
||||
img_std: 0.5 # 图像方差
|
||||
RandomHorizontalFlip: 0.5 # 随机水平翻转概率
|
||||
RandomRotation: 180 # 随机旋转角度
|
||||
ColorJitter: 0.5 # 随机颜色抖动强度
|
||||
|
||||
# 日志与监控
|
||||
logging:
|
||||
logging_dir: "./logs/resnet18_scale=0.75_nosub_log" # 日志保存目录
|
||||
tensorboard: true # 是否启用TensorBoard
|
||||
checkpoint_interval: 30 # 检查点保存间隔(epoch)
|
||||
|
||||
#event:
|
||||
# oneToOne_max_th: 0.9
|
||||
# oneToSn_min_th: 0.6
|
||||
# event_save_dir: "/home/lc/works/realtime_yolov10s/online_yolov10s_resnetv11_20250702/yolos_tracking"
|
||||
# stdlib_image_path: "/testDataAndLogs/module_test_record/comparison/标准图测试数据/pic/stlib_base"
|
||||
# pickle_path: "event.pickle"
|
@ -18,20 +18,20 @@ models:
|
||||
|
||||
# 训练参数
|
||||
training:
|
||||
epochs: 300 # 总训练轮次
|
||||
epochs: 800 # 总训练轮次
|
||||
batch_size: 64 # 批次大小
|
||||
lr: 0.005 # 初始学习率
|
||||
lr: 0.01 # 初始学习率
|
||||
optimizer: "sgd" # 优化器类型
|
||||
metric: 'arcface' # 损失函数类型(可选:arcface/cosface/sphereface/softmax)
|
||||
loss: "cross_entropy" # 损失函数类型(可选:cross_entropy/cross_entropy_smooth/center_loss/center_loss_smooth/arcface/cosface/sphereface/softmax)
|
||||
lr_step: 10 # 学习率调整间隔(epoch)
|
||||
lr_decay: 0.98 # 学习率衰减率
|
||||
lr_decay: 0.95 # 学习率衰减率
|
||||
weight_decay: 0.0005 # 权重衰减
|
||||
scheduler: "cosine_annealing" # 学习率调度器(可选:cosine_annealing/step/none)
|
||||
scheduler: "step" # 学习率调度器(可选:cosine_annealing/step/none)
|
||||
num_workers: 32 # 数据加载线程数
|
||||
checkpoints: "./checkpoints/resnet18_scatter_6.2/" # 模型保存目录
|
||||
restore: True
|
||||
restore_model: "checkpoints/resnet18_scatter_6.2/best.pth" # 模型恢复路径
|
||||
checkpoints: "./checkpoints/resnet18_scatter_7.4/" # 模型保存目录
|
||||
restore: false
|
||||
restore_model: "checkpoints/resnet18_scatter_6.25/best.pth" # 模型恢复路径
|
||||
|
||||
|
||||
|
||||
@ -46,8 +46,8 @@ data:
|
||||
train_batch_size: 128 # 训练批次大小
|
||||
val_batch_size: 100 # 验证批次大小
|
||||
num_workers: 32 # 数据加载线程数
|
||||
data_train_dir: "../data_center/scatter/train" # 训练数据集根目录
|
||||
data_val_dir: "../data_center/scatter/val" # 验证数据集根目录
|
||||
data_train_dir: "../data_center/scatter/v4/train" # 训练数据集根目录
|
||||
data_val_dir: "../data_center/scatter/v4/val" # 验证数据集根目录
|
||||
|
||||
transform:
|
||||
img_size: 224 # 图像尺寸
|
||||
@ -59,7 +59,7 @@ transform:
|
||||
|
||||
# 日志与监控
|
||||
logging:
|
||||
logging_dir: "./log/2025.6.2-scatter.txt" # 日志保存目录
|
||||
logging_dir: "./log/2025.7.4-scatter.txt" # 日志保存目录
|
||||
tensorboard: true # 是否启用TensorBoard
|
||||
checkpoint_interval: 30 # 检查点保存间隔(epoch)
|
||||
|
||||
|
19
configs/scatter_data.yml
Normal file
19
configs/scatter_data.yml
Normal file
@ -0,0 +1,19 @@
|
||||
# configs/scatter_data.yml
|
||||
# 专为散称前处理的配置文件
|
||||
|
||||
# 数据配置
|
||||
data:
|
||||
dataset: "imagenet" # 数据集名称(示例用,可替换为实际数据集)
|
||||
source_dir: "../../data_center/electornic/source" # 原始数据
|
||||
train_dir: "../../data_center/electornic/v1/train" # 训练数据集根目录
|
||||
val_dir: "../../data_center/electornic/v1/val" # 验证数据集根目录
|
||||
extra_dir: "../../data_center/electornic/v1/extra" # 验证数据集根目录
|
||||
split_ratio: 0.9
|
||||
max_files: 10 # 数据集小于该阈值则归纳至extra
|
||||
|
||||
|
||||
# 日志与监控
|
||||
logging:
|
||||
logging_dir: "./log/2025.7.4-scatter.txt" # 日志保存目录
|
||||
log_level: "info" # 日志级别(debug/info/warning/error)
|
||||
|
54
configs/similar_analysis.yml
Normal file
54
configs/similar_analysis.yml
Normal file
@ -0,0 +1,54 @@
|
||||
# configs/similar_analysis.yml
|
||||
# 专为模型训练对比设计的配置文件
|
||||
# 支持对比不同训练策略(如蒸馏vs独立训练)
|
||||
|
||||
# 基础配置
|
||||
base:
|
||||
experiment_name: "model_comparison" # 实验名称(用于结果保存目录)
|
||||
device: "cuda" # 训练设备(cuda/cpu)
|
||||
embedding_size: 256 # 特征维度
|
||||
pin_memory: true # 是否启用pin_memory
|
||||
distributed: true # 是否启用分布式训练
|
||||
|
||||
|
||||
# 模型配置
|
||||
models:
|
||||
backbone: 'resnet18'
|
||||
channel_ratio: 0.75
|
||||
# model_path: "../checkpoints/resnet18_1009/best.pth"
|
||||
model_path: "../checkpoints/resnet18_20250715_scale=0.75_sub/best.pth"
|
||||
|
||||
heatmap:
|
||||
feature_layer: "layer4"
|
||||
show_heatmap: true
|
||||
# 数据配置
|
||||
data:
|
||||
dataset: "imagenet" # 数据集名称(示例用,可替换为实际数据集)
|
||||
train_batch_size: 128 # 训练批次大小
|
||||
val_batch_size: 8 # 验证批次大小
|
||||
num_workers: 32 # 数据加载线程数
|
||||
data_dir: "/home/lc/data_center/image_analysis/error_compare_subimg"
|
||||
image_joint_pth: "/home/lc/data_center/image_analysis/error_compare_result"
|
||||
|
||||
transform:
|
||||
img_size: 224 # 图像尺寸
|
||||
img_mean: 0.5 # 图像均值
|
||||
img_std: 0.5 # 图像方差
|
||||
RandomHorizontalFlip: 0.5 # 随机水平翻转概率
|
||||
RandomRotation: 180 # 随机旋转角度
|
||||
ColorJitter: 0.5 # 随机颜色抖动强度
|
||||
|
||||
# 日志与监控
|
||||
logging:
|
||||
logging_dir: "./logs/resnet18_scale=0.75_nosub_log" # 日志保存目录
|
||||
tensorboard: true # 是否启用TensorBoard
|
||||
checkpoint_interval: 30 # 检查点保存间隔(epoch)
|
||||
|
||||
event:
|
||||
oneToOneTxt: "/home/lc/detecttracking/oneToOne.txt"
|
||||
oneToSnTxt: "/home/lc/detecttracking/oneToSn.txt"
|
||||
oneToOne_max_th: 0.9
|
||||
oneToSn_min_th: 0.6
|
||||
event_save_dir: "/home/lc/works/realtime_yolov10s/online_yolov10s_resnetv11_20250702/yolos_tracking"
|
||||
stdlib_image_path: "/testDataAndLogs/module_test_record/comparison/标准图测试数据/pic/stlib_base"
|
||||
pickle_path: "event.pickle"
|
32
configs/sub_data.yml
Normal file
32
configs/sub_data.yml
Normal file
@ -0,0 +1,32 @@
|
||||
# configs/sub_data.yml
|
||||
# 专为对比模型训练的数据集设计的配置文件
|
||||
# 支持对比不同训练策略(如蒸馏vs独立训练)
|
||||
|
||||
# 数据配置
|
||||
data:
|
||||
source_dir: "../../data_center/contrast_data/total" # 数据集名称(示例用,可替换为实际数据集)
|
||||
train_dir: "../../data_center/contrast_data/v1/train" # 训练数据集根目录
|
||||
val_dir: "../../data_center/contrast_data/v1/val" # 验证数据集根目录
|
||||
data_extra_dir: "../../data_center/contrast_data/v1/extra"
|
||||
max_files_ratio: 0.1
|
||||
min_files: 10
|
||||
split_ratio: 0.9
|
||||
combine_scr_dir: "../../data_center/contrast_data/v1/val" # 合并数据集源目录·
|
||||
combine_dst_dir: "../../data_center/contrast_data/v2/val" # 合并数据集目标目录
|
||||
|
||||
|
||||
extend:
|
||||
extend_same_dir: true
|
||||
extend_extra: true
|
||||
extend_extra_dir: "../../data_center/contrast_data/v1/extra" # 扩展测试集数据
|
||||
extend_train: true
|
||||
extend_train_dir: "../../data_center/contrast_data/v1/train" # 训练接数据扩展
|
||||
|
||||
limit:
|
||||
count_limit: true
|
||||
limit_count: 200
|
||||
limit_dir: "../../data_center/contrast_data/v1/train" # 限制单个样本数量
|
||||
|
||||
control:
|
||||
combine: true # 是否进行子类数据集合并
|
||||
split: false # 子类数据集拆解与扩增
|
@ -8,23 +8,28 @@ base:
|
||||
log_level: "info" # 日志级别(debug/info/warning/error)
|
||||
embedding_size: 256 # 特征维度
|
||||
pin_memory: true # 是否启用pin_memory
|
||||
distributed: true # 是否启用分布式训练
|
||||
distributed: false # 是否启用分布式训练 启用热力图时不能用分布式训练
|
||||
|
||||
# 模型配置
|
||||
models:
|
||||
backbone: 'resnet18'
|
||||
channel_ratio: 1.0
|
||||
model_path: "./checkpoints/resnet18_scatter_6.2/best.pth"
|
||||
model_path: "checkpoints/resnet18_electornic_20250806/best.pth"
|
||||
#resnet18_20250715_scale=0.75_sub
|
||||
#resnet18_20250718_scale=0.75_nosub
|
||||
half: false # 是否启用半精度测试(fp16)
|
||||
contrast_learning: false
|
||||
|
||||
# 数据配置
|
||||
data:
|
||||
group_test: False # 数据集名称(示例用,可替换为实际数据集)
|
||||
test_batch_size: 128 # 训练批次大小
|
||||
num_workers: 32 # 数据加载线程数
|
||||
test_dir: "../data_center/scatter/" # 验证数据集根目录
|
||||
test_dir: "../data_center/electornic/v1/val" # 验证数据集根目录
|
||||
test_group_json: "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json"
|
||||
test_list: "../data_center/scatter/val_pair.txt"
|
||||
test_list: "../data_center/electornic/v1/cross_same.txt"
|
||||
group_test: false
|
||||
save_image_joint: false
|
||||
image_joint_pth: "./joint_images"
|
||||
|
||||
transform:
|
||||
img_size: 224 # 图像尺寸
|
||||
@ -34,6 +39,11 @@ transform:
|
||||
RandomRotation: 180 # 随机旋转角度
|
||||
ColorJitter: 0.5 # 随机颜色抖动强度
|
||||
|
||||
heatmap:
|
||||
image_joint_pth: "./heatmap_joint_images"
|
||||
feature_layer: "layer4"
|
||||
show_heatmap: true
|
||||
|
||||
save:
|
||||
save_dir: ""
|
||||
save_name: ""
|
||||
|
29
configs/transform.yml
Normal file
29
configs/transform.yml
Normal file
@ -0,0 +1,29 @@
|
||||
# configs/transform.yml
|
||||
# pth转换onnx配置文件
|
||||
|
||||
# 基础配置
|
||||
base:
|
||||
experiment_name: "model_comparison" # 实验名称(用于结果保存目录)
|
||||
seed: 42 # 随机种子(保证可复现性)
|
||||
device: "cuda" # 训练设备(cuda/cpu)
|
||||
log_level: "info" # 日志级别(debug/info/warning/error)
|
||||
embedding_size: 256 # 特征维度
|
||||
pin_memory: true # 是否启用pin_memory
|
||||
distributed: true # 是否启用分布式训练
|
||||
dataset: "./dataset_electornic.txt" # 数据集名称
|
||||
|
||||
|
||||
# 模型配置
|
||||
models:
|
||||
backbone: 'resnet101'
|
||||
channel_ratio: 1.0
|
||||
model_path: "../checkpoints/resnet101_electornic_20250807/best.pth"
|
||||
onnx_model: "../checkpoints/resnet101_electornic_20250807/best.onnx"
|
||||
rknn_model: "../checkpoints/resnet101_electornic_20250807/resnet101_electornic_3588.rknn"
|
||||
rknn_batch_size: 1
|
||||
|
||||
# 日志与监控
|
||||
logging:
|
||||
logging_dir: "./logs" # 日志保存目录
|
||||
tensorboard: true # 是否启用TensorBoard
|
||||
checkpoint_interval: 30 # 检查点保存间隔(epoch)
|
@ -1,4 +1,4 @@
|
||||
from model import (resnet18, mobilevit_s, MobileNetV3_Small, MobileNetV3_Large, mobilenet_v1,
|
||||
from model import (resnet18, resnet34, resnet50, resnet101, mobilevit_s, MobileNetV3_Small, MobileNetV3_Large, mobilenet_v1,
|
||||
PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5)
|
||||
from timm.models import vit_base_patch16_224 as vit_base_16
|
||||
from model.metric import ArcFace, CosFace
|
||||
@ -13,7 +13,10 @@ class trainer_tools:
|
||||
|
||||
def get_backbone(self):
|
||||
backbone_mapping = {
|
||||
'resnet18': lambda: resnet18(scale=self.conf['models']['channel_ratio']),
|
||||
'resnet18': lambda: resnet18(scale=self.conf['models']['channel_ratio'], pretrained=True),
|
||||
'resnet34': lambda: resnet34(scale=self.conf['models']['channel_ratio'], pretrained=True),
|
||||
'resnet50': lambda: resnet50(scale=self.conf['models']['channel_ratio'], pretrained=True),
|
||||
'resnet101': lambda: resnet101(scale=self.conf['models']['channel_ratio'], pretrained=True),
|
||||
'mobilevit_s': lambda: mobilevit_s(),
|
||||
'mobilenetv3_small': lambda: MobileNetV3_Small(),
|
||||
'PPLCNET_x1_0': lambda: PPLCNET_x1_0(),
|
||||
@ -54,3 +57,24 @@ class trainer_tools:
|
||||
)
|
||||
}
|
||||
return optimizer_mapping
|
||||
|
||||
def get_scheduler(self, optimizer):
|
||||
scheduler_mapping = {
|
||||
'step': lambda: optim.lr_scheduler.StepLR(
|
||||
optimizer,
|
||||
step_size=self.conf['training']['lr_step'],
|
||||
gamma=self.conf['training']['lr_decay']
|
||||
),
|
||||
'cosine': lambda: optim.lr_scheduler.CosineAnnealingLR(
|
||||
optimizer,
|
||||
T_max=self.conf['training']['epochs'],
|
||||
eta_min=self.conf['training']['cosine_eta_min']
|
||||
),
|
||||
'cosine_warm': lambda: optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
||||
optimizer,
|
||||
T_0=self.conf['training'].get('cosine_t_0', 10),
|
||||
T_mult=self.conf['training'].get('cosine_t_mult', 1),
|
||||
eta_min=self.conf['training'].get('cosine_eta_min', 0)
|
||||
)
|
||||
}
|
||||
return scheduler_mapping
|
||||
|
@ -14,7 +14,7 @@ base:
|
||||
models:
|
||||
backbone: 'resnet18'
|
||||
channel_ratio: 0.75
|
||||
checkpoints: "../checkpoints/resnet18_1009/best.pth"
|
||||
checkpoints: "../checkpoints/resnet18_20250718_scale=0.75_nosub/best.pth"
|
||||
|
||||
# 数据配置
|
||||
data:
|
||||
@ -22,7 +22,7 @@ data:
|
||||
test_batch_size: 128 # 验证批次大小
|
||||
num_workers: 32 # 数据加载线程数
|
||||
half: true # 是否启用半精度数据
|
||||
img_dirs_path: "/shareData/temp_data/comparison/Hangzhou_Yunhe/base_data/05-09"
|
||||
img_dirs_path: "/home/lc/data_center/baseStlib/pic/stlib_base" # base标准库图片存储路径
|
||||
# img_dirs_path: "/home/lc/contrast_nettest/data/feature_json"
|
||||
xlsx_pth: false # 过滤商品, 默认None不进行过滤
|
||||
|
||||
@ -41,7 +41,8 @@ logging:
|
||||
checkpoint_interval: 30 # 检查点保存间隔(epoch)
|
||||
|
||||
save:
|
||||
json_bin: "../search_library/yunhedian_05-09.json" # 保存整个json文件
|
||||
json_path: "../data/feature_json_compare/" # 保存单个json文件
|
||||
json_bin: "../search_library/resnet101_electronic.json" # 保存整个json文件
|
||||
json_path: "/home/lc/data_center/baseStlib/feature_json/stlib_base_resnet18_sub_1k_合并" # 保存单个json文件路径
|
||||
error_barcodes: "error_barcodes.txt"
|
||||
barcodes_statistics: "../search_library/barcodes_statistics.txt"
|
||||
create_single_json: true # 是否保存单个json文件
|
0
data_preprocessing/__init__.py
Normal file
0
data_preprocessing/__init__.py
Normal file
25
data_preprocessing/combine_sub_class.py
Normal file
25
data_preprocessing/combine_sub_class.py
Normal file
@ -0,0 +1,25 @@
|
||||
import os
|
||||
import shutil
|
||||
|
||||
|
||||
def combine_dirs(conf):
|
||||
source_root = conf['data']['combine_scr_dir']
|
||||
target_root = conf['data']['combine_dst_dir']
|
||||
for roots, dir_names, files in os.walk(source_root):
|
||||
for dir_name in dir_names:
|
||||
source_dir = os.path.join(roots, dir_name)
|
||||
target_dir = os.path.join(target_root, dir_name.split('_')[0])
|
||||
if not os.path.exists(target_dir):
|
||||
os.mkdir(target_dir)
|
||||
for filename in os.listdir(source_dir):
|
||||
print(filename)
|
||||
source_file = os.sep.join([source_dir, filename])
|
||||
target_file = os.sep.join([target_dir, filename])
|
||||
shutil.copy(source_file, target_file)
|
||||
# print(f"已复制目录 {source_dir} 到 {target_dir}")
|
||||
|
||||
|
||||
# if __name__ == '__main__':
|
||||
# source_root = r'scatter_mini'
|
||||
# target_root = r'C:\Users\123\Desktop\scatter-1'
|
||||
# # combine_dirs(conf)
|
111
data_preprocessing/create_extra.py
Normal file
111
data_preprocessing/create_extra.py
Normal file
@ -0,0 +1,111 @@
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
def count_files(directory):
|
||||
"""统计目录中的文件数量"""
|
||||
try:
|
||||
return len([f for f in os.listdir(directory)
|
||||
if os.path.isfile(os.path.join(directory, f))])
|
||||
except Exception as e:
|
||||
print(f"无法统计目录 {directory}: {e}")
|
||||
return 0
|
||||
|
||||
def clear_empty_dirs(path):
|
||||
"""
|
||||
删除空目录
|
||||
:param path: 目录路径
|
||||
"""
|
||||
for root, dirs, files in os.walk(path, topdown=False):
|
||||
for dir_name in dirs:
|
||||
dir_path = os.path.join(root, dir_name)
|
||||
try:
|
||||
if not os.listdir(dir_path):
|
||||
os.rmdir(dir_path)
|
||||
print(f"Deleted empty directory: {dir_path}")
|
||||
except Exception as e:
|
||||
print(f"Error: {e.strerror}")
|
||||
|
||||
def get_max_files(conf):
|
||||
max_files_ratio = conf['data']['max_files_ratio']
|
||||
files_number = []
|
||||
for root, dirs, files in os.walk(conf['data']['source_dir']):
|
||||
if len(dirs) == 0:
|
||||
if len(files) == 0:
|
||||
print(root, dirs,files)
|
||||
files_number.append(len(files))
|
||||
files_number = sorted(files_number, reverse=False)
|
||||
max_files = files_number[int(max_files_ratio * len(files_number))]
|
||||
print(f"max_files: {max_files}")
|
||||
if max_files < conf['data']['min_files']:
|
||||
max_files = conf['data']['min_files']
|
||||
return max_files
|
||||
def megre_subdirs(pth):
|
||||
for roots, dir_names, files in os.walk(pth):
|
||||
print(f"image {dir_names}")
|
||||
for image in dir_names:
|
||||
inner_dir_path = os.path.join(pth, image)
|
||||
for inner_roots, inner_dirs, inner_files in os.walk(inner_dir_path):
|
||||
for inner_dir in inner_dirs:
|
||||
src_dir = os.path.join(inner_roots, inner_dir)
|
||||
dest_dir = os.path.join(pth, inner_dir)
|
||||
# shutil.copytree(src_dir, dest_dir)
|
||||
shutil.move(src_dir, dest_dir)
|
||||
print(f"Copied {inner_dir} to {pth}")
|
||||
clear_empty_dirs(pth)
|
||||
|
||||
|
||||
# def split_subdirs(source_dir, target_dir, max_files=10):
|
||||
def split_subdirs(conf):
|
||||
"""
|
||||
复制文件数≤max_files的子目录到目标目录
|
||||
:param source_dir: 源目录路径
|
||||
:param target_dir: 目标目录路径
|
||||
:param max_files: 最大文件数阈值
|
||||
"""
|
||||
source_dir = conf['data']['source_dir']
|
||||
target_extra_dir = conf['data']['data_extra_dir']
|
||||
train_dir = conf['data']['train_dir']
|
||||
max_files = get_max_files(conf)
|
||||
megre_subdirs(source_dir) # 合并子目录,删除上级目录
|
||||
# 创建目标目录
|
||||
Path(target_extra_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"开始处理目录: {source_dir}")
|
||||
print(f"目标目录: {target_extra_dir}")
|
||||
print(f"筛选条件: 文件数 ≤ {max_files}\n")
|
||||
|
||||
# 遍历源目录
|
||||
for subdir in os.listdir(source_dir):
|
||||
subdir_path = os.path.join(source_dir, subdir)
|
||||
|
||||
if not os.path.isdir(subdir_path):
|
||||
continue
|
||||
|
||||
try:
|
||||
file_count = count_files(subdir_path)
|
||||
print(f"复制 {subdir} (包含 {file_count} 个文件)")
|
||||
if file_count <= max_files:
|
||||
dest_path = os.path.join(target_extra_dir, subdir)
|
||||
else:
|
||||
dest_path = os.path.join(train_dir, subdir)
|
||||
# 如果目标目录已存在则跳过
|
||||
if os.path.exists(dest_path):
|
||||
print(f"目录已存在,跳过: {dest_path}")
|
||||
continue
|
||||
print(f"复制 {subdir} (包含 {file_count} 个文件) 至 {dest_path}")
|
||||
shutil.copytree(subdir_path, dest_path)
|
||||
# shutil.move(subdir_path, dest_path)
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理目录 {subdir} 时出错: {e}")
|
||||
|
||||
print("\n处理完成")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 配置路径
|
||||
with open('../configs/sub_data.yml', 'r') as f:
|
||||
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
# 执行复制操作
|
||||
split_subdirs(conf)
|
72
data_preprocessing/data_split.py
Normal file
72
data_preprocessing/data_split.py
Normal file
@ -0,0 +1,72 @@
|
||||
import os
|
||||
import shutil
|
||||
import random
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
def is_image_file(filename):
|
||||
"""检查文件是否为图像文件"""
|
||||
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff')
|
||||
return filename.lower().endswith(image_extensions)
|
||||
|
||||
|
||||
|
||||
def split_directory(conf):
|
||||
"""
|
||||
分割目录中的图像文件到train和val目录
|
||||
:param src_dir: 源目录路径
|
||||
:param train_dir: 训练集目录路径
|
||||
:param val_dir: 验证集目录路径
|
||||
:param split_ratio: 训练集比例(默认0.9)
|
||||
"""
|
||||
# 创建目标目录
|
||||
train_dir = conf['data']['train_dir']
|
||||
val_dir = conf['data']['val_dir']
|
||||
split_ratio = conf['data']['split_ratio']
|
||||
Path(val_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 遍历源目录
|
||||
for root, dirs, files in os.walk(train_dir):
|
||||
# 获取相对路径(train_dir)
|
||||
rel_path = os.path.relpath(root, train_dir)
|
||||
|
||||
# 跳过当前目录(.)
|
||||
if rel_path == '.':
|
||||
continue
|
||||
|
||||
# 创建对应的目标子目录
|
||||
val_subdir = os.path.join(val_dir, rel_path)
|
||||
os.makedirs(val_subdir, exist_ok=True)
|
||||
|
||||
# 筛选图像文件
|
||||
image_files = [f for f in files if is_image_file(f)]
|
||||
if not image_files:
|
||||
continue
|
||||
|
||||
# 随机打乱文件列表
|
||||
random.shuffle(image_files)
|
||||
|
||||
# 计算分割点
|
||||
split_point = int(len(image_files) * split_ratio)
|
||||
|
||||
# 复制文件到验证集
|
||||
for file in image_files[split_point:]:
|
||||
src = os.path.join(root, file)
|
||||
dst = os.path.join(val_subdir, file)
|
||||
# shutil.copy2(src, dst)
|
||||
shutil.move(src, dst)
|
||||
|
||||
print(f"处理完成: {rel_path} (共 {len(image_files)} 个图像, 训练集: {split_point}, 验证集: {len(image_files)-split_point})")
|
||||
|
||||
def control_train_number():
|
||||
pass
|
||||
|
||||
if __name__ == "__main__":
|
||||
# # 设置目录路径
|
||||
# TRAIN_DIR = "/home/lc/data_center/electornic/v1/train"
|
||||
# VAL_DIR = "/home/lc/data_center/electornic/v1/val"
|
||||
|
||||
with open('../configs/scatter_data.yml', 'r') as f:
|
||||
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||
print("开始分割数据集...")
|
||||
split_directory(conf)
|
||||
print("数据集分割完成")
|
192
data_preprocessing/extend.py
Normal file
192
data_preprocessing/extend.py
Normal file
@ -0,0 +1,192 @@
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from PIL import Image, ImageEnhance
|
||||
|
||||
|
||||
class ImageExtendProcessor:
|
||||
def __init__(self, conf):
|
||||
self.conf = conf
|
||||
|
||||
def is_image_file(self, filename):
|
||||
"""检查文件是否为图像文件"""
|
||||
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff')
|
||||
return filename.lower().endswith(image_extensions)
|
||||
|
||||
def random_cute_image(self, image_path, output_path, ratio=0.8):
|
||||
"""
|
||||
对图像进行随机裁剪
|
||||
:param image_path: 输入图像路径
|
||||
:param output_path: 输出图像路径
|
||||
:param ratio: 裁剪比例,决定裁剪区域的大小,默认为0.8
|
||||
"""
|
||||
try:
|
||||
with Image.open(image_path) as img:
|
||||
# 获取图像尺寸
|
||||
width, height = img.size
|
||||
|
||||
# 计算裁剪后的尺寸
|
||||
new_width = int(width * ratio)
|
||||
new_height = int(height * ratio)
|
||||
|
||||
# 随机生成裁剪起始点
|
||||
left = random.randint(0, width - new_width)
|
||||
top = random.randint(0, height - new_height)
|
||||
right = left + new_width
|
||||
bottom = top + new_height
|
||||
|
||||
# 执行裁剪
|
||||
cropped_img = img.crop((left, top, right, bottom))
|
||||
|
||||
# 保存裁剪后的图像
|
||||
cropped_img.save(output_path)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"处理图像 {image_path} 时出错: {e}")
|
||||
return False
|
||||
|
||||
def random_brightness(self, image_path, output_path, brightness_factor=None):
|
||||
"""
|
||||
对图像进行随机亮度调整
|
||||
:param image_path: 输入图像路径
|
||||
:param output_path: 输出图像路径
|
||||
:param brightness_factor: 亮度调整因子,默认为随机值
|
||||
"""
|
||||
try:
|
||||
with Image.open(image_path) as img:
|
||||
# 创建一个ImageEnhance.Brightness对象
|
||||
enhancer = ImageEnhance.Brightness(img)
|
||||
|
||||
# 如果没有指定亮度因子,则随机生成
|
||||
if brightness_factor is None:
|
||||
brightness_factor = random.uniform(0.5, 1.5)
|
||||
|
||||
# 应用亮度调整
|
||||
brightened_img = enhancer.enhance(brightness_factor)
|
||||
|
||||
# 保存调整后的图像
|
||||
brightened_img.save(output_path)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"处理图像 {image_path} 时出错: {e}")
|
||||
return False
|
||||
|
||||
def rotate_image(self, image_path, output_path, degrees):
|
||||
"""旋转图像并保存到指定路径"""
|
||||
try:
|
||||
with Image.open(image_path) as img:
|
||||
# 旋转图像并自动调整画布大小
|
||||
rotated = img.rotate(degrees, expand=True)
|
||||
rotated.save(output_path)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"处理图像 {image_path} 时出错: {e}")
|
||||
return False
|
||||
|
||||
def process_extra_directory(self, src_dir, dst_dir, same_directory, dir_name):
|
||||
"""
|
||||
处理单个目录中的图像文件
|
||||
:param src_dir: 源目录路径
|
||||
:param dst_dir: 目标目录路径
|
||||
"""
|
||||
if not os.path.exists(dst_dir):
|
||||
os.makedirs(dst_dir)
|
||||
# 获取目录中所有图像文件
|
||||
image_files = [f for f in os.listdir(src_dir)
|
||||
if self.is_image_file(f) and os.path.isfile(os.path.join(src_dir, f))]
|
||||
|
||||
# 处理每个图像文件
|
||||
for img_file in image_files:
|
||||
src_path = os.path.join(src_dir, img_file)
|
||||
base_name, ext = os.path.splitext(img_file)
|
||||
if not same_directory:
|
||||
# 复制原始文件 (另存文件夹时启用)
|
||||
shutil.copy2(src_path, os.path.join(dst_dir, img_file))
|
||||
if dir_name == 'extra':
|
||||
# 生成并保存旋转后的图像
|
||||
for angle in [90, 180, 270]:
|
||||
dst_path = os.path.join(dst_dir, f"{base_name}_rotated_{angle}{ext}")
|
||||
self.rotate_image(src_path, dst_path, angle)
|
||||
for ratio in [0.8, 0.85, 0.9]:
|
||||
dst_path = os.path.join(dst_dir, f"{base_name}_cute_{ratio}{ext}")
|
||||
self.random_cute_image(src_path, dst_path, ratio)
|
||||
for brightness_factor in [0.8, 0.9, 1.0]:
|
||||
dst_path = os.path.join(dst_dir, f"{base_name}_brightness_{brightness_factor}{ext}")
|
||||
self.random_brightness(src_path, dst_path, brightness_factor)
|
||||
elif dir_name == 'train':
|
||||
# 生成并保存旋转后的图像
|
||||
for angle in [90, 180, 270]:
|
||||
dst_path = os.path.join(dst_dir, f"{base_name}_rotated_{angle}{ext}")
|
||||
self.rotate_image(src_path, dst_path, angle)
|
||||
|
||||
def image_extend(self, src_dir, dst_dir, same_directory=False, dir_name=None):
|
||||
if same_directory:
|
||||
n_dst_dir = src_dir
|
||||
print(f"处理目录 {src_dir} 中的图像文件 保存至同一目录下")
|
||||
else:
|
||||
n_dst_dir = dst_dir
|
||||
print(f"处理目录 {src_dir} 中的图像文件 保存至不同目录下")
|
||||
for src_subdir in os.listdir(src_dir):
|
||||
src_subdir_path = os.path.join(src_dir, src_subdir)
|
||||
dst_subdir_path = os.path.join(n_dst_dir, src_subdir)
|
||||
if dir_name == 'extra':
|
||||
self.process_extra_directory(src_subdir_path,
|
||||
dst_subdir_path,
|
||||
same_directory,
|
||||
dir_name)
|
||||
if dir_name == 'train':
|
||||
if len(os.listdir(src_subdir_path)) < 50:
|
||||
self.process_extra_directory(src_subdir_path,
|
||||
dst_subdir_path,
|
||||
same_directory,
|
||||
dir_name)
|
||||
|
||||
def random_remove_image(self, subdir_path, max_count=1000):
|
||||
"""
|
||||
随机删除子目录中的图像文件,直到数量不超过max_count
|
||||
:param subdir_path: 子目录路径
|
||||
:param max_count: 最大允许的图像数量
|
||||
"""
|
||||
# 统计图像文件数量
|
||||
image_files = [f for f in os.listdir(subdir_path)
|
||||
if self.is_image_file(f) and os.path.isfile(os.path.join(subdir_path, f))]
|
||||
current_count = len(image_files)
|
||||
|
||||
# 如果图像数量不超过max_count,则无需删除
|
||||
if current_count <= max_count:
|
||||
print(f"无需处理 {subdir_path} (包含 {current_count} 个图像)")
|
||||
return
|
||||
|
||||
# 计算需要删除的文件数
|
||||
remove_count = current_count - max_count
|
||||
|
||||
# 随机选择要删除的文件
|
||||
files_to_remove = random.sample(image_files, remove_count)
|
||||
|
||||
# 删除选中的文件
|
||||
for file in files_to_remove:
|
||||
file_path = os.path.join(subdir_path, file)
|
||||
os.remove(file_path)
|
||||
print(f"已删除 {file_path}")
|
||||
|
||||
def control_number(self):
|
||||
if self.conf['extend']['extend_extra']:
|
||||
self.image_extend(self.conf['extend']['extend_extra_dir'],
|
||||
'',
|
||||
same_directory=self.conf['extend']['extend_same_dir'],
|
||||
dir_name='extra')
|
||||
if self.conf['extend']['extend_train']:
|
||||
self.image_extend(self.conf['extend']['extend_train_dir'],
|
||||
'',
|
||||
same_directory=self.conf['extend']['extend_same_dir'],
|
||||
dir_name='train')
|
||||
if self.conf['limit']['count_limit']:
|
||||
self.random_remove_image(self.conf['limit']['limit_dir'],
|
||||
max_count=self.conf['limit']['limit_count'])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
src_dir = "./scatter_mini"
|
||||
dst_dir = "./scatter_add"
|
||||
image_ex = ImageExtendProcessor()
|
||||
image_ex.image_extend(src_dir, dst_dir, same_directory=False)
|
21
data_preprocessing/sub_data_preprocessing.py
Normal file
21
data_preprocessing/sub_data_preprocessing.py
Normal file
@ -0,0 +1,21 @@
|
||||
from create_extra import split_subdirs
|
||||
from data_split import split_directory
|
||||
from extend import ImageExtendProcessor
|
||||
from combine_sub_class import combine_dirs
|
||||
import yaml
|
||||
|
||||
|
||||
def data_preprocessing(conf):
|
||||
if conf['control']['split']:
|
||||
split_subdirs(conf)
|
||||
image_ex = ImageExtendProcessor(conf)
|
||||
image_ex.control_number()
|
||||
split_directory(conf)
|
||||
if conf['control']['combine']:
|
||||
combine_dirs(conf)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
with open('../configs/sub_data.yml', 'r') as f:
|
||||
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||
data_preprocessing(conf)
|
@ -4,7 +4,7 @@ from .mobilevit import mobilevit_s
|
||||
from .metric import ArcFace, CosFace
|
||||
from .loss import FocalLoss
|
||||
from .resbam import resnet
|
||||
from .resnet_pre import resnet18, resnet34, resnet50, resnet14, CustomResNet18
|
||||
from .resnet_pre import resnet18, resnet34, resnet50, resnet101, resnet152,resnet14, CustomResNet18
|
||||
from .mobilenet_v2 import mobilenet_v2
|
||||
from .mobilenet_v3 import MobileNetV3_Small, MobileNetV3_Large
|
||||
# from .mobilenet_v1 import mobilenet_v1
|
||||
|
@ -205,7 +205,7 @@ class ResNet(nn.Module):
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
self._norm_layer = norm_layer
|
||||
print("ResNet scale: >>>>>>>>>> ", scale)
|
||||
print("通道剪枝 {}".format(scale))
|
||||
self.inplanes = 64
|
||||
self.dilation = 1
|
||||
if replace_stride_with_dilation is None:
|
||||
@ -222,13 +222,13 @@ class ResNet(nn.Module):
|
||||
self.bn1 = norm_layer(self.inplanes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.adaptiveMaxPool = nn.AdaptiveMaxPool2d((1, 1))
|
||||
self.maxpool2 = nn.Sequential(
|
||||
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
|
||||
nn.MaxPool2d(kernel_size=2, stride=1, padding=0)
|
||||
)
|
||||
# self.adaptiveMaxPool = nn.AdaptiveMaxPool2d((1, 1))
|
||||
# self.maxpool2 = nn.Sequential(
|
||||
# nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
|
||||
# nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
|
||||
# nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
|
||||
# nn.MaxPool2d(kernel_size=2, stride=1, padding=0)
|
||||
# )
|
||||
self.layer1 = self._make_layer(block, int(64 * scale), layers[0])
|
||||
self.layer2 = self._make_layer(block, int(128 * scale), layers[1], stride=2,
|
||||
dilate=replace_stride_with_dilation[0])
|
||||
@ -368,7 +368,7 @@ def resnet18(pretrained=True, progress=True, **kwargs):
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet34(pretrained=False, progress=True, **kwargs):
|
||||
def resnet34(pretrained=True, progress=True, **kwargs):
|
||||
r"""ResNet-34 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
@ -380,7 +380,7 @@ def resnet34(pretrained=False, progress=True, **kwargs):
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet50(pretrained=False, progress=True, **kwargs):
|
||||
def resnet50(pretrained=True, progress=True, **kwargs):
|
||||
r"""ResNet-50 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
@ -392,7 +392,7 @@ def resnet50(pretrained=False, progress=True, **kwargs):
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet101(pretrained=False, progress=True, **kwargs):
|
||||
def resnet101(pretrained=True, progress=True, **kwargs):
|
||||
r"""ResNet-101 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
|
87
test_ori.py
87
test_ori.py
@ -11,10 +11,13 @@ import matplotlib.pyplot as plt
|
||||
|
||||
# from config import config as conf
|
||||
from tools.dataset import get_transform
|
||||
from tools.image_joint import merge_imgs
|
||||
from tools.getHeatMap import cal_cam
|
||||
from configs import trainer_tools
|
||||
import yaml
|
||||
from datetime import datetime
|
||||
|
||||
with open('configs/test.yml', 'r') as f:
|
||||
with open('../configs/test.yml', 'r') as f:
|
||||
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
# Constants from config
|
||||
@ -22,6 +25,7 @@ embedding_size = conf["base"]["embedding_size"]
|
||||
img_size = conf["transform"]["img_size"]
|
||||
device = conf["base"]["device"]
|
||||
|
||||
|
||||
def unique_image(pair_list: str) -> Set[str]:
|
||||
unique_images = set()
|
||||
try:
|
||||
@ -115,8 +119,12 @@ def featurize(
|
||||
except Exception as e:
|
||||
print(f"Error in feature extraction: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def cosin_metric(x1, x2):
|
||||
return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))
|
||||
|
||||
|
||||
def threshold_search(y_score, y_true):
|
||||
y_score = np.asarray(y_score)
|
||||
y_true = np.asarray(y_true)
|
||||
@ -133,7 +141,7 @@ def threshold_search(y_score, y_true):
|
||||
|
||||
|
||||
def showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct):
|
||||
x = np.linspace(start=0, stop=1.0, num=50, endpoint=True).tolist()
|
||||
x = np.linspace(start=-1, stop=1.0, num=100, endpoint=True).tolist()
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.plot(x, recall, color='red', label='recall:TP/TPFN')
|
||||
plt.plot(x, recall_TN, color='black', label='recall_TN:TN/TNFP')
|
||||
@ -143,6 +151,7 @@ def showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct):
|
||||
plt.legend()
|
||||
plt.xlabel('threshold')
|
||||
# plt.ylabel('Similarity')
|
||||
|
||||
plt.grid(True, linestyle='--', alpha=0.5)
|
||||
plt.savefig('grid.png')
|
||||
plt.show()
|
||||
@ -154,19 +163,19 @@ def showHist(same, cross):
|
||||
Cross = np.array(cross)
|
||||
|
||||
fig, axs = plt.subplots(2, 1)
|
||||
axs[0].hist(Same, bins=50, edgecolor='black')
|
||||
axs[0].set_xlim([-0.1, 1])
|
||||
axs[0].hist(Same, bins=100, edgecolor='black')
|
||||
axs[0].set_xlim([-1, 1])
|
||||
axs[0].set_title('Same Barcode')
|
||||
|
||||
axs[1].hist(Cross, bins=50, edgecolor='black')
|
||||
axs[1].set_xlim([-0.1, 1])
|
||||
axs[1].hist(Cross, bins=100, edgecolor='black')
|
||||
axs[1].set_xlim([-1, 1])
|
||||
axs[1].set_title('Cross Barcode')
|
||||
plt.savefig('plot.png')
|
||||
|
||||
|
||||
def compute_accuracy_recall(score, labels):
|
||||
th = 0.1
|
||||
squence = np.linspace(-1, 1, num=50)
|
||||
squence = np.linspace(-1, 1, num=100)
|
||||
recall, PrecisePos, PreciseNeg, recall_TN, Correct = [], [], [], [], []
|
||||
Same = score[:len(score) // 2]
|
||||
Cross = score[len(score) // 2:]
|
||||
@ -179,24 +188,26 @@ def compute_accuracy_recall(score, labels):
|
||||
f_labels = (labels == 0)
|
||||
TN = np.sum(np.logical_and(f_score, f_labels))
|
||||
FP = np.sum(np.logical_and(np.logical_not(f_score), f_labels))
|
||||
print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))
|
||||
# print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))
|
||||
|
||||
PrecisePos.append(0 if TP / (TP + FP) == 'nan' else TP / (TP + FP))
|
||||
PreciseNeg.append(0 if TN == 0 else TN / (TN + FN))
|
||||
recall.append(0 if TP == 0 else TP / (TP + FN))
|
||||
recall_TN.append(0 if TN == 0 else TN / (TN + FP))
|
||||
Correct.append(0 if TP == 0 else (TP + TN) / (TP + FP + TN + FN))
|
||||
|
||||
print("Threshold:{} PrecisePos:{},recall:{},PreciseNeg:{},recall_TN:{}".format(th, PrecisePos[-1], recall[-1],
|
||||
PreciseNeg[-1], recall_TN[-1]))
|
||||
showHist(Same, Cross)
|
||||
showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct)
|
||||
|
||||
|
||||
def compute_accuracy(
|
||||
feature_dict: Dict[str, torch.Tensor],
|
||||
pair_list: str,
|
||||
test_root: str
|
||||
cam: cal_cam,
|
||||
) -> Tuple[float, float]:
|
||||
try:
|
||||
pair_list = conf['data']['test_list']
|
||||
test_root = conf['data']['test_dir']
|
||||
with open(pair_list, 'r') as f:
|
||||
pairs = f.readlines()
|
||||
except IOError as e:
|
||||
@ -211,7 +222,8 @@ def compute_accuracy(
|
||||
if not pair:
|
||||
continue
|
||||
|
||||
try:
|
||||
# try:
|
||||
print(f"Processing pair: {pair}")
|
||||
img1, img2, label = pair.split()
|
||||
img1_path = osp.join(test_root, img1)
|
||||
img2_path = osp.join(test_root, img2)
|
||||
@ -224,13 +236,20 @@ def compute_accuracy(
|
||||
feat1 = feature_dict[img1_path].cpu().numpy()
|
||||
feat2 = feature_dict[img2_path].cpu().numpy()
|
||||
similarity = cosin_metric(feat1, feat2)
|
||||
|
||||
print('{} vs {}: {}'.format(img1_path, img2_path, similarity))
|
||||
if conf['data']['save_image_joint']:
|
||||
merge_imgs(img1_path,
|
||||
img2_path,
|
||||
conf,
|
||||
similarity,
|
||||
label,
|
||||
cam)
|
||||
similarities.append(similarity)
|
||||
labels.append(int(label))
|
||||
|
||||
except Exception as e:
|
||||
print(f"Skipping invalid pair: {pair}. Error: {e}")
|
||||
continue
|
||||
# except Exception as e:
|
||||
# print(f"Skipping invalid pair: {pair}. Error: {e}")
|
||||
# continue
|
||||
|
||||
# Find optimal threshold and accuracy
|
||||
accuracy, threshold = threshold_search(similarities, labels)
|
||||
@ -267,10 +286,10 @@ def compute_group_accuracy(content_list_read):
|
||||
d = featurize(group[0], conf.test_transform, model, conf.device)
|
||||
one_group_list.append(d.values())
|
||||
if data_loaded[-1] == '1':
|
||||
similarity = deal_group_pair(one_group_list[0], one_group_list[1])
|
||||
similarity = abs(deal_group_pair(one_group_list[0], one_group_list[1]))
|
||||
Same.append(similarity)
|
||||
else:
|
||||
similarity = deal_group_pair(one_group_list[0], one_group_list[1])
|
||||
similarity = abs(deal_group_pair(one_group_list[0], one_group_list[1]))
|
||||
Cross.append(similarity)
|
||||
allLabel.append(data_loaded[-1])
|
||||
allSimilarity.extend(similarity)
|
||||
@ -291,14 +310,36 @@ def init_model():
|
||||
print('load model {} '.format(conf['models']['backbone']))
|
||||
if torch.cuda.device_count() > 1 and conf['base']['distributed']:
|
||||
model = nn.DataParallel(model).to(conf['base']['device'])
|
||||
model.load_state_dict(torch.load(conf['models']['model_path'], map_location=conf['base']['device']))
|
||||
###############正常模型加载################
|
||||
model.load_state_dict(torch.load(conf['models']['model_path'],
|
||||
map_location=conf['base']['device']))
|
||||
#######################################
|
||||
####### 对比学习模型临时运用###
|
||||
# state_dict = torch.load(conf['models']['model_path'], map_location=conf['base']['device'])
|
||||
# new_state_dict = {}
|
||||
# for k, v in state_dict.items():
|
||||
# new_key = k.replace("module.base_model.", "module.")
|
||||
# new_state_dict[new_key] = v
|
||||
# model.load_state_dict(new_state_dict, strict=False)
|
||||
###########################
|
||||
if conf['models']['half']:
|
||||
model.half()
|
||||
first_param_dtype = next(model.parameters()).dtype
|
||||
print("模型的第一个参数的数据类型: {}".format(first_param_dtype))
|
||||
else:
|
||||
model.load_state_dict(torch.load(conf['model']['model_path'], map_location=conf['base']['device']))
|
||||
if conf.model_half:
|
||||
try:
|
||||
model.load_state_dict(torch.load(conf['models']['model_path'],
|
||||
map_location=conf['base']['device']))
|
||||
except:
|
||||
state_dict = torch.load(conf['models']['model_path'],
|
||||
map_location=conf['base']['device'])
|
||||
new_state_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
new_key = k.replace("module.", "")
|
||||
new_state_dict[new_key] = v
|
||||
model.load_state_dict(new_state_dict, strict=False)
|
||||
|
||||
if conf['models']['half']:
|
||||
model.half()
|
||||
first_param_dtype = next(model.parameters()).dtype
|
||||
print("模型的第一个参数的数据类型: {}".format(first_param_dtype))
|
||||
@ -308,7 +349,7 @@ def init_model():
|
||||
if __name__ == '__main__':
|
||||
model = init_model()
|
||||
model.eval()
|
||||
|
||||
cam = cal_cam(model, conf)
|
||||
if not conf['data']['group_test']:
|
||||
images = unique_image(conf['data']['test_list'])
|
||||
images = [osp.join(conf['data']['test_dir'], img) for img in images]
|
||||
@ -318,7 +359,7 @@ if __name__ == '__main__':
|
||||
for group in groups:
|
||||
d = featurize(group, test_transform, model, conf['base']['device'])
|
||||
feature_dict.update(d)
|
||||
accuracy, threshold = compute_accuracy(feature_dict, conf['data']['test_list'], conf['data']['test_dir'])
|
||||
accuracy, threshold = compute_accuracy(feature_dict, cam)
|
||||
print(
|
||||
"Test Model: {} Accuracy: {} Threshold: {}".format(conf['models']['model_path'], accuracy, threshold)
|
||||
)
|
||||
|
@ -5,12 +5,14 @@ import torchvision.transforms as T
|
||||
# from config import config as conf
|
||||
import torch
|
||||
|
||||
|
||||
def pad_to_square(img):
|
||||
w, h = img.size
|
||||
max_wh = max(w, h)
|
||||
padding = [(max_wh - w) // 2, (max_wh - h) // 2, (max_wh - w) // 2, (max_wh - h) // 2] # (left, top, right, bottom)
|
||||
return F.pad(img, padding, fill=0, padding_mode='constant')
|
||||
|
||||
|
||||
def get_transform(cfg):
|
||||
train_transform = T.Compose([
|
||||
T.Lambda(pad_to_square), # 补边
|
||||
@ -24,7 +26,7 @@ def get_transform(cfg):
|
||||
T.Normalize(mean=[cfg['transform']['img_mean']], std=[cfg['transform']['img_std']]),
|
||||
])
|
||||
test_transform = T.Compose([
|
||||
# T.Lambda(pad_to_square), # 补边
|
||||
T.Lambda(pad_to_square), # 补边
|
||||
T.ToTensor(),
|
||||
T.Resize((cfg['transform']['img_size'], cfg['transform']['img_size']), antialias=True),
|
||||
T.ConvertImageDtype(torch.float32),
|
||||
@ -32,37 +34,73 @@ def get_transform(cfg):
|
||||
])
|
||||
return train_transform, test_transform
|
||||
|
||||
def load_data(training=True, cfg=None):
|
||||
|
||||
def load_data(training=True, cfg=None, return_dataset=False):
|
||||
train_transform, test_transform = get_transform(cfg)
|
||||
if training:
|
||||
dataroot = cfg['data']['data_train_dir']
|
||||
transform = train_transform
|
||||
# transform = conf.train_transform
|
||||
# transform.yml = conf.train_transform
|
||||
batch_size = cfg['data']['train_batch_size']
|
||||
else:
|
||||
dataroot = cfg['data']['data_val_dir']
|
||||
# transform = conf.test_transform
|
||||
# transform.yml = conf.test_transform
|
||||
transform = test_transform
|
||||
batch_size = cfg['data']['val_batch_size']
|
||||
|
||||
data = ImageFolder(dataroot, transform=transform)
|
||||
class_num = len(data.classes)
|
||||
if return_dataset:
|
||||
return data, class_num
|
||||
else:
|
||||
loader = DataLoader(data,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
shuffle=True if training else False,
|
||||
pin_memory=cfg['base']['pin_memory'],
|
||||
num_workers=cfg['data']['num_workers'],
|
||||
drop_last=True)
|
||||
return loader, class_num
|
||||
|
||||
class MultiEpochsDataLoader(torch.utils.data.DataLoader):
|
||||
"""
|
||||
MultiEpochsDataLoader 类
|
||||
通过重用工作进程来提高数据加载效率,避免每个epoch重新启动工作进程
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._DataLoader__initialized = False
|
||||
self.batch_sampler = _RepeatSampler(self.batch_sampler)
|
||||
self._DataLoader__initialized = True
|
||||
self.iterator = super().__iter__()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.batch_sampler.sampler)
|
||||
|
||||
def __iter__(self):
|
||||
for i in range(len(self)):
|
||||
yield next(self.iterator)
|
||||
|
||||
|
||||
class _RepeatSampler(object):
|
||||
"""
|
||||
重复采样器,避免每个epoch重新创建迭代器
|
||||
"""
|
||||
|
||||
def __init__(self, sampler):
|
||||
self.sampler = sampler
|
||||
|
||||
def __iter__(self):
|
||||
while True:
|
||||
yield from iter(self.sampler)
|
||||
# def load_gift_data(action):
|
||||
# train_data = ImageFolder(conf.train_gift_root, transform=conf.train_transform)
|
||||
# train_data = ImageFolder(conf.train_gift_root, transform.yml=conf.train_transform)
|
||||
# train_dataset = DataLoader(train_data, batch_size=conf.train_gift_batchsize, shuffle=True,
|
||||
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
|
||||
# val_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
|
||||
# val_data = ImageFolder(conf.test_gift_root, transform.yml=conf.test_transform)
|
||||
# val_dataset = DataLoader(val_data, batch_size=conf.val_gift_batchsize, shuffle=True,
|
||||
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
|
||||
# test_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
|
||||
# test_data = ImageFolder(conf.test_gift_root, transform.yml=conf.test_transform)
|
||||
# test_dataset = DataLoader(test_data, batch_size=conf.test_gift_batchsize, shuffle=True,
|
||||
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
|
||||
# return train_dataset, val_dataset, test_dataset
|
||||
|
@ -1,10 +1,10 @@
|
||||
./quant_imgs/20179457_20240924-110903_back_addGood_b82d2842766e_80_15583929052_tid-8_fid-72_bid-3.jpg
|
||||
./quant_imgs/6928926002103_20240309-195044_front_returnGood_70f75407ef0e_225_18120111822_14_01.jpg
|
||||
./quant_imgs/6928926002103_20240309-212145_front_returnGood_70f75407ef0e_225_18120111822_11_01.jpg
|
||||
./quant_imgs/6928947479083_20241017-133830_front_returnGood_5478c9a48b7e_10_13799009402_tid-1_fid-20_bid-1.jpg
|
||||
./quant_imgs/6928947479083_20241018-110450_front_addGood_5478c9a48c28_165_13773168720_tid-6_fid-36_bid-1.jpg
|
||||
./quant_imgs/6930044166421_20240117-141516_c6a23f41-5b16-44c6-a03e-c32c25763442_back_returnGood_6930044166421_17_01.jpg
|
||||
./quant_imgs/6930044166421_20240308-150916_back_returnGood_70f75407ef0e_175_13815402763_7_01.jpg
|
||||
./quant_imgs/6930044168920_20240117-165633_3303629b-5fbd-423b-913d-8a64c1aa51dc_front_addGood_6930044168920_26_01.jpg
|
||||
./quant_imgs/6930058201507_20240305-175434_front_addGood_70f75407ef0e_95_18120111822_28_01.jpg
|
||||
./quant_imgs/6930639267885_20241014-120446_back_addGood_5478c9a48c3e_135_13773168720_tid-5_fid-99_bid-0.jpg
|
||||
../quant_imgs/20179457_20240924-110903_back_addGood_b82d2842766e_80_15583929052_tid-8_fid-72_bid-3.jpg
|
||||
../quant_imgs/6928926002103_20240309-195044_front_returnGood_70f75407ef0e_225_18120111822_14_01.jpg
|
||||
../quant_imgs/6928926002103_20240309-212145_front_returnGood_70f75407ef0e_225_18120111822_11_01.jpg
|
||||
../quant_imgs/6928947479083_20241017-133830_front_returnGood_5478c9a48b7e_10_13799009402_tid-1_fid-20_bid-1.jpg
|
||||
../quant_imgs/6928947479083_20241018-110450_front_addGood_5478c9a48c28_165_13773168720_tid-6_fid-36_bid-1.jpg
|
||||
../quant_imgs/6930044166421_20240117-141516_c6a23f41-5b16-44c6-a03e-c32c25763442_back_returnGood_6930044166421_17_01.jpg
|
||||
../quant_imgs/6930044166421_20240308-150916_back_returnGood_70f75407ef0e_175_13815402763_7_01.jpg
|
||||
../quant_imgs/6930044168920_20240117-165633_3303629b-5fbd-423b-913d-8a64c1aa51dc_front_addGood_6930044168920_26_01.jpg
|
||||
../quant_imgs/6930058201507_20240305-175434_front_addGood_70f75407ef0e_95_18120111822_28_01.jpg
|
||||
../quant_imgs/6930639267885_20241014-120446_back_addGood_5478c9a48c3e_135_13773168720_tid-5_fid-99_bid-0.jpg
|
||||
|
23
tools/dataset_electornic.txt
Normal file
23
tools/dataset_electornic.txt
Normal file
@ -0,0 +1,23 @@
|
||||
../electronic_imgs/0.jpg
|
||||
../electronic_imgs/1.jpg
|
||||
../electronic_imgs/2.jpg
|
||||
../electronic_imgs/3.jpg
|
||||
../electronic_imgs/4.jpg
|
||||
../electronic_imgs/5.jpg
|
||||
../electronic_imgs/6.jpg
|
||||
../electronic_imgs/7.jpg
|
||||
../electronic_imgs/8.jpg
|
||||
../electronic_imgs/9.jpg
|
||||
../electronic_imgs/10.jpg
|
||||
../electronic_imgs/11.jpg
|
||||
../electronic_imgs/12.jpg
|
||||
../electronic_imgs/13.jpg
|
||||
../electronic_imgs/14.jpg
|
||||
../electronic_imgs/15.jpg
|
||||
../electronic_imgs/16.jpg
|
||||
../electronic_imgs/17.jpg
|
||||
../electronic_imgs/18.jpg
|
||||
../electronic_imgs/19.jpg
|
||||
../electronic_imgs/20.jpg
|
||||
../electronic_imgs/21.jpg
|
||||
../electronic_imgs/22.jpg
|
144
tools/event_similar_analysis.py
Normal file
144
tools/event_similar_analysis.py
Normal file
@ -0,0 +1,144 @@
|
||||
from similar_analysis import SimilarAnalysis
|
||||
import os
|
||||
import pickle
|
||||
from tools.image_joint import merge_imgs
|
||||
|
||||
|
||||
class EventSimilarAnalysis(SimilarAnalysis):
|
||||
def __init__(self):
|
||||
super(EventSimilarAnalysis, self).__init__()
|
||||
self.fn_one2one_event, self.fp_one2one_event = self.One2one_similar_analysis()
|
||||
self.fn_one2sn_event, self.fp_one2sn_event = self.One2Sn_similar_analysis()
|
||||
if os.path.exists(self.conf['event']['pickle_path']):
|
||||
print('pickle file exists')
|
||||
else:
|
||||
self.target_image = self.get_path()
|
||||
|
||||
def get_path(self):
|
||||
events = [self.fn_one2one_event, self.fp_one2one_event,
|
||||
self.fn_one2sn_event, self.fp_one2sn_event]
|
||||
event_image_path = []
|
||||
barcode_image_path = []
|
||||
for event in events:
|
||||
for event_name, bcd in event:
|
||||
event_sub_image = os.sep.join([self.conf['event']['event_save_dir'],
|
||||
event_name,
|
||||
'subimgs'])
|
||||
barcode_images = os.sep.join([self.conf['event']['stdlib_image_path'],
|
||||
bcd])
|
||||
for image_name in os.listdir(event_sub_image):
|
||||
event_image_path.append(os.sep.join([event_sub_image, image_name]))
|
||||
for barcode in os.listdir(barcode_images):
|
||||
barcode_image_path.append(os.sep.join([barcode_images, barcode]))
|
||||
return list(set(event_image_path + barcode_image_path))
|
||||
|
||||
|
||||
def write_dict_to_pickle(self, data):
|
||||
"""将字典写入pickle文件."""
|
||||
with open(self.conf['event']['pickle_path'], 'wb') as file:
|
||||
pickle.dump(data, file)
|
||||
|
||||
def get_dict_to_pickle(self):
|
||||
with open(self.conf['event']['pickle_path'], 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
return data
|
||||
|
||||
def create_total_feature(self):
|
||||
feature_dicts = self.get_feature_map(self.target_image)
|
||||
self.write_dict_to_pickle(feature_dicts)
|
||||
print(feature_dicts)
|
||||
|
||||
def One2one_similar_analysis(self):
|
||||
fn_event, fp_event = [], []
|
||||
with open(self.conf['event']['oneToOneTxt'], 'r') as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
print(line.strip().split(' '))
|
||||
event_infor = line.strip().split(' ')
|
||||
label = event_infor[0]
|
||||
event_name = event_infor[1]
|
||||
bcd = event_infor[2]
|
||||
simi1 = event_infor[3]
|
||||
simi2 = event_infor[4]
|
||||
if label == 'same' and float(simi2) < self.conf['event']['oneToOne_max_th']:
|
||||
print(event_name, bcd, simi1)
|
||||
fn_event.append((event_name, bcd))
|
||||
elif label == 'diff' and float(simi2) > self.conf['event']['oneToSn_min_th']:
|
||||
fp_event.append((event_name, bcd))
|
||||
return fn_event, fp_event
|
||||
|
||||
def One2Sn_similar_analysis(self):
|
||||
fn_event, fp_event = [], []
|
||||
with open(self.conf['event']['oneToOneTxt'], 'r') as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
print(line.strip().split(' '))
|
||||
event_infor = line.strip().split(' ')
|
||||
label = event_infor[0]
|
||||
event_name = event_infor[1]
|
||||
bcd = event_infor[2]
|
||||
simi = event_infor[3]
|
||||
if label == 'fn':
|
||||
print(event_name, bcd, simi)
|
||||
fn_event.append((event_name, bcd))
|
||||
elif label == 'fp':
|
||||
fp_event.append((event_name, bcd))
|
||||
return fn_event, fp_event
|
||||
|
||||
def save_joint_image(self, img_pth1, img_pth2, feature_dicts, record):
|
||||
feature_dict1 = feature_dicts[img_pth1]
|
||||
feature_dict2 = feature_dicts[img_pth2]
|
||||
similarity = self.get_similarity(feature_dict1.cpu().numpy(),
|
||||
feature_dict2.cpu().numpy())
|
||||
dir_name = img_pth1.split('/')[-3]
|
||||
save_path = os.sep.join([self.conf['data']['image_joint_pth'], dir_name, record])
|
||||
if "fp" in record:
|
||||
if similarity > 0.8:
|
||||
merge_imgs(img_pth1,
|
||||
img_pth2,
|
||||
self.conf,
|
||||
similarity,
|
||||
label=None,
|
||||
cam=self.cam,
|
||||
save_path=save_path)
|
||||
else:
|
||||
if similarity < 0.8:
|
||||
merge_imgs(img_pth1,
|
||||
img_pth2,
|
||||
self.conf,
|
||||
similarity,
|
||||
label=None,
|
||||
cam=self.cam,
|
||||
save_path=save_path)
|
||||
print(similarity)
|
||||
|
||||
def get_contrast(self, feature_dicts):
|
||||
events_compare = [self.fp_one2one_event, self.fn_one2one_event, self.fp_one2sn_event, self.fn_one2sn_event]
|
||||
event_record = ['fp_one2one', 'fn_one2one', 'fp_one2sn', 'fn_one2sn']
|
||||
for event_compare, record in zip(events_compare, event_record):
|
||||
for img, img_std in event_compare:
|
||||
imgs_pth1 = os.sep.join([self.conf['event']['event_save_dir'],
|
||||
img,
|
||||
'subimgs'])
|
||||
imgs_pth2 = os.sep.join([self.conf['event']['stdlib_image_path'],
|
||||
img_std])
|
||||
for img1 in os.listdir(imgs_pth1):
|
||||
for img2 in os.listdir(imgs_pth2):
|
||||
img_pth1 = os.sep.join([imgs_pth1, img1])
|
||||
img_pth2 = os.sep.join([imgs_pth2, img2])
|
||||
try:
|
||||
self.save_joint_image(img_pth1, img_pth2, feature_dicts, record)
|
||||
except Exception as e:
|
||||
continue
|
||||
print(e)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
event_similar_analysis = EventSimilarAnalysis()
|
||||
if os.path.exists(event_similar_analysis.conf['event']['pickle_path']):
|
||||
print('pickle file exists')
|
||||
else:
|
||||
event_similar_analysis.create_total_feature() # 生成pickle文件, 生成时间较长,生成一个文件即可
|
||||
feature_dicts = event_similar_analysis.get_dict_to_pickle()
|
||||
# all_compare_img = event_similar_analysis.get_image_map()
|
||||
event_similar_analysis.get_contrast(feature_dicts) # 获取比对结果
|
164
tools/getHeatMap.py
Normal file
164
tools/getHeatMap.py
Normal file
@ -0,0 +1,164 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torchvision import models
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms as tfs
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from PIL import Image
|
||||
import cv2
|
||||
# from tools.config import cfg
|
||||
# from comparative.tools.initmodel import initSimilarityModel
|
||||
import yaml
|
||||
from dataset import get_transform
|
||||
|
||||
|
||||
class cal_cam(nn.Module):
|
||||
def __init__(self, model, conf):
|
||||
super(cal_cam, self).__init__()
|
||||
self.conf = conf
|
||||
self.device = self.conf['base']['device']
|
||||
|
||||
self.model = model
|
||||
self.model.to(self.device)
|
||||
|
||||
# 要求梯度的层
|
||||
self.feature_layer = conf['heatmap']['feature_layer']
|
||||
# 记录梯度
|
||||
self.gradient = []
|
||||
# 记录输出的特征图
|
||||
self.output = []
|
||||
_, self.transform = get_transform(self.conf)
|
||||
|
||||
def get_conf(self, yaml_pth):
|
||||
with open(yaml_pth, 'r') as f:
|
||||
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||
return conf
|
||||
|
||||
def save_grad(self, grad):
|
||||
self.gradient.append(grad)
|
||||
|
||||
def get_grad(self):
|
||||
return self.gradient[-1].cpu().data
|
||||
|
||||
def get_feature(self):
|
||||
return self.output[-1][0]
|
||||
|
||||
def process_img(self, input):
|
||||
input = self.transform(input)
|
||||
input = input.unsqueeze(0)
|
||||
return input
|
||||
|
||||
# 计算最后一个卷积层的梯度,输出梯度和最后一个卷积层的特征图
|
||||
def getGrad(self, input_):
|
||||
self.gradient = [] # 清除之前的梯度
|
||||
self.output = [] # 清除之前的特征图
|
||||
# print(f"cuda.memory_allocated 1 {torch.cuda.memory_allocated()/ (1024 ** 3)}G")
|
||||
input_ = input_.to(self.device).requires_grad_(True)
|
||||
num = 1
|
||||
for name, module in self.model._modules.items():
|
||||
# print(f'module_name: {name}')
|
||||
# print(f'module: {module}')
|
||||
if (num == 1):
|
||||
input = module(input_)
|
||||
num = num + 1
|
||||
continue
|
||||
# 是待提取特征图的层
|
||||
if (name == self.feature_layer):
|
||||
input = module(input)
|
||||
input.register_hook(self.save_grad)
|
||||
self.output.append([input])
|
||||
# 马上要到全连接层了
|
||||
elif (name == "avgpool"):
|
||||
input = module(input)
|
||||
input = input.reshape(input.shape[0], -1)
|
||||
# 普通的层
|
||||
else:
|
||||
input = module(input)
|
||||
|
||||
# print(f"cuda.memory_allocated 2 {torch.cuda.memory_allocated() / (1024 ** 3)}G")
|
||||
# 到这里input就是最后全连接层的输出了
|
||||
index = torch.max(input, dim=-1)[1]
|
||||
one_hot = torch.zeros((1, input.shape[-1]), dtype=torch.float32)
|
||||
one_hot[0][index] = 1
|
||||
confidenct = one_hot * input.cpu()
|
||||
confidenct = torch.sum(confidenct, dim=-1).requires_grad_(True)
|
||||
|
||||
# print(f"cuda.memory_allocated 3 {torch.cuda.memory_allocated() / (1024 ** 3)}G")
|
||||
# 清除之前的所有梯度
|
||||
self.model.zero_grad()
|
||||
# 反向传播获取梯度
|
||||
grad_output = torch.ones_like(confidenct)
|
||||
confidenct.backward(grad_output)
|
||||
# 获取特征图的梯度
|
||||
grad_val = self.get_grad()
|
||||
feature = self.get_feature()
|
||||
|
||||
# print(f"cuda.memory_allocated 4 {torch.cuda.memory_allocated() / (1024 ** 3)}G")
|
||||
return grad_val, feature, input_.grad
|
||||
|
||||
# 计算CAM
|
||||
def getCam(self, grad_val, feature):
|
||||
# 对特征图的每个通道进行全局池化
|
||||
alpha = torch.mean(grad_val, dim=(2, 3)).cpu()
|
||||
feature = feature.cpu()
|
||||
# 将池化后的结果和相应通道特征图相乘
|
||||
cam = torch.zeros((feature.shape[2], feature.shape[3]), dtype=torch.float32)
|
||||
for idx in range(alpha.shape[1]):
|
||||
cam = cam + alpha[0][idx] * feature[0][idx]
|
||||
# 进行ReLU操作
|
||||
cam = np.maximum(cam.detach().numpy(), 0)
|
||||
|
||||
# plt.imshow(cam)
|
||||
# plt.colorbar()
|
||||
# plt.savefig("cam.jpg")
|
||||
|
||||
# 将cam区域放大到输入图片大小
|
||||
cam_ = cv2.resize(cam, (224, 224))
|
||||
cam_ = cam_ - np.min(cam_)
|
||||
cam_ = cam_ / np.max(cam_)
|
||||
# plt.imshow(cam_)
|
||||
# plt.savefig("cam_.jpg")
|
||||
cam = torch.from_numpy(cam)
|
||||
|
||||
return cam, cam_
|
||||
|
||||
def show_img(self, cam_, img, heatmap_save_pth, imgname):
|
||||
heatmap = cv2.applyColorMap(np.uint8(255 * cam_), cv2.COLORMAP_JET)
|
||||
cam_img = 0.3 * heatmap + 0.7 * np.float32(img)
|
||||
# cv2.imwrite("img.jpg", cam_img)
|
||||
cv2.imwrite(os.sep.join([heatmap_save_pth, imgname]), cam_img)
|
||||
|
||||
def get_hot_map(self, img_pth):
|
||||
img = Image.open(img_pth)
|
||||
img = img.resize((224, 224))
|
||||
input = self.process_img(img)
|
||||
grad_val, feature, input_grad = self.getGrad(input)
|
||||
cam, cam_ = self.getCam(grad_val, feature)
|
||||
heatmap = cv2.applyColorMap(np.uint8(255 * cam_), cv2.COLORMAP_JET)
|
||||
cam_img = 0.3 * heatmap + 0.7 * np.float32(img)
|
||||
cam_img = Image.fromarray(np.uint8(cam_img))
|
||||
return cam_img
|
||||
|
||||
# def __call__(self, img_root, heatmap_save_pth):
|
||||
# for imgname in os.listdir(img_root):
|
||||
# img = Image.open(os.sep.join([img_root, imgname]))
|
||||
# img = img.resize((224, 224))
|
||||
# # plt.imshow(img)
|
||||
# # plt.savefig("airplane.jpg")
|
||||
# input = self.process_img(img)
|
||||
# grad_val, feature, input_grad = self.getGrad(input)
|
||||
# cam, cam_ = self.getCam(grad_val, feature)
|
||||
# self.show_img(cam_, img, heatmap_save_pth, imgname)
|
||||
# return cam
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cam = cal_cam()
|
||||
img_root = "test_img/"
|
||||
heatmap_save_pth = "heatmap_result"
|
||||
cam(img_root, heatmap_save_pth)
|
213
tools/getpairs.py
Normal file
213
tools/getpairs.py
Normal file
@ -0,0 +1,213 @@
|
||||
import os
|
||||
import random
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Dict, Optional
|
||||
import logging
|
||||
|
||||
|
||||
class PairGenerator:
|
||||
"""Generate positive and negative image pairs for contrastive learning."""
|
||||
|
||||
def __init__(self, original_path):
|
||||
self._setup_logging()
|
||||
self.original_path = original_path
|
||||
self._delete_space()
|
||||
|
||||
def _delete_space(self): # 删除图片文件名中的空格
|
||||
print(self.original_path)
|
||||
for root, dirs, files in os.walk(self.original_path):
|
||||
for file_name in files:
|
||||
if file_name.endswith('.jpg' or '.png'):
|
||||
n_file_name = file_name.replace(' ', '')
|
||||
os.rename(os.path.join(root, file_name), os.path.join(root, n_file_name))
|
||||
if 'rotate' in file_name:
|
||||
os.remove(os.path.join(root, file_name))
|
||||
for dir_name in dirs:
|
||||
n_dir_name = dir_name.replace(' ', '')
|
||||
os.rename(os.path.join(root, dir_name), os.path.join(root, n_dir_name))
|
||||
|
||||
def _setup_logging(self):
|
||||
"""Configure logging settings."""
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def _get_image_files(self, root_dir: str) -> Dict[str, List[str]]:
|
||||
"""Scan directory and return dict of {folder: [image_paths]}."""
|
||||
root = Path(root_dir)
|
||||
if not root.is_dir():
|
||||
raise ValueError(f"Invalid directory: {root_dir}")
|
||||
|
||||
return {
|
||||
str(folder): [str(f) for f in folder.iterdir() if f.is_file()]
|
||||
for folder in root.iterdir() if folder.is_dir()
|
||||
}
|
||||
|
||||
def _generate_same_pairs(
|
||||
self,
|
||||
files_dict: Dict[str, List[str]],
|
||||
num_pairs: int,
|
||||
min_size: int, # min_size is the minimum number of images per folder
|
||||
group_size: Optional[int] = None
|
||||
) -> List[Tuple[str, str, int]]:
|
||||
"""Generate positive pairs from same folder."""
|
||||
pairs = []
|
||||
|
||||
for folder, files in files_dict.items():
|
||||
if len(files) < 2:
|
||||
continue
|
||||
|
||||
if group_size:
|
||||
# Group mode: generate all possible pairs within group
|
||||
for i in range(0, len(files), group_size):
|
||||
group = files[i:i + group_size]
|
||||
pairs.extend([
|
||||
(group[i], group[j], 1)
|
||||
for i in range(len(group))
|
||||
for j in range(i + 1, len(group))
|
||||
])
|
||||
else:
|
||||
# Individual mode: random pairs
|
||||
try:
|
||||
pairs.extend(self._random_pairs(files, min(min_size, len(files) // 2)))
|
||||
except ValueError as e:
|
||||
self.logger.warning(f"Skipping folder {folder}: {str(e)}")
|
||||
|
||||
random.shuffle(pairs)
|
||||
return pairs[:num_pairs]
|
||||
|
||||
def _generate_cross_pairs(
|
||||
self,
|
||||
files_dict: Dict[str, List[str]],
|
||||
num_pairs: int
|
||||
) -> List[Tuple[str, str, int]]:
|
||||
"""Generate negative pairs from different folders."""
|
||||
folders = list(files_dict.keys())
|
||||
pairs = []
|
||||
existing_pairs = set()
|
||||
|
||||
while len(pairs) < num_pairs and len(folders) >= 2:
|
||||
folder1, folder2 = random.sample(folders, 2)
|
||||
file1 = random.choice(files_dict[folder1])
|
||||
file2 = random.choice(files_dict[folder2])
|
||||
|
||||
pair_key = (file1, file2)
|
||||
reverse_key = (file2, file1)
|
||||
|
||||
if pair_key not in existing_pairs and reverse_key not in existing_pairs:
|
||||
pairs.append((file1, file2, 0))
|
||||
existing_pairs.add(pair_key)
|
||||
existing_pairs.add(reverse_key)
|
||||
|
||||
return pairs
|
||||
|
||||
def _random_pairs(self, files: List[str], num_pairs: int) -> List[Tuple[str, str, int]]:
|
||||
"""Generate random pairs from file list."""
|
||||
max_possible = len(files) // 2
|
||||
if max_possible == 0:
|
||||
return []
|
||||
|
||||
actual_pairs = min(num_pairs, max_possible)
|
||||
indices = random.sample(range(len(files)), 2 * actual_pairs)
|
||||
indices.sort()
|
||||
return [(files[i], files[i + 1], 1) for i in range(0, len(indices), 2)]
|
||||
|
||||
def get_pairs(self,
|
||||
root_dir: str,
|
||||
num_pairs: int = 6000,
|
||||
output_txt: Optional[str] = None
|
||||
) -> List[Tuple[str, str, int]]:
|
||||
"""
|
||||
Generate individual image pairs with labels (1=same, 0=different).
|
||||
|
||||
Args:
|
||||
root_dir: Directory containing subfolders of images
|
||||
num_pairs: Number of pairs to generate
|
||||
output_txt: Optional path to save pairs as txt file
|
||||
|
||||
Returns:
|
||||
List of (path1, path2, label) tuples
|
||||
"""
|
||||
files_dict = self._get_image_files(root_dir)
|
||||
|
||||
same_pairs = self._generate_same_pairs(files_dict, num_pairs, min_size=30)
|
||||
cross_pairs = self._generate_cross_pairs(files_dict, len(same_pairs))
|
||||
|
||||
pairs = same_pairs + cross_pairs
|
||||
self.logger.info(f"Generated {len(pairs)} pairs ({len(same_pairs)} positive, {len(cross_pairs)} negative)")
|
||||
|
||||
if output_txt:
|
||||
try:
|
||||
with open(output_txt, 'w') as f:
|
||||
for file1, file2, label in pairs:
|
||||
f.write(f"{file1} {file2} {label}\n")
|
||||
self.logger.info(f"Saved pairs to {output_txt}")
|
||||
except IOError as e:
|
||||
self.logger.warning(f"Failed to write pairs to {output_txt}: {str(e)}")
|
||||
|
||||
return pairs
|
||||
|
||||
def get_group_pairs(
|
||||
self,
|
||||
root_dir: str,
|
||||
img_num: int = 20,
|
||||
group_num: int = 10,
|
||||
num_pairs: int = 5000
|
||||
) -> List[Tuple[str, str, int]]:
|
||||
"""
|
||||
Generate grouped image pairs with labels (1=same, 0=different).
|
||||
|
||||
Args:
|
||||
root_dir: Directory containing subfolders of images
|
||||
img_num: Minimum images required per folder
|
||||
group_num: Number of images per group
|
||||
num_pairs: Number of pairs to generate
|
||||
|
||||
Returns:
|
||||
List of (path1, path2, label) tuples
|
||||
"""
|
||||
# Filter folders with enough images
|
||||
files_dict = {
|
||||
k: v for k, v in self._get_image_files(root_dir).items()
|
||||
if len(v) >= img_num
|
||||
}
|
||||
|
||||
# Split into groups
|
||||
grouped_files = {}
|
||||
for folder, files in files_dict.items():
|
||||
random.shuffle(files)
|
||||
grouped_files[folder] = [
|
||||
files[i:i + group_num]
|
||||
for i in range(0, len(files), group_num)
|
||||
]
|
||||
|
||||
# Generate pairs
|
||||
same_pairs = self._generate_same_pairs(
|
||||
grouped_files, num_pairs, min_size=30, group_size=group_num
|
||||
)
|
||||
cross_pairs = self._generate_cross_pairs(
|
||||
grouped_files, len(same_pairs)
|
||||
)
|
||||
|
||||
pairs = same_pairs + cross_pairs
|
||||
self.logger.info(f"Generated {len(pairs)} group pairs")
|
||||
|
||||
# Save to JSON
|
||||
with open("cross_same.json", 'w') as f:
|
||||
json.dump(pairs, f)
|
||||
|
||||
return pairs
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
original_path = '/home/lc/data_center/electornic/v1/val'
|
||||
parent_dir = str(Path(original_path).parent)
|
||||
generator = PairGenerator(original_path)
|
||||
|
||||
# Example usage:
|
||||
pairs = generator.get_pairs(original_path,
|
||||
output_txt=os.sep.join([parent_dir, 'cross_same.txt'])) # Individual pairs
|
||||
# groups = generator.get_group_pairs('val') # Group pairs
|
71
tools/image_joint.py
Normal file
71
tools/image_joint.py
Normal file
@ -0,0 +1,71 @@
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from tools.getHeatMap import cal_cam
|
||||
import os
|
||||
|
||||
|
||||
def merge_imgs(img1_path, img2_path, conf, similar=None, label=None, cam=None, save_path=None):
|
||||
save = True
|
||||
position = (50, 50) # 文字的左上角坐标
|
||||
color = (255, 0, 0) # 红色文字,格式为 RGB
|
||||
# if not os.path.exists(os.sep.join([save_path, str(label)])):
|
||||
# os.makedirs(os.sep.join([save_path, str(label)]))
|
||||
# save_path = os.sep.join([save_path, str(label)])
|
||||
# img_name = os.path.basename(img1_path).split('.')[0] + '_' + os.path.basename(img2_path).split('.')[0] + '.png'
|
||||
if save_path is None:
|
||||
save_path = conf['data']['image_joint_pth']
|
||||
if not conf['heatmap']['show_heatmap']:
|
||||
img1 = Image.open(img1_path)
|
||||
img2 = Image.open(img2_path)
|
||||
img1 = img1.resize((224, 224))
|
||||
img2 = img2.resize((224, 224))
|
||||
new_img = Image.new('RGB', (img1.width + img2.width + 10, img1.height))
|
||||
# save_path = conf['data']['image_joint_pth']
|
||||
else:
|
||||
assert cam is not None, 'cam is None'
|
||||
img1 = cam.get_hot_map(img1_path)
|
||||
img2 = cam.get_hot_map(img2_path)
|
||||
img1_ori = Image.open(img1_path)
|
||||
img2_ori = Image.open(img2_path)
|
||||
img1_ori = img1_ori.resize((224, 224))
|
||||
img2_ori = img2_ori.resize((224, 224))
|
||||
new_img = Image.new('RGB',
|
||||
(img1.width + img2.width + 10,
|
||||
img1.height + img2.width + 10))
|
||||
# save_path = conf['heatmap']['image_joint_pth']
|
||||
# print('img1_path', img1)
|
||||
# print('img2_path', img2)
|
||||
if not os.path.exists(os.sep.join([save_path, str(label)])) and (label is not None):
|
||||
os.makedirs(os.sep.join([save_path, str(label)]))
|
||||
save_path = os.sep.join([save_path, str(label)])
|
||||
if save_path is None:
|
||||
# save_path = os.sep.join([save_path, str(label)])
|
||||
pass
|
||||
# img_name = os.path.basename(img1_path).split('.')[0] + '_' + os.path.basename(img2_path).split('.')[0] + '.png'
|
||||
img_name = os.path.basename(img1_path).split('.')[0][:30] + '_' + os.path.basename(img2_path).split('.')[0][
|
||||
:30] + '.png'
|
||||
assert img1.height == img2.height
|
||||
|
||||
|
||||
|
||||
# print('new_img', new_img)
|
||||
if not conf['heatmap']['show_heatmap']:
|
||||
new_img.paste(img1, (0, 0))
|
||||
new_img.paste(img2, (img1.width + 10, 0))
|
||||
else:
|
||||
new_img.paste(img1_ori, (10, 10))
|
||||
new_img.paste(img2_ori, (img2_ori.width + 20, 10))
|
||||
new_img.paste(img1, (10, img1.height+20))
|
||||
new_img.paste(img2, (img2.width+20, img2.height+20))
|
||||
|
||||
if similar is not None:
|
||||
if label == '1' and (similar > 0.5 or similar < 0.25):
|
||||
save = False
|
||||
elif label == '0' and similar > 0.25:
|
||||
save = False
|
||||
similar = str(similar) + '_' + str(label)
|
||||
draw = ImageDraw.Draw(new_img)
|
||||
draw.text(position, str(similar), color, font_size=36)
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
img_save = os.path.join(save_path, img_name)
|
||||
if save:
|
||||
new_img.save(img_save)
|
@ -2,17 +2,29 @@ import pdb
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from model import resnet18
|
||||
from config import config as conf
|
||||
# from config import config as conf
|
||||
from collections import OrderedDict
|
||||
from configs import trainer_tools
|
||||
import cv2
|
||||
import yaml
|
||||
|
||||
def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth'):
|
||||
# 定义模型
|
||||
if model_name == 'resnet18':
|
||||
model = resnet18(scale=0.75)
|
||||
def tranform_onnx_model():
|
||||
# # 定义模型
|
||||
# if model_name == 'resnet18':
|
||||
# model = resnet18(scale=0.75)
|
||||
|
||||
print('model_name >>> {}'.format(model_name))
|
||||
if conf.multiple_cards:
|
||||
with open('../configs/transform.yml', 'r') as f:
|
||||
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
tr_tools = trainer_tools(conf)
|
||||
backbone_mapping = tr_tools.get_backbone()
|
||||
if conf['models']['backbone'] in backbone_mapping:
|
||||
model = backbone_mapping[conf['models']['backbone']]().to(conf['base']['device'])
|
||||
else:
|
||||
raise ValueError('不支持该模型: {}'.format({conf['models']['backbone']}))
|
||||
pretrained_weights = conf['models']['model_path']
|
||||
print('model_name >>> {}'.format(conf['models']['backbone']))
|
||||
if conf['base']['distributed']:
|
||||
model = model.to(torch.device('cpu'))
|
||||
checkpoint = torch.load(pretrained_weights)
|
||||
new_state_dict = OrderedDict()
|
||||
@ -22,22 +34,8 @@ def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth
|
||||
model.load_state_dict(new_state_dict)
|
||||
else:
|
||||
model.load_state_dict(torch.load(pretrained_weights, map_location=torch.device('cpu')))
|
||||
# try:
|
||||
# model.load_state_dict(torch.load(pretrained_weights, map_location=torch.device('cpu')))
|
||||
# except Exception as e:
|
||||
# print(e)
|
||||
# # model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(pretrained_weights, map_location='cpu').items()})
|
||||
# model = nn.DataParallel(model).to(conf.device)
|
||||
# model.load_state_dict(torch.load(conf.test_model, map_location=torch.device('cpu')))
|
||||
|
||||
|
||||
# 转换为ONNX
|
||||
if model_name == 'gift_type2':
|
||||
input_shape = [1, 64, 13, 13]
|
||||
elif model_name == 'gift_type3':
|
||||
input_shape = [1, 3, 224, 224]
|
||||
else:
|
||||
# 假设输入数据的大小是通道数*高度*宽度,例如3*224*224
|
||||
input_shape = [1, 3, 224, 224]
|
||||
|
||||
img = cv2.imread('./dog_224x224.jpg')
|
||||
@ -59,5 +57,4 @@ def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tranform_onnx_model(model_name='resnet18', # ['resnet18', 'gift_type2', 'gift_type3'] #gift_type2指resnet18中间数据判断;gift3_type3指resnet原图计算推理
|
||||
pretrained_weights='./checkpoints/resnet18_scale=1.0/best.pth')
|
||||
tranform_onnx_model()
|
||||
|
@ -6,15 +6,14 @@ import time
|
||||
import sys
|
||||
import numpy as np
|
||||
import cv2
|
||||
from config import config as conf
|
||||
from rknn.api import RKNN
|
||||
|
||||
import config
|
||||
|
||||
import yaml
|
||||
with open('../configs/transform.yml', 'r') as f:
|
||||
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||
# ONNX_MODEL = 'resnet50v2.onnx'
|
||||
# RKNN_MODEL = 'resnet50v2.rknn'
|
||||
ONNX_MODEL = 'checkpoints/resnet18_scale=1.0/best.onnx'
|
||||
RKNN_MODEL = 'checkpoints/resnet18_scale=1.0/best.rknn'
|
||||
ONNX_MODEL = conf['models']['onnx_model']
|
||||
RKNN_MODEL = conf['models']['rknn_model']
|
||||
|
||||
|
||||
# ONNX_MODEL = 'v3_small_0424.onnx'
|
||||
@ -100,7 +99,8 @@ if __name__ == '__main__':
|
||||
target_platform='rk3588',
|
||||
model_pruning=False,
|
||||
compress_weight=False,
|
||||
single_core_mode=True)
|
||||
single_core_mode=True,
|
||||
enable_flash_attention=True)
|
||||
# rknn.config(
|
||||
# mean_values=[[127.5, 127.5, 127.5]], # 对于单通道图像,可以设置为 [[127.5]]
|
||||
# std_values=[[127.5, 127.5, 127.5]], # 对于单通道图像,可以设置为 [[127.5]]
|
||||
@ -122,7 +122,10 @@ if __name__ == '__main__':
|
||||
|
||||
# Build model
|
||||
print('--> Building model')
|
||||
ret = rknn.build(do_quantization=True, dataset='./dataset.txt')
|
||||
ret = rknn.build(do_quantization=True, # True
|
||||
# dataset='./dataset.txt',
|
||||
dataset=conf['base']['dataset'],
|
||||
rknn_batch_size=conf['models']['rknn_batch_size'])
|
||||
# ret = rknn.build(do_quantization=False, dataset='./dataset.txt')
|
||||
if ret != 0:
|
||||
print('Build model failed!')
|
||||
|
242
tools/picdir_to_picdir_similar.py
Normal file
242
tools/picdir_to_picdir_similar.py
Normal file
@ -0,0 +1,242 @@
|
||||
from similar_analysis import SimilarAnalysis
|
||||
import os
|
||||
import pickle
|
||||
from tools.image_joint import merge_imgs
|
||||
import yaml
|
||||
from PIL import Image
|
||||
import torch
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
'''
|
||||
轨迹图与标准库之间的相似度分析
|
||||
1.用于生成轨迹图与标准库中所有图片的相似度
|
||||
2.用于分析轨迹图与标准库比对选取策略的判断
|
||||
'''
|
||||
|
||||
|
||||
class picDirSimilarAnalysis(SimilarAnalysis):
|
||||
def __init__(self):
|
||||
super(picDirSimilarAnalysis, self).__init__()
|
||||
with open('../configs/pic_pic_similar.yml', 'r') as f:
|
||||
self.conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||
if not os.path.exists(self.conf['data']['total_pkl']):
|
||||
# self.create_total_feature()
|
||||
self.create_total_pkl()
|
||||
if os.path.exists(self.conf['data']['total_pkl']):
|
||||
self.all_dicts = self.load_dict_from_pkl()
|
||||
|
||||
def is_image_file(self, filename):
|
||||
"""
|
||||
检查文件是否为图像文件
|
||||
"""
|
||||
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff')
|
||||
return filename.lower().endswith(image_extensions)
|
||||
|
||||
def create_total_pkl(self): # 将目录下所有的图片特征存入pkl文件
|
||||
all_images_feature_dict = {}
|
||||
for roots, dirs, files in os.walk(self.conf['data']['data_dir']):
|
||||
for file_name in files:
|
||||
if self.is_image_file(file_name):
|
||||
try:
|
||||
print(f"处理图像 {os.sep.join([roots, file_name])}")
|
||||
feature = self.extract_features(os.sep.join([roots, file_name]))
|
||||
except Exception as e:
|
||||
print(f"处理图像 {os.sep.join([roots, file_name])} 时出错: {e}")
|
||||
feature = None
|
||||
all_images_feature_dict[os.sep.join([roots, file_name])] = feature
|
||||
if not os.path.exists(self.conf['data']['total_pkl']):
|
||||
with open(self.conf['data']['total_pkl'], 'wb') as f:
|
||||
pickle.dump(all_images_feature_dict, f)
|
||||
|
||||
def load_dict_from_pkl(self):
|
||||
with open(self.conf['data']['total_pkl'], 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
print(f"字典已从 {self.conf['data']['total_pkl']} 加载")
|
||||
return data
|
||||
|
||||
def get_image_files(self, folder_path):
|
||||
"""
|
||||
获取文件夹中的所有图像文件
|
||||
"""
|
||||
image_files = []
|
||||
for root, _, files in os.walk(folder_path):
|
||||
for file in files:
|
||||
if self.is_image_file(file):
|
||||
image_files.append(os.path.join(root, file))
|
||||
return image_files
|
||||
|
||||
def extract_features(self, image_path):
|
||||
feature_dict = self.get_feature(image_path)
|
||||
return feature_dict[image_path]
|
||||
|
||||
def create_one_similarity_matrix(self, folder1_path, folder2_path):
|
||||
images1 = self.get_image_files(folder1_path)
|
||||
images2 = self.get_image_files(folder2_path)
|
||||
|
||||
print(f"文件夹1 ({folder1_path}) 包含 {len(images1)} 张图像")
|
||||
print(f"文件夹2 ({folder2_path}) 包含 {len(images2)} 张图像")
|
||||
|
||||
if len(images1) == 0 or len(images2) == 0:
|
||||
raise ValueError("至少有一个文件夹中没有图像文件")
|
||||
|
||||
# 提取文件夹1中的所有图像特征
|
||||
features1 = []
|
||||
print("正在提取文件夹1中的图像特征...")
|
||||
for i, img_path in enumerate(images1):
|
||||
try:
|
||||
# feature = self.extract_features(img_path)
|
||||
feature = self.all_dicts[img_path]
|
||||
features1.append(feature.cpu().numpy())
|
||||
# if (i + 1) % 10 == 0:
|
||||
# print(f"已处理 {i + 1}/{len(images1)} 张图像")
|
||||
except Exception as e:
|
||||
print(f"处理图像 {img_path} 时出错: {e}")
|
||||
features1.append(None)
|
||||
|
||||
# 提取文件夹2中的所有图像特征
|
||||
features2 = []
|
||||
print("正在提取文件夹2中的图像特征...")
|
||||
for i, img_path in enumerate(images2):
|
||||
try:
|
||||
# feature = self.extract_features(img_path)
|
||||
feature = self.all_dicts[img_path]
|
||||
features2.append(feature.cpu().numpy())
|
||||
# if (i + 1) % 10 == 0:
|
||||
# print(f"已处理 {i + 1}/{len(images2)} 张图像")
|
||||
except Exception as e:
|
||||
print(f"处理图像 {img_path} 时出错: {e}")
|
||||
features2.append(None)
|
||||
|
||||
# 移除处理失败的图像
|
||||
valid_features1 = []
|
||||
valid_images1 = []
|
||||
for i, feature in enumerate(features1):
|
||||
if feature is not None:
|
||||
valid_features1.append(feature)
|
||||
valid_images1.append(images1[i])
|
||||
|
||||
valid_features2 = []
|
||||
valid_images2 = []
|
||||
for i, feature in enumerate(features2):
|
||||
if feature is not None:
|
||||
valid_features2.append(feature)
|
||||
valid_images2.append(images2[i])
|
||||
|
||||
# print(f"文件夹1中成功处理 {len(valid_features1)} 张图像")
|
||||
# print(f"文件夹2中成功处理 {len(valid_features2)} 张图像")
|
||||
|
||||
if len(valid_features1) == 0 or len(valid_features2) == 0:
|
||||
raise ValueError("没有成功处理任何图像")
|
||||
|
||||
# 计算相似度矩阵
|
||||
print("正在计算相似度矩阵...")
|
||||
similarity_matrix = cosine_similarity(valid_features1, valid_features2)
|
||||
|
||||
return similarity_matrix, valid_images1, valid_images2
|
||||
|
||||
def get_group_similarity_matrix(self, folder_path):
|
||||
tracking_folder = os.sep.join([folder_path, 'tracking'])
|
||||
standard_folder = os.sep.join([folder_path, 'standard_slim'])
|
||||
for dir_name in os.listdir(tracking_folder):
|
||||
tracking_dir = os.sep.join([tracking_folder, dir_name])
|
||||
standard_dir = os.sep.join([standard_folder, dir_name])
|
||||
similarity_matrix, valid_images1, valid_images2 = self.create_one_similarity_matrix(tracking_dir,
|
||||
standard_dir)
|
||||
mean_similarity = np.mean(similarity_matrix)
|
||||
std_similarity = np.std(similarity_matrix)
|
||||
max_similarity = np.max(similarity_matrix)
|
||||
min_similarity = np.min(similarity_matrix)
|
||||
print(f"文件夹 {dir_name} 的相似度矩阵已计算完成 "
|
||||
f"均值:{mean_similarity} 标准差:{std_similarity} 最大值:{max_similarity} 最小值:{min_similarity}")
|
||||
result = f"{os.path.basename(standard_folder)} {dir_name} {mean_similarity:.3f} {std_similarity:.3f} {max_similarity:.3f} {min_similarity:.3f}"
|
||||
with open(self.conf['data']['result_txt'], 'a') as f:
|
||||
f.write(result + '\n')
|
||||
|
||||
|
||||
def read_result_txt():
|
||||
parts = []
|
||||
value_num = 2
|
||||
with open('../configs/pic_pic_similar.yml', 'r') as f:
|
||||
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||
f.close()
|
||||
with open(conf['data']['result_txt'], 'r') as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if line:
|
||||
parts.append(line.split(' '))
|
||||
parts = np.array(parts)
|
||||
print(parts)
|
||||
labels = ['Mean', 'Std', 'Max', 'Min']
|
||||
while value_num < 6:
|
||||
dicts = {}
|
||||
for barcode, value in zip(parts[:, 1], parts[:, value_num]):
|
||||
if barcode in dicts:
|
||||
dicts[barcode].append(float(value))
|
||||
else:
|
||||
dicts[barcode] = [float(value)]
|
||||
get_histogram(dicts, labels[value_num - 2])
|
||||
value_num += 1
|
||||
f.close()
|
||||
|
||||
|
||||
def get_histogram(data, label=None):
|
||||
# 准备数据
|
||||
categories = list(data.keys())
|
||||
values1 = [data[cat][0] for cat in categories] # 第一个值
|
||||
values2 = [data[cat][1] for cat in categories] # 第二个值
|
||||
|
||||
# 设置柱状图的位置
|
||||
x = np.arange(len(categories)) # 标签位置
|
||||
width = 0.35 # 柱状图的宽度
|
||||
|
||||
# 创建图形和轴
|
||||
fig, ax = plt.subplots(figsize=(10, 6))
|
||||
|
||||
# 绘制柱状图
|
||||
bars1 = ax.bar(x - width / 2, values1, width, label='standard', color='red', alpha=0.7)
|
||||
bars2 = ax.bar(x + width / 2, values2, width, label='standard_slim', color='green', alpha=0.7)
|
||||
|
||||
# 在每个柱状图上显示数值
|
||||
for bar in bars1:
|
||||
height = bar.get_height()
|
||||
ax.annotate(f'{height:.3f}',
|
||||
xy=(bar.get_x() + bar.get_width() / 2, height),
|
||||
xytext=(0, 3), # 3点垂直偏移
|
||||
textcoords="offset points",
|
||||
ha='center', va='bottom',
|
||||
fontsize=12)
|
||||
|
||||
for bar in bars2:
|
||||
height = bar.get_height()
|
||||
ax.annotate(f'{height:.3f}',
|
||||
xy=(bar.get_x() + bar.get_width() / 2, height),
|
||||
xytext=(0, 3), # 3点垂直偏移
|
||||
textcoords="offset points",
|
||||
ha='center', va='bottom',
|
||||
fontsize=12)
|
||||
|
||||
# 添加标签和标题
|
||||
if label is None:
|
||||
label = ''
|
||||
ax.set_xlabel('barcode')
|
||||
ax.set_ylabel('Values')
|
||||
ax.set_title(label)
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels(categories)
|
||||
ax.legend()
|
||||
|
||||
# 添加网格
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
# 调整布局并显示
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
picTopic_matrix = picDirSimilarAnalysis()
|
||||
picTopic_matrix.get_group_similarity_matrix('/home/lc/data_center/image_analysis/pic_pic_similar_maxtrix_new')
|
||||
# read_result_txt()
|
107
tools/similar_analysis.py
Normal file
107
tools/similar_analysis.py
Normal file
@ -0,0 +1,107 @@
|
||||
from configs.utils import trainer_tools
|
||||
from test_ori import group_image, featurize, cosin_metric
|
||||
from tools.dataset import get_transform
|
||||
from tools.getHeatMap import cal_cam
|
||||
from tools.image_joint import merge_imgs
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
from collections import ChainMap
|
||||
import yaml
|
||||
import os
|
||||
|
||||
|
||||
class SimilarAnalysis:
|
||||
def __init__(self):
|
||||
with open('../configs/similar_analysis.yml', 'r') as f:
|
||||
self.conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||
self.model = self.initialize_model(self.conf)
|
||||
_, self.test_transform = get_transform(self.conf)
|
||||
self.cam = cal_cam(self.model, self.conf)
|
||||
|
||||
def initialize_model(self, conf):
|
||||
"""初始化模型和度量方法"""
|
||||
tr_tools = trainer_tools(conf)
|
||||
backbone_mapping = tr_tools.get_backbone()
|
||||
print('model_path {}'.format(conf['models']['model_path']))
|
||||
if conf['models']['backbone'] in backbone_mapping:
|
||||
model = backbone_mapping[conf['models']['backbone']]()
|
||||
else:
|
||||
raise ValueError('不支持该模型: {}'.format({conf['models']['backbone']}))
|
||||
try:
|
||||
model.load_state_dict(torch.load(conf['models']['model_path'],
|
||||
map_location=conf['base']['device']))
|
||||
except:
|
||||
state_dict = torch.load(conf['models']['model_path'],
|
||||
map_location=conf['base']['device'])
|
||||
new_state_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
new_key = k.replace("module.", "")
|
||||
new_state_dict[new_key] = v
|
||||
model.load_state_dict(new_state_dict, strict=False)
|
||||
return model.eval()
|
||||
|
||||
def get_feature(self, img_pth):
|
||||
group = group_image([img_pth], self.conf['data']['val_batch_size'])
|
||||
feature = featurize(group[0], self.test_transform, self.model, self.conf['base']['device'])
|
||||
return feature
|
||||
|
||||
def get_similarity(self, feature_dict1, feature_dict2):
|
||||
similarity = cosin_metric(feature_dict1, feature_dict2)
|
||||
print(f"Similarity: {similarity}")
|
||||
return similarity
|
||||
|
||||
def get_feature_map(self, all_imgs):
|
||||
feature_dicts = {}
|
||||
for img_pth in all_imgs:
|
||||
print(f"Processing {img_pth}")
|
||||
feature_dict = self.get_feature(img_pth)
|
||||
feature_dicts = dict(ChainMap(feature_dict, feature_dicts))
|
||||
return feature_dicts
|
||||
|
||||
def get_image_map(self):
|
||||
all_compare_img = []
|
||||
for root, dirs, files in os.walk(self.conf['data']['data_dir']):
|
||||
if len(dirs) == 2:
|
||||
dir_pth_1 = os.sep.join([root, dirs[0]])
|
||||
dir_pth_2 = os.sep.join([root, dirs[1]])
|
||||
for img_name_1 in os.listdir(dir_pth_1):
|
||||
for img_name_2 in os.listdir(dir_pth_2):
|
||||
all_compare_img.append((os.sep.join([dir_pth_1, img_name_1]),
|
||||
os.sep.join([dir_pth_2, img_name_2])))
|
||||
return all_compare_img
|
||||
|
||||
def create_total_feature(self):
|
||||
all_imgs = []
|
||||
for root, dirs, files in os.walk(self.conf['data']['data_dir']):
|
||||
if len(dirs) == 2:
|
||||
for dir_name in dirs:
|
||||
dir_pth = os.sep.join([root, dir_name])
|
||||
for img_name in os.listdir(dir_pth):
|
||||
all_imgs.append(os.sep.join([dir_pth, img_name]))
|
||||
return all_imgs
|
||||
|
||||
def get_contrast_result(self, feature_dicts, all_compare_img):
|
||||
for img_pth1, img_pth2 in all_compare_img:
|
||||
feature_dict1 = feature_dicts[img_pth1]
|
||||
feature_dict2 = feature_dicts[img_pth2]
|
||||
similarity = self.get_similarity(feature_dict1.cpu().numpy(),
|
||||
feature_dict2.cpu().numpy())
|
||||
dir_name = img_pth1.split('/')[-3]
|
||||
save_path = os.sep.join([self.conf['data']['image_joint_pth'], dir_name])
|
||||
if similarity > 0.7:
|
||||
merge_imgs(img_pth1,
|
||||
img_pth2,
|
||||
self.conf,
|
||||
similarity,
|
||||
label=None,
|
||||
cam=self.cam,
|
||||
save_path=save_path)
|
||||
print(similarity)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ana = SimilarAnalysis()
|
||||
all_imgs = ana.create_total_feature()
|
||||
feature_dicts = ana.get_feature_map(all_imgs)
|
||||
all_compare_img = ana.get_image_map()
|
||||
ana.get_contrast_result(feature_dicts, all_compare_img)
|
@ -4,7 +4,7 @@ import logging
|
||||
import numpy as np
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from tools.dataset import get_transform
|
||||
from model import resnet18
|
||||
from model import resnet18, resnet34, resnet50, resnet101
|
||||
import torch
|
||||
from PIL import Image
|
||||
import pandas as pd
|
||||
@ -50,7 +50,16 @@ class FeatureExtractor:
|
||||
raise FileNotFoundError(f"Model weights file not found: {model_path}")
|
||||
|
||||
# Initialize model
|
||||
model = resnet18().to(self.conf['base']['device'])
|
||||
if conf['models']['backbone'] == 'resnet18':
|
||||
model = resnet18(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
|
||||
elif conf['models']['backbone'] == 'resnet34':
|
||||
model = resnet34(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
|
||||
elif conf['models']['backbone'] == 'resnet50':
|
||||
model = resnet50(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
|
||||
elif conf['models']['backbone'] == 'resnet101':
|
||||
model = resnet101(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
|
||||
else:
|
||||
print("不支持的模型: {}".format(conf['models']['backbone']))
|
||||
|
||||
# Handle multi-GPU case
|
||||
if conf['base']['distributed']:
|
||||
@ -168,7 +177,7 @@ class FeatureExtractor:
|
||||
# Validate input directory
|
||||
if not os.path.isdir(folder):
|
||||
raise ValueError(f"Invalid directory: {folder}")
|
||||
|
||||
i = 0
|
||||
# Process each barcode directory
|
||||
for root, dirs, files in tqdm(os.walk(folder), desc="Scanning directories"):
|
||||
if not dirs: # Leaf directory (contains images)
|
||||
@ -180,14 +189,16 @@ class FeatureExtractor:
|
||||
ori_barcode = basename
|
||||
barcode = basename
|
||||
# Apply filter if provided
|
||||
i += 1
|
||||
print(ori_barcode, i)
|
||||
if filter and ori_barcode not in filter:
|
||||
continue
|
||||
elif len(ori_barcode) > 13 or len(ori_barcode) < 8:
|
||||
logger.warning(f"Skipping invalid barcode {ori_barcode}")
|
||||
with open(conf['save']['error_barcodes'], 'a') as f:
|
||||
f.write(ori_barcode + '\n')
|
||||
f.close()
|
||||
continue
|
||||
# elif len(ori_barcode) > 13 or len(ori_barcode) < 8: # barcode筛选长度
|
||||
# logger.warning(f"Skipping invalid barcode {ori_barcode}")
|
||||
# with open(conf['save']['error_barcodes'], 'a') as f:
|
||||
# f.write(ori_barcode + '\n')
|
||||
# f.close()
|
||||
# continue
|
||||
|
||||
# Process image files
|
||||
if files:
|
||||
@ -299,7 +310,8 @@ class FeatureExtractor:
|
||||
dicts['value'] = truncated_imgs_list
|
||||
if create_single_json:
|
||||
# json_path = os.path.join("./search_library/v8021_overseas/", str(barcode_list[i]) + '.json')
|
||||
json_path = os.path.join(self.conf['save']['json_path'], str(barcode_list[i]) + '.json')
|
||||
json_path = os.path.join(self.conf['save']['json_path'],
|
||||
str(barcode_list[i]) + '.json')
|
||||
with open(json_path, 'w') as json_file:
|
||||
json.dump(dicts, json_file)
|
||||
else:
|
||||
@ -317,8 +329,10 @@ class FeatureExtractor:
|
||||
with open(conf['save']['barcodes_statistics'], 'w', encoding='utf-8') as f:
|
||||
for barcode in os.listdir(pth):
|
||||
print("barcode length >> {}".format(len(barcode)))
|
||||
if len(barcode) > 13 or len(barcode) < 8:
|
||||
continue
|
||||
|
||||
# if len(barcode) > 13 or len(barcode) < 8: # barcode筛选长度
|
||||
# continue
|
||||
|
||||
if filter is not None:
|
||||
f.writelines(barcode + '\n')
|
||||
if barcode in filter:
|
||||
@ -407,5 +421,5 @@ if __name__ == "__main__":
|
||||
column_values = extractor.get_shop_barcodes(conf['data']['xlsx_pth'])
|
||||
imgs_dict = extractor.get_files(conf['data']['img_dirs_path'],
|
||||
filter=column_values,
|
||||
create_single_json=False) # False
|
||||
create_single_json=conf['save']['create_single_json']) # False
|
||||
extractor.statisticsBarcodes(conf['data']['img_dirs_path'], column_values)
|
||||
|
403
train_compare.py
403
train_compare.py
@ -3,28 +3,34 @@ import os.path as osp
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from tqdm import tqdm
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from model.loss import FocalLoss
|
||||
from tools.dataset import load_data
|
||||
from tools.dataset import load_data, MultiEpochsDataLoader
|
||||
import matplotlib.pyplot as plt
|
||||
from configs import trainer_tools
|
||||
import yaml
|
||||
from datetime import datetime
|
||||
|
||||
with open('configs/scatter.yml', 'r') as f:
|
||||
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
# Data Setup
|
||||
train_dataloader, class_num = load_data(training=True, cfg=conf)
|
||||
val_dataloader, _ = load_data(training=False, cfg=conf)
|
||||
def load_configuration(config_path='configs/compare.yml'):
|
||||
"""加载配置文件"""
|
||||
with open(config_path, 'r') as f:
|
||||
return yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
|
||||
def initialize_model_and_metric(conf, class_num):
|
||||
"""初始化模型和度量方法"""
|
||||
tr_tools = trainer_tools(conf)
|
||||
backbone_mapping = tr_tools.get_backbone()
|
||||
metric_mapping = tr_tools.get_metric(class_num)
|
||||
|
||||
if conf['models']['backbone'] in backbone_mapping:
|
||||
model = backbone_mapping[conf['models']['backbone']]().to(conf['base']['device'])
|
||||
model = backbone_mapping[conf['models']['backbone']]()
|
||||
else:
|
||||
raise ValueError('不支持该模型: {}'.format({conf['models']['backbone']}))
|
||||
|
||||
@ -33,110 +39,329 @@ if conf['training']['metric'] in metric_mapping:
|
||||
else:
|
||||
raise ValueError('不支持的metric类型: {}'.format(conf['training']['metric']))
|
||||
|
||||
if torch.cuda.device_count() > 1 and conf['base']['distributed']:
|
||||
print("Let's use", torch.cuda.device_count(), "GPUs!")
|
||||
model = nn.DataParallel(model)
|
||||
metric = nn.DataParallel(metric)
|
||||
return model, metric
|
||||
|
||||
# Training Setup
|
||||
if conf['training']['loss'] == 'focal_loss':
|
||||
criterion = FocalLoss(gamma=2)
|
||||
else:
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
def setup_optimizer_and_scheduler(conf, model, metric):
|
||||
"""设置优化器和学习率调度器"""
|
||||
tr_tools = trainer_tools(conf)
|
||||
optimizer_mapping = tr_tools.get_optimizer(model, metric)
|
||||
|
||||
if conf['training']['optimizer'] in optimizer_mapping:
|
||||
optimizer = optimizer_mapping[conf['training']['optimizer']]()
|
||||
scheduler = optim.lr_scheduler.StepLR(
|
||||
optimizer,
|
||||
step_size=conf['training']['lr_step'],
|
||||
gamma=conf['training']['lr_decay']
|
||||
)
|
||||
scheduler_mapping = tr_tools.get_scheduler(optimizer)
|
||||
scheduler = scheduler_mapping[conf['training']['scheduler']]()
|
||||
print('使用{}优化器 使用{}调度器'.format(conf['training']['optimizer'],
|
||||
conf['training']['scheduler']))
|
||||
return optimizer, scheduler
|
||||
else:
|
||||
raise ValueError('不支持的优化器类型: {}'.format(conf['training']['optimizer']))
|
||||
|
||||
# Checkpoints Setup
|
||||
|
||||
def setup_loss_function(conf):
|
||||
"""配置损失函数"""
|
||||
if conf['training']['loss'] == 'focal_loss':
|
||||
return FocalLoss(gamma=2)
|
||||
else:
|
||||
return nn.CrossEntropyLoss()
|
||||
|
||||
|
||||
def train_one_epoch(model, metric, criterion, optimizer, dataloader, device, scaler, conf):
|
||||
"""执行单个训练周期"""
|
||||
model.train()
|
||||
train_loss = 0
|
||||
for data, labels in tqdm(dataloader, desc="Training", ascii=True, total=len(dataloader)):
|
||||
data = data.to(device)
|
||||
labels = labels.to(device)
|
||||
|
||||
# with torch.cuda.amp.autocast():
|
||||
embeddings = model(data)
|
||||
if not conf['training']['metric'] == 'softmax':
|
||||
thetas = metric(embeddings, labels)
|
||||
else:
|
||||
thetas = metric(embeddings)
|
||||
loss = criterion(thetas, labels)
|
||||
|
||||
optimizer.zero_grad()
|
||||
scaler.scale(loss).backward()
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
train_loss += loss.item()
|
||||
return train_loss / len(dataloader)
|
||||
|
||||
|
||||
def validate(model, metric, criterion, dataloader, device, conf):
|
||||
"""执行验证"""
|
||||
model.eval()
|
||||
val_loss = 0
|
||||
with torch.no_grad():
|
||||
for data, labels in tqdm(dataloader, desc="Validating", ascii=True, total=len(dataloader)):
|
||||
data = data.to(device)
|
||||
labels = labels.to(device)
|
||||
embeddings = model(data)
|
||||
if not conf['training']['metric'] == 'softmax':
|
||||
thetas = metric(embeddings, labels)
|
||||
else:
|
||||
thetas = metric(embeddings)
|
||||
loss = criterion(thetas, labels)
|
||||
val_loss += loss.item()
|
||||
return val_loss / len(dataloader)
|
||||
|
||||
|
||||
def save_model(model, path, is_parallel):
|
||||
"""保存模型权重"""
|
||||
if is_parallel:
|
||||
torch.save(model.module.state_dict(), path)
|
||||
else:
|
||||
torch.save(model.state_dict(), path)
|
||||
|
||||
|
||||
def log_training_info(log_path, log_info):
|
||||
"""记录训练信息到日志文件"""
|
||||
with open(log_path, 'a') as f:
|
||||
f.write(log_info + '\n')
|
||||
|
||||
|
||||
def initialize_training_components(distributed=False):
|
||||
"""初始化所有训练所需组件"""
|
||||
# 加载配置
|
||||
conf = load_configuration()
|
||||
|
||||
# 初始化分布式训练相关参数
|
||||
components = {
|
||||
'conf': conf,
|
||||
'distributed': distributed,
|
||||
'device': None,
|
||||
'train_dataloader': None,
|
||||
'val_dataloader': None,
|
||||
'model': None,
|
||||
'metric': None,
|
||||
'criterion': None,
|
||||
'optimizer': None,
|
||||
'scheduler': None,
|
||||
'checkpoints': None,
|
||||
'scaler': None
|
||||
}
|
||||
|
||||
# 如果是非分布式训练,直接创建所有组件
|
||||
if not distributed:
|
||||
# 数据加载
|
||||
train_dataloader, class_num = load_data(training=True, cfg=conf, return_dataset=True)
|
||||
val_dataloader, _ = load_data(training=False, cfg=conf, return_dataset=True)
|
||||
|
||||
train_dataloader = MultiEpochsDataLoader(train_dataloader,
|
||||
batch_size=conf['data']['train_batch_size'],
|
||||
shuffle=True,
|
||||
num_workers=conf['data']['num_workers'],
|
||||
pin_memory=conf['base']['pin_memory'],
|
||||
drop_last=True)
|
||||
val_dataloader = MultiEpochsDataLoader(val_dataloader,
|
||||
batch_size=conf['data']['val_batch_size'],
|
||||
shuffle=False,
|
||||
num_workers=conf['data']['num_workers'],
|
||||
pin_memory=conf['base']['pin_memory'],
|
||||
drop_last=False)
|
||||
# 初始化模型和度量
|
||||
model, metric = initialize_model_and_metric(conf, class_num)
|
||||
device = conf['base']['device']
|
||||
model = model.to(device)
|
||||
metric = metric.to(device)
|
||||
|
||||
# 设置损失函数、优化器和调度器
|
||||
criterion = setup_loss_function(conf)
|
||||
optimizer, scheduler = setup_optimizer_and_scheduler(conf, model, metric)
|
||||
|
||||
# 检查点目录
|
||||
checkpoints = conf['training']['checkpoints']
|
||||
os.makedirs(checkpoints, exist_ok=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('backbone>{} '.format(conf['models']['backbone']),
|
||||
'metric>{} '.format(conf['training']['metric']),
|
||||
'checkpoints>{} '.format(conf['training']['checkpoints']),
|
||||
)
|
||||
# GradScaler for mixed precision
|
||||
scaler = torch.cuda.amp.GradScaler()
|
||||
|
||||
# 更新组件字典
|
||||
components.update({
|
||||
'train_dataloader': train_dataloader,
|
||||
'val_dataloader': val_dataloader,
|
||||
'model': model,
|
||||
'metric': metric,
|
||||
'criterion': criterion,
|
||||
'optimizer': optimizer,
|
||||
'scheduler': scheduler,
|
||||
'checkpoints': checkpoints,
|
||||
'scaler': scaler,
|
||||
'device': device
|
||||
})
|
||||
|
||||
return components
|
||||
|
||||
|
||||
def run_training_loop(components):
|
||||
"""运行完整的训练循环"""
|
||||
# 解包组件
|
||||
conf = components['conf']
|
||||
train_dataloader = components['train_dataloader']
|
||||
val_dataloader = components['val_dataloader']
|
||||
model = components['model']
|
||||
metric = components['metric']
|
||||
criterion = components['criterion']
|
||||
optimizer = components['optimizer']
|
||||
scheduler = components['scheduler']
|
||||
checkpoints = components['checkpoints']
|
||||
scaler = components['scaler']
|
||||
device = components['device']
|
||||
|
||||
# 训练状态
|
||||
train_losses = []
|
||||
val_losses = []
|
||||
epochs = []
|
||||
temp_loss = 100
|
||||
|
||||
if conf['training']['restore']:
|
||||
print('load pretrain model: {}'.format(conf['training']['restore_model']))
|
||||
model.load_state_dict(torch.load(conf['training']['restore_model'],
|
||||
map_location=conf['base']['device']))
|
||||
model.load_state_dict(torch.load(conf['training']['restore_model'], map_location=device))
|
||||
|
||||
# 训练循环
|
||||
for e in range(conf['training']['epochs']):
|
||||
train_loss = 0
|
||||
model.train()
|
||||
|
||||
for train_data, train_labels in tqdm(train_dataloader,
|
||||
desc="Epoch {}/{}"
|
||||
.format(e, conf['training']['epochs']),
|
||||
ascii=True,
|
||||
total=len(train_dataloader)):
|
||||
train_data = train_data.to(conf['base']['device'])
|
||||
train_labels = train_labels.to(conf['base']['device'])
|
||||
|
||||
train_embeddings = model(train_data).to(conf['base']['device']) # [256,512]
|
||||
# pdb.set_trace()
|
||||
|
||||
if not conf['training']['metric'] == 'softmax':
|
||||
thetas = metric(train_embeddings, train_labels) # [256,357]
|
||||
else:
|
||||
thetas = metric(train_embeddings)
|
||||
tloss = criterion(thetas, train_labels)
|
||||
optimizer.zero_grad()
|
||||
tloss.backward()
|
||||
optimizer.step()
|
||||
train_loss += tloss.item()
|
||||
train_lossAvg = train_loss / len(train_dataloader)
|
||||
train_losses.append(train_lossAvg)
|
||||
train_loss_avg = train_one_epoch(model, metric, criterion, optimizer, train_dataloader, device, scaler, conf)
|
||||
train_losses.append(train_loss_avg)
|
||||
epochs.append(e)
|
||||
val_loss = 0
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for val_data, val_labels in tqdm(val_dataloader, desc="val",
|
||||
ascii=True, total=len(val_dataloader)):
|
||||
val_data = val_data.to(conf['base']['device'])
|
||||
val_labels = val_labels.to(conf['base']['device'])
|
||||
val_embeddings = model(val_data).to(conf['base']['device'])
|
||||
if not conf['training']['metric'] == 'softmax':
|
||||
thetas = metric(val_embeddings, val_labels)
|
||||
else:
|
||||
thetas = metric(val_embeddings)
|
||||
vloss = criterion(thetas, val_labels)
|
||||
val_loss += vloss.item()
|
||||
val_lossAvg = val_loss / len(val_dataloader)
|
||||
val_losses.append(val_lossAvg)
|
||||
if val_lossAvg < temp_loss:
|
||||
if torch.cuda.device_count() > 1:
|
||||
torch.save(model.state_dict(), osp.join(checkpoints, 'best.pth'))
|
||||
else:
|
||||
torch.save(model.state_dict(), osp.join(checkpoints, 'best.pth'))
|
||||
temp_loss = val_lossAvg
|
||||
|
||||
val_loss_avg = validate(model, metric, criterion, val_dataloader, device, conf)
|
||||
val_losses.append(val_loss_avg)
|
||||
|
||||
if val_loss_avg < temp_loss:
|
||||
save_model(model, osp.join(checkpoints, 'best.pth'), isinstance(model, nn.DataParallel))
|
||||
temp_loss = val_loss_avg
|
||||
|
||||
scheduler.step()
|
||||
current_lr = optimizer.param_groups[0]['lr']
|
||||
log_info = ("Epoch {}/{}, train_loss: {}, val_loss: {} lr:{}"
|
||||
.format(e, conf['training']['epochs'], train_lossAvg, val_lossAvg, current_lr))
|
||||
log_info = ("[{:%Y-%m-%d %H:%M:%S}] Epoch {}/{}, train_loss: {}, val_loss: {} lr:{}"
|
||||
.format(datetime.now(),
|
||||
e,
|
||||
conf['training']['epochs'],
|
||||
train_loss_avg,
|
||||
val_loss_avg,
|
||||
current_lr))
|
||||
print(log_info)
|
||||
# 写入日志文件
|
||||
with open(osp.join(conf['logging']['logging_dir']), 'a') as f:
|
||||
f.write(log_info + '\n')
|
||||
log_training_info(osp.join(conf['logging']['logging_dir']), log_info)
|
||||
print("第%d个epoch的学习率:%f" % (e, current_lr))
|
||||
if torch.cuda.device_count() > 1 and conf['base']['distributed']:
|
||||
torch.save(model.module.state_dict(), osp.join(checkpoints, 'last.pth'))
|
||||
else:
|
||||
torch.save(model.state_dict(), osp.join(checkpoints, 'last.pth'))
|
||||
plt.plot(epochs, train_losses, color='blue')
|
||||
plt.plot(epochs, val_losses, color='red')
|
||||
# plt.savefig('lossMobilenetv3.png')
|
||||
|
||||
# 保存最终模型
|
||||
save_model(model, osp.join(checkpoints, 'last.pth'), isinstance(model, nn.DataParallel))
|
||||
|
||||
# 绘制损失曲线
|
||||
plt.plot(epochs, train_losses, color='blue', label='Train Loss')
|
||||
plt.plot(epochs, val_losses, color='red', label='Validation Loss')
|
||||
plt.legend()
|
||||
plt.savefig('loss/mobilenetv3Large_2250_0316.png')
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数入口"""
|
||||
# 加载配置
|
||||
conf = load_configuration()
|
||||
|
||||
# 检查是否启用分布式训练
|
||||
distributed = conf['base']['distributed']
|
||||
|
||||
if distributed:
|
||||
# 分布式训练:使用mp.spawn启动多个进程
|
||||
local_size = torch.cuda.device_count()
|
||||
world_size = int(conf['distributed']['node_num'])*local_size
|
||||
mp.spawn(
|
||||
run_training,
|
||||
args=(conf['distributed']['node_rank'],
|
||||
local_size,
|
||||
world_size,
|
||||
conf),
|
||||
nprocs=local_size,
|
||||
join=True
|
||||
)
|
||||
else:
|
||||
# 单机训练:直接运行训练流程
|
||||
components = initialize_training_components(distributed=False)
|
||||
run_training_loop(components)
|
||||
|
||||
|
||||
def run_training(local_rank, node_rank, local_size, world_size, conf):
|
||||
"""实际执行训练的函数,供mp.spawn调用"""
|
||||
# 初始化分布式环境
|
||||
rank = local_rank + node_rank * local_size
|
||||
os.environ['RANK'] = str(rank)
|
||||
os.environ['WORLD_SIZE'] = str(world_size)
|
||||
os.environ['MASTER_ADDR'] = 'localhost'
|
||||
os.environ['MASTER_PORT'] = '12355'
|
||||
|
||||
dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
|
||||
torch.cuda.set_device(local_rank)
|
||||
device = torch.device('cuda', local_rank)
|
||||
|
||||
# 获取数据集而不是DataLoader
|
||||
train_dataset, class_num = load_data(training=True, cfg=conf, return_dataset=True)
|
||||
val_dataset, _ = load_data(training=False, cfg=conf, return_dataset=True)
|
||||
|
||||
# 初始化模型和度量
|
||||
model, metric = initialize_model_and_metric(conf, class_num)
|
||||
model = model.to(device)
|
||||
metric = metric.to(device)
|
||||
|
||||
# 包装为DistributedDataParallel模型
|
||||
model = DDP(model, device_ids=[local_rank], output_device=local_rank)
|
||||
metric = DDP(metric, device_ids=[local_rank], output_device=local_rank)
|
||||
|
||||
# 设置损失函数、优化器和调度器
|
||||
criterion = setup_loss_function(conf)
|
||||
optimizer, scheduler = setup_optimizer_and_scheduler(conf, model, metric)
|
||||
|
||||
# 检查点目录
|
||||
checkpoints = conf['training']['checkpoints']
|
||||
os.makedirs(checkpoints, exist_ok=True)
|
||||
|
||||
# GradScaler for mixed precision
|
||||
scaler = torch.cuda.amp.GradScaler()
|
||||
|
||||
# 创建分布式采样器
|
||||
train_sampler = DistributedSampler(train_dataset, shuffle=True)
|
||||
val_sampler = DistributedSampler(val_dataset, shuffle=False)
|
||||
|
||||
# 使用 MultiEpochsDataLoader 创建分布式数据加载器
|
||||
train_dataloader = MultiEpochsDataLoader(
|
||||
train_dataset,
|
||||
batch_size=conf['data']['train_batch_size'],
|
||||
sampler=train_sampler,
|
||||
num_workers=conf['data']['num_workers'],
|
||||
pin_memory=conf['base']['pin_memory'],
|
||||
drop_last=True
|
||||
)
|
||||
|
||||
val_dataloader = MultiEpochsDataLoader(
|
||||
val_dataset,
|
||||
batch_size=conf['data']['val_batch_size'],
|
||||
sampler=val_sampler,
|
||||
num_workers=conf['data']['num_workers'],
|
||||
pin_memory=conf['base']['pin_memory'],
|
||||
drop_last=False
|
||||
)
|
||||
|
||||
# 构建组件字典
|
||||
components = {
|
||||
'conf': conf,
|
||||
'train_dataloader': train_dataloader,
|
||||
'val_dataloader': val_dataloader,
|
||||
'model': model,
|
||||
'metric': metric,
|
||||
'criterion': criterion,
|
||||
'optimizer': optimizer,
|
||||
'scheduler': scheduler,
|
||||
'checkpoints': checkpoints,
|
||||
'scaler': scaler,
|
||||
'device': device,
|
||||
'distributed': True # 因为是在mp.spawn中运行
|
||||
}
|
||||
|
||||
# 运行训练循环
|
||||
run_training_loop(components)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
233
train_compare.py.bak
Normal file
233
train_compare.py.bak
Normal file
@ -0,0 +1,233 @@
|
||||
import os
|
||||
import os.path as osp
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
|
||||
from model.loss import FocalLoss
|
||||
from tools.dataset import load_data
|
||||
import matplotlib.pyplot as plt
|
||||
from configs import trainer_tools
|
||||
import yaml
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def load_configuration(config_path='configs/scatter.yml'):
|
||||
"""加载配置文件"""
|
||||
with open(config_path, 'r') as f:
|
||||
return yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
|
||||
def initialize_model_and_metric(conf, class_num):
|
||||
"""初始化模型和度量方法"""
|
||||
tr_tools = trainer_tools(conf)
|
||||
backbone_mapping = tr_tools.get_backbone()
|
||||
metric_mapping = tr_tools.get_metric(class_num)
|
||||
|
||||
if conf['models']['backbone'] in backbone_mapping:
|
||||
model = backbone_mapping[conf['models']['backbone']]()
|
||||
else:
|
||||
raise ValueError('不支持该模型: {}'.format({conf['models']['backbone']}))
|
||||
|
||||
if conf['training']['metric'] in metric_mapping:
|
||||
metric = metric_mapping[conf['training']['metric']]()
|
||||
else:
|
||||
raise ValueError('不支持的metric类型: {}'.format(conf['training']['metric']))
|
||||
|
||||
return model, metric
|
||||
|
||||
|
||||
def setup_optimizer_and_scheduler(conf, model, metric):
|
||||
"""设置优化器和学习率调度器"""
|
||||
tr_tools = trainer_tools(conf)
|
||||
optimizer_mapping = tr_tools.get_optimizer(model, metric)
|
||||
|
||||
if conf['training']['optimizer'] in optimizer_mapping:
|
||||
optimizer = optimizer_mapping[conf['training']['optimizer']]()
|
||||
scheduler_mapping = tr_tools.get_scheduler(optimizer)
|
||||
scheduler = scheduler_mapping[conf['training']['scheduler']]()
|
||||
print('使用{}优化器 使用{}调度器'.format(conf['training']['optimizer'],
|
||||
conf['training']['scheduler']))
|
||||
return optimizer, scheduler
|
||||
else:
|
||||
raise ValueError('不支持的优化器类型: {}'.format(conf['training']['optimizer']))
|
||||
|
||||
|
||||
def setup_loss_function(conf):
|
||||
"""配置损失函数"""
|
||||
if conf['training']['loss'] == 'focal_loss':
|
||||
return FocalLoss(gamma=2)
|
||||
else:
|
||||
return nn.CrossEntropyLoss()
|
||||
|
||||
|
||||
def train_one_epoch(model, metric, criterion, optimizer, dataloader, device, scaler, conf):
|
||||
"""执行单个训练周期"""
|
||||
model.train()
|
||||
train_loss = 0
|
||||
for data, labels in tqdm(dataloader, desc="Training", ascii=True, total=len(dataloader)):
|
||||
data = data.to(device)
|
||||
labels = labels.to(device)
|
||||
|
||||
with torch.cuda.amp.autocast():
|
||||
embeddings = model(data)
|
||||
if not conf['training']['metric'] == 'softmax':
|
||||
thetas = metric(embeddings, labels)
|
||||
else:
|
||||
thetas = metric(embeddings)
|
||||
loss = criterion(thetas, labels)
|
||||
|
||||
optimizer.zero_grad()
|
||||
scaler.scale(loss).backward()
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
train_loss += loss.item()
|
||||
return train_loss / len(dataloader)
|
||||
|
||||
|
||||
def validate(model, metric, criterion, dataloader, device, conf):
|
||||
"""执行验证"""
|
||||
model.eval()
|
||||
val_loss = 0
|
||||
with torch.no_grad():
|
||||
for data, labels in tqdm(dataloader, desc="Validating", ascii=True, total=len(dataloader)):
|
||||
data = data.to(device)
|
||||
labels = labels.to(device)
|
||||
embeddings = model(data)
|
||||
if not conf['training']['metric'] == 'softmax':
|
||||
thetas = metric(embeddings, labels)
|
||||
else:
|
||||
thetas = metric(embeddings)
|
||||
loss = criterion(thetas, labels)
|
||||
val_loss += loss.item()
|
||||
return val_loss / len(dataloader)
|
||||
|
||||
|
||||
def save_model(model, path, is_parallel):
|
||||
"""保存模型权重"""
|
||||
if is_parallel:
|
||||
torch.save(model.module.state_dict(), path)
|
||||
else:
|
||||
torch.save(model.state_dict(), path)
|
||||
|
||||
|
||||
def log_training_info(log_path, log_info):
|
||||
"""记录训练信息到日志文件"""
|
||||
with open(log_path, 'a') as f:
|
||||
f.write(log_info + '\n')
|
||||
|
||||
|
||||
def initialize_training_components():
|
||||
"""初始化所有训练所需组件"""
|
||||
# 加载配置
|
||||
conf = load_configuration()
|
||||
|
||||
# 数据加载
|
||||
train_dataloader, class_num = load_data(training=True, cfg=conf)
|
||||
val_dataloader, _ = load_data(training=False, cfg=conf)
|
||||
|
||||
# 初始化模型和度量
|
||||
model, metric = initialize_model_and_metric(conf, class_num)
|
||||
device = conf['base']['device']
|
||||
model = model.to(device)
|
||||
metric = metric.to(device)
|
||||
|
||||
if torch.cuda.device_count() > 1 and conf['base']['distributed']:
|
||||
print("Let's use", torch.cuda.device_count(), "GPUs!")
|
||||
model = nn.DataParallel(model)
|
||||
metric = nn.DataParallel(metric)
|
||||
|
||||
# 设置损失函数、优化器和调度器
|
||||
criterion = setup_loss_function(conf)
|
||||
optimizer, scheduler = setup_optimizer_and_scheduler(conf, model, metric)
|
||||
|
||||
# 检查点目录
|
||||
checkpoints = conf['training']['checkpoints']
|
||||
os.makedirs(checkpoints, exist_ok=True)
|
||||
|
||||
# GradScaler for mixed precision
|
||||
scaler = torch.cuda.amp.GradScaler()
|
||||
|
||||
return {
|
||||
'conf': conf,
|
||||
'train_dataloader': train_dataloader,
|
||||
'val_dataloader': val_dataloader,
|
||||
'model': model,
|
||||
'metric': metric,
|
||||
'criterion': criterion,
|
||||
'optimizer': optimizer,
|
||||
'scheduler': scheduler,
|
||||
'checkpoints': checkpoints,
|
||||
'scaler': scaler,
|
||||
'device': device
|
||||
}
|
||||
|
||||
|
||||
def run_training_loop(components):
|
||||
"""运行完整的训练循环"""
|
||||
# 解包组件
|
||||
conf = components['conf']
|
||||
train_dataloader = components['train_dataloader']
|
||||
val_dataloader = components['val_dataloader']
|
||||
model = components['model']
|
||||
metric = components['metric']
|
||||
criterion = components['criterion']
|
||||
optimizer = components['optimizer']
|
||||
scheduler = components['scheduler']
|
||||
checkpoints = components['checkpoints']
|
||||
scaler = components['scaler']
|
||||
device = components['device']
|
||||
|
||||
# 训练状态
|
||||
train_losses = []
|
||||
val_losses = []
|
||||
epochs = []
|
||||
temp_loss = 100
|
||||
|
||||
if conf['training']['restore']:
|
||||
print('load pretrain model: {}'.format(conf['training']['restore_model']))
|
||||
model.load_state_dict(torch.load(conf['training']['restore_model'], map_location=device))
|
||||
|
||||
# 训练循环
|
||||
for e in range(conf['training']['epochs']):
|
||||
train_loss_avg = train_one_epoch(model, metric, criterion, optimizer, train_dataloader, device, scaler, conf)
|
||||
train_losses.append(train_loss_avg)
|
||||
epochs.append(e)
|
||||
|
||||
val_loss_avg = validate(model, metric, criterion, val_dataloader, device, conf)
|
||||
val_losses.append(val_loss_avg)
|
||||
|
||||
if val_loss_avg < temp_loss:
|
||||
save_model(model, osp.join(checkpoints, 'best.pth'), isinstance(model, nn.DataParallel))
|
||||
temp_loss = val_loss_avg
|
||||
|
||||
scheduler.step()
|
||||
current_lr = optimizer.param_groups[0]['lr']
|
||||
log_info = ("[{:%Y-%m-%d %H:%M:%S}] Epoch {}/{}, train_loss: {}, val_loss: {} lr:{}"
|
||||
.format(datetime.now(),
|
||||
e,
|
||||
conf['training']['epochs'],
|
||||
train_loss_avg,
|
||||
val_loss_avg,
|
||||
current_lr))
|
||||
print(log_info)
|
||||
log_training_info(osp.join(conf['logging']['logging_dir']), log_info)
|
||||
print("第%d个epoch的学习率:%f" % (e, current_lr))
|
||||
|
||||
# 保存最终模型
|
||||
save_model(model, osp.join(checkpoints, 'last.pth'), isinstance(model, nn.DataParallel))
|
||||
|
||||
# 绘制损失曲线
|
||||
plt.plot(epochs, train_losses, color='blue', label='Train Loss')
|
||||
plt.plot(epochs, val_losses, color='red', label='Validation Loss')
|
||||
plt.legend()
|
||||
plt.savefig('loss/mobilenetv3Large_2250_0316.png')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 初始化训练组件
|
||||
components = initialize_training_components()
|
||||
|
||||
# 运行训练循环
|
||||
run_training_loop(components)
|
Reference in New Issue
Block a user