Parse callback data; compatible with v5 and v10
74 .gitignore vendored Normal file
@@ -0,0 +1,74 @@
# Repo-specific GitIgnore ----------------------------------------------------------------------------------------------
*.jpg
*.jpeg
*.png
*.bmp
*.tif
*.tiff
*.heic
*.JPG
*.JPEG
*.PNG
*.BMP
*.TIF
*.TIFF
*.HEIC
*.mp4
*.mov
*.MOV
*.avi
*.data
*.json
*.cfg

*.rar
*.pkl
*.pickle
*.npy
*.csv


# for tracking ---------------------------------------------------------------
tracking/.git
tracking/bakeup
tracking/.gitignore
tracking/result/**/*.mp4
tracking/result/**/*.png
tracking/data/boxes_imgs/*
tracking/data/trackfeats/*
tracking/data/tracks/*
tracking/data/handlocal/*
ckpts/*
doc


# Datasets -------------------------------------------------------------------------------------------------------------
coco/
coco128/
VOC/


# Neural Network weights -----------------------------------------------------------------------------------------------
*.weights
*.pt
*.pth
*.pb
*.onnx
*.engine
*.mlmodel
*.torchscript
*.tflite
*.h5
*.caffemodel
*_saved_model/
*_web_model/
*_openvino_model/
*_paddle_model/
6 .idea/vcs.xml generated Normal file
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="VcsDirectoryMappings">
    <mapping directory="$PROJECT_DIR$" vcs="Git" />
  </component>
</project>
131 .idea/workspace.xml generated Normal file
@@ -0,0 +1,131 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="AutoImportSettings">
    <option name="autoReloadType" value="SELECTIVE" />
  </component>
  <component name="ChangeListManager">
    <list default="true" id="ba103475-3e5f-4113-9d53-799254beb8ef" name="Changes" comment="" />
    <option name="SHOW_DIALOG" value="false" />
    <option name="HIGHLIGHT_CONFLICTS" value="true" />
    <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
    <option name="LAST_RESOLUTION" value="IGNORE" />
  </component>
  <component name="FlaskConsoleOptions" custom-start-script="import sys sys.path.extend([WORKING_DIR_AND_PYTHON_PATHS]) from flask.cli import ScriptInfo locals().update(ScriptInfo(create_app=None).load_app().make_shell_context()) print(&quot;Python %s on %s\nApp: %s [%s]\nInstance: %s&quot; % (sys.version, sys.platform, app.import_name, app.env, app.instance_path))">
    <envs>
      <env key="FLASK_APP" value="app" />
    </envs>
    <option name="myCustomStartScript" value="import sys sys.path.extend([WORKING_DIR_AND_PYTHON_PATHS]) from flask.cli import ScriptInfo locals().update(ScriptInfo(create_app=None).load_app().make_shell_context()) print(&quot;Python %s on %s\nApp: %s [%s]\nInstance: %s&quot; % (sys.version, sys.platform, app.import_name, app.env, app.instance_path))" />
    <option name="myEnvs">
      <map>
        <entry key="FLASK_APP" value="app" />
      </map>
    </option>
  </component>
  <component name="MarkdownSettingsMigration">
    <option name="stateVersion" value="1" />
  </component>
  <component name="ProjectColorInfo">{
  "associatedIndex": 2
}</component>
  <component name="ProjectId" id="2vpqs0oD3mrHDNf3qrBBd4BMnsc" />
  <component name="ProjectViewState">
    <option name="hideEmptyMiddlePackages" value="true" />
    <option name="showLibraryContents" value="true" />
  </component>
  <component name="PropertiesComponent">{
  "keyToString": {
    "RunOnceActivity.OpenProjectViewOnStart": "true",
    "RunOnceActivity.ShowReadmeOnStart": "true",
    "WebServerToolWindowFactoryState": "true",
    "node.js.detected.package.eslint": "true",
    "node.js.detected.package.tslint": "true",
    "node.js.selected.package.eslint": "(autodetect)",
    "node.js.selected.package.tslint": "(autodetect)",
    "settings.editor.selected.configurable": "com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable",
    "vue.rearranger.settings.migration": "true"
  }
}</component>
  <component name="RunManager" selected="Python.read_xlsx_filter_events">
    <configuration name="pipeline" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
      <module name="detecttracking_20250417_callbackdata" />
      <option name="INTERPRETER_OPTIONS" value="" />
      <option name="PARENT_ENVS" value="true" />
      <envs>
        <env name="PYTHONUNBUFFERED" value="1" />
      </envs>
      <option name="SDK_HOME" value="" />
      <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
      <option name="IS_MODULE_SDK" value="true" />
      <option name="ADD_CONTENT_ROOTS" value="true" />
      <option name="ADD_SOURCE_ROOTS" value="true" />
      <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
      <option name="SCRIPT_NAME" value="$PROJECT_DIR$/pipeline.py" />
      <option name="PARAMETERS" value="" />
      <option name="SHOW_COMMAND_LINE" value="false" />
      <option name="EMULATE_TERMINAL" value="false" />
      <option name="MODULE_MODE" value="false" />
      <option name="REDIRECT_INPUT" value="false" />
      <option name="INPUT_FILE" value="" />
      <method v="2" />
    </configuration>
    <configuration name="read_xlsx_filter_events" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
      <module name="detecttracking_20250417_callbackdata" />
      <option name="INTERPRETER_OPTIONS" value="" />
      <option name="PARENT_ENVS" value="true" />
      <envs>
        <env name="PYTHONUNBUFFERED" value="1" />
      </envs>
      <option name="SDK_HOME" value="" />
      <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
      <option name="IS_MODULE_SDK" value="true" />
      <option name="ADD_CONTENT_ROOTS" value="true" />
      <option name="ADD_SOURCE_ROOTS" value="true" />
      <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
      <option name="SCRIPT_NAME" value="$PROJECT_DIR$/read_xlsx_filter_events.py" />
      <option name="PARAMETERS" value="" />
      <option name="SHOW_COMMAND_LINE" value="false" />
      <option name="EMULATE_TERMINAL" value="false" />
      <option name="MODULE_MODE" value="false" />
      <option name="REDIRECT_INPUT" value="false" />
      <option name="INPUT_FILE" value="" />
      <method v="2" />
    </configuration>
    <recent_temporary>
      <list>
        <item itemvalue="Python.read_xlsx_filter_events" />
        <item itemvalue="Python.pipeline" />
      </list>
    </recent_temporary>
  </component>
  <component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
  <component name="TaskManager">
    <task active="true" id="Default" summary="Default task">
      <changelist id="ba103475-3e5f-4113-9d53-799254beb8ef" name="Changes" comment="" />
      <created>1744852597949</created>
      <option name="number" value="Default" />
      <option name="presentableId" value="Default" />
      <updated>1744852597949</updated>
      <workItem from="1744852602490" duration="17024000" />
      <workItem from="1744940203193" duration="9768000" />
    </task>
    <servers />
  </component>
  <component name="TypeScriptGeneratedFilesManager">
    <option name="version" value="3" />
  </component>
  <component name="XDebuggerManager">
    <breakpoint-manager>
      <breakpoints>
        <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
          <url>file://$PROJECT_DIR$/pipeline.py</url>
          <line>314</line>
          <option name="timeStamp" value="3" />
        </line-breakpoint>
      </breakpoints>
    </breakpoint-manager>
  </component>
  <component name="com.intellij.coverage.CoverageDataManagerImpl">
    <SUITE FILE_PATH="coverage/detecttracking_20250417_callbackdata$read_xlsx_filter_events.coverage" NAME="read_xlsx_filter_events Coverage Results" MODIFIED="1744957439346" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
    <SUITE FILE_PATH="coverage/detecttracking_20250417_callbackdata$pipeline.coverage" NAME="pipeline Coverage Results" MODIFIED="1744947271686" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
  </component>
</project>
36 README.en.md Normal file
@@ -0,0 +1,36 @@
# YoloV5_track

#### Description
{**When you're done, you can delete the content in this README and update the file with details for others getting started with your repository**}

#### Software Architecture
Software architecture description

#### Installation

1. xxxx
2. xxxx
3. xxxx

#### Instructions

1. xxxx
2. xxxx
3. xxxx

#### Contribution

1. Fork the repository
2. Create Feat_xxx branch
3. Commit your code
4. Create Pull Request

#### Gitee Feature

1. You can use Readme\_XXX.md to support different languages, such as Readme\_en.md, Readme\_zh.md
2. Gitee blog [blog.gitee.com](https://blog.gitee.com)
3. Explore open source project [https://gitee.com/explore](https://gitee.com/explore)
4. The most valuable open source project [GVP](https://gitee.com/gvp)
5. The manual of Gitee [https://gitee.com/help](https://gitee.com/help)
6. The most popular members [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/)
7 README.md Normal file
@@ -0,0 +1,7 @@
Notes:
This repository is a backup of the code under the yolov5_track folder; the yolov5 code here comes from: https://github.com/ultralytics/yolov5
Gitee address: https://gitee.com/nanjing-yimao-information/dettrack
Core module:
track_reid.py implements:
1. YOLOv5 detection
2. BoT-SORT for object tracking
0 __init__.py Normal file
BIN __pycache__/event_time_specify.cpython-39.pyc Normal file
Binary file not shown.
BIN __pycache__/export.cpython-312.pyc Normal file
Binary file not shown.
BIN __pycache__/export.cpython-39.pyc Normal file
Binary file not shown.
BIN __pycache__/imgs_inference.cpython-39.pyc Normal file
Binary file not shown.
BIN __pycache__/move_detect.cpython-312.pyc Normal file
Binary file not shown.
BIN __pycache__/move_detect.cpython-39.pyc Normal file
Binary file not shown.
BIN __pycache__/pipeline_01.cpython-312.pyc Normal file
Binary file not shown.
BIN __pycache__/pipeline_01.cpython-39.pyc Normal file
Binary file not shown.
BIN __pycache__/track_reid.cpython-312.pyc Normal file
Binary file not shown.
BIN __pycache__/track_reid.cpython-39.pyc Normal file
Binary file not shown.
359 bakeup/pipeline.py Normal file
@@ -0,0 +1,359 @@
# -*- coding: utf-8 -*-
"""
Created on Sun Sep 29 08:59:21 2024

@author: ym
"""
import os
# import sys
import cv2
import pickle
import numpy as np
from pathlib import Path
from scipy.spatial.distance import cdist
from track_reid import yolo_resnet_tracker, yolov10_resnet_tracker

from tracking.dotrack.dotracks_back import doBackTracks
from tracking.dotrack.dotracks_front import doFrontTracks
from tracking.utils.drawtracks import plot_frameID_y2, draw_all_trajectories
from utils.getsource import get_image_pairs, get_video_pairs
from tracking.utils.read_data import read_similar


def save_subimgs(imgdict, boxes, spath, ctype, featdict=None):
    '''
    Similarity between the current box feature and the previous box feature of
    the same track; can be compared against the similarities recorded during tracking.
    '''
    boxes = boxes[np.argsort(boxes[:, 7])]
    for i in range(len(boxes)):
        simi = None
        tid, fid, bid = int(boxes[i, 4]), int(boxes[i, 7]), int(boxes[i, 8])

        if i > 0:
            _, fid0, bid0 = int(boxes[i-1, 4]), int(boxes[i-1, 7]), int(boxes[i-1, 8])
            if f"{fid0}_{bid0}" in featdict.keys() and f"{fid}_{bid}" in featdict.keys():
                feat0 = featdict[f"{fid0}_{bid0}"]
                feat1 = featdict[f"{fid}_{bid}"]
                simi = 1 - np.maximum(0.0, cdist(feat0[None, :], feat1[None, :], "cosine"))[0][0]
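                # cdist(..., "cosine") returns the cosine distance d = 1 - cos(theta),
                # so simi is the plain cosine similarity of the two ReID features.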

        img = imgdict[f"{fid}_{bid}"]
        imgpath = spath / f"{ctype}_tid{tid}-{fid}-{bid}.png"
        if simi is not None:
            imgpath = spath / f"{ctype}_tid{tid}-{fid}-{bid}_sim{simi:.2f}.png"

        cv2.imwrite(str(imgpath), img)


def save_subimgs_1(imgdict, boxes, spath, ctype, simidict=None):
    '''
    Similarity between the current box feature and the track's smooth_feat;
    the yolo_resnet_tracker function records feature similarity this way.
    '''
    for i in range(len(boxes)):
        tid, fid, bid = int(boxes[i, 4]), int(boxes[i, 7]), int(boxes[i, 8])

        key = f"{fid}_{bid}"
        img = imgdict[key]
        imgpath = spath / f"{ctype}_tid{tid}-{fid}-{bid}.png"
        if simidict is not None and key in simidict.keys():
            imgpath = spath / f"{ctype}_tid{tid}-{fid}-{bid}_sim{simidict[key]:.2f}.png"

        cv2.imwrite(str(imgpath), img)


def pipeline(
        eventpath,
        savepath,
        SourceType,
        weights,
        YoloVersion="V5"
):
    '''
    eventpath: storage path of a single event
    '''
    optdict = {}
    optdict["weights"] = weights

    if SourceType == "video":
        vpaths = get_video_pairs(eventpath)
    elif SourceType == "image":
        vpaths = get_image_pairs(eventpath)
    event_tracks = []

    ## Build the shopping-event dictionary
    evtname = Path(eventpath).stem
    barcode = evtname.split('_')[-1] if len(evtname.split('_')) >= 2 \
                                    and len(evtname.split('_')[-1]) >= 8 \
                                    and evtname.split('_')[-1].isdigit() else ''
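    # Example (hypothetical event name): "20250310-175352-741_6920152400975"
    # -> barcode "6920152400975"; names without a trailing >=8-digit field yield ''.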
    '''Output folder for the event results'''
    if not savepath:
        savepath = Path(__file__).resolve().parents[0] / "events_result"

    savepath_pipeline = Path(savepath) / Path("Yolos_Tracking") / evtname


    """Path where the ShoppingDict pickle file is saved"""
    savepath_spdict = Path(savepath) / "ShoppingDict_pkfile"
    if not savepath_spdict.exists():
        savepath_spdict.mkdir(parents=True, exist_ok=True)
    pf_path = Path(savepath_spdict) / Path(str(evtname) + ".pickle")

    # if pf_path.exists():
    #     print(f"Pickle file have saved: {evtname}.pickle")
    #     return

    '''====================== Build the ShoppingDict ======================='''
    ShoppingDict = {"eventPath": eventpath,
                    "eventName": evtname,
                    "barcode": barcode,
                    "eventType": '',   # "input", "output", "other"
                    "frontCamera": {},
                    "backCamera": {},
                    "one2n": []
                    }
    yrtDict = {}

    procpath = Path(eventpath).joinpath('process.data')
    if procpath.is_file():
        SimiDict = read_similar(procpath)
        ShoppingDict["one2n"] = SimiDict['one2n']

    for vpath in vpaths:
        '''================= 1. Build the camera-event dictionary ================='''
        CameraEvent = {"cameraType": '',   # "front", "back"
                       "videoPath": '',
                       "imagePaths": [],
                       "yoloResnetTracker": [],
                       "tracking": [],
                       }

        if isinstance(vpath, list):
            CameraEvent["imagePaths"] = vpath
            bname = os.path.basename(vpath[0])
        if not isinstance(vpath, list):
            CameraEvent["videoPath"] = vpath
            bname = os.path.basename(vpath).split('.')[0]
        if bname.split('_')[0] == "0" or bname.find('back') >= 0:
            CameraEvent["cameraType"] = "back"
        if bname.split('_')[0] == "1" or bname.find('front') >= 0:
            CameraEvent["cameraType"] = "front"

        '''================= 2. Output folders for this camera ================='''
        if isinstance(vpath, list):
            savepath_pipeline_imgs = savepath_pipeline / Path("images")
        else:
            savepath_pipeline_imgs = savepath_pipeline / Path(str(Path(vpath).stem))

        if not savepath_pipeline_imgs.exists():
            savepath_pipeline_imgs.mkdir(parents=True, exist_ok=True)

        savepath_pipeline_subimgs = savepath_pipeline / Path("subimgs")
        if not savepath_pipeline_subimgs.exists():
            savepath_pipeline_subimgs.mkdir(parents=True, exist_ok=True)

        '''================= 3. Yolo + Resnet + Tracker ================='''
        optdict["source"] = vpath
        optdict["save_dir"] = savepath_pipeline_imgs
        optdict["is_save_img"] = True
        optdict["is_save_video"] = True

        if YoloVersion == "V5":
            yrtOut = yolo_resnet_tracker(**optdict)
        elif YoloVersion == "V10":
            yrtOut = yolov10_resnet_tracker(**optdict)

        yrtOut_save = []
        for frdict in yrtOut:
            fr_dict = {}
            for k, v in frdict.items():
                if k != "imgs":
                    fr_dict[k] = v
            yrtOut_save.append(fr_dict)
        CameraEvent["yoloResnetTracker"] = yrtOut_save

        # CameraEvent["yoloResnetTracker"] = yrtOut

        '''================= 4. tracking ================='''
        '''(1) Build the boxes and feats consumed by the tracking module'''
        bboxes = np.empty((0, 6), dtype=np.float64)
        trackerboxes = np.empty((0, 9), dtype=np.float64)
        trackefeats = {}
        for frameDict in yrtOut:
            tboxes = frameDict["tboxes"]
            ffeats = frameDict["feats"]

            boxes = frameDict["bboxes"]
            bboxes = np.concatenate((bboxes, np.array(boxes)), axis=0)
            trackerboxes = np.concatenate((trackerboxes, np.array(tboxes)), axis=0)
            for i in range(len(tboxes)):
                fid, bid = int(tboxes[i, 7]), int(tboxes[i, 8])
                trackefeats.update({f"{fid}_{bid}": ffeats[f"{fid}_{bid}"]})
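        # trackerboxes columns (9): [x1, y1, x2, y2, track_id, score, cls,
        # frame_index, box_index]; trackefeats maps "fid_bid" -> ReID feature.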

        '''(2) tracking, back camera'''
        if CameraEvent["cameraType"] == "back":
            vts = doBackTracks(trackerboxes, trackefeats)
            vts.classify()
            event_tracks.append(("back", vts))

            CameraEvent["tracking"] = vts
            ShoppingDict["backCamera"] = CameraEvent

            yrtDict["backyrt"] = yrtOut

        '''(3) tracking, front camera'''
        if CameraEvent["cameraType"] == "front":
            vts = doFrontTracks(trackerboxes, trackefeats)
            vts.classify()
            event_tracks.append(("front", vts))

            CameraEvent["tracking"] = vts
            ShoppingDict["frontCamera"] = CameraEvent

            yrtDict["frontyrt"] = yrtOut

    '''========================== Saving ================================='''
    '''(1) Save the event's ShoppingDict'''
    with open(str(pf_path), 'wb') as f:
        pickle.dump(ShoppingDict, f)

    '''(2) Save the sub-images of the tracks output by tracking, with similarities'''
    for CamerType, vts in event_tracks:
        if len(vts.tracks) == 0: continue
        if CamerType == 'front':
            # yolos = ShoppingDict["frontCamera"]["yoloResnetTracker"]
            yolos = yrtDict["frontyrt"]
            ctype = 1
        if CamerType == 'back':
            # yolos = ShoppingDict["backCamera"]["yoloResnetTracker"]
            yolos = yrtDict["backyrt"]
            ctype = 0

        imgdict, featdict, simidict = {}, {}, {}
        for y in yolos:
            imgdict.update(y["imgs"])
            featdict.update(y["feats"])
            simidict.update(y["featsimi"])

        for track in vts.Residual:
            if isinstance(track, np.ndarray):
                save_subimgs(imgdict, track, savepath_pipeline_subimgs, ctype, featdict)
            else:
                save_subimgs(imgdict, track.slt_boxes, savepath_pipeline_subimgs, ctype, featdict)

    '''(3) Draw and save the trajectories'''
    illus = [None, None]
    for CamerType, vts in event_tracks:
        if len(vts.tracks) == 0: continue

        if CamerType == 'front':
            edgeline = cv2.imread("./tracking/shopcart/cart_tempt/board_ftmp_line.png")

            h, w = edgeline.shape[:2]
            # nh, nw = h//2, w//2
            # edgeline = cv2.resize(edgeline, (nw, nh), interpolation=cv2.INTER_AREA)

            img_tracking = draw_all_trajectories(vts, edgeline, savepath_pipeline, CamerType, draw5p=True)
            illus[0] = img_tracking

            plt = plot_frameID_y2(vts)
            plt.savefig(os.path.join(savepath_pipeline, "front_y2.png"))

        if CamerType == 'back':
            edgeline = cv2.imread("./tracking/shopcart/cart_tempt/edgeline.png")

            h, w = edgeline.shape[:2]
            # nh, nw = h//2, w//2
            # edgeline = cv2.resize(edgeline, (nw, nh), interpolation=cv2.INTER_AREA)

            img_tracking = draw_all_trajectories(vts, edgeline, savepath_pipeline, CamerType, draw5p=True)
            illus[1] = img_tracking

    illus = [im for im in illus if im is not None]
    if len(illus):
        img_cat = np.concatenate(illus, axis=1)
        if len(illus) == 2:
            H, W = img_cat.shape[:2]
            cv2.line(img_cat, (int(W/2), 0), (int(W/2), int(H)), (128, 128, 255), 3)

        trajpath = os.path.join(savepath_pipeline, "trajectory.png")
        cv2.imwrite(trajpath, img_cat)


def execute_pipeline(evtdir=r"D:\datasets\ym\后台数据\unzip",
                     source_type="video",   # video, image
                     save_path=r"D:\work\result_pipeline",
                     yolo_ver="V10",        # V10, V5
                     weight_yolo_v5=r'./ckpts/best_cls10_0906.pt',
                     weight_yolo_v10=r'./ckpts/best_v10s_width0375_1205.pt',
                     k=0
                     ):
    '''
    Runs pipeline() over the event folders; each sub-folder is one event.
    '''
    parmDict = {}
    parmDict["SourceType"] = source_type
    parmDict["savepath"] = save_path
    parmDict["YoloVersion"] = yolo_ver
    if parmDict["YoloVersion"] == "V5":
        parmDict["weights"] = weight_yolo_v5
    elif parmDict["YoloVersion"] == "V10":
        parmDict["weights"] = weight_yolo_v10

    evtdir = Path(evtdir)
    errEvents = []
    for item in evtdir.iterdir():
        if item.is_dir():
            item = evtdir / Path("20250310-175352-741")
            parmDict["eventpath"] = item
            pipeline(**parmDict)
            # try:
            #     pipeline(**parmDict)
            # except Exception as e:
            #     errEvents.append(str(item))
            k += 1
            if k == 1:
                break

    errfile = os.path.join(parmDict["savepath"], 'error_events.txt')
    with open(errfile, 'w', encoding='utf-8') as f:
        for line in errEvents:
            f.write(line + '\n')


if __name__ == "__main__":
    execute_pipeline()

    # spath_v10 = r"D:\work\result_pipeline_v10"
    # spath_v5 = r"D:\work\result_pipeline_v5"
    # execute_pipeline(save_path=spath_v10, yolo_ver="V10")
    # execute_pipeline(save_path=spath_v5, yolo_ver="V5")

    datapath = r'/home/wqg/dataset/test_dataset/base_dataset/single_event/source/'
    savepath = r'/home/wqg/dataset/pipeline/contrast/single_event_V5'

    execute_pipeline(evtdir=datapath,
                     source_type="video",   # video, image
                     save_path=savepath,
                     yolo_ver="V10",        # V10, V5
                     weight_yolo_v5=r'./ckpts/best_cls10_0906.pt',
                     weight_yolo_v10=r'./ckpts/best_v10s_width0375_1205.pt',
                     k=1
                     )
629 bakeup/track_reid_20240515.py Normal file
@@ -0,0 +1,629 @@
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
"""
Run YOLOv5 detection inference on images, videos, directories, globs, YouTube, webcam, streams, etc.

Usage - sources:
    $ python detect.py --weights yolov5s.pt --source 0                               # webcam
                                                     img.jpg                         # image
                                                     vid.mp4                         # video
                                                     screen                          # screenshot
                                                     path/                           # directory
                                                     list.txt                        # list of images
                                                     list.streams                    # list of streams
                                                     'path/*.jpg'                    # glob
                                                     'https://youtu.be/Zgi9g1ksQHc'  # YouTube
                                                     'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP stream

Usage - formats:
    $ python detect.py --weights yolov5s.pt                 # PyTorch
                                 yolov5s.torchscript        # TorchScript
                                 yolov5s.onnx               # ONNX Runtime or OpenCV DNN with --dnn
                                 yolov5s_openvino_model     # OpenVINO
                                 yolov5s.engine             # TensorRT
                                 yolov5s.mlmodel            # CoreML (macOS-only)
                                 yolov5s_saved_model        # TensorFlow SavedModel
                                 yolov5s.pb                 # TensorFlow GraphDef
                                 yolov5s.tflite             # TensorFlow Lite
                                 yolov5s_edgetpu.tflite     # TensorFlow Edge TPU
                                 yolov5s_paddle_model       # PaddlePaddle
"""

import argparse
import csv
import os
import platform
import sys
from pathlib import Path
import glob
import numpy as np
import pickle
import torch

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from models.common import DetectMultiBackend
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
                           increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh)
from utils.torch_utils import select_device, smart_inference_mode

'''Integrates the tracking module and writes the tracking results to .npy files'''
# from ultralytics.engine.results import Boxes  # Results
# from ultralytics.utils import IterableSimpleNamespace, yaml_load
from tracking.utils.plotting import Annotator, colors
from tracking.utils import Boxes, IterableSimpleNamespace, yaml_load, boxes_add_fid
from tracking.trackers import BOTSORT, BYTETracker
from tracking.utils.showtrack import drawtracks
from hands.hand_inference import hand_pose

# =============================================================================
# from tracking.trackers.reid.reid_interface import ReIDInterface
# from tracking.trackers.reid.config import config as ReIDConfig
# ReIDEncoder = ReIDInterface(ReIDConfig)
# =============================================================================

# tracker_yaml = r"./tracking/trackers/cfg/botsort.yaml"

def init_trackers(tracker_yaml=None, bs=1):
    """
    Initialize trackers for object tracking during prediction.
    """
    # tracker_yaml = r"./tracking/trackers/cfg/botsort.yaml"

    TRACKER_MAP = {'bytetrack': BYTETracker, 'botsort': BOTSORT}

    cfg = IterableSimpleNamespace(**yaml_load(tracker_yaml))
    trackers = []
    for _ in range(bs):
        tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30)
        trackers.append(tracker)

    return trackers
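
# Note: the tracker_type key in the YAML must be "bytetrack" or "botsort" to
# match TRACKER_MAP above; one tracker instance is created per batch stream.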

@smart_inference_mode()
def run(
        weights=ROOT / 'yolov5s.pt',  # model path or triton URL
        source=ROOT / 'data/images',  # file/dir/URL/glob/screen/0(webcam)

        project=ROOT / 'runs/detect',  # save results to project/name
        name='exp',  # save results to project/name

        tracker_yaml="./tracking/trackers/cfg/botsort.yaml",
        imgsz=(640, 640),  # inference size (height, width)
        conf_thres=0.25,  # confidence threshold
        iou_thres=0.45,  # NMS IOU threshold
        max_det=1000,  # maximum detections per image
        device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
        view_img=False,  # show results
        save_txt=False,  # save results to *.txt
        save_csv=False,  # save results in CSV format
        save_conf=False,  # save confidences in --save-txt labels
        save_crop=False,  # save cropped prediction boxes
        nosave=False,  # do not save images/videos
        classes=None,  # filter by class: --class 0, or --class 0 2 3
        agnostic_nms=False,  # class-agnostic NMS
        augment=False,  # augmented inference
        visualize=False,  # visualize features
        update=False,  # update all models
        exist_ok=False,  # existing project/name ok, do not increment
        line_thickness=3,  # bounding box thickness (pixels)
        hide_labels=False,  # hide labels
        hide_conf=False,  # hide confidences
        half=False,  # use FP16 half-precision inference
        dnn=False,  # use OpenCV DNN for ONNX inference
        vid_stride=1,  # video frame-rate stride
        data=ROOT / 'data/coco128.yaml',  # dataset.yaml path
):
    source = str(source)
    # filename = os.path.split(source)[-1]

    save_img = not nosave and not source.endswith('.txt')  # save inference images
    is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
    is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
    webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
    screenshot = source.lower().startswith('screen')
    if is_url and is_file:
        source = check_file(source)  # download

    save_dir = Path(project) / Path(source).stem
    if save_dir.exists():
        print(Path(source).stem)
        # return

        save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run
        (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir
    else:
        save_dir.mkdir(parents=True, exist_ok=True)

    # Load model
    device = select_device(device)
    model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
    stride, names, pt = model.stride, model.names, model.pt
    imgsz = check_img_size(imgsz, s=stride)  # check image size

    # Dataloader
    bs = 1  # batch_size

    dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
    vid_path, vid_writer = [None] * bs, [None] * bs

    # Run inference
    model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))  # warmup
    seen, dt = 0, (Profile(), Profile(), Profile())

    tracker = init_trackers(tracker_yaml, bs)[0]

    handpose = hand_pose()
    handlocals_dict = {}

    boxes_and_imgs = []
    track_boxes = np.empty((0, 9), dtype=np.float32)
    det_boxes = np.empty((0, 9), dtype=np.float32)

    features_dict = {}
    for path, im, im0s, vid_cap, s in dataset:
        if save_img and 'imgshow' not in locals().keys():
            imgshow = im0s.copy()

        ## ============================= tracking handles videos only, written by WQG
        if dataset.mode == 'image':
            continue

        with dt[0]:
            im = torch.from_numpy(im).to(model.device)
            im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
            im /= 255  # 0 - 255 to 0.0 - 1.0
            if len(im.shape) == 3:
                im = im[None]  # expand for batch dim

        # Inference
        with dt[1]:
            visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
            pred = model(im, augment=augment, visualize=visualize)

        # NMS
        with dt[2]:
            pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)

        # Second-stage classifier (optional)
        # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)

        # Define the path for the CSV file
        # csv_path = save_dir / 'predictions.csv'

        # Create or append to the CSV file
        # def write_to_csv(image_name, prediction, confidence):
        #     data = {'Image Name': image_name, 'Prediction': prediction, 'Confidence': confidence}
        #     with open(csv_path, mode='a', newline='') as f:
        #         writer = csv.DictWriter(f, fieldnames=data.keys())
        #         if not csv_path.is_file():
        #             writer.writeheader()
        #         writer.writerow(data)

        # Process predictions
        for i, det in enumerate(pred):  # per image
            seen += 1
            if webcam:  # batch_size >= 1
                p, im0, frame = path[i], im0s[i].copy(), dataset.count
                s += f'{i}: '
            else:
                p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)

            im0_ant = im0.copy()

            p = Path(p)  # to Path
            save_path = str(save_dir / p.name)  # im.jpg
            s += '%gx%g ' % im.shape[2:]  # print string

            annotator = Annotator(im0_ant, line_width=line_thickness, example=str(names))

            nd = len(det)
            if nd:
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()

                det = det.cpu().numpy()
                det = np.concatenate([det[:, :4], np.arange(nd).reshape(-1, 1), det[:, 4:]], axis=-1)

                '''FeatFlag marks whether each current box is static relative to the previous frame's boxes.'''
                # def static_estimate(box1, box2, TH1=8, TH2=12):
                #     dij_abs = max(np.abs(box1 - box2))
                #     dij_euc = max([np.linalg.norm((box1[:2] - box2[:2])),
                #                    np.linalg.norm((box1[2:4] - box2[2:4]))
                #                    ])
                #     if dij_abs < TH1 and dij_euc < TH2:
                #         return True
                #     else:
                #         return False

                # FeatFlag = [-1] * nd
                # if len(boxes_and_imgs):
                #     detj = boxes_and_imgs[-1][0]
                #     frmj = boxes_and_imgs[-1][-1]
                #     for ii in range(nd):
                #         ## flag stores box indices
                #         condt1 = frame - frmj == 1
                #         flag = [idx for jj, idx in enumerate(detj[:, 4]) if condt1 and static_estimate(det[ii, :4], detj[jj, :4])]
                #         if len(flag) == 1:
                #             FeatFlag[ii] = flag[0]

                boxes_and_imgs.append((det, im0, frame))

                ## ================================================================ written by WQG
                '''tracks: [x1, y1, x2, y2, track_id, score, cls, frame_index, box_index]
                   here, frame_index can also be replaced by the video frame ID; box_index stays unchanged
                '''

                det_tracking = Boxes(det, im0.shape).cpu().numpy()
                tracks = tracker.update(det_tracking, im0)

                # detbox = [tlwh2tlbr(x._tlwh).tolist() + [x.track_id, x.score, x.cls, x.frame_id, x.idx]
                #           for x in tracker.tracked_stracks if x.is_activated]
                if len(tracks):
                    '''
                    tracks: [x1, y1, x2, y2, track_id, score, cls, frame_index, box_index]
                              0   1   2   3      4       5     6       7           8
                    '''
                    tracks[:, 7] = dataset.frame

                    '''================== 1. Extract hand positions ==================='''
                    # idx_0 = tracks[:, 6].astype(np.int_) == 0
                    # hn = 0
                    # for j, index in enumerate(idx_0):
                    #     if index:
                    #         track = tracks[j, :]
                    #         hand_local, imgshow = handpose.get_hand_local(track, im0)
                    #         handlocals_dict.update({int(track[7]): {int(track[8]): hand_local}})

                    #         # '''YOLOv5 and the hand detector do not have identical recall; replacing the hand (x1, y1, x2, y2) in tracks with hand_local would mix the two coordinate conventions'''
                    #         # if hand_local: tracks[j, :4] = hand_local

                    #         hn += 1
                    #         cv2.imwrite(f"D:\DeepLearning\yolov5\hands\images\{Path(source).stem}_{int(track[7])}_{hn}.png", imgshow)

                    '''================== 2. Store track information ==================='''
                    track_boxes = np.concatenate([track_boxes, tracks], axis=0)

                    # det_boxes = np.concatenate([det_boxes, detbox], axis=0)
                    '''================== 3. Store the tracks' ReID features ============='''

                    def crop_img(track, image):
                        tlbr = track.tlwh_to_tlbr(track._tlwh).astype(np.int_)

                        H, W = image.shape[:2]
                        tlbr[0] = max(0, tlbr[0])
                        tlbr[1] = max(0, tlbr[1])
                        tlbr[2] = min(W - 1, tlbr[2])
                        tlbr[3] = min(H - 1, tlbr[3])

                        img = image[tlbr[1]:tlbr[3], tlbr[0]:tlbr[2], :]
                        # cv2.imwrite(f"./runs/imgs/{int(track.idx)}.png", img)
                        return img

                    feat_dict_1 = {f'{int(x.idx)}_img': crop_img(x, im0) for x in tracker.tracked_stracks if x.is_activated}
                    feat_dict = {int(x.idx): x.curr_feat for x in tracker.tracked_stracks if x.is_activated}
                    feat_dict.update(feat_dict_1)

                    features_dict.update({int(dataset.frame): feat_dict})
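                    # features_dict layout: {frame_id: {box_index: ReID feature,
                    # "<box_index>_img": cropped BGR patch}}, one entry per frame.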

                    # det_anno = tracks.copy()
                # else:
                #     idmark = -1 * np.ones([det.shape[0], 1])
                #     det_anno = np.concatenate([det[:, :4], idmark, det[:, 4:]], axis=1)

                for *xyxy, id, conf, cls, fid, bid in reversed(tracks):
                    name = ('' if id == -1 else f'id:{int(id)} ') + names[int(cls)]
                    label = None if hide_labels else (name if hide_conf else f'{name} {conf:.2f}')

                    if id >= 0 and cls == 0:
                        color = colors(int(cls), True)
                    elif id >= 0 and cls != 0:
                        color = colors(int(id), True)
                    else:
                        color = colors(19, True)  # 19 is the last entry of the palette

                    annotator.box_label(xyxy, label, color=color)

            # Save results (image and video with tracking)
            im0 = annotator.result()
            save_path_img, ext = os.path.splitext(save_path)
            if save_img:
                if dataset.mode == 'image':
                    imgpath = save_path_img + f"_{dataset}.png"
                else:
                    imgpath = save_path_img + f"_{dataset.frame}.png"

                cv2.imwrite(str(imgpath), im0)

                if vid_path[i] != save_path:  # new video
                    vid_path[i] = save_path
                    if isinstance(vid_writer[i], cv2.VideoWriter):
                        vid_writer[i].release()  # release previous video writer
                    if vid_cap:  # video
                        fps = vid_cap.get(cv2.CAP_PROP_FPS)
                        w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                        h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                    else:  # stream
                        fps, w, h = 30, im0.shape[1], im0.shape[0]
                    save_path = str(Path(save_path).with_suffix('.mp4'))  # force *.mp4 suffix on results videos
                    vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
                vid_writer[i].write(im0)

        # Print time (inference-only)
        LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")

    ## ======================================================================== written by WQG
    ## track_boxes: Array, [x1, y1, x2, y2, track_id, score, cls, frame_index, box_id]

    '''The detection video and images were saved above; five further kinds of data are saved below'''
    filename = os.path.split(save_path_img)[-1]
    # file, ext = os.path.splitext(filename)
    # =============================================================================
    # fileElements = filename.split('_')
    # if len(fileElements) == 6 and len(fileElements[3]) == 1:
    #     barcode = fileElements[1]
    #     camera = fileElements[3]
    # elif len(fileElements) == 7 and len(fileElements[3]) == 1:
    #     barcode = fileElements[2]
    #     camera = fileElements[4]
    # else:
    #     barcode = ''
    #     camera = ''
    # =============================================================================

    '''======================== 1. save in './run/detect/' ===================='''
    if source.find("front") >= 0:
        carttemp = cv2.imread("./tracking/shopcart/cart_tempt/board_ftmp_line.png")
    else:
        carttemp = cv2.imread("./tracking/shopcart/cart_tempt/edgeline.png")

    imgshow = drawtracks(track_boxes, carttemp)
    showpath_1 = save_path_img + "_show.png"
    cv2.imwrite(str(showpath_1), imgshow)

    '''======================== 2. save boxes and raw images =================='''
    # boxes_imgs_dir = Path('./tracking/data/boxes_imgs/')
    # if not boxes_imgs_dir.exists():
    #     boxes_imgs_dir.mkdir(parents=True, exist_ok=True)
    # boxes_imgs_path = boxes_imgs_dir.joinpath(f'{filename}.pkl')
    # with open(boxes_imgs_path, 'wb') as file:
    #     pickle.dump(boxes_and_imgs, file)

    '''======================== 3. save tracks data ==========================='''
    tracks_dir = Path('./tracking/data/tracks/')
    if not tracks_dir.exists():
        tracks_dir.mkdir(parents=True, exist_ok=True)
    tracks_path = tracks_dir.joinpath(filename + "_track.npy")
    np.save(tracks_path, track_boxes)

    detect_path = tracks_dir.joinpath(filename + "_detect.npy")
    np.save(detect_path, det_boxes)

    '''======================== 4. save reid features data ===================='''
    feats_dir = Path('./tracking/data/trackfeats/')
    if not feats_dir.exists():
        feats_dir.mkdir(parents=True, exist_ok=True)
    feats_path = feats_dir.joinpath(f'{filename}.pkl')
    with open(feats_path, 'wb') as file:
        pickle.dump(features_dict, file)

    '''======================== 5. save hand_local data =================='''
    # handlocal_dir = Path('./tracking/data/handlocal/')
    # if not handlocal_dir.exists():
    #     handlocal_dir.mkdir(parents=True, exist_ok=True)
    # handlocal_path = handlocal_dir.joinpath(f'{filename}.pkl')
    # with open(handlocal_path, 'wb') as file:
    #     pickle.dump(handlocals_dict, file)

    # Print results
    t = tuple(x.t / seen * 1E3 for x in dt)  # speeds per image
    LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
    if save_txt or save_img:
        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
    if update:
        strip_optimizer(weights[0])  # update model (to fix SourceChangeWarning)


def parse_opt():
    modelpath = ROOT / 'ckpts/best_yolov5m_250000.pt'  # 'ckpts/best_15000_0908.pt', 'ckpts/yolov5s.pt', 'ckpts/best_20000_cls30.pt'

    '''datapath is a directory of video files, or a single video file'''
    datapath = r"D:/datasets/ym/videos/标记视频/"  # ROOT/'data/videos', ROOT/'data/images' images
    # datapath = r"D:\datasets\ym\highvalue\videos"
    # datapath = r"D:/dcheng/videos/"
    # modelpath = ROOT / 'ckpts/yolov5s.pt'

    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', nargs='+', type=str, default=modelpath, help='model path or triton URL')  # 'yolov5s.pt', best_15000_0908.pt
    parser.add_argument('--source', type=str, default=datapath, help='file/dir/URL/glob/screen/0(webcam)')  # images, videos
    parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path')
    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
    parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
    parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--view-img', action='store_true', help='show results')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    parser.add_argument('--save-csv', action='store_true', help='save results in CSV format')
    parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
    parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')
    parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--visualize', action='store_true', help='visualize features')
    parser.add_argument('--update', action='store_true', help='update all models')
    parser.add_argument('--project', default=ROOT / 'runs/detect', help='save results to project/name')
    parser.add_argument('--name', default='exp', help='save results to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)')
    parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
    parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
    parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
    parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
    parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride')
    opt = parser.parse_args()
    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand
    print_args(vars(opt))
    return opt


def main_loop_folders(opt):
    check_requirements(ROOT / 'requirements.txt', exclude=('tensorboard', 'thop'))

    # path1 = r"D:\datasets\ym\videos\标记视频"

    path2 = r"D:\datasets\ym\永辉双摄视频\加购_前摄\videos_front"
    # path3 = r"D:\datasets\ym\永辉双摄视频\加购_后摄\videos_back"
    # path4 = r"D:\datasets\ym\永辉双摄视频\退购_前摄\videos_front"
    # path5 = r"D:\datasets\ym\永辉双摄视频\退购_后摄\videos_back"

    path6 = r"D:\datasets\ym\测试数据20240328\front"
    path7 = r"D:\datasets\ym\测试数据20240328\back"

    '''Each element of paths is a folder whose entries are video files'''
    paths = [path2, path7]  # [path1, path2, path3, path4, path5]

    optdict = vars(opt)
    k1, k2 = 0, 0
    for p in paths:
        files = []
        if os.path.isdir(p):
            files.extend(sorted(glob.glob(os.path.join(p, '*.*'))))
            for file in files:
                file = r"D:\datasets\ym\测试数据20240328\front\112954521-7dd5ddad-922a-427b-b59e-a593e95e6ff4_front.mp4"

                optdict["source"] = file
                run(**optdict)

                k2 += 1
                if k2 == 1:
                    break

        elif os.path.isfile(p):
            run(**optdict)

            k1 += 1
            if k1 == 1:
                break


def find_files_in_nested_dirs(root_dir):
    all_files = []
    extensions = ['.mp4']
    for dirpath, dirnames, filenames in os.walk(root_dir):
        for filename in filenames:
            file, ext = os.path.splitext(filename)
            if ext in extensions:
                all_files.append(os.path.join(dirpath, filename))
    return all_files


print('=======')


def main(opt):
    check_requirements(ROOT / 'requirements.txt', exclude=('tensorboard', 'thop'))

    p = r"D:\datasets\ym\永辉测试数据_202404\20240402"
    optdict = vars(opt)
    files = []
    k = 0

    all_files = find_files_in_nested_dirs(p)

    if os.path.isdir(p):
        files.extend(sorted(glob.glob(os.path.join(p, '*.*'))))
        for file in files:
            optdict["source"] = file
            run(**optdict)

            k += 1
            if k == 2:
                break
    elif os.path.isfile(p):
        run(**vars(opt))


def main_loop(opt):
    check_requirements(ROOT / 'requirements.txt', exclude=('tensorboard', 'thop'))

    optdict = vars(opt)

    # p = r"D:\datasets\ym\永辉测试数据_比对"
    p = r"D:\datasets\ym\广告板遮挡测试\8"
    # p = r"D:\datasets\ym\videos\标记视频"
    # p = r"D:\datasets\ym\实验室测试"

    k = 0
    if os.path.isdir(p):
        files = find_files_in_nested_dirs(p)

        files = [r"D:\datasets\ym\videos\标记视频\test_20240402-173935_6920152400975_back_174037372.mp4",
                 r"D:\datasets\ym\videos\标记视频\test_20240402-173935_6920152400975_front_174037379.mp4"
                 ]

        files = [r"D:\datasets\ym\广告板遮挡测试\8\2500441577966_20240508-175946_front_addGood_70f75407b7ae_155_17788571404.mp4"]
        for file in files:
            optdict["source"] = file
            run(**optdict)

            k += 1
            if k == 2:
                break
    elif os.path.isfile(p):
        optdict["source"] = p
        run(**vars(opt))


if __name__ == '__main__':
    opt = parse_opt()

    # main_loop_folders(opt)
    # main(opt)
    main_loop(opt)
50 bclass.py Normal file
@@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
"""
Created on Fri Nov 15 16:23:03 2024

@author: ym
"""

class CamEvent:
    def __init__(self, datapath):
        self.data_path = datapath
        self.bboxes = None
        self.bfeats = None
        self.tboxes = None
        self.tfeats = None


class ShopEvent:
    def __init__(self, eventpath, stdpath):
        self.barcode = ""
        self.event_path = eventpath
        self.event_type = self.get_event_type(eventpath)

        self.FrontEvent = ""
        self.BackEvent = ""
        self.fusion_boxes = None
        self.fusion_feats = None
        self.stdfeats = self.get_stdfeats(stdpath)
        self.weight = None
        self.imu = None

    def get_event_type(self, eventpath):
        pass

    def get_stdfeats(self, stdpath):
        pass
174 benchmarks.py Normal file
@ -0,0 +1,174 @@
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
"""
Run YOLOv5 benchmarks on all supported export formats

Format                      | `export.py --include`         | Model
---                         | ---                           | ---
PyTorch                     | -                             | yolov5s.pt
TorchScript                 | `torchscript`                 | yolov5s.torchscript
ONNX                        | `onnx`                        | yolov5s.onnx
OpenVINO                    | `openvino`                    | yolov5s_openvino_model/
TensorRT                    | `engine`                      | yolov5s.engine
CoreML                      | `coreml`                      | yolov5s.mlmodel
TensorFlow SavedModel       | `saved_model`                 | yolov5s_saved_model/
TensorFlow GraphDef         | `pb`                          | yolov5s.pb
TensorFlow Lite             | `tflite`                      | yolov5s.tflite
TensorFlow Edge TPU         | `edgetpu`                     | yolov5s_edgetpu.tflite
TensorFlow.js               | `tfjs`                        | yolov5s_web_model/

Requirements:
    $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu  # CPU
    $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow  # GPU
    $ pip install -U nvidia-tensorrt --index-url https://pypi.ngc.nvidia.com  # TensorRT

Usage:
    $ python benchmarks.py --weights yolov5s.pt --img 640
"""

import argparse
import platform
import sys
import time
from pathlib import Path

import pandas as pd

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
# ROOT = ROOT.relative_to(Path.cwd())  # relative

import export
from models.experimental import attempt_load
from models.yolo import SegmentationModel
from segment.val import run as val_seg
from utils import notebook_init
from utils.general import LOGGER, check_yaml, file_size, print_args
from utils.torch_utils import select_device
from val import run as val_det


def run(
        weights=ROOT / 'yolov5s.pt',  # weights path
        imgsz=640,  # inference size (pixels)
        batch_size=1,  # batch size
        data=ROOT / 'data/coco128.yaml',  # dataset.yaml path
        device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
        half=False,  # use FP16 half-precision inference
        test=False,  # test exports only
        pt_only=False,  # test PyTorch only
        hard_fail=False,  # throw error on benchmark failure
):
    y, t = [], time.time()
    device = select_device(device)
    model_type = type(attempt_load(weights, fuse=False))  # DetectionModel, SegmentationModel, etc.
    for i, (name, f, suffix, cpu, gpu) in export.export_formats().iterrows():  # index, (name, file, suffix, CPU, GPU)
        try:
            assert i not in (9, 10), 'inference not supported'  # Edge TPU and TF.js are unsupported
            assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13'  # CoreML
            if 'cpu' in device.type:
                assert cpu, 'inference not supported on CPU'
            if 'cuda' in device.type:
                assert gpu, 'inference not supported on GPU'

            # Export
            if f == '-':
                w = weights  # PyTorch format
            else:
                w = export.run(weights=weights,
                               imgsz=[imgsz],
                               include=[f],
                               batch_size=batch_size,
                               device=device,
                               half=half)[-1]  # all others
            assert suffix in str(w), 'export failed'

            # Validate
            if model_type == SegmentationModel:
                result = val_seg(data, w, batch_size, imgsz, plots=False, device=device, task='speed', half=half)
                metric = result[0][7]  # (box(p, r, map50, map), mask(p, r, map50, map), *loss(box, obj, cls))
            else:  # DetectionModel:
                result = val_det(data, w, batch_size, imgsz, plots=False, device=device, task='speed', half=half)
                metric = result[0][3]  # (p, r, map50, map, *loss(box, obj, cls))
            speed = result[2][1]  # times (preprocess, inference, postprocess)
            y.append([name, round(file_size(w), 1), round(metric, 4), round(speed, 2)])  # MB, mAP, t_inference
        except Exception as e:
            if hard_fail:
                assert type(e) is AssertionError, f'Benchmark --hard-fail for {name}: {e}'
            LOGGER.warning(f'WARNING ⚠️ Benchmark failure for {name}: {e}')
            y.append([name, None, None, None])  # mAP, t_inference
        if pt_only and i == 0:
            break  # break after PyTorch

    # Print results
    LOGGER.info('\n')
    parse_opt()
    notebook_init()  # print system info
    c = ['Format', 'Size (MB)', 'mAP50-95', 'Inference time (ms)'] if map else ['Format', 'Export', '', '']
    py = pd.DataFrame(y, columns=c)
    LOGGER.info(f'\nBenchmarks complete ({time.time() - t:.2f}s)')
    LOGGER.info(str(py if map else py.iloc[:, :2]))
    if hard_fail and isinstance(hard_fail, str):
        metrics = py['mAP50-95'].array  # values to compare to floor
        floor = eval(hard_fail)  # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n
        assert all(x > floor for x in metrics if pd.notna(x)), f'HARD FAIL: mAP50-95 < floor {floor}'
    return py


def test(
        weights=ROOT / 'yolov5s.pt',  # weights path
        imgsz=640,  # inference size (pixels)
        batch_size=1,  # batch size
        data=ROOT / 'data/coco128.yaml',  # dataset.yaml path
        device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
        half=False,  # use FP16 half-precision inference
        test=False,  # test exports only
        pt_only=False,  # test PyTorch only
        hard_fail=False,  # throw error on benchmark failure
):
    y, t = [], time.time()
    device = select_device(device)
    for i, (name, f, suffix, gpu) in export.export_formats().iterrows():  # index, (name, file, suffix, gpu-capable)
        try:
            w = weights if f == '-' else \
                export.run(weights=weights, imgsz=[imgsz], include=[f], device=device, half=half)[-1]  # weights
            assert suffix in str(w), 'export failed'
            y.append([name, True])
        except Exception:
            y.append([name, False])  # mAP, t_inference

    # Print results
    LOGGER.info('\n')
    parse_opt()
    notebook_init()  # print system info
    py = pd.DataFrame(y, columns=['Format', 'Export'])
    LOGGER.info(f'\nExports complete ({time.time() - t:.2f}s)')
    LOGGER.info(str(py))
    return py


def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='weights path')
    parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)')
    parser.add_argument('--batch-size', type=int, default=1, help='batch size')
    parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
    parser.add_argument('--test', action='store_true', help='test exports only')
    parser.add_argument('--pt-only', action='store_true', help='test PyTorch only')
    parser.add_argument('--hard-fail', nargs='?', const=True, default=False, help='Exception on error or < min metric')
    opt = parser.parse_args()
    opt.data = check_yaml(opt.data)  # check YAML
    print_args(vars(opt))
    return opt


def main(opt):
    test(**vars(opt)) if opt.test else run(**vars(opt))


if __name__ == '__main__':
    opt = parse_opt()
    main(opt)
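# Editor's note (not part of benchmarks.py): besides the CLI shown in the
# docstring, run() and test() return a pandas DataFrame and can be called
# directly; the argument values below are illustrative.
#
#   from benchmarks import run
#   df = run(weights='yolov5s.pt', imgsz=640, batch_size=1, pt_only=True)
#   print(df)  # columns: Format, Size (MB), mAP50-95, Inference time (ms)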
7  contrast/__init__.py  Normal file
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 26 08:53:58 2024

@author: ym
"""
BIN  contrast/__pycache__/__init__.cpython-312.pyc  Normal file
Binary file not shown.
BIN  contrast/__pycache__/__init__.cpython-39.pyc  Normal file
Binary file not shown.
BIN  contrast/__pycache__/config.cpython-39.pyc  Normal file
Binary file not shown.
BIN  contrast/__pycache__/event_test.cpython-312.pyc  Normal file
Binary file not shown.
BIN  contrast/__pycache__/event_test.cpython-39.pyc  Normal file
Binary file not shown.
BIN  contrast/__pycache__/feat_inference.cpython-39.pyc  Normal file
Binary file not shown.
BIN  contrast/__pycache__/genfeats.cpython-312.pyc  Normal file
Binary file not shown.
BIN  contrast/__pycache__/genfeats.cpython-39.pyc  Normal file
Binary file not shown.
BIN  contrast/__pycache__/one2n_contrast.cpython-312.pyc  Normal file
Binary file not shown.
BIN  contrast/__pycache__/one2n_contrast.cpython-39.pyc  Normal file
Binary file not shown.
BIN  contrast/__pycache__/test_ori.cpython-39.pyc  Normal file
Binary file not shown.
374  contrast/event_test.py  Normal file
@@ -0,0 +1,374 @@
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 16 18:56:18 2024

@author: ym
"""
import os
import cv2
import json
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

from matplotlib import rcParams
from matplotlib.font_manager import FontProperties
from scipy.spatial.distance import cdist
from utils.event import ShoppingEvent, save_data
from utils.calsimi import calsimi_vs_stdfeat_new, get_topk_percent, cluster
from utils.tools import get_evtList
import pickle

rcParams['font.sans-serif'] = ['SimHei']  # use SimHei so Chinese labels render correctly
rcParams['axes.unicode_minus'] = False    # render minus signs correctly

'''*********** USearch ***********'''
def read_usearch():
    stdFeaturePath = r"D:\contrast\stdlib\v11_test.json"
    stdBarcode = []
    stdlib = {}
    with open(stdFeaturePath, 'r', encoding='utf-8') as f:
        data = json.load(f)
        for dic in data['total']:
            barcode = dic['key']
            feature = np.array(dic['value'])
            stdBarcode.append(barcode)
            stdlib[barcode] = feature

    return stdlib

def get_eventlist_errortxt(evtpaths):
    '''
    Read the error events recorded in a single test run.
    '''
    text1 = "one_2_Small_n_Error.txt"
    text2 = "one_2_Big_N_Error.txt"
    events = []
    text = (text1, text2)
    for txt in text:
        txtfile = os.path.join(evtpaths, txt)
        with open(txtfile, "r") as f:
            lines = f.readlines()
        for i, line in enumerate(lines):
            line = line.strip()
            if line:
                fpath = os.path.join(evtpaths, line)
                events.append(fpath)

    events = list(set(events))

    return events

def save_eventdata():
    evtpaths = r"/home/wqg/dataset/test_dataset/performence_dataset/"
    events = get_eventlist_errortxt(evtpaths)

    '''Define the storage path for the current events and create the corresponding folders'''
    resultPath = r"\\192.168.1.28\share\测试视频数据以及日志\算法全流程测试\202412\result\single_event"
    for evtpath in events:
        event = ShoppingEvent(evtpath)
        save_data(event, resultPath)

        print(event.evtname)


# def get_topk_percent(data, k):
#     """
#     Get the largest k% of the elements in data.
#     """
#     # convert data to a NumPy array
#     if isinstance(data, list):
#         data = np.array(data)

#     percentile = np.percentile(data, 100-k)
#     top_k_percent = data[data >= percentile]

#     return top_k_percent

# def cluster(data, thresh=0.15):
#     # data = np.array([0.1, 0.13, 0.7, 0.2, 0.8, 0.52, 0.3, 0.7, 0.85, 0.58])
#     # data = np.array([0.1, 0.13, 0.2, 0.3])
#     # data = np.array([0.1])

#     if isinstance(data, list):
#         data = np.array(data)

#     data1 = np.sort(data)
#     cluter, Cluters, = [data1[0]], []
#     for i in range(1, len(data1)):
#         if data1[i] - data1[i-1] < thresh:
#             cluter.append(data1[i])
#         else:
#             Cluters.append(cluter)
#             cluter = [data1[i]]
#     Cluters.append(cluter)

#     clt_center = []
#     for clt in Cluters:
#         ## Should a minimum number of track samples per cluster be enforced here?
#         ## That factor is better handled in the trajectory analysis instead.
#         # if len(clt)>=3:
#         #     clt_center.append(np.mean(clt))
#         clt_center.append(np.mean(clt))

#     # print(clt_center)

#     return clt_center

# def calsimi_vs_stdfeat_new(event, stdfeat):
#     '''Comparison strategy between an event and the standard library.
#        Could this strategy be extended to event-vs-event comparison?
#     '''

#     def calsiml(feat1, feat2, topkp=75, cluth=0.15):
#         '''Selection strategy for similarities between track samples and standard-feature samples'''
#         matrix = 1 - cdist(feat1, feat2, 'cosine')
#         simi_max = []
#         for i in range(len(matrix)):
#             sim = np.mean(get_topk_percent(matrix[i, :], topkp))
#             simi_max.append(sim)
#         cltc_max = cluster(simi_max, cluth)
#         Simi = max(cltc_max)

#         ## An empty cltc_max is a programming oversight and should be tracked down and fixed.
#         # if len(cltc_max):
#         #     Simi = max(cltc_max)
#         # else:
#         #     Simi = 0  # should never be reached

#         return Simi

#     front_boxes = np.empty((0, 9), dtype=np.float64)    ## compatible with class doTracks
#     front_feats = np.empty((0, 256), dtype=np.float64)  ## compatible with class doTracks
#     for i in range(len(event.front_boxes)):
#         front_boxes = np.concatenate((front_boxes, event.front_boxes[i]), axis=0)
#         front_feats = np.concatenate((front_feats, event.front_feats[i]), axis=0)

#     back_boxes = np.empty((0, 9), dtype=np.float64)    ## compatible with class doTracks
#     back_feats = np.empty((0, 256), dtype=np.float64)  ## compatible with class doTracks
#     for i in range(len(event.back_boxes)):
#         back_boxes = np.concatenate((back_boxes, event.back_boxes[i]), axis=0)
#         back_feats = np.concatenate((back_feats, event.back_feats[i]), axis=0)

#     if len(front_feats):
#         front_simi = calsiml(front_feats, stdfeat)
#     if len(back_feats):
#         back_simi = calsiml(back_feats, stdfeat)

#     '''Fusion strategy for front- and back-camera similarities'''
#     if len(front_feats) and len(back_feats):
#         diff_simi = abs(front_simi - back_simi)
#         if diff_simi > 0.15:
#             Similar = max([front_simi, back_simi])
#         else:
#             Similar = (front_simi + back_simi) / 2
#     elif len(front_feats) and len(back_feats)==0:
#         Similar = front_simi
#     elif len(front_feats)==0 and len(back_feats):
#         Similar = back_simi
#     else:
#         Similar = None  # when event.front_feats and event.back_feats are both empty

#     return Similar
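# Editor's sketch (not part of event_test.py): a self-contained version of the
# top-k%/clustering similarity selection described by the commented-out helpers
# above. It mirrors only what this file shows of utils.calsimi.get_topk_percent
# and cluster; treat it as an illustration, not the canonical implementation.
def _topk_percent_sketch(row, k):
    row = np.asarray(row)
    return row[row >= np.percentile(row, 100 - k)]  # largest k% of the row

def _cluster_1d_sketch(data, thresh=0.15):
    data = np.sort(np.asarray(data))
    centers, cur = [], [data[0]]
    for prev, val in zip(data[:-1], data[1:]):
        if val - prev < thresh:
            cur.append(val)  # gap below threshold: same cluster
        else:
            centers.append(np.mean(cur))
            cur = [val]
    centers.append(np.mean(cur))
    return centers

# Usage: per-row top-25% means, clustered; the similarity is the largest center.
#   sims = [np.mean(_topk_percent_sketch(r, 25)) for r in matrix]
#   Simi = max(_cluster_1d_sketch(sims))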
def simi_matrix():
    evtpaths = r"/home/wqg/dataset/pipeline/contrast/single_event_V10/evtobjs/"

    stdfeatPath = r"/home/wqg/dataset/test_dataset/total_barcode/features_json/v11_barcode_0304/"
    resultPath = r"/home/wqg/dataset/performence_dataset/result/"

    evt_paths, bcdSet = get_evtList(evtpaths)

    ## read std features
    stdDict = {}
    evtDict = {}
    for barcode in bcdSet:
        stdpath = os.path.join(stdfeatPath, f"{barcode}.json")
        if not os.path.isfile(stdpath):
            continue

        with open(stdpath, 'r', encoding='utf-8') as f:
            stddata = json.load(f)
        feat = np.array(stddata["value"])
        stdDict[barcode] = feat

    for evtpath in evt_paths:
        barcode = Path(evtpath).stem.split("_")[-1]

        if barcode not in stdDict.keys():
            continue

        # try:
        #     with open(evtpath, 'rb') as f:
        #         evtdata = pickle.load(f)
        # except Exception as e:
        #     print(evtname)

        with open(evtpath, 'rb') as f:
            event = pickle.load(f)

        stdfeat = stdDict[barcode]

        Similar = calsimi_vs_stdfeat_new(event, stdfeat)

        # build the storage paths for the box sub-images
        subimgpath = os.path.join(resultPath, f"{event.evtname}", "subimg")
        if not os.path.exists(subimgpath):
            os.makedirs(subimgpath)
        histpath = os.path.join(resultPath, "simi_hist")
        if not os.path.exists(histpath):
            os.makedirs(histpath)

        mean_values, max_values = [], []
        cameras = ('front', 'back')
        fig, ax = plt.subplots(2, 3, figsize=(16, 9), dpi=100)
        kpercent = 25
        for camera in cameras:
            boxes = np.empty((0, 9), dtype=np.float64)      ## compatible with class doTracks
            evtfeat = np.empty((0, 256), dtype=np.float64)  ## compatible with class doTracks
            if camera == 'front':
                for i in range(len(event.front_boxes)):
                    boxes = np.concatenate((boxes, event.front_boxes[i]), axis=0)
                    evtfeat = np.concatenate((evtfeat, event.front_feats[i]), axis=0)
                imgpaths = event.front_imgpaths

            else:
                for i in range(len(event.back_boxes)):
                    boxes = np.concatenate((boxes, event.back_boxes[i]), axis=0)
                    evtfeat = np.concatenate((evtfeat, event.back_feats[i]), axis=0)
                imgpaths = event.back_imgpaths

            assert len(boxes) == len(evtfeat), f"Please check the Event: {event.evtname}"
            if len(boxes) == 0: continue
            print(event.evtname)

            matrix = 1 - cdist(evtfeat, stdfeat, 'cosine')
            simi_1d = matrix.flatten()
            simi_mean = np.mean(matrix, axis=1)
            # simi_max = np.max(matrix, axis=1)

            '''average over the largest k% of similarities in each row of the similarity matrix'''
            simi_max = []
            for i in range(len(matrix)):
                sim = np.mean(get_topk_percent(matrix[i, :], kpercent))
                simi_max.append(sim)

            mean_values.append(np.mean(matrix))
            max_values.append(np.mean(simi_max))

            diff_max_mean = np.mean(simi_max) - np.mean(matrix)

            '''plot the similarity statistics'''
            k = 0
            if camera == 'front': k = 1

            '''********************* all similarity values *********************'''
            ax[k, 0].hist(simi_1d, bins=60, range=(-0.2, 1), edgecolor='black')
            ax[k, 0].set_xlim([-0.2, 1])
            ax[k, 0].set_title(camera)

            _, y_max = ax[k, 0].get_ylim()  # get the y-axis range
            '''similarity spread'''
            ax[k, 0].text(-0.1, 0.15*y_max, f"rng:{max(simi_1d)-min(simi_1d):.3f}", fontsize=18, color='b')

            '''********************* mean ********************************'''
            ax[k, 1].hist(simi_mean, bins=24, range=(-0.2, 1), edgecolor='black')
            ax[k, 1].set_xlim([-0.2, 1])
            ax[k, 1].set_title("mean")
            _, y_max = ax[k, 1].get_ylim()  # get the y-axis range
            '''similarity spread'''
            ax[k, 1].text(-0.1, 0.15*y_max, f"rng:{max(simi_mean)-min(simi_mean):.3f}", fontsize=18, color='b')

            '''********************* max ******************************'''
            ax[k, 2].hist(simi_max, bins=24, range=(-0.2, 1), edgecolor='black')
            ax[k, 2].set_xlim([-0.2, 1])
            ax[k, 2].set_title("max")
            _, y_max = ax[k, 2].get_ylim()  # get the y-axis range
            '''similarity spread'''
            ax[k, 2].text(-0.1, 0.15*y_max, f"rng:{max(simi_max)-min(simi_max):.3f}", fontsize=18, color='b')

            '''plot the cluster centers'''
            cltc_mean = cluster(simi_mean)
            for value in cltc_mean:
                ax[k, 1].axvline(x=value, color='m', linestyle='--', linewidth=3)

            cltc_max = cluster(simi_max)
            for value in cltc_max:
                ax[k, 2].axvline(x=value, color='m', linestyle='--', linewidth=3)

            '''plot the overall mean and the mean of the row-wise top-k maxima'''
            ax[k, 1].axvline(x=np.mean(matrix), color='r', linestyle='-', linewidth=3)
            ax[k, 2].axvline(x=np.mean(simi_max), color='g', linestyle='-', linewidth=3)

            '''annotate (mean of maxima) minus (overall mean)'''
            _, y_max = ax[k, 2].get_ylim()  # get the y-axis range
            ax[k, 2].text(-0.1, 0.05*y_max, f"g-r={diff_max_mean:.3f}", fontsize=18, color='m')

        plt.show()

        # for i, box in enumerate(boxes):
        #     x1, y1, x2, y2, tid, score, cls, fid, bid = box
        #     imgpath = imgpaths[int(fid-1)]
        #     image = cv2.imread(imgpath)
        #     subimg = image[int(y1/2):int(y2/2), int(x1/2):int(x2/2), :]
        #     camerType, timeTamp, _, frameID = os.path.basename(imgpath).split('.')[0].split('_')
        #     subimgName = f"cam{camerType}_{i}_tid{int(tid)}_fid({int(fid)}, {frameID})_{simi_mean[i]:.3f}.png"
        #     imgpairs.append((subimgName, subimg))
        #     spath = os.path.join(subimgpath, subimgName)
        #     cv2.imwrite(spath, subimg)

        #     oldname = f"cam{camerType}_{i}_tid{int(tid)}_fid({int(fid)}, {frameID}).png"
        #     oldpath = os.path.join(subimgpath, oldname)
        #     if os.path.exists(oldpath):
        #         os.remove(oldpath)

        if len(mean_values) == 2:
            mean_diff = abs(mean_values[1] - mean_values[0])
            ax[0, 1].set_title(f"mean diff: {mean_diff:.3f}")
        if len(max_values) == 2:
            max_diff = abs(max_values[1] - max_values[0])
            ax[0, 2].set_title(f"max diff: {max_diff:.3f}")
        try:
            fig.suptitle(f"Similar: {Similar:.3f}", fontsize=16)
        except Exception as e:
            print(e)
            print(f"Similar: {Similar}")
        pltpath = os.path.join(subimgpath, f"hist_max_{kpercent}%_.png")
        plt.savefig(pltpath)

        pltpath1 = os.path.join(histpath, f"{event.evtname}_.png")
        plt.savefig(pltpath1)

        plt.close()


def main():
    simi_matrix()


if __name__ == "__main__":
    main()
    # cluster()
BIN  contrast/feat_extract/__pycache__/config.cpython-312.pyc  Normal file
Binary file not shown.
BIN  contrast/feat_extract/__pycache__/config.cpython-39.pyc  Normal file
Binary file not shown.
BIN  contrast/feat_extract/__pycache__/inference.cpython-312.pyc  Normal file
Binary file not shown.
BIN  contrast/feat_extract/__pycache__/inference.cpython-39.pyc  Normal file
Binary file not shown.
BIN  contrast/feat_extract/checkpoints/resnet18_0515/best.rknn  Normal file
Binary file not shown.
88  contrast/feat_extract/config.py  Normal file
@@ -0,0 +1,88 @@
import torch
import torchvision.transforms as T


class Config:
    # network settings
    backbone = 'resnet18'  # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5]
    metric = 'arcface'  # [cosface, arcface]
    cbam = True
    embedding_size = 256
    drop_ratio = 0.5
    img_size = 224

    batch_size = 8

    # data preprocess
    # input_shape = [1, 128, 128]
    """transforms.RandomCrop(size),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomHorizontalFlip(),
    RandomRotate(15, 0.3),
    # RandomGaussianBlur()"""

    train_transform = T.Compose([
        T.ToTensor(),
        T.Resize((img_size, img_size)),
        # T.RandomCrop(img_size),
        # T.RandomHorizontalFlip(p=0.5),
        T.RandomRotation(180),
        T.ColorJitter(brightness=0.5),
        T.ConvertImageDtype(torch.float32),
        T.Normalize(mean=[0.5], std=[0.5]),
    ])
    test_transform = T.Compose([
        T.ToTensor(),
        T.Resize((img_size, img_size)),
        T.ConvertImageDtype(torch.float32),
        T.Normalize(mean=[0.5], std=[0.5]),
    ])

    # dataset
    train_root = './data/2250_train/train'  # dataset after one initial filtering pass
    # train_root = './data/0612_train/train'
    test_root = "./data/2250_train/val/"
    # test_root = "./data/0612_train/val"
    test_list = "./data/2250_train/val_pair.txt"

    test_group_json = "./2250_train/cross_same_0508.json"

    # test_list = "./data/test_data_100/val_pair.txt"

    # training settings
    checkpoints = "checkpoints/resnet18_0613/"  # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3]
    restore = False
    # restore_model = "checkpoints/renet18_2250_0315/best_resnet18_2250_0315.pth"  # best_resnet18_1491_0306.pth
    restore_model = "checkpoints/resnet18_0515/best.pth"  # best_resnet18_1491_0306.pth

    # test_model = "checkpoints/renet18_2250_0314/best_resnet18_2250_0314.pth"
    testbackbone = 'resnet18'  # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5]
    test_val = "D:/比对/cl"
    # test_val = "./data/test_data_100"

    test_model = "checkpoints/best_20250228.pth"
    # test_model = "checkpoints/zhanting_res_801.pth"
    # test_model = "checkpoints/zhanting_res_abroad_8021.pth"

    train_batch_size = 512  # 256
    test_batch_size = 256  # 256

    epoch = 300
    optimizer = 'sgd'  # ['sgd', 'adam']
    lr = 1.5e-2  # 1e-2
    lr_step = 5  # 10
    lr_decay = 0.95  # 0.98
    weight_decay = 5e-4
    loss = 'cross_entropy'  # ['focal_loss', 'cross_entropy']
    # device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    pin_memory = True  # if memory is large, set it True to speed up a bit
    num_workers = 4  # dataloader

    group_test = True

config = Config()
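# Editor's note (not part of config.py): applying the transforms above to a PIL
# image; since T.ToTensor() runs first, T.Resize() operates on a tensor here.
# The file path is hypothetical.
#
#   from PIL import Image
#   img = Image.open("sample.png").convert("RGB")
#   x = config.test_transform(img).unsqueeze(0)   # shape: (1, 3, 224, 224)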
606  contrast/feat_extract/inference.py  Normal file
@@ -0,0 +1,606 @@
# -*- coding: utf-8 -*-
"""

@author: LiChen
"""
# import pdb
# import shutil
import torch.nn as nn
# import statistics
import os
import numpy as np
from scipy.spatial.distance import cdist
import torch
import os.path as osp
from PIL import Image
import json
import matplotlib.pyplot as plt
from pathlib import Path
# import sys
# sys.path.append(r"D:\DetectTracking")
# from contrast.config import config as conf
# from contrast.model import resnet18

from .config import config as conf
from .model import resnet18

# from model import (mobilevit_s, resnet14, resnet18, resnet34, resnet50, mobilenet_v2,
#                    MobileNetV3_Small, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5)

curpath = Path(__file__).resolve().parents[0]

class FeatsInterface:
    def __init__(self, conf):
        self.device = conf.device

        # if conf.backbone == 'resnet18':
        #     model = resnet18().to(conf.device)

        model = resnet18().to(conf.device)
        self.transform = conf.test_transform
        self.batch_size = conf.batch_size
        self.embedding_size = conf.embedding_size

        ### commented out by yj:
        # if conf.test_model.find("zhanting") == -1:
        #     model = nn.DataParallel(model).to(conf.device)
        self.model = model

        modpath = os.path.join(curpath, conf.test_model)
        self.model.load_state_dict(torch.load(modpath, map_location=conf.device))
        self.model.eval()
        # print('load model {} '.format(conf.testbackbone))

    def inference(self, images, detections=None):
        '''
        BGR inputs must be converted to RGB first.
        '''
        if isinstance(images, np.ndarray):
            imgs, features = self.inference_image(images, detections)
            return imgs, features

        batch_patches = []
        patches = []
        for i, img in enumerate(images):
            img = img.copy()

            ## pad img with black borders to make it square, producing new_img
            width, height = img.size
            new_size = max(width, height)
            new_img = Image.new("RGB", (new_size, new_size), (0, 0, 0))
            paste_x = (new_size - width) // 2
            paste_y = (new_size - height) // 2
            new_img.paste(img, (paste_x, paste_y))

            patch = self.transform(new_img)
            patch = patch.to(device=self.device)
            # if str(self.device) != "cpu":
            #     patch = patch.to(device=self.device).half()
            # else:
            #     patch = patch.to(device=self.device)

            patches.append(patch)
            if (i + 1) % self.batch_size == 0:
                patches = torch.stack(patches, dim=0)
                batch_patches.append(patches)
                patches = []

        if len(patches):
            patches = torch.stack(patches, dim=0)
            batch_patches.append(patches)

        features = np.zeros((0, self.embedding_size))
        for patches in batch_patches:
            pred = self.model(patches)
            pred[torch.isinf(pred)] = 1.0
            feat = pred.cpu().data.numpy()
            features = np.vstack((features, feat))
        return features

    def inference_image(self, image, detections):
        H, W, _ = np.shape(image)

        batch_patches = []
        patches = []
        imgs = []
        for d in range(np.size(detections, 0)):
            tlbr = detections[d, :4].astype(np.int_)
            tlbr[0] = max(0, tlbr[0])
            tlbr[1] = max(0, tlbr[1])
            tlbr[2] = min(W - 1, tlbr[2])
            tlbr[3] = min(H - 1, tlbr[3])
            img = image[tlbr[1]:tlbr[3], tlbr[0]:tlbr[2], :]

            imgs.append(img)

            img1 = img[:, :, ::-1].copy()  # the model expects RGB inputs
            patch = self.transform(img1)

            # patch = patch.to(device=self.device).half()
            # if str(self.device) != "cpu":
            #     patch = patch.to(device=self.device).half()
            #     patch = patch.to(device=self.device)
            # else:
            #     patch = patch.to(device=self.device)
            patch = patch.to(device=self.device)

            patches.append(patch)
            if (d + 1) % self.batch_size == 0:
                patches = torch.stack(patches, dim=0)
                batch_patches.append(patches)
                patches = []

        if len(patches):
            patches = torch.stack(patches, dim=0)
            batch_patches.append(patches)

        features = np.zeros((0, self.embedding_size))
        for patches in batch_patches:
            pred = self.model(patches)
            pred[torch.isinf(pred)] = 1.0
            feat = pred.cpu().data.numpy()
            features = np.vstack((features, feat))

        return imgs, features


def unique_image(pair_list) -> set:
    """Return unique image path in pair_list.txt"""
    with open(pair_list, 'r') as fd:
        pairs = fd.readlines()
    unique = set()
    for pair in pairs:
        id1, id2, _ = pair.split()
        unique.add(id1)
        unique.add(id2)
    return unique


def group_image(images: set, batch) -> list:
    """Group image paths by batch size"""
    images = list(images)
    size = len(images)
    res = []
    for i in range(0, size, batch):
        end = min(batch + i, size)
        res.append(images[i: end])
    return res


def _preprocess(images: list, transform) -> torch.Tensor:
    res = []
    for img in images:
        im = Image.open(img)
        im = transform(im)
        res.append(im)
    # data = torch.cat(res, dim=0)  # shape: (batch, 128, 128)
    # data = data[:, None, :, :]    # shape: (batch, 1, 128, 128)
    data = torch.stack(res)
    return data


def test_preprocess(images: list, transform) -> torch.Tensor:
    res = []
    for img in images:
        im = Image.open(img)
        im = transform(im)
        res.append(im)
    # data = torch.cat(res, dim=0)  # shape: (batch, 128, 128)
    # data = data[:, None, :, :]    # shape: (batch, 1, 128, 128)
    data = torch.stack(res)
    return data


def featurize(images: list, transform, net, device, train=False) -> dict:
    """featurize each image and save into a dictionary
    Args:
        images: image paths
        transform: test transform
        net: pretrained model
        device: cpu or cuda
    Returns:
        Dict (key: imagePath, value: feature)
    """
    if train:
        data = _preprocess(images, transform)
        data = data.to(device)
        net = net.to(device)
        with torch.no_grad():
            features = net(data)
        res = {img: feature for (img, feature) in zip(images, features)}
    else:
        data = test_preprocess(images, transform)
        data = data.to(device)
        net = net.to(device)
        with torch.no_grad():
            features = net(data)
        res = {img: feature for (img, feature) in zip(images, features)}
    return res

# def inference_image(images: list, transform, net, device, bs=16, embedding_size=256) -> dict:
#     batch_patches = []
#     patches = []
#     for d, img in enumerate(images):
#         img = Image.open(img)
#         patch = transform(img)

#         if str(device) != "cpu":
#             patch = patch.to(device).half()
#         else:
#             patch = patch.to(device)

#         patches.append(patch)
#         if (d + 1) % bs == 0:
#             patches = torch.stack(patches, dim=0)
#             batch_patches.append(patches)
#             patches = []

#     if len(patches):
#         patches = torch.stack(patches, dim=0)
#         batch_patches.append(patches)

#     features = np.zeros((0, embedding_size), dtype=np.float32)
#     for patches in batch_patches:
#         pred = net(patches)
#         pred[torch.isinf(pred)] = 1.0
#         feat = pred.cpu().data.numpy()
#         features = np.vstack((features, feat))

#     return features


def featurize_1(images: list, transform, net, device, train=False) -> dict:
    """featurize each image and save into a dictionary
    Args:
        images: image paths
        transform: test transform
        net: pretrained model
        device: cpu or cuda
    Returns:
        Dict (key: imagePath, value: feature)
    """

    data = test_preprocess(images, transform)
    data = data.to(device)
    net = net.to(device)
    with torch.no_grad():
        features = net(data).data.numpy()

    return features


def cosin_metric(x1, x2):
    return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))


def threshold_search(y_score, y_true):
    y_score = np.asarray(y_score)
    y_true = np.asarray(y_true)
    best_acc = 0
    best_th = 0
    for i in range(len(y_score)):
        th = y_score[i]
        y_test = (y_score >= th)
        acc = np.mean((y_test == y_true).astype(int))
        if acc > best_acc:
            best_acc = acc
            best_th = th
    return best_acc, best_th


def showgrid(recall, recall_TN, PrecisePos, PreciseNeg):
    x = np.linspace(start=-1.0, stop=1.0, num=50, endpoint=True).tolist()
    plt.figure(figsize=(10, 6))
    plt.plot(x, recall, color='red', label='recall')
    plt.plot(x, recall_TN, color='black', label='recall_TN')
    plt.plot(x, PrecisePos, color='blue', label='PrecisePos')
    plt.plot(x, PreciseNeg, color='green', label='PreciseNeg')
    plt.legend()
    plt.xlabel('threshold')
    # plt.ylabel('Similarity')
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.savefig('accuracy_recall_grid.png')
    plt.show()
    plt.close()


def compute_accuracy_recall(score, labels):
    th = 0.1
    squence = np.linspace(-1, 1, num=50)
    # squence = [0.4]
    recall, PrecisePos, PreciseNeg, recall_TN = [], [], [], []
    for th in squence:
        t_score = (score > th)
        t_labels = (labels == 1)
        # print(t_score)
        # print(t_labels)
        TP = np.sum(np.logical_and(t_score, t_labels))
        FN = np.sum(np.logical_and(np.logical_not(t_score), t_labels))
        f_score = (score < th)
        f_labels = (labels == 0)
        TN = np.sum(np.logical_and(f_score, f_labels))
        FP = np.sum(np.logical_and(np.logical_not(f_score), f_labels))
        print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))

        PrecisePos.append(0 if (TP + FP) == 0 else TP / (TP + FP))  # guard the zero-denominator case (the original compared a float to the string 'nan')
        PreciseNeg.append(0 if TN == 0 else TN / (TN + FN))
        recall.append(0 if TP == 0 else TP / (TP + FN))
        recall_TN.append(0 if TN == 0 else TN / (TN + FP))
    showgrid(recall, recall_TN, PrecisePos, PreciseNeg)


def compute_accuracy(feature_dict, pair_list, test_root):
    with open(pair_list, 'r') as f:
        pairs = f.readlines()

    similarities = []
    labels = []
    for pair in pairs:
        img1, img2, label = pair.split()
        img1 = osp.join(test_root, img1)
        img2 = osp.join(test_root, img2)
        feature1 = feature_dict[img1].cpu().numpy()
        feature2 = feature_dict[img2].cpu().numpy()
        label = int(label)

        similarity = cosin_metric(feature1, feature2)
        similarities.append(similarity)
        labels.append(label)

    accuracy, threshold = threshold_search(similarities, labels)
    # print('similarities >> {}'.format(similarities))
    # print('labels >> {}'.format(labels))
    compute_accuracy_recall(np.array(similarities), np.array(labels))
    return accuracy, threshold


def deal_group_pair(pairList1, pairList2):
    allsimilarity = []
    one_similarity = []
    for pair1 in pairList1:
        for pair2 in pairList2:
            similarity = cosin_metric(pair1.cpu().numpy(), pair2.cpu().numpy())
            one_similarity.append(similarity)
        allsimilarity.append(max(one_similarity))  # maximum
        # allsimilarity.append(sum(one_similarity)/len(one_similarity))  # mean
        # allsimilarity.append(statistics.median(one_similarity))  # median
    # print(allsimilarity)
    # print(labels)
    return allsimilarity

def compute_group_accuracy(content_list_read):
    allSimilarity, allLabel = [], []
    for data_loaded in content_list_read:
        one_group_list = []
        for i in range(2):
            images = [osp.join(conf.test_val, img) for img in data_loaded[i]]
            group = group_image(images, conf.test_batch_size)
            d = featurize(group[0], conf.test_transform, model, conf.device)
            one_group_list.append(d.values())
        similarity = deal_group_pair(one_group_list[0], one_group_list[1])
        allLabel.append(data_loaded[-1])
        allSimilarity.extend(similarity)
    # print(allSimilarity)
    # print(allLabel)
    return allSimilarity, allLabel

def compute_contrast_accuracy(content_list_read):

    npairs = 50

    same_folder_pairs = content_list_read['same_folder_pairs']
    cross_folder_pairs = content_list_read['cross_folder_pairs']

    npairs = min((len(same_folder_pairs), len(cross_folder_pairs)))

    Encoder = FeatsInterface(conf)

    same_pairs = same_folder_pairs[:npairs]
    cross_pairs = cross_folder_pairs[:npairs]

    same_pairs_similarity = []
    for i in range(len(same_pairs)):
        images_a = [osp.join(conf.test_val, img) for img in same_pairs[i][0]]
        images_b = [osp.join(conf.test_val, img) for img in same_pairs[i][1]]

        feats_a = Encoder.inference(images_a)
        feats_b = Encoder.inference(images_b)
        # matrix = 1 - np.maximum(0.0, cdist(feats_a, feats_b, 'cosine'))
        matrix = 1 - cdist(feats_a, feats_b, 'cosine')

        feats_am = np.mean(feats_a, axis=0, keepdims=True)
        feats_bm = np.mean(feats_b, axis=0, keepdims=True)
        matrixm = 1 - np.maximum(0.0, cdist(feats_am, feats_bm, 'cosine'))

        same_pairs_similarity.append(np.mean(matrix))

        '''save image pairs with the same barcode'''
        # foldi = os.path.join('./result/same', f'{i}')
        # if os.path.exists(foldi):
        #     shutil.rmtree(foldi)
        #     os.makedirs(foldi)
        # else:
        #     os.makedirs(foldi)
        # for ipt in range(len(images_a)):
        #     source_path = images_a[ipt]
        #     destination_path = os.path.join(foldi, f'a_{ipt}.png')
        #     shutil.copy2(source_path, destination_path)
        # for ipt in range(len(images_b)):
        #     source_path = images_b[ipt]
        #     destination_path = os.path.join(foldi, f'b_{ipt}.png')
        #     shutil.copy2(source_path, destination_path)

    cross_pairs_similarity = []
    for i in range(len(cross_pairs)):
        images_a = [osp.join(conf.test_val, img) for img in cross_pairs[i][0]]
        images_b = [osp.join(conf.test_val, img) for img in cross_pairs[i][1]]

        feats_a = Encoder.inference(images_a)
        feats_b = Encoder.inference(images_b)
        # matrix = 1 - np.maximum(0.0, cdist(feats_a, feats_b, 'cosine'))
        matrix = 1 - cdist(feats_a, feats_b, 'cosine')

        feats_am = np.mean(feats_a, axis=0, keepdims=True)
        feats_bm = np.mean(feats_b, axis=0, keepdims=True)
        matrixm = 1 - np.maximum(0.0, cdist(feats_am, feats_bm, 'cosine'))

        cross_pairs_similarity.append(np.mean(matrix))

        '''save image pairs with different barcodes'''
        # foldi = os.path.join('./result/cross', f'{i}')
        # if os.path.exists(foldi):
        #     shutil.rmtree(foldi)
        #     os.makedirs(foldi)
        # else:
        #     os.makedirs(foldi)
        # for ipt in range(len(images_a)):
        #     source_path = images_a[ipt]
        #     destination_path = os.path.join(foldi, f'a_{ipt}.png')
        #     shutil.copy2(source_path, destination_path)
        # for ipt in range(len(images_b)):
        #     source_path = images_b[ipt]
        #     destination_path = os.path.join(foldi, f'b_{ipt}.png')
        #     shutil.copy2(source_path, destination_path)

    Thresh = np.linspace(-0.2, 1, 100)

    Same = np.array(same_pairs_similarity)
    Cross = np.array(cross_pairs_similarity)

    fig, axs = plt.subplots(2, 1)
    axs[0].hist(Same, bins=60, edgecolor='black')
    axs[0].set_xlim([-0.2, 1])
    axs[0].set_title('Same Barcode')

    axs[1].hist(Cross, bins=60, edgecolor='black')
    axs[1].set_xlim([-0.2, 1])
    axs[1].set_title('Cross Barcode')

    TPFN = len(Same)
    TNFP = len(Cross)
    Recall_Pos, Recall_Neg = [], []
    Precision_Pos, Precision_Neg = [], []
    Correct = []
    for th in Thresh:
        TP = np.sum(Same > th)
        FN = TPFN - TP
        TN = np.sum(Cross < th)
        FP = TNFP - TN

        Recall_Pos.append(TP/TPFN)
        Recall_Neg.append(TN/TNFP)
        Precision_Pos.append(TP/(TP+FP))
        Precision_Neg.append(TN/(TN+FN))
        Correct.append((TN+TP)/(TPFN+TNFP))

    fig, ax = plt.subplots()
    ax.plot(Thresh, Correct, 'r', label='Correct: (TN+TP)/(TPFN+TNFP)')
    ax.plot(Thresh, Recall_Pos, 'b', label='Recall_Pos: TP/TPFN')
    ax.plot(Thresh, Recall_Neg, 'g', label='Recall_Neg: TN/TNFP')
    ax.plot(Thresh, Precision_Pos, 'c', label='Precision_Pos: TP/(TP+FP)')
    ax.plot(Thresh, Precision_Neg, 'm', label='Precision_Neg: TN/(TN+FN)')

    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    ax.grid(True)
    ax.set_title('PrecisePos & PreciseNeg')
    ax.legend()
    plt.show()

    print("All done!")


if __name__ == '__main__':

    # Network Setup
    if conf.testbackbone == 'resnet18':
        # model = ResIRSE(conf.img_size, conf.embedding_size, conf.drop_ratio).to(conf.device)
        model = resnet18().to(conf.device)
    # elif conf.testbackbone == 'resnet34':
    #     model = resnet34().to(conf.device)
    # elif conf.testbackbone == 'resnet50':
    #     model = resnet50().to(conf.device)
    # elif conf.testbackbone == 'mobilevit_s':
    #     model = mobilevit_s().to(conf.device)
    # elif conf.testbackbone == 'mobilenetv3':
    #     model = MobileNetV3_Small().to(conf.device)
    # elif conf.testbackbone == 'mobilenet_v1':
    #     model = mobilenet_v1().to(conf.device)
    # elif conf.testbackbone == 'PPLCNET_x1_0':
    #     model = PPLCNET_x1_0().to(conf.device)
    # elif conf.testbackbone == 'PPLCNET_x0_5':
    #     model = PPLCNET_x0_5().to(conf.device)
    # elif conf.backbone == 'PPLCNET_x2_5':
    #     model = PPLCNET_x2_5().to(conf.device)
    # elif conf.testbackbone == 'mobilenet_v2':
    #     model = mobilenet_v2().to(conf.device)
    # elif conf.testbackbone == 'resnet14':
    #     model = resnet14().to(conf.device)
    else:
        raise ValueError('Unsupported model {}'.format(conf.backbone))

    print('load model {} '.format(conf.testbackbone))
    # model = nn.DataParallel(model).to(conf.device)
    model.load_state_dict(torch.load(conf.test_model, map_location=conf.device))
    model.eval()
    if not conf.group_test:
        images = unique_image(conf.test_list)
        images = [osp.join(conf.test_val, img) for img in images]

        groups = group_image(images, conf.test_batch_size)  ## group images by batch_size

        feature_dict = dict()
        for group in groups:
            d = featurize(group, conf.test_transform, model, conf.device)
            feature_dict.update(d)
        # print('feature_dict', feature_dict)
        accuracy, threshold = compute_accuracy(feature_dict, conf.test_list, conf.test_val)

        print(
            f"Test Model: {conf.test_model}\n"
            f"Accuracy: {accuracy:.3f}\n"
            f"Threshold: {threshold:.3f}\n"
        )
    elif conf.group_test:
        """
        conf.test_val: path of the test dataset
        conf.test_group_json: grouping config file for the test data
        """
        filename = conf.test_group_json

        filename = "../cl/images_1.json"
        with open(filename, 'r', encoding='utf-8') as file:
            content_list_read = json.load(file)

        compute_contrast_accuracy(content_list_read)

# =============================================================================
#         Similarity, Label = compute_group_accuracy(content_list_read)
#         print('allSimilarity >> {}'.format(Similarity))
#         print('allLabel >> {}'.format(Label))
#         compute_accuracy_recall(np.array(Similarity), np.array(Label))
#         # compute_group_accuracy(data_loaded)
#
# =============================================================================
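# Editor's note (not part of inference.py): FeatsInterface.inference() has two
# modes, sketched below with illustrative values.
#
#   encoder = FeatsInterface(conf)
#   # ndarray frame + detections -> (crops, features)
#   dets = np.array([[50, 60, 210, 240]], dtype=np.float32)   # hypothetical tlbr box
#   crops, feats = encoder.inference(frame, dets)             # frame: HxWx3 BGR ndarray
#   # list of PIL images -> features only
#   feats = encoder.inference(pil_images)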
88  contrast/feat_extract/model/BAM.py  Normal file
@@ -0,0 +1,88 @@
import torch.nn as nn
import torchvision
from torch.nn import init


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.shape[0], -1)


class ChannelAttention(nn.Module):
    # was "__int__" in the original, so the constructor body never ran;
    # num_layers default assumed (=3) so the BAMblock call below works
    def __init__(self, channel, reduction=16, num_layers=3):
        super(ChannelAttention, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        gate_channels = [channel]
        gate_channels += [channel // reduction] * num_layers  # was len(channel), which fails for an int
        gate_channels += [channel]

        self.ca = nn.Sequential()
        self.ca.add_module('flatten', Flatten())
        for i in range(len(gate_channels) - 2):  # unique names: add_module('') raises in PyTorch
            self.ca.add_module('fc%d' % i, nn.Linear(gate_channels[i], gate_channels[i + 1]))
            self.ca.add_module('bn%d' % i, nn.BatchNorm1d(gate_channels[i + 1]))
            self.ca.add_module('relu%d' % i, nn.ReLU())
        self.ca.add_module('last_fc', nn.Linear(gate_channels[-2], gate_channels[-1]))

    def forward(self, x):
        res = self.avgpool(x)
        res = self.ca(res)
        res = res.unsqueeze(-1).unsqueeze(-1).expand_as(x)
        return res


class SpatialAttention(nn.Module):
    # was "__int__" with a malformed super() call in the original
    def __init__(self, channel, reduction=16, num_lay=3, dilation=2):
        super(SpatialAttention, self).__init__()
        self.sa = nn.Sequential()
        # out_channels was (channel // reduction) * 3 in the original, which
        # mismatched the BatchNorm2d below; fixed to channel // reduction
        self.sa.add_module('conv_reduce', nn.Conv2d(kernel_size=1, in_channels=channel, out_channels=channel // reduction))
        self.sa.add_module('bn_reduce', nn.BatchNorm2d(num_features=(channel // reduction)))
        self.sa.add_module('relu_reduce', nn.ReLU())
        for i in range(num_lay):
            # padding=dilation keeps the spatial size (the original's padding=1 shrank it)
            self.sa.add_module('conv%d' % i, nn.Conv2d(kernel_size=3,
                                                       in_channels=(channel // reduction),
                                                       out_channels=(channel // reduction),
                                                       padding=dilation,
                                                       dilation=dilation))
            self.sa.add_module('bn%d' % i, nn.BatchNorm2d(channel // reduction))
            self.sa.add_module('relu%d' % i, nn.ReLU())
        self.sa.add_module('conv_out', nn.Conv2d(channel // reduction, 1, kernel_size=1))

    def forward(self, x):
        res = self.sa(x)
        res = res.expand_as(x)
        return res


class BAMblock(nn.Module):
    def __init__(self, channel=512, reduction=16, dia_val=2):
        super(BAMblock, self).__init__()
        self.ca = ChannelAttention(channel, reduction)
        self.sa = SpatialAttention(channel, reduction, dilation=dia_val)
        self.sigmoid = nn.Sigmoid()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')  # was the deprecated init.kaiming_normal
                if m.bias is not None:  # was "m.bais"
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, _, _ = x.size()
        sa_out = self.sa(x)
        ca_out = self.ca(x)
        weight = self.sigmoid(sa_out + ca_out)
        out = (1 + weight) * x
        return out


if __name__ == "__main__":
    print(512 // 14)
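# Editor's note (not part of BAM.py): a shape-only smoke test of the repaired
# block above; requires `import torch`.
#
#   x = torch.randn(2, 512, 7, 7)
#   out = BAMblock(channel=512, reduction=16, dia_val=2)(x)
#   print(out.shape)  # expected: torch.Size([2, 512, 7, 7])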
70  contrast/feat_extract/model/CBAM.py  Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
import torch
import torch.nn as nn
import torch.nn.init as init


class channelAttention(nn.Module):

    def __init__(self, channel, reduction=16):
        super(channelAttention, self).__init__()
        self.Maxpooling = nn.AdaptiveMaxPool2d(1)
        self.Avepooling = nn.AdaptiveAvgPool2d(1)
        self.ca = nn.Sequential()
        self.ca.add_module('conv1', nn.Conv2d(channel, channel // reduction, 1, bias=False))
        self.ca.add_module('Relu', nn.ReLU())
        self.ca.add_module('conv2', nn.Conv2d(channel // reduction, channel, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        M_out = self.Maxpooling(x)
        A_out = self.Avepooling(x)
        M_out = self.ca(M_out)
        A_out = self.ca(A_out)
        out = self.sigmoid(M_out + A_out)
        return out


class SpatialAttention(nn.Module):

    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=kernel_size, padding=kernel_size // 2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        max_result, _ = torch.max(x, dim=1, keepdim=True)
        avg_result = torch.mean(x, dim=1, keepdim=True)
        result = torch.cat([max_result, avg_result], dim=1)
        output = self.conv(result)
        output = self.sigmoid(output)
        return output


class CBAM(nn.Module):

    def __init__(self, channel, reduction=16, kernel_size=7):
        super().__init__()
        self.ca = channelAttention(channel, reduction)
        self.sa = SpatialAttention(kernel_size)

    def init_weights(self):
        for m in self.modules():  # weight initialization
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        # channel attention first, then spatial attention (the CBAM ordering)
        out = x * self.ca(x)
        out = out * self.sa(out)
        return out


if __name__ == '__main__':
    input = torch.randn(50, 512, 7, 7)
    kernel_size = input.shape[2]
    cbam = CBAM(channel=512, reduction=16, kernel_size=kernel_size)
    output = cbam(input)
    print(output.shape)
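CBAM is typically inserted into a backbone right before a residual addition; a sketch under that assumption (BasicBlockWithCBAM is illustrative, not part of this file):

import torch
import torch.nn as nn

class BasicBlockWithCBAM(nn.Module):
    # illustrative residual block that refines features with CBAM before the skip
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
        )
        self.cbam = CBAM(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(x + self.cbam(self.conv(x)))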
33 contrast/feat_extract/model/Tool.py Normal file
@@ -0,0 +1,33 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class GeM(nn.Module):

    def __init__(self, p=3, eps=1e-6):
        super(GeM, self).__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        return self.gem(x, p=self.p, eps=self.eps, stride=2)

    def gem(self, x, p=3, eps=1e-6, stride=2):
        # generalized-mean pooling: clamp, raise to p, average over the map, take the 1/p root
        return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1)), stride=stride).pow(1. / p)

    def __repr__(self):
        return self.__class__.__name__ + \
            '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + \
            ', ' + 'eps=' + str(self.eps) + ')'


class TripletLoss(nn.Module):

    def __init__(self, margin):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative, size_average=True):
        distance_positive = (anchor - positive).pow(2).sum(1)
        distance_negative = (anchor - negative).pow(2).sum(1)
        # hinge: the anchor must sit closer to the positive than to the
        # negative by at least the margin
        losses = F.relu(distance_positive - distance_negative + self.margin)
        return losses.mean() if size_average else losses.sum()


if __name__ == '__main__':
    print('')
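A short usage sketch for the two utilities above (shapes are illustrative):

import torch

pool = GeM(p=3)
feats = torch.randn(4, 512, 7, 7)
descriptor = pool(feats).flatten(1)      # (4, 512) pooled global descriptor

criterion = TripletLoss(margin=0.3)
anchor, positive, negative = torch.randn(3, 4, 512).unbind(0)
loss = criterion(anchor, positive, negative)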
11 contrast/feat_extract/model/__init__.py Normal file
@@ -0,0 +1,11 @@
from .fmobilenet import FaceMobileNet
from .resnet_face import ResIRSE
from .mobilevit import mobilevit_s
from .metric import ArcFace, CosFace
from .loss import FocalLoss
from .resbam import resnet
from .resnet_pre import resnet18, resnet34, resnet50, resnet14
from .mobilenet_v2 import mobilenet_v2
from .mobilenet_v3 import MobileNetV3_Small, MobileNetV3_Large
# from .mobilenet_v1 import mobilenet_v1
from .lcnet import PPLCNET_x0_25, PPLCNET_x0_35, PPLCNET_x0_5, PPLCNET_x0_75, PPLCNET_x1_0, PPLCNET_x1_5, PPLCNET_x2_0, PPLCNET_x2_5
BIN contrast/feat_extract/model/__pycache__/BAM.cpython-38.pyc Normal file (binary file not shown)
BIN contrast/feat_extract/model/__pycache__/CBAM.cpython-312.pyc Normal file (binary file not shown)
BIN contrast/feat_extract/model/__pycache__/CBAM.cpython-38.pyc Normal file (binary file not shown)
BIN contrast/feat_extract/model/__pycache__/CBAM.cpython-39.pyc Normal file (binary file not shown)
BIN contrast/feat_extract/model/__pycache__/Tool.cpython-312.pyc Normal file (binary file not shown)
BIN contrast/feat_extract/model/__pycache__/Tool.cpython-38.pyc Normal file (binary file not shown)
BIN contrast/feat_extract/model/__pycache__/Tool.cpython-39.pyc Normal file (binary file not shown)
BIN contrast/feat_extract/model/__pycache__/__init__.cpython-310.pyc Normal file (binary file not shown)
BIN contrast/feat_extract/model/__pycache__/__init__.cpython-312.pyc Normal file (binary file not shown)
BIN contrast/feat_extract/model/__pycache__/__init__.cpython-38.pyc Normal file (binary file not shown)
BIN contrast/feat_extract/model/__pycache__/__init__.cpython-39.pyc Normal file (binary file not shown)
BIN contrast/feat_extract/model/__pycache__/lcnet.cpython-312.pyc Normal file (binary file not shown)
BIN contrast/feat_extract/model/__pycache__/lcnet.cpython-38.pyc Normal file (binary file not shown)
BIN contrast/feat_extract/model/__pycache__/lcnet.cpython-39.pyc Normal file (binary file not shown)
BIN contrast/feat_extract/model/__pycache__/loss.cpython-312.pyc Normal file (binary file not shown)
BIN contrast/feat_extract/model/__pycache__/loss.cpython-38.pyc Normal file (binary file not shown)
BIN contrast/feat_extract/model/__pycache__/loss.cpython-39.pyc Normal file (binary file not shown)
BIN contrast/feat_extract/model/__pycache__/metric.cpython-312.pyc Normal file (binary file not shown)
BIN contrast/feat_extract/model/__pycache__/metric.cpython-38.pyc Normal file (binary file not shown)
BIN contrast/feat_extract/model/__pycache__/metric.cpython-39.pyc Normal file (binary file not shown)
BIN contrast/feat_extract/model/__pycache__/mobilevit.cpython-38.pyc Normal file (binary file not shown)
BIN contrast/feat_extract/model/__pycache__/mobilevit.cpython-39.pyc Normal file (binary file not shown)
BIN contrast/feat_extract/model/__pycache__/resbam.cpython-312.pyc Normal file (binary file not shown)
BIN contrast/feat_extract/model/__pycache__/resbam.cpython-38.pyc Normal file (binary file not shown)
BIN contrast/feat_extract/model/__pycache__/resbam.cpython-39.pyc Normal file (binary file not shown)
BIN contrast/feat_extract/model/__pycache__/resnet.cpython-310.pyc Normal file (binary file not shown)
BIN contrast/feat_extract/model/__pycache__/resnet.cpython-38.pyc Normal file (binary file not shown)
BIN contrast/feat_extract/model/__pycache__/utils.cpython-312.pyc Normal file (binary file not shown)
BIN contrast/feat_extract/model/__pycache__/utils.cpython-38.pyc Normal file (binary file not shown)
BIN contrast/feat_extract/model/__pycache__/utils.cpython-39.pyc Normal file (binary file not shown)
(further binary .pyc files changed but not shown by name)
124 contrast/feat_extract/model/fmobilenet.py Normal file
@@ -0,0 +1,124 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.shape[0], -1)


class ConvBn(nn.Module):

    def __init__(self, in_c, out_c, kernel=(1, 1), stride=1, padding=0, groups=1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
            nn.BatchNorm2d(out_c)
        )

    def forward(self, x):
        return self.net(x)


class ConvBnPrelu(nn.Module):

    def __init__(self, in_c, out_c, kernel=(1, 1), stride=1, padding=0, groups=1):
        super().__init__()
        self.net = nn.Sequential(
            ConvBn(in_c, out_c, kernel, stride, padding, groups),
            nn.PReLU(out_c)
        )

    def forward(self, x):
        return self.net(x)


class DepthWise(nn.Module):

    def __init__(self, in_c, out_c, kernel=(3, 3), stride=2, padding=1, groups=1):
        super().__init__()
        self.net = nn.Sequential(
            ConvBnPrelu(in_c, groups, kernel=(1, 1), stride=1, padding=0),
            ConvBnPrelu(groups, groups, kernel=kernel, stride=stride, padding=padding, groups=groups),
            ConvBn(groups, out_c, kernel=(1, 1), stride=1, padding=0),
        )

    def forward(self, x):
        return self.net(x)


class DepthWiseRes(nn.Module):
    """DepthWise with Residual"""

    def __init__(self, in_c, out_c, kernel=(3, 3), stride=2, padding=1, groups=1):
        super().__init__()
        self.net = DepthWise(in_c, out_c, kernel, stride, padding, groups)

    def forward(self, x):
        return self.net(x) + x


class MultiDepthWiseRes(nn.Module):

    def __init__(self, num_block, channels, kernel=(3, 3), stride=1, padding=1, groups=1):
        super().__init__()

        self.net = nn.Sequential(*[
            DepthWiseRes(channels, channels, kernel, stride, padding, groups)
            for _ in range(num_block)
        ])

    def forward(self, x):
        return self.net(x)


class FaceMobileNet(nn.Module):

    def __init__(self, embedding_size):
        super().__init__()
        self.conv1 = ConvBnPrelu(1, 64, kernel=(3, 3), stride=2, padding=1)
        self.conv2 = ConvBn(64, 64, kernel=(3, 3), stride=1, padding=1, groups=64)
        self.conv3 = DepthWise(64, 64, kernel=(3, 3), stride=2, padding=1, groups=128)
        self.conv4 = MultiDepthWiseRes(num_block=4, channels=64, kernel=3, stride=1, padding=1, groups=128)
        self.conv5 = DepthWise(64, 128, kernel=(3, 3), stride=2, padding=1, groups=256)
        self.conv6 = MultiDepthWiseRes(num_block=6, channels=128, kernel=(3, 3), stride=1, padding=1, groups=256)
        self.conv7 = DepthWise(128, 128, kernel=(3, 3), stride=2, padding=1, groups=512)
        self.conv8 = MultiDepthWiseRes(num_block=2, channels=128, kernel=(3, 3), stride=1, padding=1, groups=256)
        self.conv9 = ConvBnPrelu(128, 512, kernel=(1, 1))
        self.conv10 = ConvBn(512, 512, groups=512, kernel=(7, 7))
        self.flatten = Flatten()
        self.linear = nn.Linear(2048, embedding_size, bias=False)
        self.bn = nn.BatchNorm1d(embedding_size)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        out = self.conv6(out)
        out = self.conv7(out)
        out = self.conv8(out)
        out = self.conv9(out)
        out = self.conv10(out)
        out = self.flatten(out)
        out = self.linear(out)
        out = self.bn(out)
        return out


if __name__ == "__main__":
    from PIL import Image
    import numpy as np

    x = Image.open("../samples/009.jpg").convert('L')
    x = x.resize((128, 128))
    x = np.asarray(x, dtype=np.float32)
    x = x[None, None, ...]
    x = torch.from_numpy(x)
    net = FaceMobileNet(512)
    net.eval()
    with torch.no_grad():
        out = net(x)
    print(out.shape)
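The __main__ demo above produces one embedding; comparing two faces then reduces to a cosine similarity between embeddings. A sketch, with random tensors standing in for preprocessed images:

import torch
import torch.nn.functional as F

net = FaceMobileNet(512).eval()
a = torch.randn(1, 1, 128, 128)              # grayscale 128x128, as in the demo
b = torch.randn(1, 1, 128, 128)
with torch.no_grad():
    sim = F.cosine_similarity(net(a), net(b))   # value in [-1, 1]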
233 contrast/feat_extract/model/lcnet.py Normal file
@@ -0,0 +1,233 @@
import os
import torch
import torch.nn as nn
import thop

# try:
#     import softpool_cuda
#     from SoftPool import soft_pool2d, SoftPool2d
# except ImportError:
#     print('Please install SoftPool first: https://github.com/alexandrosstergiou/SoftPool')
#     exit(0)

NET_CONFIG = {
    # k, in_c, out_c, s, use_se
    "blocks2": [[3, 16, 32, 1, False]],
    "blocks3": [[3, 32, 64, 2, False], [3, 64, 64, 1, False]],
    "blocks4": [[3, 64, 128, 2, False], [3, 128, 128, 1, False]],
    "blocks5": [[3, 128, 256, 2, False], [5, 256, 256, 1, False],
                [5, 256, 256, 1, False], [5, 256, 256, 1, False],
                [5, 256, 256, 1, False], [5, 256, 256, 1, False]],
    "blocks6": [[5, 256, 512, 2, True], [5, 512, 512, 1, True]]
}


def autopad(k, p=None):
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
    return p


def make_divisible(v, divisor=8, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class HardSwish(nn.Module):
    def __init__(self, inplace=True):
        super(HardSwish, self).__init__()
        self.relu6 = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return x * self.relu6(x + 3) / 6


class HardSigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(HardSigmoid, self).__init__()
        self.relu6 = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return (self.relu6(x + 3)) / 6


class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            HardSigmoid()
        )

    def forward(self, x):
        b, c, h, w = x.size()
        y = self.avgpool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)


class DepthwiseSeparable(nn.Module):
    def __init__(self, inp, oup, dw_size, stride, use_se=False):
        super(DepthwiseSeparable, self).__init__()
        self.use_se = use_se
        self.stride = stride
        self.inp = inp
        self.oup = oup
        self.dw_size = dw_size
        self.dw_sp = nn.Sequential(
            nn.Conv2d(self.inp, self.inp, kernel_size=self.dw_size, stride=self.stride,
                      padding=autopad(self.dw_size, None), groups=self.inp, bias=False),
            nn.BatchNorm2d(self.inp),
            HardSwish(),

            nn.Conv2d(self.inp, self.oup, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(self.oup),
            HardSwish(),
        )
        self.se = SELayer(self.oup)

    def forward(self, x):
        x = self.dw_sp(x)
        if self.use_se:
            x = self.se(x)
        return x


class PP_LCNet(nn.Module):
    def __init__(self, scale=1.0, class_num=10, class_expand=1280, dropout_prob=0.2):
        super(PP_LCNet, self).__init__()
        self.scale = scale
        self.conv1 = nn.Conv2d(3, out_channels=make_divisible(16 * self.scale),
                               kernel_size=3, stride=2, padding=1, bias=False)
        # each NET_CONFIG row (k, in_c, out_c, s, use_se) maps onto
        # DepthwiseSeparable(inp, oup, dw_size, stride, use_se)
        self.blocks2 = nn.Sequential(*[
            DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
                               oup=make_divisible(out_c * self.scale),
                               dw_size=k, stride=s, use_se=use_se)
            for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks2"])
        ])

        self.blocks3 = nn.Sequential(*[
            DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
                               oup=make_divisible(out_c * self.scale),
                               dw_size=k, stride=s, use_se=use_se)
            for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks3"])
        ])

        self.blocks4 = nn.Sequential(*[
            DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
                               oup=make_divisible(out_c * self.scale),
                               dw_size=k, stride=s, use_se=use_se)
            for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks4"])
        ])

        self.blocks5 = nn.Sequential(*[
            DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
                               oup=make_divisible(out_c * self.scale),
                               dw_size=k, stride=s, use_se=use_se)
            for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks5"])
        ])

        self.blocks6 = nn.Sequential(*[
            DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
                               oup=make_divisible(out_c * self.scale),
                               dw_size=k, stride=s, use_se=use_se)
            for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks6"])
        ])

        self.GAP = nn.AdaptiveAvgPool2d(1)

        self.last_conv = nn.Conv2d(in_channels=make_divisible(NET_CONFIG["blocks6"][-1][2] * scale),
                                   out_channels=class_expand,
                                   kernel_size=1, stride=1, padding=0, bias=False)

        self.hardswish = HardSwish()
        self.dropout = nn.Dropout(p=dropout_prob)

        self.fc = nn.Linear(class_expand, class_num)

    def forward(self, x):
        x = self.conv1(x)
        x = self.blocks2(x)
        x = self.blocks3(x)
        x = self.blocks4(x)
        x = self.blocks5(x)
        x = self.blocks6(x)

        x = self.GAP(x)
        x = self.last_conv(x)
        x = self.hardswish(x)
        x = self.dropout(x)
        x = torch.flatten(x, start_dim=1, end_dim=-1)
        x = self.fc(x)
        return x


def PPLCNET_x0_25(**kwargs):
    model = PP_LCNet(scale=0.25, **kwargs)
    return model


def PPLCNET_x0_35(**kwargs):
    model = PP_LCNet(scale=0.35, **kwargs)
    return model


def PPLCNET_x0_5(**kwargs):
    model = PP_LCNet(scale=0.5, **kwargs)
    return model


def PPLCNET_x0_75(**kwargs):
    model = PP_LCNet(scale=0.75, **kwargs)
    return model


def PPLCNET_x1_0(**kwargs):
    model = PP_LCNet(scale=1.0, **kwargs)
    return model


def PPLCNET_x1_5(**kwargs):
    model = PP_LCNet(scale=1.5, **kwargs)
    return model


def PPLCNET_x2_0(**kwargs):
    model = PP_LCNet(scale=2.0, **kwargs)
    return model


def PPLCNET_x2_5(**kwargs):
    model = PP_LCNet(scale=2.5, **kwargs)
    return model


if __name__ == '__main__':
    # input = torch.randn(1, 3, 640, 640)
    # model = PPLCNET_x2_5()
    # flops, params = thop.profile(model, inputs=(input,))
    # print('flops:', flops / 1000000000)
    # print('params:', params / 1000000)

    model = PPLCNET_x1_0()
    # model_1 = PW_Conv(3, 16)
    input = torch.randn(2, 3, 256, 256)
    print(input.shape)
    output = model(input)
    print(output.shape)  # [2, num_class]
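The commented-out block in __main__ hints at the intended thop profiling; a working sketch of the same idea:

import torch
import thop

model = PPLCNET_x1_0()
dummy = torch.randn(1, 3, 256, 256)
flops, params = thop.profile(model, inputs=(dummy,))
print('GFLOPs:', flops / 1e9, 'Params (M):', params / 1e6)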
18 contrast/feat_extract/model/loss.py Normal file
@@ -0,0 +1,18 @@
import torch
import torch.nn as nn


class FocalLoss(nn.Module):

    def __init__(self, gamma=2):
        super().__init__()
        self.gamma = gamma
        # reduction='none' keeps per-sample CE so the focal factor can
        # down-weight easy samples individually before the final mean
        self.ce = torch.nn.CrossEntropyLoss(reduction='none')

    def forward(self, input, target):
        logp = self.ce(input, target)
        p = torch.exp(-logp)
        loss = (1 - p) ** self.gamma * logp
        return loss.mean()
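Usage is drop-in for CrossEntropyLoss; a sketch with illustrative shapes:

import torch

criterion = FocalLoss(gamma=2)
logits = torch.randn(8, 10)              # (batch, num_classes) raw scores
target = torch.randint(0, 10, (8,))
loss = criterion(logits, target)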
83 contrast/feat_extract/model/metric.py Normal file
@@ -0,0 +1,83 @@
# Definition of ArcFace loss and CosFace loss

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class ArcFace(nn.Module):

    def __init__(self, embedding_size, class_num, s=30.0, m=0.50):
        """ArcFace formula:
            cos(m + theta) = cos(m)cos(theta) - sin(m)sin(theta)
        Note that:
            0 <= m + theta <= Pi
        So if (m + theta) >= Pi, then theta >= Pi - m. In [0, Pi]
        we have:
            cos(theta) < cos(Pi - m)
        So we can use cos(Pi - m) as a threshold to check whether
        (m + theta) goes out of [0, Pi].

        Args:
            embedding_size: usually 128, 256, 512 ...
            class_num: num of people when training
            s: scale, see normface https://arxiv.org/abs/1704.06369
            m: margin, see SphereFace, CosFace, and ArcFace paper
        """
        super().__init__()
        self.in_features = embedding_size
        self.out_features = class_num
        self.s = s
        self.m = m
        self.weight = nn.Parameter(torch.FloatTensor(class_num, embedding_size))
        nn.init.xavier_uniform_(self.weight)

        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        sine = ((1.0 - cosine.pow(2)).clamp(0, 1)).sqrt()
        phi = cosine * self.cos_m - sine * self.sin_m
        phi = torch.where(cosine > self.th, phi, cosine - self.mm)  # drop to CosFace
        # update y_i by phi in cosine
        output = cosine * 1.0  # clone so backward still works after the in-place row update
        batch_size = len(output)
        output[range(batch_size), label] = phi[range(batch_size), label]
        return output * self.s


class CosFace(nn.Module):

    def __init__(self, in_features, out_features, s=30.0, m=0.40):
        """
        Args:
            in_features: embedding size, usually 128, 256, 512 ...
            out_features: num of people when training
            s: scale, see normface https://arxiv.org/abs/1704.06369
            m: margin, see SphereFace, CosFace, and ArcFace paper
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input, label):
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        phi = cosine - self.m
        output = cosine * 1.0  # clone so backward still works after the in-place row update
        batch_size = len(output)
        output[range(batch_size), label] = phi[range(batch_size), label]
        return output * self.s
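A sketch of one training step with the ArcFace head, assuming FaceMobileNet from fmobilenet.py as the embedding backbone (any embedder with matching output size works):

import torch

embedder = FaceMobileNet(512)
head = ArcFace(embedding_size=512, class_num=1000)
optimizer = torch.optim.SGD(list(embedder.parameters()) + list(head.parameters()), lr=0.1)
ce = torch.nn.CrossEntropyLoss()

images = torch.randn(8, 1, 128, 128)
labels = torch.randint(0, 1000, (8,))
optimizer.zero_grad()
logits = head(embedder(images), labels)  # margin-adjusted cosine logits, scaled by s
loss = ce(logits, labels)
loss.backward()
optimizer.step()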
148 contrast/feat_extract/model/mobilenet_v1.py Normal file
@@ -0,0 +1,148 @@
# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Callable, Any, Optional

import torch
from torch import Tensor
from torch import nn
from torchvision.ops.misc import Conv2dNormActivation
from config import config as conf

__all__ = [
    "MobileNetV1",
    "DepthWiseSeparableConv2d",
    "mobilenet_v1",
]


class MobileNetV1(nn.Module):

    def __init__(
            self,
            num_classes: int = conf.embedding_size,
    ) -> None:
        super(MobileNetV1, self).__init__()
        self.features = nn.Sequential(
            Conv2dNormActivation(3,
                                 32,
                                 kernel_size=3,
                                 stride=2,
                                 padding=1,
                                 norm_layer=nn.BatchNorm2d,
                                 activation_layer=nn.ReLU,
                                 inplace=True,
                                 bias=False,
                                 ),

            DepthWiseSeparableConv2d(32, 64, 1),
            DepthWiseSeparableConv2d(64, 128, 2),
            DepthWiseSeparableConv2d(128, 128, 1),
            DepthWiseSeparableConv2d(128, 256, 2),
            DepthWiseSeparableConv2d(256, 256, 1),
            DepthWiseSeparableConv2d(256, 512, 2),
            DepthWiseSeparableConv2d(512, 512, 1),
            DepthWiseSeparableConv2d(512, 512, 1),
            DepthWiseSeparableConv2d(512, 512, 1),
            DepthWiseSeparableConv2d(512, 512, 1),
            DepthWiseSeparableConv2d(512, 512, 1),
            DepthWiseSeparableConv2d(512, 1024, 2),
            DepthWiseSeparableConv2d(1024, 1024, 1),
        )

        self.avgpool = nn.AvgPool2d((7, 7))

        self.classifier = nn.Linear(1024, num_classes)

        # Initialize neural network weights
        self._initialize_weights()

    def forward(self, x: Tensor) -> Tensor:
        out = self._forward_impl(x)

        return out

    # Support torch.script function
    def _forward_impl(self, x: Tensor) -> Tensor:
        out = self.features(x)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.classifier(out)

        return out

    def _initialize_weights(self) -> None:
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.zeros_(module.bias)


class DepthWiseSeparableConv2d(nn.Module):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            stride: int,
            norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(DepthWiseSeparableConv2d, self).__init__()
        self.stride = stride
        if stride not in [1, 2]:
            raise ValueError(f"stride should be 1 or 2 instead of {stride}")

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        self.conv = nn.Sequential(
            Conv2dNormActivation(in_channels,
                                 in_channels,
                                 kernel_size=3,
                                 stride=stride,
                                 padding=1,
                                 groups=in_channels,
                                 norm_layer=norm_layer,
                                 activation_layer=nn.ReLU,
                                 inplace=True,
                                 bias=False,
                                 ),
            Conv2dNormActivation(in_channels,
                                 out_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0,
                                 norm_layer=norm_layer,
                                 activation_layer=nn.ReLU,
                                 inplace=True,
                                 bias=False,
                                 ),
        )

    def forward(self, x: Tensor) -> Tensor:
        out = self.conv(x)

        return out


def mobilenet_v1(**kwargs: Any) -> MobileNetV1:
    model = MobileNetV1(**kwargs)

    return model
200 contrast/feat_extract/model/mobilenet_v2.py Normal file
@@ -0,0 +1,200 @@
from torch import nn
from .utils import load_state_dict_from_url
from ..config import config as conf

__all__ = ['MobileNetV2', 'mobilenet_v2']


model_urls = {
    'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
}


def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8.
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class ConvBNReLU(nn.Sequential):
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None):
        padding = (kernel_size - 1) // 2
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
            norm_layer(out_planes),
            nn.ReLU6(inplace=True)
        )


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pw
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
        layers.extend([
            # dw
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            norm_layer(oup),
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):
    def __init__(self,
                 num_classes=conf.embedding_size,
                 width_mult=1.0,
                 inverted_residual_setting=None,
                 round_nearest=8,
                 block=None,
                 norm_layer=None):
        """
        MobileNet V2 main class

        Args:
            num_classes (int): Number of classes
            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
            inverted_residual_setting: Network structure
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number.
                Set to 1 to turn off rounding
            block: Module specifying inverted residual building block for mobilenet
            norm_layer: Module specifying the normalization layer to use

        """
        super(MobileNetV2, self).__init__()

        if block is None:
            block = InvertedResidual

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        input_channel = 32
        last_channel = 1280

        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]

        # only check the first element, assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError("inverted_residual_setting should be non-empty "
                             "or a 4-element list, got {}".format(inverted_residual_setting))

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
                input_channel = output_channel
        # building last several layers
        features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer))
        # make it nn.Sequential
        self.features = nn.Sequential(*features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, num_classes),
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x):
        # This exists since TorchScript doesn't support inheritance, so the superclass method
        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
        x = self.features(x)
        # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
        x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1)
        x = self.classifier(x)
        return x

    def forward(self, x):
        return self._forward_impl(x)


def mobilenet_v2(pretrained=True, progress=True, **kwargs):
    """
    Constructs a MobileNetV2 architecture from
    `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    model = MobileNetV2(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
                                              progress=progress)
        src_state_dict = state_dict
        target_state_dict = model.state_dict()
        skip_keys = []
        # skip mismatched-size tensors when loading pretrained weights
        for k in src_state_dict.keys():
            if k not in target_state_dict:
                continue
            if src_state_dict[k].size() != target_state_dict[k].size():
                skip_keys.append(k)
        for k in skip_keys:
            del src_state_dict[k]
        missing_keys, unexpected_keys = model.load_state_dict(src_state_dict, strict=False)
    return model
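Because mismatched tensors are dropped into skip_keys rather than raising an error, the ImageNet checkpoint can still seed a model whose head width differs; a sketch:

# the 1000-way classifier weights are skipped and trained from scratch,
# while all feature-extractor weights load from the checkpoint
model = mobilenet_v2(pretrained=True, num_classes=256)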
Some files were not shown because too many files have changed in this diff.