Compare commits

17 commits: dev ... bc896fc688

SHA1:
bc896fc688
27ffb62223
ebba07d1ca
3392d76e38
54898e30ec
09f41f6289
0701538a73
6640f2bc5e
bcbabd9313
5deaf4727f
2219c0a303
537ed838fc
061820c34f
bf9604ec29
180a41ae90
e27e6c3d5b
1803f319a5
349  .idea/CopilotChatHistory.xml  (generated)

In this generated IDE file, sixteen new <Conversation> entries (each recording createTime, id, title, and updateTime) dated between 2025-06-11 and 2025-08-04 are appended to <option name="conversations">, and in every recorded <Turn> the <option name="variables"> block is collapsed from a list holding an empty <CodebaseVariable> placeholder to a bare <list />; the SUCCESS status values are unchanged.
720  .idea/CopilotWebChatHistory.xml  (generated)
File diff suppressed because one or more lines are too long
@@ -15,25 +15,28 @@ base:

 # Model configuration
 models:
-  backbone: 'resnet18'
+  backbone: 'resnet50'
-  channel_ratio: 0.75
+  channel_ratio: 1.0

 # Training parameters
 training:
-  epochs: 600            # total training epochs
+  epochs: 400            # total training epochs
   batch_size: 128        # batch size
-  lr: 0.001              # initial learning rate
+  lr: 0.01               # initial learning rate
   optimizer: "sgd"       # optimizer type
   metric: 'arcface'      # metric head (options: arcface/cosface/sphereface/softmax)
   loss: "cross_entropy"  # loss type (options: cross_entropy/cross_entropy_smooth/center_loss/center_loss_smooth/arcface/cosface/sphereface/softmax)
-  lr_step: 10            # LR adjustment interval (epochs)
+  lr_step: 5             # LR adjustment interval (epochs)
-  lr_decay: 0.98         # LR decay rate
+  lr_decay: 0.95         # LR decay rate
   weight_decay: 0.0005   # weight decay
-  scheduler: "cosine_annealing"  # LR scheduler (options: cosine_annealing/step/none)
+  scheduler: "step"              # LR scheduler (options: cosine/cosine_warm/step/None)
   num_workers: 32        # data-loading workers
-  checkpoints: "./checkpoints/resnet18_test/"                  # checkpoint save directory
+  checkpoints: "./checkpoints/resnet50_electornic_20250807/"  # checkpoint save directory
   restore: false
-  restore_model: "resnet18_test/epoch_600.pth"                                # model restore path
+  restore_model: "./checkpoints/resnet18_20250717_scale=0.75_nosub/best.pth"  # model restore path
+  cosine_t_0: 10           # initial cycle length
+  cosine_t_mult: 1         # cycle length multiplier
+  cosine_eta_min: 0.00001  # minimum learning rate

 # Validation parameters
 validation:
@@ -46,8 +49,8 @@ data:
   train_batch_size: 128  # training batch size
   val_batch_size: 128    # validation batch size
   num_workers: 32        # data-loading workers
-  data_train_dir: "../data_center/contrast_learning/data_base/train"  # training dataset root
+  data_train_dir: "../data_center/electornic/v1/train"                # training dataset root
-  data_val_dir: "../data_center/contrast_learning/data_base/val"      # validation dataset root
+  data_val_dir: "../data_center/electornic/v1/val"                    # validation dataset root

 transform:
   img_size: 224          # image size
@@ -59,7 +62,7 @@ transform:

 # Logging & monitoring
 logging:
-  logging_dir: "./logs"                          # log directory
+  logging_dir: "./logs/resnet50_electornic_log"  # log directory
   tensorboard: true          # enable TensorBoard
   checkpoint_interval: 30    # checkpoint save interval (epochs)
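With scheduler: "step", lr_step: 5 and lr_decay: 0.95, the learning rate follows lr(epoch) = lr0 * 0.95^(epoch // 5). A small self-contained check of that schedule with PyTorch's StepLR (the tiny stand-in model is only for illustration, not the repo's backbone):

import torch
from torch import optim

model = torch.nn.Linear(8, 2)                      # stand-in for the real backbone
opt = optim.SGD(model.parameters(), lr=0.01)       # lr: 0.01 from the config above
sched = optim.lr_scheduler.StepLR(opt, step_size=5, gamma=0.95)  # lr_step / lr_decay

for epoch in range(20):
    opt.step()                                     # one (dummy) training epoch
    sched.step()                                   # advance the schedule once per epoch
print(opt.param_groups[0]["lr"])                   # 0.01 * 0.95**4 ≈ 0.00814 after 20 epochs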
@@ -51,7 +51,7 @@ data:
   dataset: "imagenet"      # dataset name (example; replace with the actual dataset)
   train_batch_size: 128    # training batch size
   val_batch_size: 100      # validation batch size
-  num_workers: 4           # data-loading workers
+  num_workers: 16          # data-loading workers
   data_train_dir: "../data_center/contrast_learning/data_base/train"  # training dataset root
   data_val_dir: "../data_center/contrast_learning/data_base/val"      # validation dataset root
53  configs/pic_pic_similar.yml  (new file)
@@ -0,0 +1,53 @@
# configs/similar_analysis.yml
# Config file designed for comparing model training runs
# Supports comparing different training strategies (e.g. distillation vs. independent training)

# Base configuration
base:
  experiment_name: "model_comparison"  # experiment name (used for the results directory)
  device: "cuda"                       # training device (cuda/cpu)
  embedding_size: 256                  # feature dimension
  pin_memory: true                     # enable pin_memory
  distributed: true                    # enable distributed training

# Model configuration
models:
  backbone: 'resnet18'
  channel_ratio: 0.75
  model_path: "../checkpoints/resnet18_1009/best.pth"

heatmap:
  feature_layer: "layer4"
  show_heatmap: true

# Data configuration
data:
  dataset: "imagenet"        # dataset name (example; replace with the actual dataset)
  train_batch_size: 128      # training batch size
  val_batch_size: 8          # validation batch size
  num_workers: 32            # data-loading workers
  data_dir: "/home/lc/data_center/image_analysis/pic_pic_similar_maxtrix"
  image_joint_pth: "/home/lc/data_center/image_analysis/error_compare_result"
  total_pkl: "/home/lc/data_center/image_analysis/pic_pic_similar_maxtrix/total.pkl"
  result_txt: "/home/lc/data_center/image_analysis/pic_pic_similar_maxtrix/result.txt"

transform:
  img_size: 224              # image size
  img_mean: 0.5              # image mean
  img_std: 0.5               # image std
  RandomHorizontalFlip: 0.5  # random horizontal-flip probability
  RandomRotation: 180        # random rotation angle
  ColorJitter: 0.5           # random color-jitter strength

# Logging & monitoring
logging:
  logging_dir: "./logs/resnet18_scale=0.75_nosub_log"  # log directory
  tensorboard: true          # enable TensorBoard
  checkpoint_interval: 30    # checkpoint save interval (epochs)

event:
  oneToOne_max_th: 0.9
  oneToSn_min_th: 0.6
  event_save_dir: "/home/lc/works/realtime_yolov10s/online_yolov10s_resnetv11_20250702/yolos_tracking"
  stdlib_image_path: "/testDataAndLogs/module_test_record/comparison/标准图测试数据/pic/stlib_base"
  pickle_path: "event.pickle"
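A minimal sketch of how this config might be consumed: load the YAML, read the precomputed similarity matrix from total_pkl, and write pairs above oneToOne_max_th to result_txt. The pickle layout (a dict of pairwise scores) is an assumption, not something shown in this diff:

import pickle
import yaml

with open("configs/pic_pic_similar.yml", "r") as f:
    conf = yaml.safe_load(f)

# Assumption: total.pkl stores {(img_a, img_b): cosine_similarity, ...}
with open(conf["data"]["total_pkl"], "rb") as f:
    similarity = pickle.load(f)

max_th = conf["event"]["oneToOne_max_th"]          # 0.9 in the file above
with open(conf["data"]["result_txt"], "w") as out:
    for (img_a, img_b), score in similarity.items():
        if score >= max_th:
            out.write(f"{img_a}\t{img_b}\t{score:.4f}\n")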
@@ -18,20 +18,20 @@ models:

 # Training parameters
 training:
-  epochs: 300            # total training epochs
+  epochs: 800            # total training epochs
   batch_size: 64         # batch size
-  lr: 0.005              # initial learning rate
+  lr: 0.01               # initial learning rate
   optimizer: "sgd"       # optimizer type
   metric: 'arcface'      # metric head (options: arcface/cosface/sphereface/softmax)
   loss: "cross_entropy"  # loss type (options: cross_entropy/cross_entropy_smooth/center_loss/center_loss_smooth/arcface/cosface/sphereface/softmax)
   lr_step: 10            # LR adjustment interval (epochs)
-  lr_decay: 0.98         # LR decay rate
+  lr_decay: 0.95         # LR decay rate
   weight_decay: 0.0005   # weight decay
-  scheduler: "cosine_annealing"  # LR scheduler (options: cosine_annealing/step/none)
+  scheduler: "step"              # LR scheduler (options: cosine_annealing/step/none)
   num_workers: 32        # data-loading workers
-  checkpoints: "./checkpoints/resnet18_scatter_6.2/"  # checkpoint save directory
+  checkpoints: "./checkpoints/resnet18_scatter_7.4/"  # checkpoint save directory
-  restore: True
+  restore: false
-  restore_model: "checkpoints/resnet18_scatter_6.2/best.pth"   # model restore path
+  restore_model: "checkpoints/resnet18_scatter_6.25/best.pth"  # model restore path

@@ -46,8 +46,8 @@ data:
   train_batch_size: 128  # training batch size
   val_batch_size: 100    # validation batch size
   num_workers: 32        # data-loading workers
-  data_train_dir: "../data_center/scatter/train"     # training dataset root
+  data_train_dir: "../data_center/scatter/v4/train"  # training dataset root
-  data_val_dir: "../data_center/scatter/val"         # validation dataset root
+  data_val_dir: "../data_center/scatter/v4/val"      # validation dataset root

 transform:
   img_size: 224          # image size
@@ -59,7 +59,7 @@ transform:

 # Logging & monitoring
 logging:
-  logging_dir: "./log/2025.6.2-scatter.txt"  # log directory
+  logging_dir: "./log/2025.7.4-scatter.txt"  # log directory
   tensorboard: true          # enable TensorBoard
   checkpoint_interval: 30    # checkpoint save interval (epochs)
19  configs/scatter_data.yml  (new file)
@@ -0,0 +1,19 @@
# configs/scatter_data.yml
# Config file dedicated to pre-processing the loose-goods (scatter) data

# Data configuration
data:
  dataset: "imagenet"                                 # dataset name (example; replace with the actual dataset)
  source_dir: "../../data_center/electornic/source"   # raw data
  train_dir: "../../data_center/electornic/v1/train"  # training dataset root
  val_dir: "../../data_center/electornic/v1/val"      # validation dataset root
  extra_dir: "../../data_center/electornic/v1/extra"  # extra dataset root
  split_ratio: 0.9
  max_files: 10            # classes with fewer files than this threshold go to extra

# Logging & monitoring
logging:
  logging_dir: "./log/2025.7.4-scatter.txt"  # log directory
  log_level: "info"                          # log level (debug/info/warning/error)
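data_split.py (added later in this diff) consumes split_ratio from this file; a quick worked example of the 0.9 split applied to one hypothetical class folder:

import random

split_ratio = 0.9
image_files = [f"img_{i}.jpg" for i in range(57)]   # a hypothetical class with 57 images
random.shuffle(image_files)

split_point = int(len(image_files) * split_ratio)   # int(57 * 0.9) = 51
train_files = image_files[:split_point]             # 51 stay in train
val_files = image_files[split_point:]               # 6 are moved to val
print(len(train_files), len(val_files))             # 51 6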
53  configs/similar_analysis.yml  (new file)
@@ -0,0 +1,53 @@
# configs/similar_analysis.yml
# Config file designed for comparing model training runs
# Supports comparing different training strategies (e.g. distillation vs. independent training)

# Base configuration
base:
  experiment_name: "model_comparison"  # experiment name (used for the results directory)
  device: "cuda"                       # training device (cuda/cpu)
  embedding_size: 256                  # feature dimension
  pin_memory: true                     # enable pin_memory
  distributed: true                    # enable distributed training

# Model configuration
models:
  backbone: 'resnet18'
  channel_ratio: 0.75
  model_path: "../checkpoints/resnet18_1009/best.pth"

heatmap:
  feature_layer: "layer4"
  show_heatmap: true

# Data configuration
data:
  dataset: "imagenet"        # dataset name (example; replace with the actual dataset)
  train_batch_size: 128      # training batch size
  val_batch_size: 8          # validation batch size
  num_workers: 32            # data-loading workers
  data_dir: "/home/lc/data_center/image_analysis/error_compare_subimg"
  image_joint_pth: "/home/lc/data_center/image_analysis/error_compare_result"

transform:
  img_size: 224              # image size
  img_mean: 0.5              # image mean
  img_std: 0.5               # image std
  RandomHorizontalFlip: 0.5  # random horizontal-flip probability
  RandomRotation: 180        # random rotation angle
  ColorJitter: 0.5           # random color-jitter strength

# Logging & monitoring
logging:
  logging_dir: "./logs/resnet18_scale=0.75_nosub_log"  # log directory
  tensorboard: true          # enable TensorBoard
  checkpoint_interval: 30    # checkpoint save interval (epochs)

event:
  oneToOneTxt: "/home/lc/detecttracking/oneToOne.txt"
  oneToSnTxt: "/home/lc/detecttracking/oneToSn.txt"
  oneToOne_max_th: 0.9
  oneToSn_min_th: 0.6
  event_save_dir: "/home/lc/works/realtime_yolov10s/online_yolov10s_resnetv11_20250702/yolos_tracking"
  stdlib_image_path: "/testDataAndLogs/module_test_record/comparison/标准图测试数据/pic/stlib_base"
  pickle_path: "event.pickle"
32  configs/sub_data.yml  (new file)
@@ -0,0 +1,32 @@
# configs/sub_data.yml
# Config file for building the datasets used in contrastive model training
# Supports comparing different training strategies (e.g. distillation vs. independent training)

# Data configuration
data:
  source_dir: "../../data_center/contrast_data/total"     # raw data (example; replace with the actual dataset)
  train_dir: "../../data_center/contrast_data/v1/train"   # training dataset root
  val_dir: "../../data_center/contrast_data/v1/val"       # validation dataset root
  data_extra_dir: "../../data_center/contrast_data/v1/extra"
  max_files_ratio: 0.1
  min_files: 10
  split_ratio: 0.9
  combine_scr_dir: "../../data_center/contrast_data/v1/val"  # source directory for merging sub-class datasets
  combine_dst_dir: "../../data_center/contrast_data/v2/val"  # target directory for merging sub-class datasets

extend:
  extend_same_dir: true
  extend_extra: true
  extend_extra_dir: "../../data_center/contrast_data/v1/extra"  # augment the extra test data
  extend_train: true
  extend_train_dir: "../../data_center/contrast_data/v1/train"  # augment the training data

limit:
  count_limit: true
  limit_count: 200
  limit_dir: "../../data_center/contrast_data/v1/train"  # cap the number of samples per class

control:
  combine: true   # whether to merge sub-class datasets
  split: false    # whether to split and augment sub-class datasets
@@ -8,23 +8,28 @@ base:
   log_level: "info"     # log level (debug/info/warning/error)
   embedding_size: 256   # feature dimension
   pin_memory: true      # enable pin_memory
-  distributed: true     # enable distributed training
+  distributed: false    # enable distributed training; distributed mode cannot be used while heatmaps are enabled

 # Model configuration
 models:
   backbone: 'resnet18'
-  channel_ratio: 1.0
+  channel_ratio: 0.75
-  model_path: "./checkpoints/resnet18_scatter_6.2/best.pth"
+  model_path: "checkpoints/resnet18_1009/best.pth"
+  #resnet18_20250715_scale=0.75_sub
+  #resnet18_20250718_scale=0.75_nosub
   half: false           # enable half-precision (fp16) testing
+  contrast_learning: false

 # Data configuration
 data:
-  group_test: False     # dataset name (example; replace with the actual dataset)
   test_batch_size: 128  # batch size
   num_workers: 32       # data-loading workers
-  test_dir: "../data_center/scatter/"                # validation dataset root
+  test_dir: "../data_center/contrast_data/v1/extra"  # validation dataset root
   test_group_json: "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json"
-  test_list: "../data_center/scatter/val_pair.txt"
+  test_list: "../data_center/contrast_data/v1/extra_cross_same.txt"
+  group_test: false
+  save_image_joint: true
+  image_joint_pth: "./joint_images"

 transform:
   img_size: 224         # image size
@@ -34,6 +39,11 @@ transform:
   RandomRotation: 180   # random rotation angle
   ColorJitter: 0.5      # random color-jitter strength
+
+heatmap:
+  image_joint_pth: "./heatmap_joint_images"
+  feature_layer: "layer4"
+  show_heatmap: true

 save:
   save_dir: ""
   save_name: ""
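The new heatmap block (feature_layer: "layer4", show_heatmap: true) implies hooking the last ResNet stage to visualise activations; the repo's own tools.getHeatMap.cal_cam is imported in test_ori.py later in this diff, but its signature is not shown, so the following is only a self-contained illustration of the idea with a forward hook:

import torch
import torch.nn.functional as F
from torchvision.models import resnet18   # stand-in for the repo's pruned resnet18

model = resnet18()
model.eval()

features = {}

def hook(_module, _inputs, output):
    # Keep the activation map of the configured layer (e.g. "layer4").
    features["map"] = output.detach()

model.layer4.register_forward_hook(hook)

img = torch.randn(1, 3, 224, 224)          # one preprocessed image
with torch.no_grad():
    model(img)

# Average over channels and upsample to the input size to get a coarse heatmap.
heat = features["map"].mean(dim=1, keepdim=True)           # (1, 1, 7, 7)
heat = F.interpolate(heat, size=(224, 224), mode="bilinear", align_corners=False)
heat = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)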
28  configs/transform.yml  (new file)
@@ -0,0 +1,28 @@
# configs/transform.yml
# Config file for converting .pth checkpoints to ONNX

# Base configuration
base:
  experiment_name: "model_comparison"  # experiment name (used for the results directory)
  seed: 42                             # random seed (for reproducibility)
  device: "cuda"                       # training device (cuda/cpu)
  log_level: "info"                    # log level (debug/info/warning/error)
  embedding_size: 256                  # feature dimension
  pin_memory: true                     # enable pin_memory
  distributed: true                    # enable distributed training

# Model configuration
models:
  backbone: 'resnet18'
  channel_ratio: 1.0
  model_path: "../checkpoints/resnet18_1009/best.pth"
  onnx_model: "../checkpoints/resnet18_3399_sancheng/best.onnx"
  rknn_model: "../checkpoints/resnet18_3399_sancheng/best_rknn2.3.2_RK3566.rknn"
  rknn_batch_size: 1

# Logging & monitoring
logging:
  logging_dir: "./logs"      # log directory
  tensorboard: true          # enable TensorBoard
  checkpoint_interval: 30    # checkpoint save interval (epochs)
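transform.yml only lists the source .pth and the target .onnx/.rknn paths; the conversion script itself is not part of this diff. A minimal sketch of the pth-to-ONNX step those keys imply (the checkpoint layout, input size, and the torchvision stand-in model are assumptions):

import torch
import yaml
from torchvision.models import resnet18   # stand-in for the repo's resnet18

with open("configs/transform.yml", "r") as f:
    conf = yaml.safe_load(f)

model = resnet18()
# Assumption: the checkpoint is a plain state_dict for the backbone.
state = torch.load(conf["models"]["model_path"], map_location="cpu")
model.load_state_dict(state, strict=False)
model.eval()

dummy = torch.randn(conf["models"]["rknn_batch_size"], 3, 224, 224)
torch.onnx.export(model, dummy, conf["models"]["onnx_model"],
                  input_names=["input"], output_names=["embedding"],
                  opset_version=12)
# The resulting ONNX file would then be fed to the RKNN toolkit to produce rknn_model.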
@@ -1,4 +1,4 @@
-from model import (resnet18, mobilevit_s, MobileNetV3_Small, MobileNetV3_Large, mobilenet_v1,
+from model import (resnet18, resnet34, resnet50, mobilevit_s, MobileNetV3_Small, MobileNetV3_Large, mobilenet_v1,
                    PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5)
 from timm.models import vit_base_patch16_224 as vit_base_16
 from model.metric import ArcFace, CosFace
@@ -14,6 +14,8 @@ class trainer_tools:
     def get_backbone(self):
         backbone_mapping = {
             'resnet18': lambda: resnet18(scale=self.conf['models']['channel_ratio']),
+            'resnet34': lambda: resnet34(scale=self.conf['models']['channel_ratio']),
+            'resnet50': lambda: resnet50(scale=self.conf['models']['channel_ratio']),
             'mobilevit_s': lambda: mobilevit_s(),
             'mobilenetv3_small': lambda: MobileNetV3_Small(),
             'PPLCNET_x1_0': lambda: PPLCNET_x1_0(),
@@ -54,3 +56,24 @@ class trainer_tools:
             )
         }
         return optimizer_mapping
+
+    def get_scheduler(self, optimizer):
+        scheduler_mapping = {
+            'step': lambda: optim.lr_scheduler.StepLR(
+                optimizer,
+                step_size=self.conf['training']['lr_step'],
+                gamma=self.conf['training']['lr_decay']
+            ),
+            'cosine': lambda: optim.lr_scheduler.CosineAnnealingLR(
+                optimizer,
+                T_max=self.conf['training']['epochs'],
+                eta_min=self.conf['training']['cosine_eta_min']
+            ),
+            'cosine_warm': lambda: optim.lr_scheduler.CosineAnnealingWarmRestarts(
+                optimizer,
+                T_0=self.conf['training'].get('cosine_t_0', 10),
+                T_mult=self.conf['training'].get('cosine_t_mult', 1),
+                eta_min=self.conf['training'].get('cosine_eta_min', 0)
+            )
+        }
+        return scheduler_mapping
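test_ori.py later in this diff already does `from configs import trainer_tools`; a sketch of how the new get_scheduler mapping would plug into a training loop. The config path, the `trainer_tools(conf)` construction and the loop body are assumptions based on the config keys above, not code shown in this diff:

import torch
import yaml
from configs import trainer_tools        # same import path test_ori.py uses

with open("configs/train.yml", "r") as f:  # hypothetical training-config path
    conf = yaml.safe_load(f)

tools = trainer_tools(conf)                                        # assumes the class wraps the parsed config
model = tools.get_backbone()[conf["models"]["backbone"]]()         # e.g. the new 'resnet50' entry
optimizer = torch.optim.SGD(model.parameters(), lr=conf["training"]["lr"])

# Each mapping entry is a lambda, so select by the config key and call it.
scheduler = tools.get_scheduler(optimizer)[conf["training"]["scheduler"]]()

for epoch in range(conf["training"]["epochs"]):
    # ... forward/backward passes for one epoch ...
    scheduler.step()        # advance the LR schedule once per epoch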
@@ -14,7 +14,7 @@ base:
 models:
   backbone: 'resnet18'
   channel_ratio: 0.75
-  checkpoints: "../checkpoints/resnet18_1009/best.pth"
+  checkpoints: "../checkpoints/resnet18_20250715_scale=0.75_sub/best.pth"

 # Data configuration
 data:
@@ -22,7 +22,7 @@ data:
   test_batch_size: 128  # validation batch size
   num_workers: 32       # data-loading workers
   half: true            # use half-precision data
-  img_dirs_path: "/shareData/temp_data/comparison/Hangzhou_Yunhe/base_data/05-09"
+  img_dirs_path: "/home/lc/data_center/baseStlib/pic/stlib_base"  # storage path of the standard-library (base) images
   # img_dirs_path: "/home/lc/contrast_nettest/data/feature_json"
   xlsx_pth: false       # product filter; default None means no filtering
@@ -42,6 +42,7 @@ logging:

 save:
   json_bin: "../search_library/yunhedian_05-09.json"  # save the whole json file
-  json_path: "../data/feature_json_compare/"          # save individual json files
+  json_path: "/home/lc/data_center/baseStlib/feature_json/stlib_base_resnet18_sub"  # path for saving individual json files
   error_barcodes: "error_barcodes.txt"
   barcodes_statistics: "../search_library/barcodes_statistics.txt"
+  create_single_json: true  # whether to save individual json files
0  data_preprocessing/__init__.py  (new file)

25  data_preprocessing/combine_sub_class.py  (new file)
@@ -0,0 +1,25 @@
import os
import shutil


def combine_dirs(conf):
    source_root = conf['data']['combine_scr_dir']
    target_root = conf['data']['combine_dst_dir']
    for roots, dir_names, files in os.walk(source_root):
        for dir_name in dir_names:
            source_dir = os.path.join(roots, dir_name)
            # Sub-class folders are named "<class>_<suffix>"; merge them back into "<class>".
            target_dir = os.path.join(target_root, dir_name.split('_')[0])
            if not os.path.exists(target_dir):
                os.mkdir(target_dir)
            for filename in os.listdir(source_dir):
                print(filename)
                source_file = os.sep.join([source_dir, filename])
                target_file = os.sep.join([target_dir, filename])
                shutil.copy(source_file, target_file)
            # print(f"Copied directory {source_dir} to {target_dir}")


# if __name__ == '__main__':
#     source_root = r'scatter_mini'
#     target_root = r'C:\Users\123\Desktop\scatter-1'
#     # combine_dirs(conf)
111  data_preprocessing/create_extra.py  (new file)
@@ -0,0 +1,111 @@
import os
import shutil
from pathlib import Path

import yaml


def count_files(directory):
    """Count the files in a directory."""
    try:
        return len([f for f in os.listdir(directory)
                    if os.path.isfile(os.path.join(directory, f))])
    except Exception as e:
        print(f"Cannot count directory {directory}: {e}")
        return 0


def clear_empty_dirs(path):
    """
    Delete empty directories.
    :param path: directory path
    """
    for root, dirs, files in os.walk(path, topdown=False):
        for dir_name in dirs:
            dir_path = os.path.join(root, dir_name)
            try:
                if not os.listdir(dir_path):
                    os.rmdir(dir_path)
                    print(f"Deleted empty directory: {dir_path}")
            except Exception as e:
                print(f"Error: {e.strerror}")


def get_max_files(conf):
    max_files_ratio = conf['data']['max_files_ratio']
    files_number = []
    for root, dirs, files in os.walk(conf['data']['source_dir']):
        if len(dirs) == 0:
            if len(files) == 0:
                print(root, dirs, files)
            files_number.append(len(files))
    files_number = sorted(files_number, reverse=False)
    # Take the folder size at the max_files_ratio quantile of the sorted class-folder sizes.
    max_files = files_number[int(max_files_ratio * len(files_number))]
    print(f"max_files: {max_files}")
    if max_files < conf['data']['min_files']:
        max_files = conf['data']['min_files']
    return max_files


def megre_subdirs(pth):
    for roots, dir_names, files in os.walk(pth):
        print(f"image {dir_names}")
        for image in dir_names:
            inner_dir_path = os.path.join(pth, image)
            for inner_roots, inner_dirs, inner_files in os.walk(inner_dir_path):
                for inner_dir in inner_dirs:
                    src_dir = os.path.join(inner_roots, inner_dir)
                    dest_dir = os.path.join(pth, inner_dir)
                    # shutil.copytree(src_dir, dest_dir)
                    shutil.move(src_dir, dest_dir)
                    print(f"Copied {inner_dir} to {pth}")
    clear_empty_dirs(pth)


def split_subdirs(conf):
    """
    Copy sub-directories with at most max_files files to the extra directory,
    and the remaining sub-directories to the training directory.
    """
    source_dir = conf['data']['source_dir']
    target_extra_dir = conf['data']['data_extra_dir']
    train_dir = conf['data']['train_dir']
    max_files = get_max_files(conf)
    megre_subdirs(source_dir)  # merge nested sub-directories and drop their parents
    # Create the target directory
    Path(target_extra_dir).mkdir(parents=True, exist_ok=True)

    print(f"Processing directory: {source_dir}")
    print(f"Target directory: {target_extra_dir}")
    print(f"Filter: file count <= {max_files}\n")

    # Walk the source directory
    for subdir in os.listdir(source_dir):
        subdir_path = os.path.join(source_dir, subdir)

        if not os.path.isdir(subdir_path):
            continue

        try:
            file_count = count_files(subdir_path)
            print(f"Copying {subdir} ({file_count} files)")
            if file_count <= max_files:
                dest_path = os.path.join(target_extra_dir, subdir)
            else:
                dest_path = os.path.join(train_dir, subdir)
            # Skip if the target directory already exists
            if os.path.exists(dest_path):
                print(f"Directory already exists, skipping: {dest_path}")
                continue
            print(f"Copying {subdir} ({file_count} files) to {dest_path}")
            shutil.copytree(subdir_path, dest_path)
            # shutil.move(subdir_path, dest_path)

        except Exception as e:
            print(f"Error while processing directory {subdir}: {e}")

    print("\nDone")


if __name__ == "__main__":
    # Load the config
    with open('../configs/sub_data.yml', 'r') as f:
        conf = yaml.load(f, Loader=yaml.FullLoader)

    # Run the copy
    split_subdirs(conf)
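A small worked example of the threshold rule in get_max_files above: with max_files_ratio 0.1 and min_files 10 (the values in configs/sub_data.yml), the cutoff is the class-folder size at the 10% quantile, floored at 10. The folder sizes below are made up for illustration:

# Hypothetical folder sizes for 10 classes, just to illustrate the rule.
files_number = sorted([3, 5, 8, 12, 20, 35, 50, 80, 120, 200])
max_files_ratio, min_files = 0.1, 10

max_files = files_number[int(max_files_ratio * len(files_number))]  # index 1 -> 5
max_files = max(max_files, min_files)                               # floored to 10
print(max_files)  # 10: classes with <= 10 images would be routed to the extra set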
72  data_preprocessing/data_split.py  (new file)
@@ -0,0 +1,72 @@
import os
import shutil
import random
from pathlib import Path

import yaml


def is_image_file(filename):
    """Check whether a file is an image."""
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff')
    return filename.lower().endswith(image_extensions)


def split_directory(conf):
    """
    Split the image files of each class between the train and val directories.
    :param conf: parsed config providing data.train_dir, data.val_dir and data.split_ratio
    """
    # Create the target directory
    train_dir = conf['data']['train_dir']
    val_dir = conf['data']['val_dir']
    split_ratio = conf['data']['split_ratio']
    Path(val_dir).mkdir(parents=True, exist_ok=True)

    # Walk the source (train) directory
    for root, dirs, files in os.walk(train_dir):
        # Path relative to train_dir
        rel_path = os.path.relpath(root, train_dir)

        # Skip the top-level directory itself
        if rel_path == '.':
            continue

        # Create the matching val sub-directory
        val_subdir = os.path.join(val_dir, rel_path)
        os.makedirs(val_subdir, exist_ok=True)

        # Keep only image files
        image_files = [f for f in files if is_image_file(f)]
        if not image_files:
            continue

        # Shuffle the file list
        random.shuffle(image_files)

        # Compute the split point
        split_point = int(len(image_files) * split_ratio)

        # Move the tail of the list to the validation set
        for file in image_files[split_point:]:
            src = os.path.join(root, file)
            dst = os.path.join(val_subdir, file)
            # shutil.copy2(src, dst)
            shutil.move(src, dst)

        print(f"Done: {rel_path} ({len(image_files)} images, train: {split_point}, val: {len(image_files) - split_point})")


def control_train_number():
    pass


if __name__ == "__main__":
    # # Example directory paths
    # TRAIN_DIR = "/home/lc/data_center/electornic/v1/train"
    # VAL_DIR = "/home/lc/data_center/electornic/v1/val"

    with open('../configs/scatter_data.yml', 'r') as f:
        conf = yaml.load(f, Loader=yaml.FullLoader)
    print("Splitting dataset...")
    split_directory(conf)
    print("Dataset split complete")
192  data_preprocessing/extend.py  (new file)
@@ -0,0 +1,192 @@
import os
import random
import shutil

from PIL import Image, ImageEnhance


class ImageExtendProcessor:
    def __init__(self, conf):
        self.conf = conf

    def is_image_file(self, filename):
        """Check whether a file is an image."""
        image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff')
        return filename.lower().endswith(image_extensions)

    def random_cute_image(self, image_path, output_path, ratio=0.8):
        """
        Randomly crop an image.
        :param image_path: input image path
        :param output_path: output image path
        :param ratio: crop ratio controlling the size of the cropped region, default 0.8
        """
        try:
            with Image.open(image_path) as img:
                # Image size
                width, height = img.size

                # Size after cropping
                new_width = int(width * ratio)
                new_height = int(height * ratio)

                # Random crop origin
                left = random.randint(0, width - new_width)
                top = random.randint(0, height - new_height)
                right = left + new_width
                bottom = top + new_height

                # Crop
                cropped_img = img.crop((left, top, right, bottom))

                # Save the cropped image
                cropped_img.save(output_path)
                return True
        except Exception as e:
            print(f"Error while processing image {image_path}: {e}")
            return False

    def random_brightness(self, image_path, output_path, brightness_factor=None):
        """
        Randomly adjust the brightness of an image.
        :param image_path: input image path
        :param output_path: output image path
        :param brightness_factor: brightness factor, random by default
        """
        try:
            with Image.open(image_path) as img:
                # Build an ImageEnhance.Brightness object
                enhancer = ImageEnhance.Brightness(img)

                # Pick a random factor if none is given
                if brightness_factor is None:
                    brightness_factor = random.uniform(0.5, 1.5)

                # Apply the brightness adjustment
                brightened_img = enhancer.enhance(brightness_factor)

                # Save the adjusted image
                brightened_img.save(output_path)
                return True
        except Exception as e:
            print(f"Error while processing image {image_path}: {e}")
            return False

    def rotate_image(self, image_path, output_path, degrees):
        """Rotate an image and save it to the given path."""
        try:
            with Image.open(image_path) as img:
                # Rotate and let the canvas grow to fit
                rotated = img.rotate(degrees, expand=True)
                rotated.save(output_path)
                return True
        except Exception as e:
            print(f"Error while processing image {image_path}: {e}")
            return False

    def process_extra_directory(self, src_dir, dst_dir, same_directory, dir_name):
        """
        Process the image files of a single directory.
        :param src_dir: source directory path
        :param dst_dir: target directory path
        """
        if not os.path.exists(dst_dir):
            os.makedirs(dst_dir)
        # Collect all image files in the directory
        image_files = [f for f in os.listdir(src_dir)
                       if self.is_image_file(f) and os.path.isfile(os.path.join(src_dir, f))]

        # Process each image file
        for img_file in image_files:
            src_path = os.path.join(src_dir, img_file)
            base_name, ext = os.path.splitext(img_file)
            if not same_directory:
                # Copy the original file (only when saving to a separate folder)
                shutil.copy2(src_path, os.path.join(dst_dir, img_file))
            if dir_name == 'extra':
                # Save rotated copies
                for angle in [90, 180, 270]:
                    dst_path = os.path.join(dst_dir, f"{base_name}_rotated_{angle}{ext}")
                    self.rotate_image(src_path, dst_path, angle)
                for ratio in [0.8, 0.85, 0.9]:
                    dst_path = os.path.join(dst_dir, f"{base_name}_cute_{ratio}{ext}")
                    self.random_cute_image(src_path, dst_path, ratio)
                for brightness_factor in [0.8, 0.9, 1.0]:
                    dst_path = os.path.join(dst_dir, f"{base_name}_brightness_{brightness_factor}{ext}")
                    self.random_brightness(src_path, dst_path, brightness_factor)
            elif dir_name == 'train':
                # Save rotated copies only
                for angle in [90, 180, 270]:
                    dst_path = os.path.join(dst_dir, f"{base_name}_rotated_{angle}{ext}")
                    self.rotate_image(src_path, dst_path, angle)

    def image_extend(self, src_dir, dst_dir, same_directory=False, dir_name=None):
        if same_directory:
            n_dst_dir = src_dir
            print(f"Processing images in {src_dir}, saving into the same directory")
        else:
            n_dst_dir = dst_dir
            print(f"Processing images in {src_dir}, saving into a separate directory")
        for src_subdir in os.listdir(src_dir):
            src_subdir_path = os.path.join(src_dir, src_subdir)
            dst_subdir_path = os.path.join(n_dst_dir, src_subdir)
            if dir_name == 'extra':
                self.process_extra_directory(src_subdir_path,
                                             dst_subdir_path,
                                             same_directory,
                                             dir_name)
            if dir_name == 'train':
                # Only augment small training classes
                if len(os.listdir(src_subdir_path)) < 50:
                    self.process_extra_directory(src_subdir_path,
                                                 dst_subdir_path,
                                                 same_directory,
                                                 dir_name)

    def random_remove_image(self, subdir_path, max_count=1000):
        """
        Randomly delete images from a sub-directory until at most max_count remain.
        :param subdir_path: sub-directory path
        :param max_count: maximum allowed number of images
        """
        # Count the image files
        image_files = [f for f in os.listdir(subdir_path)
                       if self.is_image_file(f) and os.path.isfile(os.path.join(subdir_path, f))]
        current_count = len(image_files)

        # Nothing to delete if we are already under the cap
        if current_count <= max_count:
            print(f"Nothing to do for {subdir_path} ({current_count} images)")
            return

        # Number of files to delete
        remove_count = current_count - max_count

        # Pick the files to delete at random
        files_to_remove = random.sample(image_files, remove_count)

        # Delete them
        for file in files_to_remove:
            file_path = os.path.join(subdir_path, file)
            os.remove(file_path)
            print(f"Deleted {file_path}")

    def control_number(self):
        if self.conf['extend']['extend_extra']:
            self.image_extend(self.conf['extend']['extend_extra_dir'],
                              '',
                              same_directory=self.conf['extend']['extend_same_dir'],
                              dir_name='extra')
        if self.conf['extend']['extend_train']:
            self.image_extend(self.conf['extend']['extend_train_dir'],
                              '',
                              same_directory=self.conf['extend']['extend_same_dir'],
                              dir_name='train')
        if self.conf['limit']['count_limit']:
            self.random_remove_image(self.conf['limit']['limit_dir'],
                                     max_count=self.conf['limit']['limit_count'])


if __name__ == "__main__":
    src_dir = "./scatter_mini"
    dst_dir = "./scatter_add"
    image_ex = ImageExtendProcessor(conf=None)  # __init__ expects a conf; image_extend() below does not read self.conf
    image_ex.image_extend(src_dir, dst_dir, same_directory=False)
21  data_preprocessing/sub_data_preprocessing.py  (new file)
@@ -0,0 +1,21 @@
import yaml

from combine_sub_class import combine_dirs
from create_extra import split_subdirs
from data_split import split_directory
from extend import ImageExtendProcessor


def data_preprocessing(conf):
    if conf['control']['split']:
        split_subdirs(conf)
        image_ex = ImageExtendProcessor(conf)
        image_ex.control_number()
        split_directory(conf)
    if conf['control']['combine']:
        combine_dirs(conf)


if __name__ == '__main__':
    with open('../configs/sub_data.yml', 'r') as f:
        conf = yaml.load(f, Loader=yaml.FullLoader)
    data_preprocessing(conf)
@ -205,7 +205,7 @@ class ResNet(nn.Module):
         if norm_layer is None:
             norm_layer = nn.BatchNorm2d
         self._norm_layer = norm_layer
-        print("ResNet scale: >>>>>>>>>> ", scale)
+        print("通道剪枝 {}".format(scale))
         self.inplanes = 64
         self.dilation = 1
         if replace_stride_with_dilation is None:
@ -222,13 +222,13 @@ class ResNet(nn.Module):
         self.bn1 = norm_layer(self.inplanes)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-        self.adaptiveMaxPool = nn.AdaptiveMaxPool2d((1, 1))
-        self.maxpool2 = nn.Sequential(
-            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
-            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
-            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
-            nn.MaxPool2d(kernel_size=2, stride=1, padding=0)
-        )
+        # self.adaptiveMaxPool = nn.AdaptiveMaxPool2d((1, 1))
+        # self.maxpool2 = nn.Sequential(
+        #     nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
+        #     nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
+        #     nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
+        #     nn.MaxPool2d(kernel_size=2, stride=1, padding=0)
+        # )
         self.layer1 = self._make_layer(block, int(64 * scale), layers[0])
         self.layer2 = self._make_layer(block, int(128 * scale), layers[1], stride=2,
                                        dilate=replace_stride_with_dilation[0])
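The scale factor logged above ("通道剪枝", channel pruning) is what the _make_layer calls use to shrink each stage, e.g. int(64 * scale) and int(128 * scale). A small sketch of the resulting stage widths, assuming the standard ResNet stage plan of 64/128/256/512 channels:

for scale in (1.0, 0.75, 0.5):
    widths = [int(c * scale) for c in (64, 128, 256, 512)]
    print(scale, widths)   # e.g. 0.75 -> [48, 96, 192, 384]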
87 test_ori.py
@ -11,10 +11,13 @@ import matplotlib.pyplot as plt
 # from config import config as conf
 from tools.dataset import get_transform
+from tools.image_joint import merge_imgs
+from tools.getHeatMap import cal_cam
 from configs import trainer_tools
 import yaml
+from datetime import datetime

-with open('configs/test.yml', 'r') as f:
+with open('./configs/test.yml', 'r') as f:
     conf = yaml.load(f, Loader=yaml.FullLoader)

 # Constants from config
@ -22,6 +25,7 @@ embedding_size = conf["base"]["embedding_size"]
 img_size = conf["transform"]["img_size"]
 device = conf["base"]["device"]


 def unique_image(pair_list: str) -> Set[str]:
     unique_images = set()
     try:
@ -115,8 +119,12 @@ def featurize(
     except Exception as e:
         print(f"Error in feature extraction: {e}")
         raise


 def cosin_metric(x1, x2):
     return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))


 def threshold_search(y_score, y_true):
     y_score = np.asarray(y_score)
     y_true = np.asarray(y_true)
@ -133,7 +141,7 @@ def threshold_search(y_score, y_true):


 def showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct):
-    x = np.linspace(start=0, stop=1.0, num=50, endpoint=True).tolist()
+    x = np.linspace(start=-1, stop=1.0, num=100, endpoint=True).tolist()
     plt.figure(figsize=(10, 6))
     plt.plot(x, recall, color='red', label='recall:TP/TPFN')
     plt.plot(x, recall_TN, color='black', label='recall_TN:TN/TNFP')
@ -143,6 +151,7 @@ def showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct):
     plt.legend()
     plt.xlabel('threshold')
     # plt.ylabel('Similarity')
+
     plt.grid(True, linestyle='--', alpha=0.5)
     plt.savefig('grid.png')
     plt.show()
@ -154,19 +163,19 @@ def showHist(same, cross):
     Cross = np.array(cross)

     fig, axs = plt.subplots(2, 1)
-    axs[0].hist(Same, bins=50, edgecolor='black')
-    axs[0].set_xlim([-0.1, 1])
+    axs[0].hist(Same, bins=100, edgecolor='black')
+    axs[0].set_xlim([-1, 1])
     axs[0].set_title('Same Barcode')

-    axs[1].hist(Cross, bins=50, edgecolor='black')
-    axs[1].set_xlim([-0.1, 1])
+    axs[1].hist(Cross, bins=100, edgecolor='black')
+    axs[1].set_xlim([-1, 1])
     axs[1].set_title('Cross Barcode')
     plt.savefig('plot.png')


 def compute_accuracy_recall(score, labels):
     th = 0.1
-    squence = np.linspace(-1, 1, num=50)
+    squence = np.linspace(-1, 1, num=100)
     recall, PrecisePos, PreciseNeg, recall_TN, Correct = [], [], [], [], []
     Same = score[:len(score) // 2]
     Cross = score[len(score) // 2:]
@ -179,24 +188,26 @@ def compute_accuracy_recall(score, labels):
         f_labels = (labels == 0)
         TN = np.sum(np.logical_and(f_score, f_labels))
         FP = np.sum(np.logical_and(np.logical_not(f_score), f_labels))
-        print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))
+        # print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))

         PrecisePos.append(0 if TP / (TP + FP) == 'nan' else TP / (TP + FP))
         PreciseNeg.append(0 if TN == 0 else TN / (TN + FN))
         recall.append(0 if TP == 0 else TP / (TP + FN))
         recall_TN.append(0 if TN == 0 else TN / (TN + FP))
         Correct.append(0 if TP == 0 else (TP + TN) / (TP + FP + TN + FN))
+        print("Threshold:{} PrecisePos:{},recall:{},PreciseNeg:{},recall_TN:{}".format(th, PrecisePos[-1], recall[-1],
+                                                                                       PreciseNeg[-1], recall_TN[-1]))
     showHist(Same, Cross)
     showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct)


 def compute_accuracy(
         feature_dict: Dict[str, torch.Tensor],
-        pair_list: str,
-        test_root: str
+        cam: cal_cam,
 ) -> Tuple[float, float]:
     try:
+        pair_list = conf['data']['test_list']
+        test_root = conf['data']['test_dir']
         with open(pair_list, 'r') as f:
             pairs = f.readlines()
     except IOError as e:
@ -211,7 +222,8 @@ def compute_accuracy(
         if not pair:
             continue

-        try:
+        # try:
+        print(f"Processing pair: {pair}")
         img1, img2, label = pair.split()
         img1_path = osp.join(test_root, img1)
         img2_path = osp.join(test_root, img2)
@ -224,13 +236,20 @@ def compute_accuracy(
         feat1 = feature_dict[img1_path].cpu().numpy()
         feat2 = feature_dict[img2_path].cpu().numpy()
         similarity = cosin_metric(feat1, feat2)
+        print('{} vs {}: {}'.format(img1_path, img2_path, similarity))
+        if conf['data']['save_image_joint']:
+            merge_imgs(img1_path,
+                       img2_path,
+                       conf,
+                       similarity,
+                       label,
+                       cam)
         similarities.append(similarity)
         labels.append(int(label))

-        except Exception as e:
-            print(f"Skipping invalid pair: {pair}. Error: {e}")
-            continue
+        # except Exception as e:
+        #     print(f"Skipping invalid pair: {pair}. Error: {e}")
+        #     continue

     # Find optimal threshold and accuracy
     accuracy, threshold = threshold_search(similarities, labels)
@ -267,10 +286,10 @@ def compute_group_accuracy(content_list_read):
         d = featurize(group[0], conf.test_transform, model, conf.device)
         one_group_list.append(d.values())
         if data_loaded[-1] == '1':
-            similarity = deal_group_pair(one_group_list[0], one_group_list[1])
+            similarity = abs(deal_group_pair(one_group_list[0], one_group_list[1]))
             Same.append(similarity)
         else:
-            similarity = deal_group_pair(one_group_list[0], one_group_list[1])
+            similarity = abs(deal_group_pair(one_group_list[0], one_group_list[1]))
             Cross.append(similarity)
         allLabel.append(data_loaded[-1])
         allSimilarity.extend(similarity)
@ -291,14 +310,36 @@ def init_model():
     print('load model {} '.format(conf['models']['backbone']))
     if torch.cuda.device_count() > 1 and conf['base']['distributed']:
         model = nn.DataParallel(model).to(conf['base']['device'])
-        model.load_state_dict(torch.load(conf['models']['model_path'], map_location=conf['base']['device']))
+        ############### regular model loading ################
+        model.load_state_dict(torch.load(conf['models']['model_path'],
+                                         map_location=conf['base']['device']))
+        #######################################
+        ####### temporary loading for the contrastive-learning model ###
+        # state_dict = torch.load(conf['models']['model_path'], map_location=conf['base']['device'])
+        # new_state_dict = {}
+        # for k, v in state_dict.items():
+        #     new_key = k.replace("module.base_model.", "module.")
+        #     new_state_dict[new_key] = v
+        # model.load_state_dict(new_state_dict, strict=False)
+        ###########################
         if conf['models']['half']:
             model.half()
             first_param_dtype = next(model.parameters()).dtype
             print("模型的第一个参数的数据类型: {}".format(first_param_dtype))
     else:
-        model.load_state_dict(torch.load(conf['model']['model_path'], map_location=conf['base']['device']))
-        if conf.model_half:
+        try:
+            model.load_state_dict(torch.load(conf['models']['model_path'],
+                                             map_location=conf['base']['device']))
+        except:
+            state_dict = torch.load(conf['models']['model_path'],
+                                    map_location=conf['base']['device'])
+            new_state_dict = {}
+            for k, v in state_dict.items():
+                new_key = k.replace("module.", "")
+                new_state_dict[new_key] = v
+            model.load_state_dict(new_state_dict, strict=False)
+
+        if conf['models']['half']:
             model.half()
             first_param_dtype = next(model.parameters()).dtype
             print("模型的第一个参数的数据类型: {}".format(first_param_dtype))
@ -308,7 +349,7 @@ def init_model():
 if __name__ == '__main__':
     model = init_model()
     model.eval()
+    cam = cal_cam(model, conf)
     if not conf['data']['group_test']:
         images = unique_image(conf['data']['test_list'])
         images = [osp.join(conf['data']['test_dir'], img) for img in images]
@ -318,7 +359,7 @@ if __name__ == '__main__':
     for group in groups:
         d = featurize(group, test_transform, model, conf['base']['device'])
         feature_dict.update(d)
-    accuracy, threshold = compute_accuracy(feature_dict, conf['data']['test_list'], conf['data']['test_dir'])
+    accuracy, threshold = compute_accuracy(feature_dict, cam)
     print(
         "Test Model: {} Accuracy: {} Threshold: {}".format(conf['models']['model_path'], accuracy, threshold)
     )
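compute_accuracy() scores every pair with cosin_metric() and then lets threshold_search() pick the operating point. The body of threshold_search is not shown in these hunks, so the sketch below is only an illustration of that idea: sweep a grid of thresholds and keep the one with the best pair accuracy.

import numpy as np

def best_threshold(scores, labels, grid=np.linspace(-1, 1, 100)):
    # Illustrative only: accuracy of "same if score >= threshold" over a grid.
    scores, labels = np.asarray(scores), np.asarray(labels, dtype=int)
    accs = [np.mean((scores >= t).astype(int) == labels) for t in grid]
    best = int(np.argmax(accs))
    return accs[best], grid[best]

acc, th = best_threshold([0.91, 0.12, 0.78, 0.05], [1, 0, 1, 0])   # toy scores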
@ -5,12 +5,14 @@ import torchvision.transforms as T
 # from config import config as conf
 import torch


 def pad_to_square(img):
     w, h = img.size
     max_wh = max(w, h)
     padding = [(max_wh - w) // 2, (max_wh - h) // 2, (max_wh - w) // 2, (max_wh - h) // 2]  # (left, top, right, bottom)
     return F.pad(img, padding, fill=0, padding_mode='constant')


 def get_transform(cfg):
     train_transform = T.Compose([
         T.Lambda(pad_to_square),  # pad edges
@ -24,7 +26,7 @@ def get_transform(cfg):
         T.Normalize(mean=[cfg['transform']['img_mean']], std=[cfg['transform']['img_std']]),
     ])
     test_transform = T.Compose([
-        # T.Lambda(pad_to_square),  # pad edges
+        T.Lambda(pad_to_square),  # pad edges
         T.ToTensor(),
         T.Resize((cfg['transform']['img_size'], cfg['transform']['img_size']), antialias=True),
         T.ConvertImageDtype(torch.float32),
@ -32,37 +34,73 @@ def get_transform(cfg):
     ])
     return train_transform, test_transform


-def load_data(training=True, cfg=None):
+def load_data(training=True, cfg=None, return_dataset=False):
     train_transform, test_transform = get_transform(cfg)
     if training:
         dataroot = cfg['data']['data_train_dir']
         transform = train_transform
-        # transform = conf.train_transform
+        # transform.yml = conf.train_transform
         batch_size = cfg['data']['train_batch_size']
     else:
         dataroot = cfg['data']['data_val_dir']
-        # transform = conf.test_transform
+        # transform.yml = conf.test_transform
         transform = test_transform
         batch_size = cfg['data']['val_batch_size']

     data = ImageFolder(dataroot, transform=transform)
     class_num = len(data.classes)
-    loader = DataLoader(data,
-                        batch_size=batch_size,
-                        shuffle=True,
-                        pin_memory=cfg['base']['pin_memory'],
-                        num_workers=cfg['data']['num_workers'],
-                        drop_last=True)
-    return loader, class_num
+    if return_dataset:
+        return data, class_num
+    else:
+        loader = DataLoader(data,
+                            batch_size=batch_size,
+                            shuffle=True if training else False,
+                            pin_memory=cfg['base']['pin_memory'],
+                            num_workers=cfg['data']['num_workers'],
+                            drop_last=True)
+        return loader, class_num
+
+
+class MultiEpochsDataLoader(torch.utils.data.DataLoader):
+    """
+    MultiEpochsDataLoader: improves data-loading efficiency by reusing worker
+    processes instead of restarting them at the beginning of every epoch.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._DataLoader__initialized = False
+        self.batch_sampler = _RepeatSampler(self.batch_sampler)
+        self._DataLoader__initialized = True
+        self.iterator = super().__iter__()
+
+    def __len__(self):
+        return len(self.batch_sampler.sampler)
+
+    def __iter__(self):
+        for i in range(len(self)):
+            yield next(self.iterator)
+
+
+class _RepeatSampler(object):
+    """
+    Repeating sampler that avoids re-creating the iterator every epoch.
+    """
+
+    def __init__(self, sampler):
+        self.sampler = sampler
+
+    def __iter__(self):
+        while True:
+            yield from iter(self.sampler)
 # def load_gift_data(action):
-#     train_data = ImageFolder(conf.train_gift_root, transform=conf.train_transform)
+#     train_data = ImageFolder(conf.train_gift_root, transform.yml=conf.train_transform)
 #     train_dataset = DataLoader(train_data, batch_size=conf.train_gift_batchsize, shuffle=True,
 #                                pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
-#     val_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
+#     val_data = ImageFolder(conf.test_gift_root, transform.yml=conf.test_transform)
 #     val_dataset = DataLoader(val_data, batch_size=conf.val_gift_batchsize, shuffle=True,
 #                              pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
-#     test_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
+#     test_data = ImageFolder(conf.test_gift_root, transform.yml=conf.test_transform)
 #     test_dataset = DataLoader(test_data, batch_size=conf.test_gift_batchsize, shuffle=True,
 #                               pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
 #     return train_dataset, val_dataset, test_dataset
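A minimal usage sketch, not part of the diff: MultiEpochsDataLoader is intended as a drop-in replacement for torch.utils.data.DataLoader so that worker processes persist across epochs. The dataset path and loader arguments below are illustrative only.

import torchvision.transforms as T
from torchvision.datasets import ImageFolder

dataset = ImageFolder('./data/train', transform=T.ToTensor())   # hypothetical path
loader = MultiEpochsDataLoader(dataset, batch_size=64, shuffle=True,
                               num_workers=4, pin_memory=True, drop_last=True)
for epoch in range(3):
    for images, targets in loader:   # workers are reused instead of restarted
        pass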
@ -1,10 +1,10 @@ (dataset.txt: all ten quantization image paths change prefix from ./quant_imgs/ to ../quant_imgs/)
../quant_imgs/20179457_20240924-110903_back_addGood_b82d2842766e_80_15583929052_tid-8_fid-72_bid-3.jpg
../quant_imgs/6928926002103_20240309-195044_front_returnGood_70f75407ef0e_225_18120111822_14_01.jpg
../quant_imgs/6928926002103_20240309-212145_front_returnGood_70f75407ef0e_225_18120111822_11_01.jpg
../quant_imgs/6928947479083_20241017-133830_front_returnGood_5478c9a48b7e_10_13799009402_tid-1_fid-20_bid-1.jpg
../quant_imgs/6928947479083_20241018-110450_front_addGood_5478c9a48c28_165_13773168720_tid-6_fid-36_bid-1.jpg
../quant_imgs/6930044166421_20240117-141516_c6a23f41-5b16-44c6-a03e-c32c25763442_back_returnGood_6930044166421_17_01.jpg
../quant_imgs/6930044166421_20240308-150916_back_returnGood_70f75407ef0e_175_13815402763_7_01.jpg
../quant_imgs/6930044168920_20240117-165633_3303629b-5fbd-423b-913d-8a64c1aa51dc_front_addGood_6930044168920_26_01.jpg
../quant_imgs/6930058201507_20240305-175434_front_addGood_70f75407ef0e_95_18120111822_28_01.jpg
../quant_imgs/6930639267885_20241014-120446_back_addGood_5478c9a48c3e_135_13773168720_tid-5_fid-99_bid-0.jpg
144 tools/event_similar_analysis.py Normal file
@ -0,0 +1,144 @@
from similar_analysis import SimilarAnalysis
import os
import pickle
from tools.image_joint import merge_imgs


class EventSimilarAnalysis(SimilarAnalysis):
    def __init__(self):
        super(EventSimilarAnalysis, self).__init__()
        self.fn_one2one_event, self.fp_one2one_event = self.One2one_similar_analysis()
        self.fn_one2sn_event, self.fp_one2sn_event = self.One2Sn_similar_analysis()
        if os.path.exists(self.conf['event']['pickle_path']):
            print('pickle file exists')
        else:
            self.target_image = self.get_path()

    def get_path(self):
        events = [self.fn_one2one_event, self.fp_one2one_event,
                  self.fn_one2sn_event, self.fp_one2sn_event]
        event_image_path = []
        barcode_image_path = []
        for event in events:
            for event_name, bcd in event:
                event_sub_image = os.sep.join([self.conf['event']['event_save_dir'],
                                               event_name,
                                               'subimgs'])
                barcode_images = os.sep.join([self.conf['event']['stdlib_image_path'],
                                              bcd])
                for image_name in os.listdir(event_sub_image):
                    event_image_path.append(os.sep.join([event_sub_image, image_name]))
                for barcode in os.listdir(barcode_images):
                    barcode_image_path.append(os.sep.join([barcode_images, barcode]))
        return list(set(event_image_path + barcode_image_path))

    def write_dict_to_pickle(self, data):
        """Write the dict to a pickle file."""
        with open(self.conf['event']['pickle_path'], 'wb') as file:
            pickle.dump(data, file)

    def get_dict_to_pickle(self):
        with open(self.conf['event']['pickle_path'], 'rb') as f:
            data = pickle.load(f)
        return data

    def create_total_feature(self):
        feature_dicts = self.get_feature_map(self.target_image)
        self.write_dict_to_pickle(feature_dicts)
        print(feature_dicts)

    def One2one_similar_analysis(self):
        fn_event, fp_event = [], []
        with open(self.conf['event']['oneToOneTxt'], 'r') as f:
            lines = f.readlines()
            for line in lines:
                print(line.strip().split(' '))
                event_infor = line.strip().split(' ')
                label = event_infor[0]
                event_name = event_infor[1]
                bcd = event_infor[2]
                simi1 = event_infor[3]
                simi2 = event_infor[4]
                if label == 'same' and float(simi2) < self.conf['event']['oneToOne_max_th']:
                    print(event_name, bcd, simi1)
                    fn_event.append((event_name, bcd))
                elif label == 'diff' and float(simi2) > self.conf['event']['oneToSn_min_th']:
                    fp_event.append((event_name, bcd))
        return fn_event, fp_event

    def One2Sn_similar_analysis(self):
        fn_event, fp_event = [], []
        with open(self.conf['event']['oneToOneTxt'], 'r') as f:
            lines = f.readlines()
            for line in lines:
                print(line.strip().split(' '))
                event_infor = line.strip().split(' ')
                label = event_infor[0]
                event_name = event_infor[1]
                bcd = event_infor[2]
                simi = event_infor[3]
                if label == 'fn':
                    print(event_name, bcd, simi)
                    fn_event.append((event_name, bcd))
                elif label == 'fp':
                    fp_event.append((event_name, bcd))
        return fn_event, fp_event

    def save_joint_image(self, img_pth1, img_pth2, feature_dicts, record):
        feature_dict1 = feature_dicts[img_pth1]
        feature_dict2 = feature_dicts[img_pth2]
        similarity = self.get_similarity(feature_dict1.cpu().numpy(),
                                         feature_dict2.cpu().numpy())
        dir_name = img_pth1.split('/')[-3]
        save_path = os.sep.join([self.conf['data']['image_joint_pth'], dir_name, record])
        if "fp" in record:
            if similarity > 0.8:
                merge_imgs(img_pth1,
                           img_pth2,
                           self.conf,
                           similarity,
                           label=None,
                           cam=self.cam,
                           save_path=save_path)
        else:
            if similarity < 0.8:
                merge_imgs(img_pth1,
                           img_pth2,
                           self.conf,
                           similarity,
                           label=None,
                           cam=self.cam,
                           save_path=save_path)
        print(similarity)

    def get_contrast(self, feature_dicts):
        events_compare = [self.fp_one2one_event, self.fn_one2one_event, self.fp_one2sn_event, self.fn_one2sn_event]
        event_record = ['fp_one2one', 'fn_one2one', 'fp_one2sn', 'fn_one2sn']
        for event_compare, record in zip(events_compare, event_record):
            for img, img_std in event_compare:
                imgs_pth1 = os.sep.join([self.conf['event']['event_save_dir'],
                                         img,
                                         'subimgs'])
                imgs_pth2 = os.sep.join([self.conf['event']['stdlib_image_path'],
                                         img_std])
                for img1 in os.listdir(imgs_pth1):
                    for img2 in os.listdir(imgs_pth2):
                        img_pth1 = os.sep.join([imgs_pth1, img1])
                        img_pth2 = os.sep.join([imgs_pth2, img2])
                        try:
                            self.save_joint_image(img_pth1, img_pth2, feature_dicts, record)
                        except Exception as e:
                            print(e)
                            continue


if __name__ == '__main__':
    event_similar_analysis = EventSimilarAnalysis()
    if os.path.exists(event_similar_analysis.conf['event']['pickle_path']):
        print('pickle file exists')
    else:
        event_similar_analysis.create_total_feature()  # build the pickle file; slow, only needs to run once
    feature_dicts = event_similar_analysis.get_dict_to_pickle()
    # all_compare_img = event_similar_analysis.get_image_map()
    event_similar_analysis.get_contrast(feature_dicts)  # run the comparisons
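Both parsers above split each line of conf['event']['oneToOneTxt'] on single spaces. The record below is invented, but it has the column layout One2one_similar_analysis() indexes into (label, event name, barcode, two similarity values):

line = "same 20250721-103_event 6930044166421 0.41 0.38"   # invented example record
label, event_name, bcd, simi1, simi2 = line.strip().split(' ')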
164 tools/getHeatMap.py Normal file
@ -0,0 +1,164 @@
# -*- coding: UTF-8 -*-
import os

import torch
from torchvision import models
import torch.nn as nn
import torchvision.transforms as tfs
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
# from tools.config import cfg
# from comparative.tools.initmodel import initSimilarityModel
import yaml
from dataset import get_transform


class cal_cam(nn.Module):
    def __init__(self, model, conf):
        super(cal_cam, self).__init__()
        self.conf = conf
        self.device = self.conf['base']['device']

        self.model = model
        self.model.to(self.device)

        # Layer whose gradients we track
        self.feature_layer = conf['heatmap']['feature_layer']
        # Collected gradients
        self.gradient = []
        # Collected output feature maps
        self.output = []
        _, self.transform = get_transform(self.conf)

    def get_conf(self, yaml_pth):
        with open(yaml_pth, 'r') as f:
            conf = yaml.load(f, Loader=yaml.FullLoader)
        return conf

    def save_grad(self, grad):
        self.gradient.append(grad)

    def get_grad(self):
        return self.gradient[-1].cpu().data

    def get_feature(self):
        return self.output[-1][0]

    def process_img(self, input):
        input = self.transform(input)
        input = input.unsqueeze(0)
        return input

    # Compute the gradient of the last conv layer and return it together with that layer's feature map
    def getGrad(self, input_):
        self.gradient = []  # clear previous gradients
        self.output = []  # clear previous feature maps
        # print(f"cuda.memory_allocated 1 {torch.cuda.memory_allocated()/ (1024 ** 3)}G")
        input_ = input_.to(self.device).requires_grad_(True)
        num = 1
        for name, module in self.model._modules.items():
            # print(f'module_name: {name}')
            # print(f'module: {module}')
            if (num == 1):
                input = module(input_)
                num = num + 1
                continue
            # The layer whose feature map we extract
            if (name == self.feature_layer):
                input = module(input)
                input.register_hook(self.save_grad)
                self.output.append([input])
            # About to reach the fully connected layer
            elif (name == "avgpool"):
                input = module(input)
                input = input.reshape(input.shape[0], -1)
            # An ordinary layer
            else:
                input = module(input)

        # print(f"cuda.memory_allocated 2 {torch.cuda.memory_allocated() / (1024 ** 3)}G")
        # At this point `input` is the output of the final fully connected layer
        index = torch.max(input, dim=-1)[1]
        one_hot = torch.zeros((1, input.shape[-1]), dtype=torch.float32)
        one_hot[0][index] = 1
        confidenct = one_hot * input.cpu()
        confidenct = torch.sum(confidenct, dim=-1).requires_grad_(True)

        # print(f"cuda.memory_allocated 3 {torch.cuda.memory_allocated() / (1024 ** 3)}G")
        # Clear all previous gradients
        self.model.zero_grad()
        # Backpropagate to obtain the gradients
        grad_output = torch.ones_like(confidenct)
        confidenct.backward(grad_output)
        # Fetch the gradient of the feature map
        grad_val = self.get_grad()
        feature = self.get_feature()

        # print(f"cuda.memory_allocated 4 {torch.cuda.memory_allocated() / (1024 ** 3)}G")
        return grad_val, feature, input_.grad

    # Compute the CAM
    def getCam(self, grad_val, feature):
        # Global-average-pool the gradient over each channel of the feature map
        alpha = torch.mean(grad_val, dim=(2, 3)).cpu()
        feature = feature.cpu()
        # Weight each channel's feature map by the pooled gradient
        cam = torch.zeros((feature.shape[2], feature.shape[3]), dtype=torch.float32)
        for idx in range(alpha.shape[1]):
            cam = cam + alpha[0][idx] * feature[0][idx]
        # Apply ReLU
        cam = np.maximum(cam.detach().numpy(), 0)

        # plt.imshow(cam)
        # plt.colorbar()
        # plt.savefig("cam.jpg")

        # Resize the CAM region to the input image size
        cam_ = cv2.resize(cam, (224, 224))
        cam_ = cam_ - np.min(cam_)
        cam_ = cam_ / np.max(cam_)
        # plt.imshow(cam_)
        # plt.savefig("cam_.jpg")
        cam = torch.from_numpy(cam)

        return cam, cam_

    def show_img(self, cam_, img, heatmap_save_pth, imgname):
        heatmap = cv2.applyColorMap(np.uint8(255 * cam_), cv2.COLORMAP_JET)
        cam_img = 0.3 * heatmap + 0.7 * np.float32(img)
        # cv2.imwrite("img.jpg", cam_img)
        cv2.imwrite(os.sep.join([heatmap_save_pth, imgname]), cam_img)

    def get_hot_map(self, img_pth):
        img = Image.open(img_pth)
        img = img.resize((224, 224))
        input = self.process_img(img)
        grad_val, feature, input_grad = self.getGrad(input)
        cam, cam_ = self.getCam(grad_val, feature)
        heatmap = cv2.applyColorMap(np.uint8(255 * cam_), cv2.COLORMAP_JET)
        cam_img = 0.3 * heatmap + 0.7 * np.float32(img)
        cam_img = Image.fromarray(np.uint8(cam_img))
        return cam_img

    # def __call__(self, img_root, heatmap_save_pth):
    #     for imgname in os.listdir(img_root):
    #         img = Image.open(os.sep.join([img_root, imgname]))
    #         img = img.resize((224, 224))
    #         # plt.imshow(img)
    #         # plt.savefig("airplane.jpg")
    #         input = self.process_img(img)
    #         grad_val, feature, input_grad = self.getGrad(input)
    #         cam, cam_ = self.getCam(grad_val, feature)
    #         self.show_img(cam_, img, heatmap_save_pth, imgname)
    #     return cam


if __name__ == "__main__":
    # NOTE: cal_cam now requires (model, conf); this standalone entry point predates that change.
    cam = cal_cam()
    img_root = "test_img/"
    heatmap_save_pth = "heatmap_result"
    cam(img_root, heatmap_save_pth)
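A hedged usage sketch (the image path is invented): test_ori.py constructs cal_cam with an already-loaded model and the parsed config, and get_hot_map() returns a PIL image with the Grad-CAM heatmap blended over the resized input.

model = init_model()                 # assumed: the helper defined in test_ori.py
model.eval()
cam = cal_cam(model, conf)
overlay = cam.get_hot_map('./quant_imgs/example.jpg')   # hypothetical image path
overlay.save('heatmap_example.png')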
213 tools/getpairs.py Normal file
@ -0,0 +1,213 @@
import os
import random
import json
from pathlib import Path
from typing import List, Tuple, Dict, Optional
import logging


class PairGenerator:
    """Generate positive and negative image pairs for contrastive learning."""

    def __init__(self, original_path):
        self._setup_logging()
        self.original_path = original_path
        self._delete_space()

    def _delete_space(self):  # remove spaces from image file names
        print(self.original_path)
        for root, dirs, files in os.walk(self.original_path):
            for file_name in files:
                if file_name.endswith(('.jpg', '.png')):
                    n_file_name = file_name.replace(' ', '')
                    os.rename(os.path.join(root, file_name), os.path.join(root, n_file_name))
                if 'rotate' in file_name:
                    os.remove(os.path.join(root, file_name))
            for dir_name in dirs:
                n_dir_name = dir_name.replace(' ', '')
                os.rename(os.path.join(root, dir_name), os.path.join(root, n_dir_name))

    def _setup_logging(self):
        """Configure logging settings."""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

    def _get_image_files(self, root_dir: str) -> Dict[str, List[str]]:
        """Scan directory and return dict of {folder: [image_paths]}."""
        root = Path(root_dir)
        if not root.is_dir():
            raise ValueError(f"Invalid directory: {root_dir}")

        return {
            str(folder): [str(f) for f in folder.iterdir() if f.is_file()]
            for folder in root.iterdir() if folder.is_dir()
        }

    def _generate_same_pairs(
            self,
            files_dict: Dict[str, List[str]],
            num_pairs: int,
            min_size: int,  # min_size is the minimum number of images per folder
            group_size: Optional[int] = None
    ) -> List[Tuple[str, str, int]]:
        """Generate positive pairs from same folder."""
        pairs = []

        for folder, files in files_dict.items():
            if len(files) < 2:
                continue

            if group_size:
                # Group mode: generate all possible pairs within group
                for i in range(0, len(files), group_size):
                    group = files[i:i + group_size]
                    pairs.extend([
                        (group[i], group[j], 1)
                        for i in range(len(group))
                        for j in range(i + 1, len(group))
                    ])
            else:
                # Individual mode: random pairs
                try:
                    pairs.extend(self._random_pairs(files, min(min_size, len(files) // 2)))
                except ValueError as e:
                    self.logger.warning(f"Skipping folder {folder}: {str(e)}")

        random.shuffle(pairs)
        return pairs[:num_pairs]

    def _generate_cross_pairs(
            self,
            files_dict: Dict[str, List[str]],
            num_pairs: int
    ) -> List[Tuple[str, str, int]]:
        """Generate negative pairs from different folders."""
        folders = list(files_dict.keys())
        pairs = []
        existing_pairs = set()

        while len(pairs) < num_pairs and len(folders) >= 2:
            folder1, folder2 = random.sample(folders, 2)
            file1 = random.choice(files_dict[folder1])
            file2 = random.choice(files_dict[folder2])

            pair_key = (file1, file2)
            reverse_key = (file2, file1)

            if pair_key not in existing_pairs and reverse_key not in existing_pairs:
                pairs.append((file1, file2, 0))
                existing_pairs.add(pair_key)
                existing_pairs.add(reverse_key)

        return pairs

    def _random_pairs(self, files: List[str], num_pairs: int) -> List[Tuple[str, str, int]]:
        """Generate random pairs from file list."""
        max_possible = len(files) // 2
        if max_possible == 0:
            return []

        actual_pairs = min(num_pairs, max_possible)
        indices = random.sample(range(len(files)), 2 * actual_pairs)
        indices.sort()
        return [(files[indices[i]], files[indices[i + 1]], 1) for i in range(0, len(indices), 2)]

    def get_pairs(self,
                  root_dir: str,
                  num_pairs: int = 6000,
                  output_txt: Optional[str] = None
                  ) -> List[Tuple[str, str, int]]:
        """
        Generate individual image pairs with labels (1=same, 0=different).

        Args:
            root_dir: Directory containing subfolders of images
            num_pairs: Number of pairs to generate
            output_txt: Optional path to save pairs as txt file

        Returns:
            List of (path1, path2, label) tuples
        """
        files_dict = self._get_image_files(root_dir)

        same_pairs = self._generate_same_pairs(files_dict, num_pairs, min_size=30)
        cross_pairs = self._generate_cross_pairs(files_dict, len(same_pairs))

        pairs = same_pairs + cross_pairs
        self.logger.info(f"Generated {len(pairs)} pairs ({len(same_pairs)} positive, {len(cross_pairs)} negative)")

        if output_txt:
            try:
                with open(output_txt, 'w') as f:
                    for file1, file2, label in pairs:
                        f.write(f"{file1} {file2} {label}\n")
                self.logger.info(f"Saved pairs to {output_txt}")
            except IOError as e:
                self.logger.warning(f"Failed to write pairs to {output_txt}: {str(e)}")

        return pairs

    def get_group_pairs(
            self,
            root_dir: str,
            img_num: int = 20,
            group_num: int = 10,
            num_pairs: int = 5000
    ) -> List[Tuple[str, str, int]]:
        """
        Generate grouped image pairs with labels (1=same, 0=different).

        Args:
            root_dir: Directory containing subfolders of images
            img_num: Minimum images required per folder
            group_num: Number of images per group
            num_pairs: Number of pairs to generate

        Returns:
            List of (path1, path2, label) tuples
        """
        # Filter folders with enough images
        files_dict = {
            k: v for k, v in self._get_image_files(root_dir).items()
            if len(v) >= img_num
        }

        # Split into groups
        grouped_files = {}
        for folder, files in files_dict.items():
            random.shuffle(files)
            grouped_files[folder] = [
                files[i:i + group_num]
                for i in range(0, len(files), group_num)
            ]

        # Generate pairs
        same_pairs = self._generate_same_pairs(
            grouped_files, num_pairs, min_size=30, group_size=group_num
        )
        cross_pairs = self._generate_cross_pairs(
            grouped_files, len(same_pairs)
        )

        pairs = same_pairs + cross_pairs
        self.logger.info(f"Generated {len(pairs)} group pairs")

        # Save to JSON
        with open("cross_same.json", 'w') as f:
            json.dump(pairs, f)

        return pairs


if __name__ == "__main__":
    original_path = '/home/lc/data_center/contrast_data/v1/extra'
    parent_dir = str(Path(original_path).parent)
    generator = PairGenerator(original_path)

    # Example usage:
    pairs = generator.get_pairs(original_path,
                                output_txt=os.sep.join([parent_dir, 'extra_cross_same.txt']))  # Individual pairs
    # groups = generator.get_group_pairs('val')  # Group pairs
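Each line that get_pairs() writes to output_txt is "img1 img2 label", the same three-token layout that compute_accuracy() in test_ori.py splits; the paths below are invented:

6930044166421/a_01.jpg 6930044166421/a_07.jpg 1
6930044166421/a_01.jpg 6928926002103/b_03.jpg 0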
71 tools/image_joint.py Normal file
@ -0,0 +1,71 @@
from PIL import Image, ImageDraw, ImageFont
from tools.getHeatMap import cal_cam
import os


def merge_imgs(img1_path, img2_path, conf, similar=None, label=None, cam=None, save_path=None):
    save = True
    position = (50, 50)  # top-left corner of the text
    color = (255, 0, 0)  # red text, RGB
    # if not os.path.exists(os.sep.join([save_path, str(label)])):
    #     os.makedirs(os.sep.join([save_path, str(label)]))
    # save_path = os.sep.join([save_path, str(label)])
    # img_name = os.path.basename(img1_path).split('.')[0] + '_' + os.path.basename(img2_path).split('.')[0] + '.png'
    if save_path is None:
        save_path = conf['data']['image_joint_pth']
    if not conf['heatmap']['show_heatmap']:
        img1 = Image.open(img1_path)
        img2 = Image.open(img2_path)
        img1 = img1.resize((224, 224))
        img2 = img2.resize((224, 224))
        new_img = Image.new('RGB', (img1.width + img2.width + 10, img1.height))
        # save_path = conf['data']['image_joint_pth']
    else:
        assert cam is not None, 'cam is None'
        img1 = cam.get_hot_map(img1_path)
        img2 = cam.get_hot_map(img2_path)
        img1_ori = Image.open(img1_path)
        img2_ori = Image.open(img2_path)
        img1_ori = img1_ori.resize((224, 224))
        img2_ori = img2_ori.resize((224, 224))
        new_img = Image.new('RGB',
                            (img1.width + img2.width + 10,
                             img1.height + img2.width + 10))
        # save_path = conf['heatmap']['image_joint_pth']
    # print('img1_path', img1)
    # print('img2_path', img2)
    if not os.path.exists(os.sep.join([save_path, str(label)])) and (label is not None):
        os.makedirs(os.sep.join([save_path, str(label)]))
        save_path = os.sep.join([save_path, str(label)])
    if save_path is None:
        # save_path = os.sep.join([save_path, str(label)])
        pass
    # img_name = os.path.basename(img1_path).split('.')[0] + '_' + os.path.basename(img2_path).split('.')[0] + '.png'
    img_name = os.path.basename(img1_path).split('.')[0][:30] + '_' + os.path.basename(img2_path).split('.')[0][
                                                                      :30] + '.png'
    assert img1.height == img2.height

    # print('new_img', new_img)
    if not conf['heatmap']['show_heatmap']:
        new_img.paste(img1, (0, 0))
        new_img.paste(img2, (img1.width + 10, 0))
    else:
        new_img.paste(img1_ori, (10, 10))
        new_img.paste(img2_ori, (img2_ori.width + 20, 10))
        new_img.paste(img1, (10, img1.height + 20))
        new_img.paste(img2, (img2.width + 20, img2.height + 20))

    if similar is not None:
        if label == '1' and (similar > 0.5 or similar < 0.25):
            save = False
        elif label == '0' and similar > 0.25:
            save = False
        similar = str(similar) + '_' + str(label)
        draw = ImageDraw.Draw(new_img)
        draw.text(position, str(similar), color, font_size=36)
    os.makedirs(save_path, exist_ok=True)
    img_save = os.path.join(save_path, img_name)
    if save:
        new_img.save(img_save)
@ -2,17 +2,29 @@ import pdb
 import torch
 import torch.nn as nn
 from model import resnet18
-from config import config as conf
+# from config import config as conf
 from collections import OrderedDict
+from configs import trainer_tools
 import cv2
+import yaml

-def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth'):
-    # define the model
-    if model_name == 'resnet18':
-        model = resnet18(scale=0.75)
+def tranform_onnx_model():
+    # # define the model
+    # if model_name == 'resnet18':
+    #     model = resnet18(scale=0.75)

-    print('model_name >>> {}'.format(model_name))
-    if conf.multiple_cards:
+    with open('../configs/transform.yml', 'r') as f:
+        conf = yaml.load(f, Loader=yaml.FullLoader)
+
+    tr_tools = trainer_tools(conf)
+    backbone_mapping = tr_tools.get_backbone()
+    if conf['models']['backbone'] in backbone_mapping:
+        model = backbone_mapping[conf['models']['backbone']]().to(conf['base']['device'])
+    else:
+        raise ValueError('不支持该模型: {}'.format({conf['models']['backbone']}))
+    pretrained_weights = conf['models']['model_path']
+    print('model_name >>> {}'.format(conf['models']['backbone']))
+    if conf['base']['distributed']:
         model = model.to(torch.device('cpu'))
         checkpoint = torch.load(pretrained_weights)
         new_state_dict = OrderedDict()
@ -22,22 +34,8 @@ def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth
         model.load_state_dict(new_state_dict)
     else:
         model.load_state_dict(torch.load(pretrained_weights, map_location=torch.device('cpu')))
-    # try:
-    #     model.load_state_dict(torch.load(pretrained_weights, map_location=torch.device('cpu')))
-    # except Exception as e:
-    #     print(e)
-    #     # model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(pretrained_weights, map_location='cpu').items()})
-    #     model = nn.DataParallel(model).to(conf.device)
-    #     model.load_state_dict(torch.load(conf.test_model, map_location=torch.device('cpu')))

     # convert to ONNX
-    if model_name == 'gift_type2':
-        input_shape = [1, 64, 13, 13]
-    elif model_name == 'gift_type3':
-        input_shape = [1, 3, 224, 224]
-    else:
-        # assume input of size channels*height*width, e.g. 3*224*224
-        input_shape = [1, 3, 224, 224]
+    input_shape = [1, 3, 224, 224]

     img = cv2.imread('./dog_224x224.jpg')
@ -59,5 +57,4 @@ def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth
 if __name__ == '__main__':
-    tranform_onnx_model(model_name='resnet18',  # ['resnet18', 'gift_type2', 'gift_type3']; gift_type2 runs on intermediate resnet18 features, gift_type3 on the original image
-                        pretrained_weights='./checkpoints/resnet18_scale=1.0/best.pth')
+    tranform_onnx_model()
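The actual torch.onnx.export call sits outside the lines shown in this hunk, so the following is only a sketch of a typical export for the 1x3x224x224 input prepared above; the output path reuses the same config key the RKNN script below reads, and the tensor names are invented.

dummy = torch.randn(*input_shape)
torch.onnx.export(model, dummy, conf['models']['onnx_model'],
                  input_names=['input'], output_names=['output'],
                  opset_version=12)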
@ -6,15 +6,14 @@ import time
 import sys
 import numpy as np
 import cv2
-from config import config as conf
 from rknn.api import RKNN
+import yaml

-import config
+with open('../configs/transform.yml', 'r') as f:
+    conf = yaml.load(f, Loader=yaml.FullLoader)
 # ONNX_MODEL = 'resnet50v2.onnx'
 # RKNN_MODEL = 'resnet50v2.rknn'
-ONNX_MODEL = 'checkpoints/resnet18_scale=1.0/best.onnx'
-RKNN_MODEL = 'checkpoints/resnet18_scale=1.0/best.rknn'
+ONNX_MODEL = conf['models']['onnx_model']
+RKNN_MODEL = conf['models']['rknn_model']


 # ONNX_MODEL = 'v3_small_0424.onnx'
@ -97,10 +96,11 @@ if __name__ == '__main__':
     rknn.config(
         mean_values=[[127.5, 127.5, 127.5]],
         std_values=[[127.5, 127.5, 127.5]],
-        target_platform='rk3588',
+        target_platform='rk3566',
         model_pruning=False,
         compress_weight=False,
-        single_core_mode=True)
+        single_core_mode=True,
+        enable_flash_attention=True)
     # rknn.config(
     #     mean_values=[[127.5, 127.5, 127.5]],  # for single-channel images this can be [[127.5]]
     #     std_values=[[127.5, 127.5, 127.5]],   # for single-channel images this can be [[127.5]]
@ -122,7 +122,9 @@ if __name__ == '__main__':

     # Build model
     print('--> Building model')
-    ret = rknn.build(do_quantization=True, dataset='./dataset.txt')
+    ret = rknn.build(do_quantization=False,  # True
+                     dataset='./dataset.txt',
+                     rknn_batch_size=conf['models']['rknn_batch_size'])
     # ret = rknn.build(do_quantization=False, dataset='./dataset.txt')
     if ret != 0:
         print('Build model failed!')
242 tools/picdir_to_picdir_similar.py Normal file
@ -0,0 +1,242 @@
from similar_analysis import SimilarAnalysis
import os
import pickle
from tools.image_joint import merge_imgs
import yaml
from PIL import Image
import torch
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import matplotlib.pyplot as plt

'''
Similarity analysis between tracking images and the standard library:
1. generate similarities between each tracking image and every image in the standard library;
2. analyse the selection strategy used when matching tracking images against the standard library.
'''


class picDirSimilarAnalysis(SimilarAnalysis):
    def __init__(self):
        super(picDirSimilarAnalysis, self).__init__()
        with open('../configs/pic_pic_similar.yml', 'r') as f:
            self.conf = yaml.load(f, Loader=yaml.FullLoader)
        if not os.path.exists(self.conf['data']['total_pkl']):
            # self.create_total_feature()
            self.create_total_pkl()
        if os.path.exists(self.conf['data']['total_pkl']):
            self.all_dicts = self.load_dict_from_pkl()

    def is_image_file(self, filename):
        """
        Check whether the file is an image file.
        """
        image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff')
        return filename.lower().endswith(image_extensions)

    def create_total_pkl(self):  # dump features of every image under the directory into a pkl file
        all_images_feature_dict = {}
        for roots, dirs, files in os.walk(self.conf['data']['data_dir']):
            for file_name in files:
                if self.is_image_file(file_name):
                    try:
                        print(f"处理图像 {os.sep.join([roots, file_name])}")
                        feature = self.extract_features(os.sep.join([roots, file_name]))
                    except Exception as e:
                        print(f"处理图像 {os.sep.join([roots, file_name])} 时出错: {e}")
                        feature = None
                    all_images_feature_dict[os.sep.join([roots, file_name])] = feature
        if not os.path.exists(self.conf['data']['total_pkl']):
            with open(self.conf['data']['total_pkl'], 'wb') as f:
                pickle.dump(all_images_feature_dict, f)

    def load_dict_from_pkl(self):
        with open(self.conf['data']['total_pkl'], 'rb') as f:
            data = pickle.load(f)
        print(f"字典已从 {self.conf['data']['total_pkl']} 加载")
        return data

    def get_image_files(self, folder_path):
        """
        Collect all image files under a folder.
        """
        image_files = []
        for root, _, files in os.walk(folder_path):
            for file in files:
                if self.is_image_file(file):
                    image_files.append(os.path.join(root, file))
        return image_files

    def extract_features(self, image_path):
        feature_dict = self.get_feature(image_path)
        return feature_dict[image_path]

    def create_one_similarity_matrix(self, folder1_path, folder2_path):
        images1 = self.get_image_files(folder1_path)
        images2 = self.get_image_files(folder2_path)

        print(f"文件夹1 ({folder1_path}) 包含 {len(images1)} 张图像")
        print(f"文件夹2 ({folder2_path}) 包含 {len(images2)} 张图像")

        if len(images1) == 0 or len(images2) == 0:
            raise ValueError("至少有一个文件夹中没有图像文件")

        # Extract features for every image in folder 1
        features1 = []
        print("正在提取文件夹1中的图像特征...")
        for i, img_path in enumerate(images1):
            try:
                # feature = self.extract_features(img_path)
                feature = self.all_dicts[img_path]
                features1.append(feature.cpu().numpy())
                # if (i + 1) % 10 == 0:
                #     print(f"已处理 {i + 1}/{len(images1)} 张图像")
            except Exception as e:
                print(f"处理图像 {img_path} 时出错: {e}")
                features1.append(None)

        # Extract features for every image in folder 2
        features2 = []
        print("正在提取文件夹2中的图像特征...")
        for i, img_path in enumerate(images2):
            try:
                # feature = self.extract_features(img_path)
                feature = self.all_dicts[img_path]
                features2.append(feature.cpu().numpy())
                # if (i + 1) % 10 == 0:
                #     print(f"已处理 {i + 1}/{len(images2)} 张图像")
            except Exception as e:
                print(f"处理图像 {img_path} 时出错: {e}")
                features2.append(None)

        # Drop images whose feature extraction failed
        valid_features1 = []
        valid_images1 = []
        for i, feature in enumerate(features1):
            if feature is not None:
                valid_features1.append(feature)
                valid_images1.append(images1[i])

        valid_features2 = []
        valid_images2 = []
        for i, feature in enumerate(features2):
            if feature is not None:
                valid_features2.append(feature)
                valid_images2.append(images2[i])

        # print(f"文件夹1中成功处理 {len(valid_features1)} 张图像")
        # print(f"文件夹2中成功处理 {len(valid_features2)} 张图像")

        if len(valid_features1) == 0 or len(valid_features2) == 0:
            raise ValueError("没有成功处理任何图像")

        # Compute the similarity matrix
        print("正在计算相似度矩阵...")
        similarity_matrix = cosine_similarity(valid_features1, valid_features2)

        return similarity_matrix, valid_images1, valid_images2

    def get_group_similarity_matrix(self, folder_path):
        tracking_folder = os.sep.join([folder_path, 'tracking'])
        standard_folder = os.sep.join([folder_path, 'standard_slim'])
        for dir_name in os.listdir(tracking_folder):
            tracking_dir = os.sep.join([tracking_folder, dir_name])
            standard_dir = os.sep.join([standard_folder, dir_name])
            similarity_matrix, valid_images1, valid_images2 = self.create_one_similarity_matrix(tracking_dir,
                                                                                                standard_dir)
            mean_similarity = np.mean(similarity_matrix)
            std_similarity = np.std(similarity_matrix)
            max_similarity = np.max(similarity_matrix)
            min_similarity = np.min(similarity_matrix)
            print(f"文件夹 {dir_name} 的相似度矩阵已计算完成 "
                  f"均值:{mean_similarity} 标准差:{std_similarity} 最大值:{max_similarity} 最小值:{min_similarity}")
            result = f"{os.path.basename(standard_folder)} {dir_name} {mean_similarity:.3f} {std_similarity:.3f} {max_similarity:.3f} {min_similarity:.3f}"
            with open(self.conf['data']['result_txt'], 'a') as f:
                f.write(result + '\n')


def read_result_txt():
    parts = []
    value_num = 2
    with open('../configs/pic_pic_similar.yml', 'r') as f:
        conf = yaml.load(f, Loader=yaml.FullLoader)
    f.close()
    with open(conf['data']['result_txt'], 'r') as f:
        lines = f.readlines()
        for line in lines:
            line = line.strip()
            if line:
                parts.append(line.split(' '))
        parts = np.array(parts)
        print(parts)
        labels = ['Mean', 'Std', 'Max', 'Min']
        while value_num < 6:
            dicts = {}
            for barcode, value in zip(parts[:, 1], parts[:, value_num]):
                if barcode in dicts:
                    dicts[barcode].append(float(value))
                else:
                    dicts[barcode] = [float(value)]
            get_histogram(dicts, labels[value_num - 2])
            value_num += 1
    f.close()


def get_histogram(data, label=None):
    # Prepare the data
    categories = list(data.keys())
    values1 = [data[cat][0] for cat in categories]  # first value
    values2 = [data[cat][1] for cat in categories]  # second value

    # Bar positions
    x = np.arange(len(categories))  # label positions
    width = 0.35  # bar width

    # Create the figure and axes
    fig, ax = plt.subplots(figsize=(10, 6))

    # Draw the bars
    bars1 = ax.bar(x - width / 2, values1, width, label='standard', color='red', alpha=0.7)
    bars2 = ax.bar(x + width / 2, values2, width, label='standard_slim', color='green', alpha=0.7)
|
||||||
|
|
||||||
|
# 在每个柱状图上显示数值
|
||||||
|
for bar in bars1:
|
||||||
|
height = bar.get_height()
|
||||||
|
ax.annotate(f'{height:.3f}',
|
||||||
|
xy=(bar.get_x() + bar.get_width() / 2, height),
|
||||||
|
xytext=(0, 3), # 3点垂直偏移
|
||||||
|
textcoords="offset points",
|
||||||
|
ha='center', va='bottom',
|
||||||
|
fontsize=12)
|
||||||
|
|
||||||
|
for bar in bars2:
|
||||||
|
height = bar.get_height()
|
||||||
|
ax.annotate(f'{height:.3f}',
|
||||||
|
xy=(bar.get_x() + bar.get_width() / 2, height),
|
||||||
|
xytext=(0, 3), # 3点垂直偏移
|
||||||
|
textcoords="offset points",
|
||||||
|
ha='center', va='bottom',
|
||||||
|
fontsize=12)
|
||||||
|
|
||||||
|
# 添加标签和标题
|
||||||
|
if label is None:
|
||||||
|
label = ''
|
||||||
|
ax.set_xlabel('barcode')
|
||||||
|
ax.set_ylabel('Values')
|
||||||
|
ax.set_title(label)
|
||||||
|
ax.set_xticks(x)
|
||||||
|
ax.set_xticklabels(categories)
|
||||||
|
ax.legend()
|
||||||
|
|
||||||
|
# 添加网格
|
||||||
|
ax.grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
# 调整布局并显示
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# picTopic_matrix = picDirSimilarAnalysis()
|
||||||
|
# picTopic_matrix.get_group_similarity_matrix('/home/lc/data_center/image_analysis/pic_pic_similar_maxtrix')
|
||||||
|
read_result_txt()
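The "<folder> <barcode> mean std max min" line that get_group_similarity_matrix appends to result_txt, and that read_result_txt later parses, is just a set of aggregates over the pairwise cosine-similarity matrix. A minimal, self-contained sketch with dummy feature vectors (the array shapes and the barcode value are illustrative only, not taken from this repository):

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

feats1 = np.random.rand(3, 512)   # e.g. 3 "tracking" crops, 512-D embeddings
feats2 = np.random.rand(2, 512)   # e.g. 2 "standard_slim" reference images

sim = cosine_similarity(feats1, feats2)   # shape (3, 2): one score per image pair
line = f"standard_slim 6901234567890 {np.mean(sim):.3f} {np.std(sim):.3f} {np.max(sim):.3f} {np.min(sim):.3f}"
print(line)   # same layout that read_result_txt() splits on single spaces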
107 tools/similar_analysis.py Normal file
@ -0,0 +1,107 @@
from configs.utils import trainer_tools
from test_ori import group_image, featurize, cosin_metric
from tools.dataset import get_transform
from tools.getHeatMap import cal_cam
from tools.image_joint import merge_imgs
import torch.nn as nn
import torch
from collections import ChainMap
import yaml
import os


class SimilarAnalysis:
    def __init__(self):
        with open('../configs/similar_analysis.yml', 'r') as f:
            self.conf = yaml.load(f, Loader=yaml.FullLoader)
        self.model = self.initialize_model(self.conf)
        _, self.test_transform = get_transform(self.conf)
        self.cam = cal_cam(self.model, self.conf)

    def initialize_model(self, conf):
        """Initialize the model and metric"""
        tr_tools = trainer_tools(conf)
        backbone_mapping = tr_tools.get_backbone()

        if conf['models']['backbone'] in backbone_mapping:
            model = backbone_mapping[conf['models']['backbone']]()
        else:
            raise ValueError('Unsupported model: {}'.format({conf['models']['backbone']}))
        try:
            model.load_state_dict(torch.load(conf['models']['model_path'],
                                             map_location=conf['base']['device']))
        except:
            state_dict = torch.load(conf['models']['model_path'],
                                    map_location=conf['base']['device'])
            new_state_dict = {}
            for k, v in state_dict.items():
                new_key = k.replace("module.", "")
                new_state_dict[new_key] = v
            model.load_state_dict(new_state_dict, strict=False)
        return model.eval()

    def get_feature(self, img_pth):
        group = group_image([img_pth], self.conf['data']['val_batch_size'])
        feature = featurize(group[0], self.test_transform, self.model, self.conf['base']['device'])
        return feature

    def get_similarity(self, feature_dict1, feature_dict2):
        similarity = cosin_metric(feature_dict1, feature_dict2)
        print(f"Similarity: {similarity}")
        return similarity

    def get_feature_map(self, all_imgs):
        feature_dicts = {}
        for img_pth in all_imgs:
            print(f"Processing {img_pth}")
            feature_dict = self.get_feature(img_pth)
            feature_dicts = dict(ChainMap(feature_dict, feature_dicts))
        return feature_dicts

    def get_image_map(self):
        all_compare_img = []
        for root, dirs, files in os.walk(self.conf['data']['data_dir']):
            if len(dirs) == 2:
                dir_pth_1 = os.sep.join([root, dirs[0]])
                dir_pth_2 = os.sep.join([root, dirs[1]])
                for img_name_1 in os.listdir(dir_pth_1):
                    for img_name_2 in os.listdir(dir_pth_2):
                        all_compare_img.append((os.sep.join([dir_pth_1, img_name_1]),
                                                os.sep.join([dir_pth_2, img_name_2])))
        return all_compare_img

    def create_total_feature(self):
        all_imgs = []
        for root, dirs, files in os.walk(self.conf['data']['data_dir']):
            if len(dirs) == 2:
                for dir_name in dirs:
                    dir_pth = os.sep.join([root, dir_name])
                    for img_name in os.listdir(dir_pth):
                        all_imgs.append(os.sep.join([dir_pth, img_name]))
        return all_imgs

    def get_contrast_result(self, feature_dicts, all_compare_img):
        for img_pth1, img_pth2 in all_compare_img:
            feature_dict1 = feature_dicts[img_pth1]
            feature_dict2 = feature_dicts[img_pth2]
            similarity = self.get_similarity(feature_dict1.cpu().numpy(),
                                             feature_dict2.cpu().numpy())
            dir_name = img_pth1.split('/')[-3]
            save_path = os.sep.join([self.conf['data']['image_joint_pth'], dir_name])
            if similarity > 0.7:
                merge_imgs(img_pth1,
                           img_pth2,
                           self.conf,
                           similarity,
                           label=None,
                           cam=self.cam,
                           save_path=save_path)
            print(similarity)


if __name__ == '__main__':
    ana = SimilarAnalysis()
    all_imgs = ana.create_total_feature()
    feature_dicts = ana.get_feature_map(all_imgs)
    all_compare_img = ana.get_image_map()
    ana.get_contrast_result(feature_dicts, all_compare_img)
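For a single pair of images, the SimilarAnalysis class above can also be used directly; a minimal usage sketch, assuming ../configs/similar_analysis.yml and the model weights it references exist (the image paths and the 0.7 threshold are illustrative only):

ana = SimilarAnalysis()

# get_feature returns a {image_path: feature_tensor} dict produced by featurize()
img_a = '/data/demo/group/a/img_001.jpg'
img_b = '/data/demo/group/b/img_002.jpg'
feat_a = ana.get_feature(img_a)[img_a]
feat_b = ana.get_feature(img_b)[img_b]

# cosin_metric works on plain numpy vectors, as in get_contrast_result above
score = ana.get_similarity(feat_a.cpu().numpy(), feat_b.cpu().numpy())
if score > 0.7:
    print('likely the same product')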
@ -50,7 +50,7 @@ class FeatureExtractor
             raise FileNotFoundError(f"Model weights file not found: {model_path}")

         # Initialize model
-        model = resnet18().to(self.conf['base']['device'])
+        model = resnet18(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])

         # Handle multi-GPU case
         if conf['base']['distributed']:
@ -407,5 +407,5 @@ if __name__ == "__main__":
     column_values = extractor.get_shop_barcodes(conf['data']['xlsx_pth'])
     imgs_dict = extractor.get_files(conf['data']['img_dirs_path'],
                                     filter=column_values,
-                                    create_single_json=False)  # False
+                                    create_single_json=conf['save']['create_single_json'])  # False
     extractor.statisticsBarcodes(conf['data']['img_dirs_path'], column_values)
422 train_compare.py
@ -3,140 +3,360 @@ import os.path as osp
 import torch
 import torch.nn as nn
-import torch.optim as optim
 from tqdm import tqdm
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data.distributed import DistributedSampler

 from model.loss import FocalLoss
-from tools.dataset import load_data
+from tools.dataset import load_data, MultiEpochsDataLoader
 import matplotlib.pyplot as plt
 from configs import trainer_tools
 import yaml
+from datetime import datetime

-with open('configs/scatter.yml', 'r') as f:
-    conf = yaml.load(f, Loader=yaml.FullLoader)

-# Data Setup
-train_dataloader, class_num = load_data(training=True, cfg=conf)
-val_dataloader, _ = load_data(training=False, cfg=conf)
+def load_configuration(config_path='configs/compare.yml'):
+    """Load the configuration file"""
+    with open(config_path, 'r') as f:
+        return yaml.load(f, Loader=yaml.FullLoader)

-tr_tools = trainer_tools(conf)
-backbone_mapping = tr_tools.get_backbone()
-metric_mapping = tr_tools.get_metric(class_num)

-if conf['models']['backbone'] in backbone_mapping:
-    model = backbone_mapping[conf['models']['backbone']]().to(conf['base']['device'])
-else:
+def initialize_model_and_metric(conf, class_num):
+    """Initialize the model and metric"""
+    tr_tools = trainer_tools(conf)
+    backbone_mapping = tr_tools.get_backbone()
+    metric_mapping = tr_tools.get_metric(class_num)
+
+    if conf['models']['backbone'] in backbone_mapping:
+        model = backbone_mapping[conf['models']['backbone']]()
+    else:
         raise ValueError('Unsupported model: {}'.format({conf['models']['backbone']}))

     if conf['training']['metric'] in metric_mapping:
         metric = metric_mapping[conf['training']['metric']]()
     else:
         raise ValueError('Unsupported metric type: {}'.format(conf['training']['metric']))

-if torch.cuda.device_count() > 1 and conf['base']['distributed']:
-    print("Let's use", torch.cuda.device_count(), "GPUs!")
-    model = nn.DataParallel(model)
-    metric = nn.DataParallel(metric)
+    return model, metric

-# Training Setup
-if conf['training']['loss'] == 'focal_loss':
-    criterion = FocalLoss(gamma=2)
-else:
-    criterion = nn.CrossEntropyLoss()

-optimizer_mapping = tr_tools.get_optimizer(model, metric)
-if conf['training']['optimizer'] in optimizer_mapping:
+def setup_optimizer_and_scheduler(conf, model, metric):
+    """Set up the optimizer and learning-rate scheduler"""
+    tr_tools = trainer_tools(conf)
+    optimizer_mapping = tr_tools.get_optimizer(model, metric)
+
+    if conf['training']['optimizer'] in optimizer_mapping:
         optimizer = optimizer_mapping[conf['training']['optimizer']]()
-    scheduler = optim.lr_scheduler.StepLR(
-        optimizer,
-        step_size=conf['training']['lr_step'],
-        gamma=conf['training']['lr_decay']
-    )
-else:
+        scheduler_mapping = tr_tools.get_scheduler(optimizer)
+        scheduler = scheduler_mapping[conf['training']['scheduler']]()
+        print('Using {} optimizer with {} scheduler'.format(conf['training']['optimizer'],
+                                                            conf['training']['scheduler']))
+        return optimizer, scheduler
     else:
         raise ValueError('Unsupported optimizer type: {}'.format(conf['training']['optimizer']))

-# Checkpoints Setup
-checkpoints = conf['training']['checkpoints']
-os.makedirs(checkpoints, exist_ok=True)

-if __name__ == '__main__':
-    print('backbone>{} '.format(conf['models']['backbone']),
-          'metric>{} '.format(conf['training']['metric']),
-          'checkpoints>{} '.format(conf['training']['checkpoints']),
-          )
+def setup_loss_function(conf):
+    """Configure the loss function"""
+    if conf['training']['loss'] == 'focal_loss':
+        return FocalLoss(gamma=2)
+    else:
+        return nn.CrossEntropyLoss()
+
+
+def train_one_epoch(model, metric, criterion, optimizer, dataloader, device, scaler, conf):
+    """Run one training epoch"""
+    model.train()
+    train_loss = 0
+    for data, labels in tqdm(dataloader, desc="Training", ascii=True, total=len(dataloader)):
+        data = data.to(device)
+        labels = labels.to(device)
+
+        # with torch.cuda.amp.autocast():
+        embeddings = model(data)
+        if not conf['training']['metric'] == 'softmax':
+            thetas = metric(embeddings, labels)
+        else:
+            thetas = metric(embeddings)
+        loss = criterion(thetas, labels)
+
+        optimizer.zero_grad()
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()
+        train_loss += loss.item()
+    return train_loss / len(dataloader)
+
+
+def validate(model, metric, criterion, dataloader, device, conf):
+    """Run validation"""
+    model.eval()
+    val_loss = 0
+    with torch.no_grad():
+        for data, labels in tqdm(dataloader, desc="Validating", ascii=True, total=len(dataloader)):
+            data = data.to(device)
+            labels = labels.to(device)
+            embeddings = model(data)
+            if not conf['training']['metric'] == 'softmax':
+                thetas = metric(embeddings, labels)
+            else:
+                thetas = metric(embeddings)
+            loss = criterion(thetas, labels)
+            val_loss += loss.item()
+    return val_loss / len(dataloader)
+
+
+def save_model(model, path, is_parallel):
+    """Save model weights"""
+    if is_parallel:
+        torch.save(model.module.state_dict(), path)
+    else:
+        torch.save(model.state_dict(), path)
+
+
+def log_training_info(log_path, log_info):
+    """Append training info to the log file"""
+    with open(log_path, 'a') as f:
+        f.write(log_info + '\n')
+
+
+def initialize_training_components(distributed=False):
+    """Initialize all components needed for training"""
+    # Load configuration
+    conf = load_configuration()
+
+    # Parameters related to distributed training
+    components = {
+        'conf': conf,
+        'distributed': distributed,
+        'device': None,
+        'train_dataloader': None,
+        'val_dataloader': None,
+        'model': None,
+        'metric': None,
+        'criterion': None,
+        'optimizer': None,
+        'scheduler': None,
+        'checkpoints': None,
+        'scaler': None
+    }
+
+    # For non-distributed training, create all components directly
+    if not distributed:
+        # Data loading
+        train_dataloader, class_num = load_data(training=True, cfg=conf, return_dataset=True)
+        val_dataloader, _ = load_data(training=False, cfg=conf, return_dataset=True)
+
+        train_dataloader = MultiEpochsDataLoader(train_dataloader,
+                                                 batch_size=conf['data']['train_batch_size'],
+                                                 shuffle=True,
+                                                 num_workers=conf['data']['num_workers'],
+                                                 pin_memory=conf['base']['pin_memory'],
+                                                 drop_last=True)
+        val_dataloader = MultiEpochsDataLoader(val_dataloader,
+                                               batch_size=conf['data']['val_batch_size'],
+                                               shuffle=False,
+                                               num_workers=conf['data']['num_workers'],
+                                               pin_memory=conf['base']['pin_memory'],
+                                               drop_last=False)
+        # Initialize model and metric
+        model, metric = initialize_model_and_metric(conf, class_num)
+        device = conf['base']['device']
+        model = model.to(device)
+        metric = metric.to(device)
+
+        # Set up loss function, optimizer and scheduler
+        criterion = setup_loss_function(conf)
+        optimizer, scheduler = setup_optimizer_and_scheduler(conf, model, metric)
+
+        # Checkpoint directory
+        checkpoints = conf['training']['checkpoints']
+        os.makedirs(checkpoints, exist_ok=True)
+
+        # GradScaler for mixed precision
+        scaler = torch.cuda.amp.GradScaler()
+
+        # Update the component dict
+        components.update({
+            'train_dataloader': train_dataloader,
+            'val_dataloader': val_dataloader,
+            'model': model,
+            'metric': metric,
+            'criterion': criterion,
+            'optimizer': optimizer,
+            'scheduler': scheduler,
+            'checkpoints': checkpoints,
+            'scaler': scaler,
+            'device': device
+        })
+
+    return components
+
+
+def run_training_loop(components):
+    """Run the full training loop"""
+    # Unpack components
+    conf = components['conf']
+    train_dataloader = components['train_dataloader']
+    val_dataloader = components['val_dataloader']
+    model = components['model']
+    metric = components['metric']
+    criterion = components['criterion']
+    optimizer = components['optimizer']
+    scheduler = components['scheduler']
+    checkpoints = components['checkpoints']
+    scaler = components['scaler']
+    device = components['device']
+
+    # Training state
     train_losses = []
     val_losses = []
     epochs = []
     temp_loss = 100

     if conf['training']['restore']:
         print('load pretrain model: {}'.format(conf['training']['restore_model']))
-        model.load_state_dict(torch.load(conf['training']['restore_model'],
-                                         map_location=conf['base']['device']))
+        model.load_state_dict(torch.load(conf['training']['restore_model'], map_location=device))

+    # Training loop
     for e in range(conf['training']['epochs']):
-        train_loss = 0
-        model.train()
-
-        for train_data, train_labels in tqdm(train_dataloader,
-                                             desc="Epoch {}/{}"
-                                             .format(e, conf['training']['epochs']),
-                                             ascii=True,
-                                             total=len(train_dataloader)):
-            train_data = train_data.to(conf['base']['device'])
-            train_labels = train_labels.to(conf['base']['device'])
-
-            train_embeddings = model(train_data).to(conf['base']['device'])  # [256,512]
-            # pdb.set_trace()
-
-            if not conf['training']['metric'] == 'softmax':
-                thetas = metric(train_embeddings, train_labels)  # [256,357]
-            else:
-                thetas = metric(train_embeddings)
-            tloss = criterion(thetas, train_labels)
-            optimizer.zero_grad()
-            tloss.backward()
-            optimizer.step()
-            train_loss += tloss.item()
-        train_lossAvg = train_loss / len(train_dataloader)
-        train_losses.append(train_lossAvg)
+        train_loss_avg = train_one_epoch(model, metric, criterion, optimizer, train_dataloader, device, scaler, conf)
+        train_losses.append(train_loss_avg)
         epochs.append(e)
-        val_loss = 0
-        model.eval()
-        with torch.no_grad():
-            for val_data, val_labels in tqdm(val_dataloader, desc="val",
-                                             ascii=True, total=len(val_dataloader)):
-                val_data = val_data.to(conf['base']['device'])
-                val_labels = val_labels.to(conf['base']['device'])
-                val_embeddings = model(val_data).to(conf['base']['device'])
-                if not conf['training']['metric'] == 'softmax':
-                    thetas = metric(val_embeddings, val_labels)
-                else:
-                    thetas = metric(val_embeddings)
-                vloss = criterion(thetas, val_labels)
-                val_loss += vloss.item()
-        val_lossAvg = val_loss / len(val_dataloader)
-        val_losses.append(val_lossAvg)
-        if val_lossAvg < temp_loss:
-            if torch.cuda.device_count() > 1:
-                torch.save(model.state_dict(), osp.join(checkpoints, 'best.pth'))
-            else:
-                torch.save(model.state_dict(), osp.join(checkpoints, 'best.pth'))
-            temp_loss = val_lossAvg
+        val_loss_avg = validate(model, metric, criterion, val_dataloader, device, conf)
+        val_losses.append(val_loss_avg)
+
+        if val_loss_avg < temp_loss:
+            save_model(model, osp.join(checkpoints, 'best.pth'), isinstance(model, nn.DataParallel))
+            temp_loss = val_loss_avg

         scheduler.step()
         current_lr = optimizer.param_groups[0]['lr']
-        log_info = ("Epoch {}/{}, train_loss: {}, val_loss: {} lr:{}"
-                    .format(e, conf['training']['epochs'], train_lossAvg, val_lossAvg, current_lr))
+        log_info = ("[{:%Y-%m-%d %H:%M:%S}] Epoch {}/{}, train_loss: {}, val_loss: {} lr:{}"
+                    .format(datetime.now(),
+                            e,
+                            conf['training']['epochs'],
+                            train_loss_avg,
+                            val_loss_avg,
+                            current_lr))
         print(log_info)
-        # Write to the log file
-        with open(osp.join(conf['logging']['logging_dir']), 'a') as f:
-            f.write(log_info + '\n')
+        log_training_info(osp.join(conf['logging']['logging_dir']), log_info)
         print("Learning rate at epoch %d: %f" % (e, current_lr))
-    if torch.cuda.device_count() > 1 and conf['base']['distributed']:
-        torch.save(model.module.state_dict(), osp.join(checkpoints, 'last.pth'))
-    else:
-        torch.save(model.state_dict(), osp.join(checkpoints, 'last.pth'))
-    plt.plot(epochs, train_losses, color='blue')
-    plt.plot(epochs, val_losses, color='red')
-    # plt.savefig('lossMobilenetv3.png')
+
+    # Save the final model
+    save_model(model, osp.join(checkpoints, 'last.pth'), isinstance(model, nn.DataParallel))
+
+    # Plot the loss curves
+    plt.plot(epochs, train_losses, color='blue', label='Train Loss')
+    plt.plot(epochs, val_losses, color='red', label='Validation Loss')
+    plt.legend()
     plt.savefig('loss/mobilenetv3Large_2250_0316.png')
+
+
+def main():
+    """Main entry point"""
+    # Load configuration
+    conf = load_configuration()
+
+    # Check whether distributed training is enabled
+    distributed = conf['base']['distributed']
+
+    if distributed:
+        # Distributed training: launch one process per GPU with mp.spawn
+        world_size = torch.cuda.device_count()
+        mp.spawn(
+            run_training,
+            args=(world_size, conf),
+            nprocs=world_size,
+            join=True
+        )
+    else:
+        # Single-machine training: run the training flow directly
+        components = initialize_training_components(distributed=False)
+        run_training_loop(components)
+
+
+def run_training(rank, world_size, conf):
+    """Function that actually runs training; called by mp.spawn"""
+    # Initialize the distributed environment
+    os.environ['RANK'] = str(rank)
+    os.environ['WORLD_SIZE'] = str(world_size)
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = '12355'
+
+    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
+    torch.cuda.set_device(rank)
+    device = torch.device('cuda', rank)
+
+    # Get datasets instead of DataLoaders
+    train_dataset, class_num = load_data(training=True, cfg=conf, return_dataset=True)
+    val_dataset, _ = load_data(training=False, cfg=conf, return_dataset=True)
+
+    # Initialize model and metric
+    model, metric = initialize_model_and_metric(conf, class_num)
+    model = model.to(device)
+    metric = metric.to(device)
+
+    # Wrap with DistributedDataParallel
+    model = DDP(model, device_ids=[rank], output_device=rank)
+    metric = DDP(metric, device_ids=[rank], output_device=rank)
+
+    # Set up loss function, optimizer and scheduler
+    criterion = setup_loss_function(conf)
+    optimizer, scheduler = setup_optimizer_and_scheduler(conf, model, metric)
+
+    # Checkpoint directory
+    checkpoints = conf['training']['checkpoints']
+    os.makedirs(checkpoints, exist_ok=True)
+
+    # GradScaler for mixed precision
+    scaler = torch.cuda.amp.GradScaler()
+
+    # Create distributed samplers
+    train_sampler = DistributedSampler(train_dataset, shuffle=True)
+    val_sampler = DistributedSampler(val_dataset, shuffle=False)
+
+    # Build distributed data loaders with MultiEpochsDataLoader
+    train_dataloader = MultiEpochsDataLoader(
+        train_dataset,
+        batch_size=conf['data']['train_batch_size'],
+        sampler=train_sampler,
+        num_workers=conf['data']['num_workers'],
+        pin_memory=conf['base']['pin_memory'],
+        drop_last=True
+    )
+
+    val_dataloader = MultiEpochsDataLoader(
+        val_dataset,
+        batch_size=conf['data']['val_batch_size'],
+        sampler=val_sampler,
+        num_workers=conf['data']['num_workers'],
+        pin_memory=conf['base']['pin_memory'],
+        drop_last=False
+    )
+
+    # Build the component dict
+    components = {
+        'conf': conf,
+        'train_dataloader': train_dataloader,
+        'val_dataloader': val_dataloader,
+        'model': model,
+        'metric': metric,
+        'criterion': criterion,
+        'optimizer': optimizer,
+        'scheduler': scheduler,
+        'checkpoints': checkpoints,
+        'scaler': scaler,
+        'device': device,
+        'distributed': True  # running inside mp.spawn
+    }
+
+    # Run the training loop
+    run_training_loop(components)
+
+
+if __name__ == '__main__':
+    main()
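One detail worth noting in the refactored train_one_epoch: the torch.cuda.amp.autocast() context is commented out while GradScaler is still used, so the forward pass runs in full precision and the scaler only rescales an fp32 loss. For reference, the usual autocast/GradScaler pairing looks like the sketch below; this is a generic illustration (assuming a margin-style metric head that takes embeddings and labels), not code from this repository:

import torch

def amp_train_step(model, metric, criterion, optimizer, scaler, data, labels, device):
    """One mixed-precision training step: autocast forward + scaled backward (sketch only)."""
    data, labels = data.to(device), labels.to(device)
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():      # run the forward pass in mixed precision
        thetas = metric(model(data), labels)
        loss = criterion(thetas, labels)
    scaler.scale(loss).backward()        # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)               # unscales gradients, skips the step on inf/nan
    scaler.update()
    return loss.item()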
233 train_compare.py.bak Normal file
@ -0,0 +1,233 @@
import os
import os.path as osp

import torch
import torch.nn as nn
from tqdm import tqdm

from model.loss import FocalLoss
from tools.dataset import load_data
import matplotlib.pyplot as plt
from configs import trainer_tools
import yaml
from datetime import datetime


def load_configuration(config_path='configs/scatter.yml'):
    """Load the configuration file"""
    with open(config_path, 'r') as f:
        return yaml.load(f, Loader=yaml.FullLoader)


def initialize_model_and_metric(conf, class_num):
    """Initialize the model and metric"""
    tr_tools = trainer_tools(conf)
    backbone_mapping = tr_tools.get_backbone()
    metric_mapping = tr_tools.get_metric(class_num)

    if conf['models']['backbone'] in backbone_mapping:
        model = backbone_mapping[conf['models']['backbone']]()
    else:
        raise ValueError('Unsupported model: {}'.format({conf['models']['backbone']}))

    if conf['training']['metric'] in metric_mapping:
        metric = metric_mapping[conf['training']['metric']]()
    else:
        raise ValueError('Unsupported metric type: {}'.format(conf['training']['metric']))

    return model, metric


def setup_optimizer_and_scheduler(conf, model, metric):
    """Set up the optimizer and learning-rate scheduler"""
    tr_tools = trainer_tools(conf)
    optimizer_mapping = tr_tools.get_optimizer(model, metric)

    if conf['training']['optimizer'] in optimizer_mapping:
        optimizer = optimizer_mapping[conf['training']['optimizer']]()
        scheduler_mapping = tr_tools.get_scheduler(optimizer)
        scheduler = scheduler_mapping[conf['training']['scheduler']]()
        print('Using {} optimizer with {} scheduler'.format(conf['training']['optimizer'],
                                                            conf['training']['scheduler']))
        return optimizer, scheduler
    else:
        raise ValueError('Unsupported optimizer type: {}'.format(conf['training']['optimizer']))


def setup_loss_function(conf):
    """Configure the loss function"""
    if conf['training']['loss'] == 'focal_loss':
        return FocalLoss(gamma=2)
    else:
        return nn.CrossEntropyLoss()


def train_one_epoch(model, metric, criterion, optimizer, dataloader, device, scaler, conf):
    """Run one training epoch"""
    model.train()
    train_loss = 0
    for data, labels in tqdm(dataloader, desc="Training", ascii=True, total=len(dataloader)):
        data = data.to(device)
        labels = labels.to(device)

        with torch.cuda.amp.autocast():
            embeddings = model(data)
            if not conf['training']['metric'] == 'softmax':
                thetas = metric(embeddings, labels)
            else:
                thetas = metric(embeddings)
            loss = criterion(thetas, labels)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        train_loss += loss.item()
    return train_loss / len(dataloader)


def validate(model, metric, criterion, dataloader, device, conf):
    """Run validation"""
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for data, labels in tqdm(dataloader, desc="Validating", ascii=True, total=len(dataloader)):
            data = data.to(device)
            labels = labels.to(device)
            embeddings = model(data)
            if not conf['training']['metric'] == 'softmax':
                thetas = metric(embeddings, labels)
            else:
                thetas = metric(embeddings)
            loss = criterion(thetas, labels)
            val_loss += loss.item()
    return val_loss / len(dataloader)


def save_model(model, path, is_parallel):
    """Save model weights"""
    if is_parallel:
        torch.save(model.module.state_dict(), path)
    else:
        torch.save(model.state_dict(), path)


def log_training_info(log_path, log_info):
    """Append training info to the log file"""
    with open(log_path, 'a') as f:
        f.write(log_info + '\n')


def initialize_training_components():
    """Initialize all components needed for training"""
    # Load configuration
    conf = load_configuration()

    # Data loading
    train_dataloader, class_num = load_data(training=True, cfg=conf)
    val_dataloader, _ = load_data(training=False, cfg=conf)

    # Initialize model and metric
    model, metric = initialize_model_and_metric(conf, class_num)
    device = conf['base']['device']
    model = model.to(device)
    metric = metric.to(device)

    if torch.cuda.device_count() > 1 and conf['base']['distributed']:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
        metric = nn.DataParallel(metric)

    # Set up loss function, optimizer and scheduler
    criterion = setup_loss_function(conf)
    optimizer, scheduler = setup_optimizer_and_scheduler(conf, model, metric)

    # Checkpoint directory
    checkpoints = conf['training']['checkpoints']
    os.makedirs(checkpoints, exist_ok=True)

    # GradScaler for mixed precision
    scaler = torch.cuda.amp.GradScaler()

    return {
        'conf': conf,
        'train_dataloader': train_dataloader,
        'val_dataloader': val_dataloader,
        'model': model,
        'metric': metric,
        'criterion': criterion,
        'optimizer': optimizer,
        'scheduler': scheduler,
        'checkpoints': checkpoints,
        'scaler': scaler,
        'device': device
    }


def run_training_loop(components):
    """Run the full training loop"""
    # Unpack components
    conf = components['conf']
    train_dataloader = components['train_dataloader']
    val_dataloader = components['val_dataloader']
    model = components['model']
    metric = components['metric']
    criterion = components['criterion']
    optimizer = components['optimizer']
    scheduler = components['scheduler']
    checkpoints = components['checkpoints']
    scaler = components['scaler']
    device = components['device']

    # Training state
    train_losses = []
    val_losses = []
    epochs = []
    temp_loss = 100

    if conf['training']['restore']:
        print('load pretrain model: {}'.format(conf['training']['restore_model']))
        model.load_state_dict(torch.load(conf['training']['restore_model'], map_location=device))

    # Training loop
    for e in range(conf['training']['epochs']):
        train_loss_avg = train_one_epoch(model, metric, criterion, optimizer, train_dataloader, device, scaler, conf)
        train_losses.append(train_loss_avg)
        epochs.append(e)

        val_loss_avg = validate(model, metric, criterion, val_dataloader, device, conf)
        val_losses.append(val_loss_avg)

        if val_loss_avg < temp_loss:
            save_model(model, osp.join(checkpoints, 'best.pth'), isinstance(model, nn.DataParallel))
            temp_loss = val_loss_avg

        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        log_info = ("[{:%Y-%m-%d %H:%M:%S}] Epoch {}/{}, train_loss: {}, val_loss: {} lr:{}"
                    .format(datetime.now(),
                            e,
                            conf['training']['epochs'],
                            train_loss_avg,
                            val_loss_avg,
                            current_lr))
        print(log_info)
        log_training_info(osp.join(conf['logging']['logging_dir']), log_info)
        print("Learning rate at epoch %d: %f" % (e, current_lr))

    # Save the final model
    save_model(model, osp.join(checkpoints, 'last.pth'), isinstance(model, nn.DataParallel))

    # Plot the loss curves
    plt.plot(epochs, train_losses, color='blue', label='Train Loss')
    plt.plot(epochs, val_losses, color='red', label='Validation Loss')
    plt.legend()
    plt.savefig('loss/mobilenetv3Large_2250_0316.png')


if __name__ == '__main__':
    # Initialize training components
    components = initialize_training_components()

    # Run the training loop
    run_training_loop(components)