update
.gitignore  (vendored, new file, 12 lines)
@@ -0,0 +1,12 @@
OneByOneSimilarity.py
Return_purchase_test_analysis.py
compairsonOneByOne/
comparative/
comparisonData/
comparisonResult/
data_test/
imageQualityData/
search_library/
Single_purchase_data/
*.pth
*.pt
.idea/.gitignore  (generated, vendored, new file, 8 lines)
@@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
.idea/deployment.xml  (generated, new file, 15 lines)
@@ -0,0 +1,15 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="PublishConfigData" autoUpload="Always" remoteFilesAllowedToDisappearOnAutoupload="false">
    <serverData>
      <paths name="lc@192.168.1.142:22 password">
        <serverdata>
          <mappings>
            <mapping deploy="imageIQA" local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
    </serverData>
    <option name="myAutoUpload" value="ALWAYS" />
  </component>
</project>
.idea/inspectionProfiles/profiles_settings.xml  (generated, new file, 6 lines)
@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
  <settings>
    <option name="USE_PROJECT_PROFILE" value="false" />
    <version value="1.0" />
  </settings>
</component>
.idea/kmeans.iml  (generated, new file, 8 lines)
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$" />
    <orderEntry type="jdk" jdkName="Python 3.8 (my_env)" jdkType="Python SDK" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
</module>
.idea/misc.xml  (generated, new file, 7 lines)
@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="Black">
    <option name="sdkName" value="Python 3.8" />
  </component>
  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (my_env)" project-jdk-type="Python SDK" />
</project>
.idea/modules.xml  (generated, new file, 8 lines)
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectModuleManager">
    <modules>
      <module fileurl="file://$PROJECT_DIR$/.idea/kmeans.iml" filepath="$PROJECT_DIR$/.idea/kmeans.iml" />
    </modules>
  </component>
</project>
.idea/vcs.xml  (generated, new file, 6 lines)
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="VcsDirectoryMappings">
    <mapping directory="$PROJECT_DIR$" vcs="Git" />
  </component>
</project>
ComparisonAnalysis.py  (new file, 169 lines)
@@ -0,0 +1,169 @@
import os.path
import shutil

import numpy as np
import matplotlib.pyplot as plt


def showHist(err, correct, date):
    err = np.array(err)
    correct = np.array(correct)

    fig, axs = plt.subplots(2, 1)
    axs[0].hist(err, bins=50, edgecolor='black')
    axs[0].set_xlim([0, 1])
    axs[0].set_title('err')
    axs[1].hist(correct, bins=50, edgecolor='black')
    axs[1].set_xlim([0, 1])
    axs[1].set_title('correct')
    fig.suptitle(date)  # title the whole figure; plt.title() here would only retitle the last axes
    plt.savefig('comparisonResult/Hist_' + date + '.png')
    plt.show()


def showgrid(recall, prec, date):
    x = np.linspace(start=0, stop=1, num=11, endpoint=True).tolist()
    plt.figure(figsize=(10, 6))
    plt.plot(x, recall, color='red', label='recall')
    plt.plot(x, prec, color='blue', label='PrecisePos')
    plt.legend()
    plt.xlabel('threshold')
    plt.title(date)
    # plt.ylabel('Similarity')
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.savefig('comparisonResult/grid_' + date + '.png')
    plt.show()
    plt.close()


def read_txt_file(filePth):
    with open(filePth, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    record = {}  # renamed from `dict` so the builtin is not shadowed
    all_list = []
    barcode_list = []
    similarity_list = []
    split_flag = False
    clean_lines = [line.strip().replace("'", '').replace('"', '') for line in lines]
    for line in clean_lines:
        stripped_line = line.strip()
        if not stripped_line:  # a blank line closes the current record
            split_flag = False
            all_list.append(record)
            record = {}
            barcode_list, similarity_list = [], []
            continue

        label = line.split(':')[0]
        value = line.split(':')[1]
        if label == 'SeqDir':
            record['SeqDir'] = value
        if label == 'Deleted':
            record['Deleted'] = value
        if label == 'List':
            split_flag = True
            continue
        if split_flag:  # lines after 'List:' are `barcode: similarity` pairs
            record['barcode'] = barcode_list
            record['similarity'] = similarity_list
            barcode_list.append(label)
            similarity_list.append(value)
    all_list.append(record)
    all_list = [d for d in all_list if d]  # drop empty records
    return all_list


def move_file(seqdirs, filePth):
    comparisonData = os.path.basename(filePth).split('.')[0]
    # raw string so the backslashes in the Windows path are not treated as escapes
    comparisonData = os.sep.join([r'D:\Project\ieemoo\image_quality_assessment\comparisonData', comparisonData])
    print('comparisonData', comparisonData)
    for seqdir in seqdirs:
        print(seqdir[0], seqdir[1])
        err_pair = os.sep.join([comparisonData, 'err_pair', seqdir[1]])
        if not os.path.exists(err_pair):
            os.makedirs(err_pair)
        for comparison in os.listdir(comparisonData):
            print(os.sep.join([comparisonData, comparison]))
            print(os.sep.join([err_pair, comparison]))
            try:
                if seqdir[0] in comparison:
                    shutil.copytree(os.sep.join([comparisonData, comparison]),
                                    os.sep.join([err_pair, comparison]))
                if seqdir[1] in comparison:
                    shutil.copytree(os.sep.join([comparisonData, comparison]),
                                    os.sep.join([err_pair, comparison]))
            except Exception:
                continue


def compute_recall_precision(err_similarity, correct_similarity, date=None):
    ths = np.linspace(start=0, stop=1, num=11, endpoint=True).tolist()
    recall, prec = [], []
    for th in ths:
        TP = len([num for num in correct_similarity if num >= th])
        FP = len([num for num in err_similarity if num >= th])
        if (TP + FP) == 0:
            prec.append(1)
            recall.append(0)
        else:
            prec.append(TP / (TP + FP))
            recall.append(TP / (len(err_similarity) + len(correct_similarity)))
    showgrid(recall, prec, date)
    return recall, prec


def deal_one_file(filePth):
    date = filePth.split('\\')[-1].split('.')[0]
    all_list = read_txt_file(filePth)
    num = 0
    err_barcode_list, correct_barcode_list, err_similarity, correct_similarity, seqdirs = [], [], [], [], []
    for s_list in all_list:
        seqdir = s_list['SeqDir'].strip()
        delete = s_list['Deleted'].strip()
        barcodes = [s.strip() for s in s_list['barcode']]
        similarity = [float(s.strip()) for s in s_list['similarity']]
        if delete in barcodes[:1]:  # top-1 hit: the deleted barcode ranks first
            num += 1
            correct_barcode_list.append(delete)
            correct_similarity.append(similarity[0])
        else:
            seqdirs.append((seqdir, delete))
            err_barcode_list.append(delete)
            err_similarity.append(similarity[0])

    compute_recall_precision(err_similarity, correct_similarity, date)
    showHist(err_similarity, correct_similarity, date)
    move_file(seqdirs, filePth)  # collect the mismatched pairs
    return err_similarity, correct_similarity


def main(filesPth):
    err_similarities, correct_similarities = [], []
    dir_lists = os.listdir(filesPth)
    # dir_lists = ['deletedBarcode_10_0716_pm_3.txt']
    for name in dir_lists:
        filePth = os.sep.join([filesPth, name])
        err_similarity, correct_similarity = deal_one_file(filePth)
        err_similarities += err_similarity
        correct_similarities += correct_similarity
    compute_recall_precision(err_similarities, correct_similarities, date='zong')
    showHist(err_similarities, correct_similarities, date='zong')
    print('END')


if __name__ == '__main__':
    # filePth = r'D:\Project\ieemoo\image_quality_assessment\comparisonResult\deletedBarcode_10_0722_pm_01.txt'
    # read_txt_file(filePth)

    filesPth = 'comparisonResult'
    main(filesPth)

    # for name in os.listdir(r'comparisonData\deletedBarcode_15_0628_pm\err_pair'):
    #     path = os.sep.join([r'comparisonData\deletedBarcode_15_0628_pm\err_pair', name])
    #     print(len(os.listdir(path)))
    #     if len(os.listdir(path)) < 2:
    #         print(path)
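The block layout that `read_txt_file` expects is not documented in the file itself; the sketch below reconstructs it from the parser (the field names are real, the sample barcodes and scores are hypothetical): each record carries a `SeqDir:` line, a `Deleted:` line, a `List:` marker, then `barcode: similarity` pairs, with a blank line between records. It assumes `ComparisonAnalysis` is importable from the working directory.

```python
# A minimal, assumed example of the deletedBarcode_*.txt format parsed above.
sample = (
    "SeqDir: seq_0001\n"
    "Deleted: 6901234567890\n"
    "List:\n"
    "6901234567890: 0.93\n"
    "6909999999999: 0.41\n"
    "\n"
)
with open('sample_deletedBarcode.txt', 'w', encoding='utf-8') as f:
    f.write(sample)

from ComparisonAnalysis import read_txt_file

records = read_txt_file('sample_deletedBarcode.txt')
# -> [{'SeqDir': ' seq_0001', 'Deleted': ' 6901234567890',
#      'barcode': ['6901234567890', '6909999999999'],
#      'similarity': [' 0.93', ' 0.41']}]   (values are stripped later, in deal_one_file)
print(records)
```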
BIN
__pycache__/ComparisonAnalysis.cpython-38.pyc
Normal file
BIN
__pycache__/Return_purchase_test_analysis.cpython-38.pyc
Normal file
BIN
__pycache__/dealdata.cpython-38.pyc
Normal file
BIN
__pycache__/imgcompare.cpython-38.pyc
Normal file
BIN
comprehensive_similarity.png
Normal file
After: Size 20 KiB
contrast/.idea/contrast_nettest.iml  (generated, new file, 12 lines)
@@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$" />
    <orderEntry type="inheritedJdk" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
  <component name="PyDocumentationSettings">
    <option name="format" value="GOOGLE" />
    <option name="myDocStringFormat" value="Google" />
  </component>
</module>
contrast/.idea/inspectionProfiles/Project_Default.xml  (generated, new file, 19 lines)
@@ -0,0 +1,19 @@
<component name="InspectionProjectProfileManager">
  <profile version="1.0">
    <option name="myName" value="Project Default" />
    <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
      <option name="ignoredPackages">
        <value>
          <list size="6">
            <item index="0" class="java.lang.String" itemvalue="thop" />
            <item index="1" class="java.lang.String" itemvalue="regex" />
            <item index="2" class="java.lang.String" itemvalue="tensorboardX" />
            <item index="3" class="java.lang.String" itemvalue="torch" />
            <item index="4" class="java.lang.String" itemvalue="numpy" />
            <item index="5" class="java.lang.String" itemvalue="terminaltables" />
          </list>
        </value>
      </option>
    </inspection_tool>
  </profile>
</component>
contrast/.idea/inspectionProfiles/profiles_settings.xml  (generated, new file, 6 lines)
@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
  <settings>
    <option name="USE_PROJECT_PROFILE" value="false" />
    <version value="1.0" />
  </settings>
</component>
contrast/.idea/misc.xml  (generated, new file, 4 lines)
@@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10" project-jdk-type="Python SDK" />
</project>
contrast/.idea/modules.xml  (generated, new file, 8 lines)
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectModuleManager">
    <modules>
      <module fileurl="file://$PROJECT_DIR$/.idea/contrast_nettest.iml" filepath="$PROJECT_DIR$/.idea/contrast_nettest.iml" />
    </modules>
  </component>
</project>
contrast/.idea/vcs.xml  (generated, new file, 6 lines)
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="VcsDirectoryMappings">
    <mapping directory="$PROJECT_DIR$/.." vcs="Git" />
  </component>
</project>
contrast/README.md  (new file, 60 lines)
@@ -0,0 +1,60 @@
# Build Your Own Face Recognition Model

Train your own face recognition model!

Face recognition has moved from the original softmax embedding, through the triplet-loss metric learning led by FaceNet in 2015, to additive-margin metric learning. This blog series implements ArcFace, proposed in 2018.

### Dependencies

```
Python >= 3.6
pytorch >= 1.0
torchvision
imutils
pillow == 6.2.0
tqdm
```

### Data Preparation

+ Download WebFace (search Baidu for it) and the cleaned image list ([BaiduYun](http://pan.baidu.com/s/1hrKpbm8)) for training
+ Download LFW ([BaiduYun](https://pan.baidu.com/s/12IKEpvM8-tYgSaUiz_adGA), extraction code u7z4) and the [test pair list](https://github.com/ronghuaiyang/arcface-pytorch/blob/master/lfw_test_pair.txt) for testing
+ Remove the dirty data from WebFace with `utils.py`

### Configuration

See `config.py`

### Training

Single-machine multi-GPU training is supported out of the box

```sh
export CUDA_VISIBLE_DEVICES=0,1
python train.py
```

### Testing

```sh
python test.py
```

### Blog

Plenty of introductions to face recognition already exist, but inspired by the many [Build-Your-Own-x](https://github.com/danistefanovic/build-your-own-x) articles, I wanted to write a Build Your Own Face Model series, in the hope that it helps others.

+ 001 [Data Preparation](./blog/data.md)
+ 002 [Model Architecture](./blog/model.md)
+ 003 [Loss Function](./blog/loss.md)
+ 004 [Metric Function](./blog/metric.md)
+ 005 [Training](./blog/train.md)
+ 006 [Testing](./blog/test.md)

### Acknowledgements

Although not marked inline, some of the code in this project was copied directly from, or adapted from, the following repositories, under the same licenses:

+ [insightFace](https://github.com/deepinsight/insightface/tree/master/recognition)
+ [insightFace_Pytorch](https://github.com/TreB1eN/InsightFace_Pytorch)
+ [arcface-pytorch](https://github.com/ronghuaiyang/arcface-pytorch)
BIN
contrast/__pycache__/img_data.cpython-310.pyc
Normal file
BIN
contrast/__pycache__/img_data.cpython-38.pyc
Normal file
BIN
contrast/__pycache__/logic.cpython-38.pyc
Normal file
BIN
contrast/__pycache__/search.cpython-310.pyc
Normal file
BIN
contrast/__pycache__/search.cpython-38.pyc
Normal file
BIN
contrast/__pycache__/test_logic.cpython-310.pyc
Normal file
BIN
contrast/__pycache__/test_logic.cpython-38.pyc
Normal file
contrast/config.py.bak  (new file, 21 lines)
@@ -0,0 +1,21 @@
import torch
import torchvision.transforms as T


class Config:
    host = "192.168.1.28"
    port = "19530"

    embedding_size = 256
    img_size = 224
    test_transform = T.Compose([
        T.ToTensor(),
        T.Resize((img_size, img_size)),
        T.ConvertImageDtype(torch.float32),
        T.Normalize(mean=[0.5], std=[0.5]),
    ])

    # test_model = "checkpoints/resnet18_our388.pth"
    test_model = "checkpoints/mobilenetv3Large_our388_noPara.pth"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


config = Config()
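For reference, a short sketch of how this pipeline would be applied before inference (`example.jpg` is a hypothetical input; since `ToTensor` runs first, `Resize` operates on the tensor):

```python
from PIL import Image

from config import config as conf  # the Config instance defined above

img = Image.open('example.jpg').convert('RGB')  # hypothetical input image
tensor = conf.test_transform(img)               # -> (3, 224, 224), float32, roughly in [-1, 1]
batch = tensor.unsqueeze(0).to(conf.device)     # add a batch dimension before calling the model
```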
contrast/dataset.py  (new file, 21 lines)
@@ -0,0 +1,21 @@
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

from config import config as conf


def load_data(conf, training=True):  # note: this parameter shadows the module-level `conf` import above
    if training:
        dataroot = conf.train_root
        transform = conf.train_transform
        batch_size = conf.train_batch_size
    else:
        dataroot = conf.test_root
        transform = conf.test_transform
        batch_size = conf.test_batch_size

    data = ImageFolder(dataroot, transform=transform)
    class_num = len(data.classes)
    loader = DataLoader(data, batch_size=batch_size, shuffle=True,
                        pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
    return loader, class_num
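A quick usage sketch (it assumes the full training `Config` also defines `train_root`/`test_root`, the batch sizes, `train_transform`, `pin_memory`, and `num_workers`; the `.bak` config above only carries the test-time fields):

```python
# Build the evaluation loader and inspect one batch.
loader, class_num = load_data(conf, training=False)
print('classes:', class_num)
images, labels = next(iter(loader))
print(images.shape, labels.shape)  # (batch, 3, H, W), (batch,)
```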
contrast/img_data.py  (new file, 129 lines; diff content not shown)
contrast/logic.py  (new file, 66 lines)
@@ -0,0 +1,66 @@
import sys

import torch

from tools.config import cfg as conf
sys.path.append('contrast')
# from config import config as conf
# from model import resnet18, MobileNetV3_Large
from test_logic import similarity_interface
from img_data import queueImgs_add

# import pymilvus


class datacollection:
    barcode_flag = None
    add_flag = None
    queImgsDict = None
    mainMilvus = None
    tempLibList = None
    model = None
    barcode_list = None
    actionModel = True  # run mode flag: True = production mode, False = test mode


class similarityResult:
    top10 = None
    top1 = None
    tempLibList = None
    topn = None


class similarity:
    def __init__(self):
        pass

    def getSimilarity(self, model, dataCollection, similarityRes):
        dataCollection.mainMilvus = model.milvusModel
        dataCollection.model = model.similarityModel
        # try:
        if dataCollection.add_flag:
            if dataCollection.barcode_flag:  # add-to-cart with barcode -> return top10 and top1
                similarityRes.top10, similarityRes.top1, similarityRes.tempLibList = similarity_interface(
                    dataCollection)
                print(f"top10: {similarityRes.top10}\ntop1: {similarityRes.top1}")
            else:  # add-to-cart without barcode -> return top10
                similarityRes.top10, similarityRes.tempLibList = similarity_interface(dataCollection)
        else:  # item return -> return top10 and topn
            if dataCollection.barcode_flag:
                similarityRes.top10, similarityRes.top1, similarityRes.topn = similarity_interface(dataCollection)
            else:
                similarityRes.top10, similarityRes.topn = similarity_interface(dataCollection)
        return similarityRes
        # except pymilvus.exceptions.SchemaNotReadyException as SchemaNotReadyException:  # the feature collection does not exist
        #     print('pymilvus.exceptions.SchemaNotReadyException', SchemaNotReadyException)


def main():
    data_collection = datacollection()
    similarityRes = similarityResult()
    data_collection.barcode_flag = queueImgs_add['barcode_flag']
    data_collection.add_flag = queueImgs_add['add_flag']
    data_collection.queImgsDict = queueImgs_add
    # NOTE: getSimilarity takes (model, dataCollection, similarityRes); this call is
    # missing the model argument and raises a TypeError as written.
    similarity().getSimilarity(data_collection, similarityRes)


if __name__ == '__main__':
    main()
contrast/main_barcodes.json  (new file, 1894 lines; diff content not shown)
contrast/main_library.py  (new file, 17 lines)
@@ -0,0 +1,17 @@
"""Build the main feature library."""
from test_logic import create_milvus, img2feature
from config import config as conf
from img_data import library_imgs, temp_imgs


def createMainMilvus(imgs_dict):  # imgs -> {barcode1: [img1_1...img1_n], barcode2: [img2_1...img2_n]}
    # NOTE: img2feature in test_logic.py now takes (imgs_dict, model, actionModel, barcode_flag);
    # this one-argument call matches an older signature.
    barcode_list, imgs_list = img2feature(imgs_dict)
    mainMilvus = create_milvus('main_features', conf.host, conf.port, barcode_list, imgs_list)
    return mainMilvus


def createTempMilvus(imgs_dict):  # imgs -> {barcode1: [img1_1...img1_n], barcode2: [img2_1...img2_n]}
    barcode_list, imgs_list = img2feature(imgs_dict)
    tempMilvus = create_milvus('temp_features', conf.host, conf.port, barcode_list, imgs_list)
    return tempMilvus


if __name__ == "__main__":
    createMainMilvus(library_imgs)
    # createTempMilvus(temp_imgs)
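For orientation, a hypothetical `imgs_dict` matching the comment above (the barcodes and paths are made up):

```python
# Each barcode maps to the list of image paths to embed for that product.
library_imgs_example = {
    '6901234567890': ['library/6901234567890/0.jpg', 'library/6901234567890/1.jpg'],
    '6909999999999': ['library/6909999999999/0.jpg'],
}
# createMainMilvus(library_imgs_example)  # would embed the images and insert them
#                                         # into the 'main_features' collection
```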
contrast/model/__init__.py  (new file, 2 lines)
@@ -0,0 +1,2 @@
from .resnet_pre import resnet18
from .mobilenet_v3 import MobileNetV3_Small, MobileNetV3_Large
BIN
contrast/model/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
contrast/model/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
contrast/model/__pycache__/fmobilenet.cpython-310.pyc
Normal file
BIN
contrast/model/__pycache__/fmobilenet.cpython-38.pyc
Normal file
BIN
contrast/model/__pycache__/loss.cpython-38.pyc
Normal file
BIN
contrast/model/__pycache__/metric.cpython-38.pyc
Normal file
BIN
contrast/model/__pycache__/mobilenet_v3.cpython-310.pyc
Normal file
BIN
contrast/model/__pycache__/mobilenet_v3.cpython-38.pyc
Normal file
BIN
contrast/model/__pycache__/mobilevit.cpython-310.pyc
Normal file
BIN
contrast/model/__pycache__/mobilevit.cpython-38.pyc
Normal file
BIN
contrast/model/__pycache__/resnet_pre.cpython-310.pyc
Normal file
BIN
contrast/model/__pycache__/resnet_pre.cpython-38.pyc
Normal file
BIN
contrast/model/__pycache__/utils.cpython-310.pyc
Normal file
contrast/model/mobilenet_v3.py  (new file, 200 lines)
@@ -0,0 +1,200 @@
'''MobileNetV3 in PyTorch.

See the paper "Searching for MobileNetV3" for more details.
(The original docstring cited the MobileNetV2 paper, "Inverted Residuals and Linear Bottlenecks".)
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from tools.config import config as conf


class hswish(nn.Module):
    def forward(self, x):
        out = x * F.relu6(x + 3, inplace=True) / 6
        return out


class hsigmoid(nn.Module):
    def forward(self, x):
        out = F.relu6(x + 3, inplace=True) / 6
        return out


class SeModule(nn.Module):
    def __init__(self, in_size, reduction=4):
        super(SeModule, self).__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_size, in_size // reduction, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(in_size // reduction),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_size // reduction, in_size, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(in_size),
            hsigmoid()
        )

    def forward(self, x):
        return x * self.se(x)


class Block(nn.Module):
    '''expand + depthwise + pointwise'''

    def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, semodule, stride):
        super(Block, self).__init__()
        self.stride = stride
        self.se = semodule

        self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(expand_size)
        self.nolinear1 = nolinear
        self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size=kernel_size, stride=stride,
                               padding=kernel_size // 2, groups=expand_size, bias=False)
        self.bn2 = nn.BatchNorm2d(expand_size)
        self.nolinear2 = nolinear
        self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_size)

        self.shortcut = nn.Sequential()
        if stride == 1 and in_size != out_size:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_size),
            )

    def forward(self, x):
        out = self.nolinear1(self.bn1(self.conv1(x)))
        out = self.nolinear2(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        if self.se is not None:
            out = self.se(out)
        out = out + self.shortcut(x) if self.stride == 1 else out
        return out


class MobileNetV3_Large(nn.Module):
    def __init__(self, num_classes=conf.embedding_size):
        super(MobileNetV3_Large, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.hs1 = hswish()

        self.bneck = nn.Sequential(
            Block(3, 16, 16, 16, nn.ReLU(inplace=True), None, 1),
            Block(3, 16, 64, 24, nn.ReLU(inplace=True), None, 2),
            Block(3, 24, 72, 24, nn.ReLU(inplace=True), None, 1),
            Block(5, 24, 72, 40, nn.ReLU(inplace=True), SeModule(40), 2),
            Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
            Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
            Block(3, 40, 240, 80, hswish(), None, 2),
            Block(3, 80, 200, 80, hswish(), None, 1),
            Block(3, 80, 184, 80, hswish(), None, 1),
            Block(3, 80, 184, 80, hswish(), None, 1),
            Block(3, 80, 480, 112, hswish(), SeModule(112), 1),
            Block(3, 112, 672, 112, hswish(), SeModule(112), 1),
            Block(5, 112, 672, 160, hswish(), SeModule(160), 1),
            Block(5, 160, 672, 160, hswish(), SeModule(160), 2),
            Block(5, 160, 960, 160, hswish(), SeModule(160), 1),
        )

        self.conv2 = nn.Conv2d(160, 960, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(960)
        self.hs2 = hswish()
        self.linear3 = nn.Linear(960, 1280)
        self.bn3 = nn.BatchNorm1d(1280)
        self.hs3 = hswish()
        self.linear4 = nn.Linear(1280, num_classes)
        self.init_params()

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        out = self.hs1(self.bn1(self.conv1(x)))
        out = self.bneck(out)
        out = self.hs2(self.bn2(self.conv2(out)))
        out = F.avg_pool2d(out, conf.img_size // 32)
        out = out.view(out.size(0), -1)
        out = self.hs3(self.bn3(self.linear3(out)))
        out = self.linear4(out)
        return out


class MobileNetV3_Small(nn.Module):
    def __init__(self, num_classes=conf.embedding_size):
        super(MobileNetV3_Small, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.hs1 = hswish()

        self.bneck = nn.Sequential(
            Block(3, 16, 16, 16, nn.ReLU(inplace=True), SeModule(16), 2),
            Block(3, 16, 72, 24, nn.ReLU(inplace=True), None, 2),
            Block(3, 24, 88, 24, nn.ReLU(inplace=True), None, 1),
            Block(5, 24, 96, 40, hswish(), SeModule(40), 2),
            Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
            Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
            Block(5, 40, 120, 48, hswish(), SeModule(48), 1),
            Block(5, 48, 144, 48, hswish(), SeModule(48), 1),
            Block(5, 48, 288, 96, hswish(), SeModule(96), 2),
            Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
            Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
        )

        self.conv2 = nn.Conv2d(96, 576, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(576)
        self.hs2 = hswish()
        self.linear3 = nn.Linear(576, 1280)
        self.bn3 = nn.BatchNorm1d(1280)
        self.hs3 = hswish()
        self.linear4 = nn.Linear(1280, num_classes)
        self.init_params()

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        out = self.hs1(self.bn1(self.conv1(x)))
        out = self.bneck(out)
        out = self.hs2(self.bn2(self.conv2(out)))
        out = F.avg_pool2d(out, conf.img_size // 32)
        out = out.view(out.size(0), -1)

        out = self.hs3(self.bn3(self.linear3(out)))
        out = self.linear4(out)
        return out


def test():
    net = MobileNetV3_Small()
    x = torch.randn(2, 3, 224, 224)
    y = net(x)
    print(y.size())

# test()
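Both variants end in a `num_classes`-wide linear layer that this repo repurposes as an embedding head; a quick shape check for the Large variant, mirroring the `test()` helper above (assumes `tools.config` supplies `img_size` and `embedding_size`, as imported at the top of the file):

```python
net = MobileNetV3_Large()
x = torch.randn(2, 3, conf.img_size, conf.img_size)
print(net(x).shape)  # -> torch.Size([2, conf.embedding_size])
```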
contrast/model/resnet_pre.py  (new file, 462 lines)
@@ -0,0 +1,462 @@
import torch
import torch.nn as nn
from tools.config import config as conf

try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url
# from .utils import load_state_dict_from_url

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
           'wide_resnet50_2', 'wide_resnet101_2']

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, cam=False, bam=False):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.cam = cam
        self.bam = bam
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride
        if self.cam:
            if planes == 64:
                self.globalAvgPool = nn.AvgPool2d(56, stride=1)
            elif planes == 128:
                self.globalAvgPool = nn.AvgPool2d(28, stride=1)
            elif planes == 256:
                self.globalAvgPool = nn.AvgPool2d(14, stride=1)
            elif planes == 512:
                self.globalAvgPool = nn.AvgPool2d(7, stride=1)

            self.fc1 = nn.Linear(in_features=planes, out_features=round(planes / 16))
            self.fc2 = nn.Linear(in_features=round(planes / 16), out_features=planes)
            self.sigmod = nn.Sigmoid()
        if self.bam:
            self.bam = SpatialAttention()

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        if self.cam:
            ori_out = self.globalAvgPool(out)
            out = out.view(out.size(0), -1)
            out = self.fc1(out)
            out = self.relu(out)
            out = self.fc2(out)
            out = self.sigmod(out)
            out = out.view(out.size(0), out.size(-1), 1, 1)
            out = out * ori_out

        if self.bam:
            out = out * self.bam(out)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution (self.conv2)
    # while the original implementation places the stride at the first 1x1 convolution (self.conv1)
    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, cam=False, bam=False):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        self.cam = cam
        self.bam = bam
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        if self.cam:
            if planes == 64:
                self.globalAvgPool = nn.AvgPool2d(56, stride=1)
            elif planes == 128:
                self.globalAvgPool = nn.AvgPool2d(28, stride=1)
            elif planes == 256:
                self.globalAvgPool = nn.AvgPool2d(14, stride=1)
            elif planes == 512:
                self.globalAvgPool = nn.AvgPool2d(7, stride=1)

            self.fc1 = nn.Linear(planes * self.expansion, round(planes / 4))
            self.fc2 = nn.Linear(round(planes / 4), planes * self.expansion)
            self.sigmod = nn.Sigmoid()

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        if self.cam:
            ori_out = self.globalAvgPool(out)
            out = out.view(out.size(0), -1)
            out = self.fc1(out)
            out = self.relu(out)
            out = self.fc2(out)
            out = self.sigmod(out)
            out = out.view(out.size(0), out.size(-1), 1, 1)
            out = out * ori_out
        out += identity
        out = self.relu(out)
        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=conf.embedding_size, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, scale=0.75):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, int(64 * scale), layers[0])
        self.layer2 = self._make_layer(block, int(128 * scale), layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, int(256 * scale), layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, int(512 * scale), layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(int(512 * block.expansion * scale), num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))
        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # print('poolBefore', x.shape)
        x = self.avgpool(x)
        # print('poolAfter', x.shape)
        x = torch.flatten(x, 1)
        # print('fcBefore', x.shape)
        x = self.fc(x)
        # print('fcAfter', x.shape)

        return x

    def forward(self, x):
        return self._forward_impl(x)


# def _resnet(arch, block, layers, pretrained, progress, **kwargs):
#     model = ResNet(block, layers, **kwargs)
#     if pretrained:
#         state_dict = load_state_dict_from_url(model_urls[arch],
#                                               progress=progress)
#         model.load_state_dict(state_dict, strict=False)
#     return model
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)

        src_state_dict = state_dict
        target_state_dict = model.state_dict()
        skip_keys = []
        # skip mismatched-size tensors in case of pretraining
        for k in src_state_dict.keys():
            if k not in target_state_dict:
                continue
            if src_state_dict[k].size() != target_state_dict[k].size():
                skip_keys.append(k)
        for k in skip_keys:
            del src_state_dict[k]
        missing_keys, unexpected_keys = model.load_state_dict(src_state_dict, strict=False)

    return model


def resnet14(pretrained=True, progress=True, **kwargs):
    r"""ResNet-14 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 1, 1, 2], pretrained, progress,
                   **kwargs)


def resnet18(pretrained=True, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)


def resnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
                   **kwargs)


def resnet152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
                   **kwargs)


def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)


def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)
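Because of the embedding-sized `fc` head and the `scale=0.75` channel multiplier, the stock torchvision weights only partially match this architecture; `_resnet` above silently drops the mismatched tensors. A quick shape check (assumes `tools.config` provides `embedding_size`, as imported at the top of the file):

```python
model = resnet18(pretrained=False)  # pretrained=True would load the partial ImageNet weights
emb = model(torch.randn(2, 3, 224, 224))
print(emb.shape)                    # -> torch.Size([2, conf.embedding_size])
```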
contrast/search.py  (new file, 101 lines)
@@ -0,0 +1,101 @@
import pdb


class ImgSearch():
    def __init__(self):
        self.search_params = {
            "metric_type": "COSINE",
        }

    def get_max(self, a, b):
        if a > b:
            return a
        else:
            return b

    def check_keys(self, dict1_last, dict2):
        for key2 in list(dict2.keys()):
            if key2 in list(dict1_last.keys()):
                value = self.get_max(dict1_last[key2], dict2[key2])
                dict1_last[key2] = value
            else:
                dict1_last[key2] = dict2[key2]
        return dict1_last

    def result_analysis(self, result, top1_flag=False):
        result_dict = dict()  # collects the comparison results of every image of the same barcode
        for hits in result:
            for hit in hits:
                if hit.id not in result_dict:  # barcode (hit.id) not yet in the result dict
                    result_dict.update({hit.id: round(hit.distance, 2)})
                else:  # keep the higher similarity for the same barcode
                    distance = result_dict.get(hit.id)
                    distance_new = self.get_max(distance, hit.distance)
                    result_dict.update({hit.id: round(distance_new, 2)})
        if top1_flag:
            return result_dict
        else:
            # sort all barcode similarities and keep at most the top 10
            if len(result_dict) > 10:
                result_sort_dict = dict(sorted(result_dict.items(), key=lambda x: x[1], reverse=True)[:10])
            else:
                result_sort_dict = dict(sorted(result_dict.items(), key=lambda x: x[1], reverse=True))
            return result_sort_dict

    def result_update(self, temp_result, last_result):
        temp_keys = list(temp_result.keys())
        last_keys = list(last_result.keys())
        for ke in temp_keys:
            temp_value = temp_result[ke]
            if ke in last_keys:  # the two track results share a barcode; update only when the new similarity is higher
                last_value = last_result[ke]
                if temp_value > last_value:
                    last_result.update({ke: temp_value})
            else:  # no shared barcode between the two track results
                last_result.update({ke: temp_value})
        return last_result

    def mainSearch10(self, mainMilvus, queBarIdList, queueFeatures):  # queBarIdList -> incoming box barcode-track_Id
        result_last = dict()
        for i in range(len(queBarIdList)):
            vectorsSearch = queueFeatures[i]
            result = mainMilvus.search(vectorsSearch, anns_field='embeddings', param=self.search_params, limit=10)
            result_sort_dic = self.result_analysis(result)
            result_last.update({queBarIdList[i]: result_sort_dic})
        return result_last

    def tempSearch(self, tempMilvus, queueList, queueFeatures, barIdList, tempbarId):
        newBarList = []
        # tempbarId format -> [macID_barcode_trackId1, ..., macID_barcode_trackIdn]
        for bar in tempbarId:  # find the barcodes shared by barIdList and tempbarId
            if len(bar.split('_')) == 3:
                mac_barcode = bar.split('_')[0] + '_' + bar.split('_')[1]
                if mac_barcode in barIdList:
                    newBarList.append(bar)  # newBarList format -> [macID_barcode_trackId1, ..., macID_barcode_trackIdm]
        if len(newBarList) == 0:
            return {}
        else:
            expr = f"pk in {newBarList}"
            result_last = dict()
            for i in range(len(queueList)):
                vectorsSearch = queueFeatures[i]
                result = tempMilvus.search(vectorsSearch, anns_field='embeddings', expr=expr, param=self.search_params,
                                           limit=len(newBarList))
                result_sort_dic = self.result_analysis(result)
                result_last.update({queueList[i]: result_sort_dic})
            return result_last

    def mainSearch1(self, mainMilvus, queBarIdList, queFeatures):  # queBarIdList -> incoming box macID_barcode_trackId
        result_last = dict()
        for i in range(len(queBarIdList)):
            pk_barcode = queBarIdList[i].split('_')[1]  # parse the barcode; query image names are macID_barcode_trackId
            vectorsSearch = queFeatures[i]
            result = mainMilvus.search(vectorsSearch, anns_field='embeddings', expr=f"pk=='{pk_barcode}'",
                                       param=self.search_params, limit=1)
            result_dic = self.result_analysis(result, top1_flag=True)
            if (len(result_dic) != 0) and (len(result_last) != 0):
                result_last = self.result_update(result_dic, result_last)
            else:
                result_last.update({key: value for key, value in result_dic.items()})
        if len(result_last) == 0:
            pk_barcode = queBarIdList[0].split('_')[1]
            result_last.update({pk_barcode: 0})
        return result_last
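A tiny, self-contained example of the merge rule shared by `check_keys` and `result_update`: for barcodes present on both sides the higher similarity wins, otherwise the dicts are unioned.

```python
s = ImgSearch()
merged = s.check_keys({'A': 0.80, 'B': 0.65}, {'B': 0.72, 'C': 0.90})
print(merged)  # -> {'A': 0.8, 'B': 0.72, 'C': 0.9}
```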
contrast/test_logic.py  (new file, 317 lines; hunk truncated below)
@@ -0,0 +1,317 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import pdb
|
||||||
|
import random
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from contrast.model import resnet18, MobileNetV3_Large
|
||||||
|
# import pymilvus
|
||||||
|
# from pymilvus import (
|
||||||
|
# connections,
|
||||||
|
# utility,
|
||||||
|
# FieldSchema, CollectionSchema, DataType,
|
||||||
|
# Collection,
|
||||||
|
# Milvus
|
||||||
|
# )
|
||||||
|
# from config import config as conf
|
||||||
|
from contrast.search import ImgSearch
|
||||||
|
from contrast.img_data import queueImgs_add
|
||||||
|
import sys
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
|
|
||||||
|
sys.path.append('../tools')
|
||||||
|
from tools.config import cfg as conf
|
||||||
|
from tools.config import gvalue
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_preprocess(images: list, actionModel) -> torch.Tensor:
|
||||||
|
res = []
|
||||||
|
for img in images:
|
||||||
|
# print(img)
|
||||||
|
try:
|
||||||
|
im = conf.test_transform(img) if actionModel else conf.test_transform(Image.open(img))
|
||||||
|
res.append(im)
|
||||||
|
except Exception:  # skip images that fail to load or transform
|
||||||
|
continue
|
||||||
|
data = torch.stack(res)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def inference(images, model, actionModel):
|
||||||
|
data = test_preprocess(images, actionModel)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
data = data.to(conf.device)
|
||||||
|
features = model(data)
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
def group_image(images, batch=64) -> list:
|
||||||
|
"""Group image paths by batch size"""
|
||||||
|
size = len(images)
|
||||||
|
res = []
|
||||||
|
for i in range(0, size, batch):
|
||||||
|
end = min(batch + i, size)
|
||||||
|
res.append(images[i:end])
|
||||||
|
return res
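# Quick illustration of the batching behaviour (illustrative self-check):
# 150 items with batch=64 are split into chunks of 64, 64 and 22.
assert [len(g) for g in group_image(list(range(150)), batch=64)] == [64, 64, 22]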
|
||||||
|
|
||||||
|
|
||||||
|
def barcode_state(barcodeIDList):
|
||||||
|
with open('contrast/main_barcodes.json', 'r') as file:
|
||||||
|
data = json.load(file)
|
||||||
|
main_barcode = list(data.values())[0]
|
||||||
|
barIdList_true = []
|
||||||
|
barIdList_false = []
|
||||||
|
for barId in barcodeIDList:
|
||||||
|
bar = barId.split('_')[1]
|
||||||
|
if bar in main_barcode:
|
||||||
|
barIdList_true.append(barId)
|
||||||
|
else:
|
||||||
|
barIdList_false.append(barId)
|
||||||
|
return barIdList_true, barIdList_false
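# Assumed layout of contrast/main_barcodes.json (the real file is not part of this
# diff; the key name and barcodes below are illustrative). barcode_state() uses the
# first value of the top-level dict as the list of barcodes in the main library.
def _write_example_main_barcodes(path='contrast/main_barcodes.json'):
    example = {"main_barcodes": ["6902088146356", "4260290263776"]}
    with open(path, 'w') as f:
        json.dump(example, f)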
|
||||||
|
|
||||||
|
|
||||||
|
def getFeatureList(barList, imgList, model, actionModel):
|
||||||
|
featList = [[] for i in range(len(barList))]
|
||||||
|
for index, feat in enumerate(imgList):
|
||||||
|
groups = group_image(feat)
|
||||||
|
for group in groups:
|
||||||
|
feat_tensor = inference(group, model, actionModel)
|
||||||
|
for fe in feat_tensor:
|
||||||
|
if fe.device.type == 'cpu':
|
||||||
|
fe_np = fe.squeeze().detach().numpy()
|
||||||
|
else:
|
||||||
|
fe_np = fe.squeeze().detach().cpu().numpy()
|
||||||
|
featList[index].append(fe_np)
|
||||||
|
return featList
|
||||||
|
|
||||||
|
|
||||||
|
def img2feature(imgs_dict, model, actionModel, barcode_flag):
|
||||||
|
if not len(imgs_dict) > 0:
|
||||||
|
raise ValueError("Tracking failed: no image files provided")
|
||||||
|
queBarIdList = list(imgs_dict.keys())
|
||||||
|
if barcode_flag:
|
||||||
|
# # ======== check whether the barcodes exist in the main feature library ============
|
||||||
|
queBarIdList_t, barIdList_f = barcode_state(queBarIdList)
|
||||||
|
queFeatList_t = []
|
||||||
|
if len(queBarIdList_t) == 0:
|
||||||
|
print(f"None of the barcodes are in the main_library: {barIdList_f}")
|
||||||
|
return queBarIdList_t, queFeatList_t
|
||||||
|
else:
|
||||||
|
if len(barIdList_f) > 0:  ## drop barcodes (and their images) that are not in the barcode library
|
||||||
|
print(f"These barcodes are not in the main_library: {barIdList_f}")
|
||||||
|
for bar_f in barIdList_f:
|
||||||
|
del imgs_dict[bar_f]
|
||||||
|
queImgList_t = list(imgs_dict.values())
|
||||||
|
queFeatList_t = getFeatureList(queBarIdList_t, queImgList_t, model, actionModel)
|
||||||
|
return queBarIdList_t, queFeatList_t
|
||||||
|
else:
|
||||||
|
queImgsList = list(imgs_dict.values())
|
||||||
|
queFeatList = getFeatureList(queBarIdList, queImgsList, model, actionModel)
|
||||||
|
return queBarIdList, queFeatList
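# Usage sketch (illustrative names and paths): imgs_dict maps each
# macID_barcode_trackId key to a list of image paths (or PIL images when
# actionModel=True); the model is assumed to be loaded elsewhere.
#   imgs_dict = {'mac01_6902088146356_1': ['data_test/a.jpg', 'data_test/b.jpg']}
#   ids, feats = img2feature(imgs_dict, model, actionModel=False, barcode_flag=False)
#   # ids   -> ['mac01_6902088146356_1']
#   # feats -> one list of feature vectors (numpy arrays) per id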
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# def create_milvus(collection_name, host, port, barcode_list, features):
|
||||||
|
# # 1. connect to Milvus
|
||||||
|
# fmt = "\n=== {:30} ===\n"
|
||||||
|
# connections.connect('default', host=host, port=port) # 连接到 Milvus 服务器
|
||||||
|
# has = utility.has_collection(collection_name) ##检查collection_name是否存在milvus中
|
||||||
|
# print(f"Does collection {collection_name} exist in Milvus: {has}")
|
||||||
|
# # if has: ## 删除collection_name的库
|
||||||
|
# # utility.drop_collection(collection_name)
|
||||||
|
#
|
||||||
|
# # 2. create colllection
|
||||||
|
# fields = [
|
||||||
|
# FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100), ###图片路径
|
||||||
|
# FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=256)
|
||||||
|
# ]
|
||||||
|
# schema = CollectionSchema(fields)
|
||||||
|
# print(fmt.format(f"Create collection {collection_name}"))
|
||||||
|
# hello_milvus = Collection(collection_name, schema, consistency_level="Strong")
|
||||||
|
# # 3. insert data
|
||||||
|
# for i in range(len(features)):
|
||||||
|
# entities = [
|
||||||
|
# # provide the pk field because `auto_id` is set to False
|
||||||
|
# [barcode_list[i]] * len(features[i]), ## 图片维度和向量维度需匹配 每个向量都生成一个barcode
|
||||||
|
# features[i],
|
||||||
|
# ]
|
||||||
|
# print(fmt.format("Start inserting entities"))
|
||||||
|
# insert_result = hello_milvus.insert(entities)
|
||||||
|
# hello_milvus.flush()
|
||||||
|
# print(f"Number of entities in {collection_name}: {hello_milvus.num_entities}") # check the num_entities
|
||||||
|
# return hello_milvus
|
||||||
|
|
||||||
|
|
||||||
|
# def load_collection(collection_name):
|
||||||
|
# collection = Collection(collection_name)
|
||||||
|
# # collection.release() ### 将collection从加载状态变成未加载
|
||||||
|
# # collection.drop_index() ### 删除索引
|
||||||
|
#
|
||||||
|
# index_params = {
|
||||||
|
# "index_type": "IVF_FLAT",
|
||||||
|
# # "index_type": "IVF_SQ8",
|
||||||
|
# # "index_type": "GPU_IVF_FLAT",
|
||||||
|
# "metric_type": "COSINE",
|
||||||
|
# "params": {
|
||||||
|
# "nlist": 10000
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
# #### 准确率低
|
||||||
|
# # index_params = {
|
||||||
|
# # "index_type": "IVF_PQ",
|
||||||
|
# # "metric_type": "COSINE",
|
||||||
|
# # "params": {
|
||||||
|
# # "nlist": 99,
|
||||||
|
# # "m": 2,
|
||||||
|
# # "nbits": 8
|
||||||
|
# # }
|
||||||
|
# # }
|
||||||
|
# collection.create_index(
|
||||||
|
# field_name="embeddings",
|
||||||
|
# index_params=index_params,
|
||||||
|
# index_name="SQ8"
|
||||||
|
# )
|
||||||
|
# collection.load()
|
||||||
|
# return collection
|
||||||
|
|
||||||
|
|
||||||
|
# def similarity(queImgsDict, add_flag, barcode_flag, main_milvus, model, barcode_list, actionModel):
|
||||||
|
# searchImg = ImgSearch() ## 相似度比较
|
||||||
|
# # 将输入图片加入临时库
|
||||||
|
# if add_flag:
|
||||||
|
# if actionModel:
|
||||||
|
# queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[2:-2]), model, actionModel, barcode_flag)
|
||||||
|
# else:
|
||||||
|
# queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[:-2]), model, actionModel, barcode_flag)
|
||||||
|
#
|
||||||
|
# if barcode_flag: ### 加购 有barcode -> 输出top10和top1
|
||||||
|
# if len(queBarIdList) == 0:
|
||||||
|
# top10, top1 = {}, {}
|
||||||
|
# else:
|
||||||
|
# for bar in queBarIdList:
|
||||||
|
# # gvalue.tempLibList.append(bar) ## 临时特征库key值为macID_barcode_trackID
|
||||||
|
# if gvalue.tempLibLists.get(gvalue.mac_id) is not None:
|
||||||
|
# gvalue.tempLibLists[gvalue.mac_id] += [bar] ## 临时特征库key值为macID_barcode_trackID
|
||||||
|
# else:
|
||||||
|
# gvalue.tempLibLists[gvalue.mac_id] = [bar]
|
||||||
|
# # 存入临时特征库
|
||||||
|
# # create_milvus('temp_features', conf.host, conf.port, queBarIdList, queBarIdFeatures)
|
||||||
|
#
|
||||||
|
# thread = Thread(target=create_milvus, kwargs={'collection_name': 'temp_features',
|
||||||
|
# 'host': conf.host,
|
||||||
|
# 'port': conf.port,
|
||||||
|
# 'barcode_list': queBarIdList,
|
||||||
|
# 'features': queBarIdFeatures})
|
||||||
|
# thread.start()
|
||||||
|
# start1 = time.time()
|
||||||
|
# top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures)
|
||||||
|
# start2 = time.time()
|
||||||
|
# print('search top10 time>>>> {}'.format(start2-start1))
|
||||||
|
# top1 = searchImg.mainSearch1(main_milvus, queBarIdList, queBarIdFeatures)
|
||||||
|
# start3 = time.time()
|
||||||
|
# print('search top1 time>>>>> {}'.format(start3-start2))
|
||||||
|
# return top10, top1, gvalue.tempLibLists
|
||||||
|
# else: # 加购 无barcode -> 输出top10
|
||||||
|
# # 无barcode时,生成随机数作为字典key值
|
||||||
|
# queBarIdList_rand = []
|
||||||
|
# for i in range(len(queBarIdList)):
|
||||||
|
# random_number = ''.join(random.choices('0123456789', k=10))
|
||||||
|
# queBarIdList_rand.append(str(random_number))
|
||||||
|
# # gvalue.tempLibList.append(str(random_number))
|
||||||
|
# if gvalue.tempLibLists.get(gvalue.mac_id) is not None:
|
||||||
|
# gvalue.tempLibLists[gvalue.mac_id] += [str(random_number)] ## 临时特征库key值为macID_barcode_trackID
|
||||||
|
# else:
|
||||||
|
# gvalue.tempLibLists[gvalue.mac_id] = [str(random_number)]
|
||||||
|
# # create_milvus('temp_features', conf.host, conf.port, queBarIdList_rand, queBarIdFeatures)
|
||||||
|
# thread = Thread(target=create_milvus, kwargs={'collection_name': 'temp_features',
|
||||||
|
# 'host': conf.host,
|
||||||
|
# 'port': conf.port,
|
||||||
|
# 'barcode_list': queBarIdList_rand,
|
||||||
|
# 'features': queBarIdFeatures})
|
||||||
|
# thread.start()
|
||||||
|
# top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures)
|
||||||
|
# # print(f'top10: {top10}')
|
||||||
|
# return top10, gvalue.tempLibLists
|
||||||
|
# else: # 退购 -> 输出top10和topn
|
||||||
|
# if gvalue.tempLibLists.get(gvalue.mac_id) is None:
|
||||||
|
# gvalue.tempLibList = []
|
||||||
|
# else:
|
||||||
|
# gvalue.tempLibList = gvalue.tempLibLists[gvalue.mac_id]
|
||||||
|
# ## 加载临时特征库
|
||||||
|
# tempMilvusName = "temp_features"
|
||||||
|
# has = utility.has_collection(tempMilvusName)
|
||||||
|
# print(f"Does collection {tempMilvusName} exist in Milvus: {has}")
|
||||||
|
# tempMilvus = load_collection(tempMilvusName)
|
||||||
|
# print(f"Number of entities in {tempMilvusName}: {tempMilvus.num_entities}")
|
||||||
|
# if actionModel:
|
||||||
|
# barcode_list = barcode_list
|
||||||
|
# else:
|
||||||
|
# barcode_list = queueImgs_add['barcode_list']
|
||||||
|
# if actionModel:
|
||||||
|
# queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[2:-1]), model, actionModel, barcode_flag)
|
||||||
|
# else:
|
||||||
|
# queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[:-3]), model, actionModel, barcode_flag)
|
||||||
|
# if barcode_flag:
|
||||||
|
# if len(queBarIdList) == 0:
|
||||||
|
# top10, top1, top_n = {}, {}, {}
|
||||||
|
# else:
|
||||||
|
# start1 = time.time()
|
||||||
|
# top1 = searchImg.mainSearch1(main_milvus, queBarIdList, queBarIdFeatures)
|
||||||
|
# start2 = time.time()
|
||||||
|
# print('search top1 time>>>> {}'.format(start2 - start1))
|
||||||
|
# top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures)
|
||||||
|
# start3 = time.time()
|
||||||
|
# print('search top10 time>>>> {}'.format(start3 - start2))
|
||||||
|
# top_n = searchImg.tempSearch(tempMilvus, queBarIdList, queBarIdFeatures, barcode_list, gvalue.tempLibList)
|
||||||
|
# # print(f'top10: {top10}, top1: {top1}, topn: {top_n}')
|
||||||
|
# return top10, top1, top_n
|
||||||
|
# else:
|
||||||
|
# top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures)
|
||||||
|
# top_n = searchImg.tempSearch(tempMilvus, queBarIdList, queBarIdFeatures, barcode_list, gvalue.tempLibList)
|
||||||
|
# # print(f'top10: {top10}, topn: {top_n}')
|
||||||
|
# return top10, top_n
|
||||||
|
|
||||||
|
|
||||||
|
def similarity_interface(dataCollection):
|
||||||
|
queImgsDict = dataCollection.queImgsDict
|
||||||
|
add_flag = dataCollection.add_flag
|
||||||
|
barcode_flag = dataCollection.barcode_flag
|
||||||
|
main_milvus = dataCollection.mainMilvus
|
||||||
|
#tempLibList = dataCollection.tempLibList
|
||||||
|
model = dataCollection.model
|
||||||
|
actionModel = dataCollection.actionModel
|
||||||
|
barcode_list = dataCollection.barcode_list
|
||||||
|
#return similarity(queImgsDict, add_flag, barcode_flag, main_milvus, tempLibList, model, barcode_list, actionModel)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pass
|
||||||
|
|
||||||
|
# connections.connect('default', host=conf.host, port=conf.port)
|
||||||
|
# # 加载主特征库
|
||||||
|
# mainMilvusName = "main_features"
|
||||||
|
# has = utility.has_collection(mainMilvusName)
|
||||||
|
# print(f"Does collection {mainMilvusName} exist in Milvus: {has}")
|
||||||
|
# mainMilvus = Collection(mainMilvusName)
|
||||||
|
# mainMilvus.load()
|
||||||
|
# model = initModel()
|
||||||
|
# # queueImgs_add queueImgs_back 分别为加购和退购时的入参
|
||||||
|
# add_flag = queueImgs_add['add_flag']
|
||||||
|
# barcode_flag = queueImgs_add['barcode_flag']
|
||||||
|
# tempLibList = [] # 临时特征库的barcodeId_list
|
||||||
|
# # tempLibList = ['3500610085338_01', '4260290263776_01'] ##test
|
||||||
|
# if add_flag:
|
||||||
|
# if barcode_flag: # 加购 有barcode -> 输出top10和top1
|
||||||
|
# top10, top1, tempLibList = similarity(queueImgs_add, add_flag, barcode_flag, mainMilvus, tempLibList, model)
|
||||||
|
# print(f"top10: {top10}\ntop1: {top1}")
|
||||||
|
# else: # 加购 无barcode -> 输出top10
|
||||||
|
# top10, tempLibList = similarity(queueImgs_add, add_flag, barcode_flag, mainMilvus, tempLibList, model)
|
||||||
|
# else: # 退购 -> 输出top10和topn
|
||||||
|
# top10, topn = similarity(queueImgs_back, add_flag, barcode_flag, mainMilvus, tempLibList, model)
|
64
contrast/utils.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
"""Train List
|
||||||
|
Format:
|
||||||
|
ImagePath Label
|
||||||
|
|
||||||
|
Example:
|
||||||
|
/data/WebFace/0124920/003.jpg 10572
|
||||||
|
/data/WebFace/0124920/012.jpg 10572
|
||||||
|
/data/WebFace/0124920/020.jpg 10572
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import os.path as osp
|
||||||
|
from imutils import paths
|
||||||
|
|
||||||
|
def generate_list(images_directory, saved_name=None):
|
||||||
|
"""Generate the data list
|
||||||
|
Args:
|
||||||
|
images_directory: face data directory, usually containing multiple sub-folders, e.g.
|
||||||
|
in the WebFace / LFW layout
|
||||||
|
Returns:
|
||||||
|
data_list: [<path> <label>]
|
||||||
|
"""
|
||||||
|
subdirs = os.listdir(images_directory)
|
||||||
|
num_ids = len(subdirs)
|
||||||
|
data_list = []
|
||||||
|
for i in range(num_ids):
|
||||||
|
subdir = osp.join(images_directory, subdirs[i])
|
||||||
|
files = os.listdir(subdir)
|
||||||
|
img_paths = [osp.join(subdir, file) for file in files]  # renamed to avoid shadowing imutils.paths
|
||||||
|
# append the sub-folder ID as the face label
|
||||||
|
paths_with_Id = [f"{p} {i}\n" for p in img_paths]
|
||||||
|
data_list.extend(paths_with_Id)
|
||||||
|
|
||||||
|
if saved_name:
|
||||||
|
with open(saved_name, 'w', encoding='utf-8') as f:
|
||||||
|
f.writelines(data_list)
|
||||||
|
return data_list
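# Example call (illustrative paths): writes one "<image path> <label>" line per image,
# where the label is the index of the identity sub-folder.
#   generate_list('/data/CASIA-WebFace/', saved_name='/data/train_list.txt')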
|
||||||
|
|
||||||
|
def transform_clean_list(webface_directory, cleaned_list_path):
|
||||||
|
"""Convert the WebFace cleaned-list format
|
||||||
|
Args:
|
||||||
|
webface_directory: WebFace data directory
|
||||||
|
cleaned_list_path: path to cleaned_list.txt
|
||||||
|
Returns:
|
||||||
|
cleaned_list: the converted data list
|
||||||
|
"""
|
||||||
|
with open(cleaned_list_path, encoding='utf-8') as f:
|
||||||
|
cleaned_list = f.readlines()
|
||||||
|
cleaned_list = [p.replace('\\', '/') for p in cleaned_list]
|
||||||
|
cleaned_list = [osp.join(webface_directory, p) for p in cleaned_list]
|
||||||
|
return cleaned_list
|
||||||
|
|
||||||
|
def remove_dirty_image(webface_directory, cleaned_list):
|
||||||
|
cleaned_list = set([c.split()[0] for c in cleaned_list])
|
||||||
|
for p in paths.list_images(webface_directory):
|
||||||
|
if p not in cleaned_list:
|
||||||
|
print(f"remove {p}")
|
||||||
|
os.remove(p)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
data = '/data/CASIA-WebFace/'
|
||||||
|
lst = '/data/cleaned_list.txt'
|
||||||
|
cleaned_list = transform_clean_list(data, lst)
|
||||||
|
remove_dirty_image(data, cleaned_list)
|
160
dealdata.py
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
import pickle
|
||||||
|
import os
|
||||||
|
|
||||||
|
os.environ["OMP_NUM_THREADS"] = '1'
|
||||||
|
from sklearn.decomposition import PCA
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from sklearn.cluster import KMeans
|
||||||
|
from sklearn import metrics
|
||||||
|
|
||||||
|
from sklearn.metrics import silhouette_score, calinski_harabasz_score, davies_bouldin_score
|
||||||
|
import shutil
|
||||||
|
import math
|
||||||
|
|
||||||
|
distance_lists = []
|
||||||
|
all_distance_lists = []
|
||||||
|
|
||||||
|
|
||||||
|
def showImg(newx):
|
||||||
|
plt.scatter(newx[:, 0], newx[:, 1])
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def showHistogram(data):
|
||||||
|
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(10, 8))
|
||||||
|
bins = np.arange(0, 1.5, 0.1)
|
||||||
|
axs[0, 0].hist(data[0], bins=bins, edgecolor='black')
|
||||||
|
axs[0, 0].set_title('silhouette')
|
||||||
|
axs[0, 0].set_xlabel('Similarity')
|
||||||
|
axs[0, 0].set_ylabel('Frequency')
|
||||||
|
axs[0, 0].legend(labels=['Frequency'])
|
||||||
|
axs[0, 1].hist(data[1], bins=bins, edgecolor='black')
|
||||||
|
axs[0, 1].set_title('calinski')
|
||||||
|
axs[0, 1].set_xlabel('Similarity')
|
||||||
|
axs[0, 1].set_ylabel('Frequency')
|
||||||
|
axs[0, 1].legend(labels=['Frequency'])
|
||||||
|
axs[1, 0].hist(data[2], bins=bins, edgecolor='black')
|
||||||
|
axs[1, 0].set_title('davies')
|
||||||
|
axs[1, 0].set_xlabel('Similarity')
|
||||||
|
axs[1, 0].set_ylabel('Frequency')
|
||||||
|
axs[1, 0].legend(labels=['Frequency'])
|
||||||
|
axs[1, 1].hist(data[3], bins=bins, edgecolor='black')
|
||||||
|
axs[1, 1].set_title('inertia')
|
||||||
|
axs[1, 1].set_xlabel('Similarity')
|
||||||
|
axs[1, 1].set_ylabel('Frequency')
|
||||||
|
axs[1, 1].legend(labels=['Frequency'])
|
||||||
|
# show the figure
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig('multiple_histograms_in_ranges_non_pca.png')
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def get_pca_data(data):
|
||||||
|
pca = PCA(n_components=16)
|
||||||
|
newx = pca.fit_transform(data)
|
||||||
|
# print(pca.explained_variance_ratio_)
|
||||||
|
return newx
|
||||||
|
|
||||||
|
|
||||||
|
def cosine_similarity(vec1, vec2):
|
||||||
|
# dot_product = sum(a * b for a, b in zip(vec1, vec2))
|
||||||
|
# norm_vec1 = math.sqrt(sum(x ** 2 for x in vec1))
|
||||||
|
# norm_vec2 = math.sqrt(sum(x ** 2 for x in vec2))
|
||||||
|
# return dot_product / (norm_vec1 * norm_vec2)
|
||||||
|
vec1 = np.array(vec1)
|
||||||
|
vec2 = np.array(vec2)
|
||||||
|
cos_sim = vec1.dot(vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
|
||||||
|
return cos_sim
|
||||||
|
|
||||||
|
def get_kclusters(data, k, dicts, standard='silhouette'):
|
||||||
|
Kmeans = KMeans(n_clusters=k, n_init='auto')
|
||||||
|
y_pred = Kmeans.fit_predict(data)
|
||||||
|
if standard == 'silhouette':
|
||||||
|
s_score = silhouette_score(data, y_pred, metric='euclidean')
|
||||||
|
dicts[s_score] = [y_pred, Kmeans.cluster_centers_]
|
||||||
|
elif standard == 'calinski':
|
||||||
|
c_score = calinski_harabasz_score(data, y_pred)
|
||||||
|
dicts[c_score] = [y_pred, Kmeans.cluster_centers_]
|
||||||
|
elif standard == 'davies':
|
||||||
|
d_score = davies_bouldin_score(data, y_pred)
|
||||||
|
dicts[d_score] = [y_pred, Kmeans.cluster_centers_]
|
||||||
|
elif standard == 'inertia':
|
||||||
|
i_score = Kmeans.inertia_
|
||||||
|
dicts[i_score] = [y_pred, Kmeans.cluster_centers_]
|
||||||
|
return dicts
|
||||||
|
|
||||||
|
|
||||||
|
def get_keams(data, standard='silhouette'):
|
||||||
|
global distance_lists
|
||||||
|
distance_list = []
|
||||||
|
dicts = {}
|
||||||
|
for k in range(2, 10):
|
||||||
|
dicts = get_kclusters(data, k, dicts, standard)
|
||||||
|
if standard == 'silhouette' or standard == 'calinski':
|
||||||
|
max_key = max(dicts.keys())
|
||||||
|
value = dicts[max_key]
|
||||||
|
elif standard == 'davies' or standard == 'inertia':
|
||||||
|
max_key = min(dicts.keys())
|
||||||
|
value = dicts[max_key]
|
||||||
|
|
||||||
|
for num, i in enumerate(value[0]):
|
||||||
|
distance_list.append(abs(cosine_similarity(value[1][i], data[num])))
|
||||||
|
distance_lists += distance_list
|
||||||
|
return distance_lists
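# Usage sketch (illustrative): stack the per-image feature vectors into an (n, d)
# array, let get_keams pick the best k in [2, 10) under the chosen criterion, and
# collect each sample's cosine similarity to its cluster centre.
#   feats = np.vstack(feature_vectors)               # shape (n, d)
#   sims = get_keams(feats, standard='silhouette')   # accumulates into distance_lists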
|
||||||
|
|
||||||
|
def move_file(labels, y_pred):
|
||||||
|
for label, y_label in zip(labels, y_pred):
|
||||||
|
print('label >>> {}'.format(label))
|
||||||
|
if not os.path.isdir(os.sep.join(['data', label.split('_')[0], str(y_label)])):
|
||||||
|
os.mkdir(os.sep.join(['data', label.split('_')[0], str(y_label)]))
|
||||||
|
try:
|
||||||
|
shutil.move(os.sep.join(['data', label.split('_')[0], label]),
|
||||||
|
os.sep.join(['data', label.split('_')[0], str(y_label), label]))
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
def move_video(root_pth='data_test'):  # data preparation
|
||||||
|
for Dirs in os.listdir(root_pth):
|
||||||
|
for root, dirs, files in os.walk(os.sep.join([root_pth, Dirs])):
|
||||||
|
if len(dirs) > 0:
|
||||||
|
for dir in dirs:
|
||||||
|
filespth = os.sep.join([root_pth, Dirs, dir])
|
||||||
|
for file in os.listdir(filespth):
|
||||||
|
if file.endswith('.jpg'):
|
||||||
|
oldpth = os.sep.join([filespth, file])
|
||||||
|
newpth = os.sep.join([root_pth, Dirs, file])
|
||||||
|
shutil.move(oldpth, newpth)
|
||||||
|
|
||||||
|
def main():
|
||||||
|
global distance_lists
|
||||||
|
for standard in ['silhouette', 'calinski', 'davies', 'inertia']:
|
||||||
|
with open('data_test/data.pkl', 'rb') as f:
|
||||||
|
data = pickle.load(f)
|
||||||
|
for keys, values in data.items():
|
||||||
|
all_data, labels = [], []
|
||||||
|
temp = None
|
||||||
|
for value in values:
|
||||||
|
for key1, value1 in value.items():
|
||||||
|
# print('key1 >>>> {}'.format(key1))
|
||||||
|
if temp is None:
|
||||||
|
labels.append(key1)
|
||||||
|
temp = value1
|
||||||
|
else:
|
||||||
|
labels.append(key1)
|
||||||
|
temp = np.vstack((temp, value1))
|
||||||
|
try:
|
||||||
|
# temp = get_pca_data(temp)
|
||||||
|
get_keams(temp, standard)
|
||||||
|
# move_file(labels, y_pred)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
f.close()
|
||||||
|
all_distance_lists.append(distance_lists)
|
||||||
|
distance_lists = []
|
||||||
|
showHistogram(all_distance_lists)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
|
# move_video()
|
104
image_quality.py
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ytracking.track_ import *
|
||||||
|
from tools.Interface import AiInterface, AiClass
|
||||||
|
from tools.operate_usearch import create_base_index, search_in_index
|
||||||
|
from tools.initModel import models
|
||||||
|
from imgcompare import get_feature_list, compute_similarity_matrix
|
||||||
|
import pickle
|
||||||
|
models.initModel()
|
||||||
|
ai_obj = AiClass()
|
||||||
|
|
||||||
|
|
||||||
|
def get_img_lists(pth):
|
||||||
|
imglist, imglists = [], []
|
||||||
|
for root, dirs, files in os.walk(pth):
|
||||||
|
if not any(dirs):
|
||||||
|
for file in files:
|
||||||
|
if file.endswith('.jpg'):
|
||||||
|
imglist.append(os.sep.join([root, file]))
|
||||||
|
imglists.append(imglist)
|
||||||
|
imglist = []
|
||||||
|
return imglists
|
||||||
|
|
||||||
|
def get_standard_image(cosine_similarities, similarity_threshold=0.6):
|
||||||
|
"""
|
||||||
|
:param cosine_similarities: pairwise cosine-similarity matrix of the images
|
||||||
|
:return: indices of the selected representative ("standard") images
|
||||||
|
"""
|
||||||
|
target_indexs = []
|
||||||
|
max_similarity = {}
|
||||||
|
mask = (cosine_similarities > similarity_threshold)
|
||||||
|
counts = mask.sum(axis=1)
|
||||||
|
for key in range(counts.shape[0]):
|
||||||
|
max_similarity[key] = counts[key]
|
||||||
|
sorted_dict_desc = dict(sorted(max_similarity.items(), key=lambda item: item[1], reverse=True))
|
||||||
|
keys = list(sorted_dict_desc.keys())
|
||||||
|
while len(keys) > 10:
|
||||||
|
target_indexs.append(keys[0])
|
||||||
|
single_line = cosine_similarities[keys[0], :]
|
||||||
|
rows = np.where((single_line > similarity_threshold))
|
||||||
|
if len(rows[0]) < 2:
|
||||||
|
break
|
||||||
|
for row in rows[0]:
|
||||||
|
try:
|
||||||
|
keys.remove(row)
|
||||||
|
except Exception as e:
|
||||||
|
continue
|
||||||
|
# print(target_indexs)
|
||||||
|
return target_indexs
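# Worked sketch (illustrative): with threshold 0.6, the image whose row has the most
# similarities above the threshold becomes an exemplar, and the images it covers are
# dropped from the candidate list, so near-duplicates end up represented by a single
# "standard" image.
#   sims = compute_similarity_matrix(feature_list)   # (n, n) cosine similarities
#   standard_idx = get_standard_image(sims, similarity_threshold=0.6)
#   standard_feats = [feature_list[i] for i in standard_idx]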
|
||||||
|
|
||||||
|
def create_feature_library(pth, save_index_name, index_file_pth=None):
|
||||||
|
target_feature_lists, target_barcode_lists = [], []
|
||||||
|
imglists = get_img_lists(pth)
|
||||||
|
for imglist in imglists:
|
||||||
|
feature_list = get_feature_list(imglist, False)
|
||||||
|
cosine_similarities = compute_similarity_matrix(feature_list)
|
||||||
|
target_indexs = get_standard_image(cosine_similarities)
|
||||||
|
target_feature_lists.append([feature_list[i] for i in target_indexs])
|
||||||
|
target_barcode_lists.append([os.path.basename(imglist[i]).split('_')[0] for i in target_indexs])
|
||||||
|
create_base_index(save_index_name=save_index_name,
|
||||||
|
barcodes=target_barcode_lists,
|
||||||
|
features=target_feature_lists,
|
||||||
|
index_file_pth=index_file_pth)
|
||||||
|
with open('search_library/target_barcode_lists.pkl', 'wb') as f:
|
||||||
|
pickle.dump(target_barcode_lists, f)
|
||||||
|
|
||||||
|
def search_top_in_index(test_image_pth, index_name): #1:N
|
||||||
|
s_barcode, s_similarity = [], []
|
||||||
|
img_lists = [os.sep.join([test_image_pth, name]) for name in os.listdir(test_image_pth)]
|
||||||
|
feature_lists = get_feature_list(img_lists, False)
|
||||||
|
for feature in feature_lists:
|
||||||
|
result = search_in_index(query=np.array(feature), index_name=index_name)
|
||||||
|
s_barcode.append(result.keys)
|
||||||
|
s_similarity.append(1-result.distances)
|
||||||
|
s_barcode = np.array(s_barcode)
|
||||||
|
s_similarity = np.array(s_similarity)
|
||||||
|
return s_barcode, s_similarity
|
||||||
|
|
||||||
|
def search_one_in_index(test_image_pth, index_name): # 1:1
|
||||||
|
barcodes = [int(os.path.basename(name).split('_')[0]) for name in os.listdir(test_image_pth)]
|
||||||
|
barcodes = list(set(barcodes))
|
||||||
|
# barcodes = ['6934364805640']
|
||||||
|
img_lists = [os.sep.join([test_image_pth, name]) for name in os.listdir(test_image_pth)]
|
||||||
|
feature_lists = get_feature_list(img_lists, False)
|
||||||
|
result = search_in_index(barcode=barcodes,
|
||||||
|
query=feature_lists,
|
||||||
|
index_name=index_name,
|
||||||
|
temp_index=False)
|
||||||
|
print(feature_lists)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pth = 'imageQualityData/test_data'
|
||||||
|
save_index_name = 'search_library/test_index_10_simple_0717.usearch'
|
||||||
|
create_feature_library(pth,
|
||||||
|
save_index_name=save_index_name)
|
||||||
|
|
||||||
|
# test_images_pth = 'D:/Project/ieemoo/image_quality_assessment/imageQualityData/test_images'
|
||||||
|
# # index_name = 'D:/Project/ieemoo/image_quality_assessment/search_library/test_index_10_normal_0717.usearch'
|
||||||
|
# index_name = 'D:/Project/ieemoo/image_quality_assessment/search_library/test_index_10_simple_0717.usearch'
|
||||||
|
# # search_top_in_index(test_images_pth, index_name)
|
||||||
|
# search_one_in_index(test_images_pth, index_name)
|
313
imgcompare.py
Normal file
@ -0,0 +1,313 @@
|
|||||||
|
import os.path
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ytracking.track_ import *
|
||||||
|
from contrast.test_logic import group_image, inference
|
||||||
|
from tools.Interface import AiInterface, AiClass
|
||||||
|
from tools.config import cfg, gvalue
|
||||||
|
|
||||||
|
from tools.initModel import models
|
||||||
|
from scipy.spatial.distance import cdist
|
||||||
|
|
||||||
|
from dealdata import get_keams
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from contrast.model.resnet_pre import resnet18
|
||||||
|
from prettytable import PrettyTable
|
||||||
|
from sklearn.cluster import KMeans
|
||||||
|
|
||||||
|
models.initModel()
|
||||||
|
ai_obj = AiClass()
|
||||||
|
|
||||||
|
|
||||||
|
def showComprehensiveHistogram(data, title):
|
||||||
|
bins = np.arange(0, 1.01, 0.1)
|
||||||
|
plt.hist(data, bins, edgecolor='black')
|
||||||
|
plt.title(title)
|
||||||
|
plt.xlabel('Similarity')
|
||||||
|
plt.ylabel('Frequency')
|
||||||
|
# plt.show()
|
||||||
|
plt.savefig(title + '.png')
|
||||||
|
|
||||||
|
|
||||||
|
def showHistogram():
|
||||||
|
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(10, 8))
|
||||||
|
bins = np.arange(0, 1.01, 0.1)
|
||||||
|
axs[0, 0].hist(gvalue.back_return_similarity, bins=bins, edgecolor='black')
|
||||||
|
axs[0, 0].set_title('back_return_similarity')
|
||||||
|
axs[0, 0].set_xlabel('Similarity')
|
||||||
|
axs[0, 0].set_ylabel('Frequency')
|
||||||
|
axs[0, 0].legend(labels=['back_return_similarity'])
|
||||||
|
|
||||||
|
axs[0, 1].hist(gvalue.back_add_similarity, bins=bins, edgecolor='black')
|
||||||
|
axs[0, 1].set_title('back_add_similarity')
|
||||||
|
axs[0, 1].set_xlabel('Similarity')
|
||||||
|
axs[0, 1].set_ylabel('Frequency')
|
||||||
|
axs[0, 1].legend(labels=['back_add_similarity'])
|
||||||
|
|
||||||
|
axs[1, 0].hist(gvalue.front_return_similarity, bins=bins, edgecolor='black')
|
||||||
|
axs[1, 0].set_title('front_return_similarity')
|
||||||
|
axs[1, 0].set_xlabel('Similarity')
|
||||||
|
axs[1, 0].set_ylabel('Frequency')
|
||||||
|
axs[1, 0].legend(labels=['front_return_similarity'])
|
||||||
|
|
||||||
|
axs[1, 1].hist(gvalue.front_add_similarity, bins=bins, edgecolor='black')
|
||||||
|
axs[1, 1].set_title('front_add_similarity')
|
||||||
|
axs[1, 1].set_xlabel('Similarity')
|
||||||
|
axs[1, 1].set_ylabel('Frequency')
|
||||||
|
axs[1, 1].legend(labels=['front_add_similarity'])
|
||||||
|
# show the figure
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig('multiple_histograms.png')
|
||||||
|
plt.close(fig)
|
||||||
|
# plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def showgrid():
|
||||||
|
y_back_return = get_count_number(gvalue.back_return_similarity)
|
||||||
|
y_back_add = get_count_number(gvalue.back_add_similarity)
|
||||||
|
y_front_return = get_count_number(gvalue.front_return_similarity)
|
||||||
|
y_front_add = get_count_number(gvalue.front_add_similarity)
|
||||||
|
y_comprehensive = get_count_number(gvalue.comprehensive_similarity)
|
||||||
|
x = np.linspace(start=0.1, stop=1.0, num=10, endpoint=True).tolist()
|
||||||
|
plt.figure(figsize=(10, 6))
|
||||||
|
plt.plot(x, y_back_return, color='red', label='back_return')
|
||||||
|
plt.plot(x, y_back_add, color='blue', label='back_add')
|
||||||
|
plt.plot(x, y_front_return, color='green', label='front_return')
|
||||||
|
plt.plot(x, y_front_add, color='purple', label='front_add')
|
||||||
|
plt.plot(x, y_comprehensive, color='orange', label='comprehensive')
|
||||||
|
plt.legend()
|
||||||
|
plt.xlabel('Similarity')
|
||||||
|
plt.ylabel('Frequency')
|
||||||
|
plt.grid(True, linestyle='--', alpha=0.5)
|
||||||
|
plt.savefig('multiple_grid.png')
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
def showtable():  # outlier statistics at the given similarity thresholds
|
||||||
|
temp_lists = [get_count_number(gvalue.back_return_similarity),
|
||||||
|
get_count_number(gvalue.back_add_similarity),
|
||||||
|
get_count_number(gvalue.front_return_similarity),
|
||||||
|
get_count_number(gvalue.front_add_similarity),
|
||||||
|
get_count_number(gvalue.comprehensive_similarity)]
|
||||||
|
rows = []
|
||||||
|
table = PrettyTable()
|
||||||
|
tablename = ['back_return', 'back_add', 'front_return', 'front_add', 'comprehensive']
|
||||||
|
table.field_names = ['name', '0.1', '0.2', '0.3', '0.4', '0.5', '0.6', '0.7', '0.8', '0.9', '1.0']
|
||||||
|
for List, name in zip(temp_lists, tablename):
|
||||||
|
o_data = [round(data / List[-1], 3) for data in List]
|
||||||
|
o_data.insert(0, name)
|
||||||
|
rows.append(o_data)
|
||||||
|
# print(rows)
|
||||||
|
table.add_rows(rows)
|
||||||
|
print(table)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_similarity_matrix(featurelists):
|
||||||
|
"""Compute the pairwise cosine-similarity matrix between images"""
|
||||||
|
# cosine similarity between every pair of feature vectors
|
||||||
|
cosine_similarities = 1 - cdist(featurelists, featurelists, metric='cosine')
|
||||||
|
cosine_similarities = np.around(cosine_similarities, decimals=3)
|
||||||
|
return cosine_similarities
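# Quick numeric illustration: collinear vectors give 1.0, orthogonal ones 0.0
# (the function rounds to 3 decimals).
#   compute_similarity_matrix([[1, 0], [0, 1], [2, 0]])
#   # -> [[1., 0., 1.],
#   #     [0., 1., 0.],
#   #     [1., 0., 1.]]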
|
||||||
|
|
||||||
|
|
||||||
|
def remove_empty_folders(root_dir):
|
||||||
|
for foldername, subfolders, files in os.walk(root_dir):
|
||||||
|
if not subfolders and not files:  # the current folder has no sub-folders and no files
|
||||||
|
print(f"Removing empty folder: {foldername}")
|
||||||
|
try:
|
||||||
|
shutil.rmtree(foldername)  # remove the empty folder
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error removing folder {foldername}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def cosine_similarity(vec_mean, vecs, k=False, y_pred=None):  # cosine similarity
|
||||||
|
all_similarity = []
|
||||||
|
if not k:
|
||||||
|
vec_mean = np.array(vec_mean)
|
||||||
|
for ovec in vecs:
|
||||||
|
ovec = np.array(ovec)
|
||||||
|
cos_sim = ovec.dot(vec_mean) / (np.linalg.norm(vec_mean) * np.linalg.norm(ovec))
|
||||||
|
all_similarity.append(cos_sim)
|
||||||
|
else:
|
||||||
|
for nu, ks in enumerate(y_pred):
|
||||||
|
ovec = np.array(vecs[nu])
|
||||||
|
vecmean = np.array(vec_mean[ks])
|
||||||
|
cos_sim = ovec.dot(vecmean) / (np.linalg.norm(vecmean) * np.linalg.norm(ovec))
|
||||||
|
all_similarity.append(cos_sim)
|
||||||
|
# print(all_similarity)
|
||||||
|
return all_similarity
|
||||||
|
|
||||||
|
|
||||||
|
def get_count_number(numbers):
|
||||||
|
count_less = []
|
||||||
|
thresholds = np.linspace(start=0.1, stop=1.0, num=10, endpoint=True).tolist()
|
||||||
|
for threshold in thresholds:
|
||||||
|
count_less.append(sum(map(lambda x: x < threshold, numbers)))
|
||||||
|
print(count_less)
|
||||||
|
return count_less
|
||||||
|
|
||||||
|
|
||||||
|
def shuntVideo_imgs(obj: AiInterface, rootpth, vpth):  # build the similarity matrix for a single track id
|
||||||
|
videospth = os.sep.join([rootpth, vpth])
|
||||||
|
for videoname in os.listdir(videospth):
|
||||||
|
if videoname.endswith('mp4'):
|
||||||
|
cameraId = '0' if videoname.split('_')[2] == 'back' else '1'
|
||||||
|
videopth = os.sep.join([videospth, videoname])
|
||||||
|
save_imgs_dir = os.sep.join([rootpth, 'images', videoname.split('.')[0]])
|
||||||
|
if not os.path.exists(save_imgs_dir):
|
||||||
|
os.makedirs(save_imgs_dir)
|
||||||
|
track_boxes, features_dict, frame_id_img = run(models, source=videopth)
|
||||||
|
allimages, trackIdList = obj.getTrackingBox(track_boxes,
|
||||||
|
features_dict,
|
||||||
|
cameraId,
|
||||||
|
frame_id_img,
|
||||||
|
save_imgs_dir)
|
||||||
|
featList = get_feature_list(allimages)
|
||||||
|
cosine_similarities = compute_similarity_matrix(featList)
|
||||||
|
print(len(cosine_similarities))
|
||||||
|
print(cosine_similarities)
|
||||||
|
|
||||||
|
|
||||||
|
def get_feature_list(allimages, actionModel=True):
|
||||||
|
featList = []
|
||||||
|
groups = group_image(allimages, batch=64)
|
||||||
|
if not actionModel:
|
||||||
|
groups = [groups]
|
||||||
|
for group in groups:
|
||||||
|
for img in group:
|
||||||
|
feat_tensor = inference(img, models.similarityModel, actionModel)
|
||||||
|
for fe in feat_tensor:
|
||||||
|
if fe.device.type == 'cpu':
|
||||||
|
fe_np = fe.squeeze().detach().numpy()
|
||||||
|
else:
|
||||||
|
fe_np = fe.squeeze().detach().cpu().numpy()
|
||||||
|
featList.append(fe_np)
|
||||||
|
return featList
|
||||||
|
|
||||||
|
|
||||||
|
def k_similarity(imgs_pth, k, actionModel=False):  # similarity between each image and the k cluster centres
|
||||||
|
remove_empty_folders(imgs_pth)
|
||||||
|
for imgdirs in os.listdir(imgs_pth):
|
||||||
|
imgpth = []
|
||||||
|
for img in os.listdir(os.sep.join([imgs_pth, imgdirs])):
|
||||||
|
imgpth.append(os.sep.join([imgs_pth, imgdirs, img]))
|
||||||
|
featList = get_feature_list(imgpth, actionModel)
|
||||||
|
# assert all(len(lst) == len(featList[0]) for lst in featList)
|
||||||
|
if len(featList) < k:
|
||||||
|
continue
|
||||||
|
featList = np.array(featList)
|
||||||
|
Kmeans = KMeans(n_clusters=k)
|
||||||
|
y_pred = Kmeans.fit_predict(featList)
|
||||||
|
ores = cosine_similarity(Kmeans.cluster_centers_, featList, k=True, y_pred=y_pred)
|
||||||
|
if 'back_return' in imgdirs:
|
||||||
|
gvalue.back_return_similarity += ores
|
||||||
|
elif 'back_add' in imgdirs:
|
||||||
|
gvalue.back_add_similarity += ores
|
||||||
|
elif 'front_return' in imgdirs:
|
||||||
|
gvalue.front_return_similarity += ores
|
||||||
|
elif 'front_add' in imgdirs:
|
||||||
|
gvalue.front_add_similarity += ores
|
||||||
|
gvalue.comprehensive_similarity += ores
|
||||||
|
showtable()  # outlier table
|
||||||
|
|
||||||
|
|
||||||
|
def average_similarity(imgs_pth, actionModel=False):  # similarity between the mean vector and each image
|
||||||
|
remove_empty_folders(imgs_pth)
|
||||||
|
for imgdirs in os.listdir(imgs_pth):
|
||||||
|
imgpth = []
|
||||||
|
if len(os.listdir(os.sep.join([imgs_pth, imgdirs]))) < 10:
|
||||||
|
continue
|
||||||
|
for img in os.listdir(os.sep.join([imgs_pth, imgdirs])):
|
||||||
|
imgpth.append(os.sep.join([imgs_pth, imgdirs, img]))
|
||||||
|
featList = get_feature_list(imgpth, actionModel)
|
||||||
|
assert all(len(lst) == len(featList[0]) for lst in featList)
|
||||||
|
vec_mean = [sum(column) / len(featList) for column in zip(*featList)]
|
||||||
|
ores = cosine_similarity(vec_mean, featList)
|
||||||
|
if 'back_return' in imgdirs:
|
||||||
|
gvalue.back_return_similarity += ores
|
||||||
|
elif 'back_add' in imgdirs:
|
||||||
|
gvalue.back_add_similarity += ores
|
||||||
|
elif 'front_return' in imgdirs:
|
||||||
|
gvalue.front_return_similarity += ores
|
||||||
|
elif 'front_add' in imgdirs:
|
||||||
|
gvalue.front_add_similarity += ores
|
||||||
|
gvalue.comprehensive_similarity += ores
|
||||||
|
showHistogram()  # plot histograms
|
||||||
|
showgrid()  # plot line charts
|
||||||
|
showtable()  # outlier table
|
||||||
|
showComprehensiveHistogram(gvalue.comprehensive_similarity, 'comprehensive_similarity')
|
||||||
|
|
||||||
|
|
||||||
|
def barcode_similarity(rootpths):
|
||||||
|
for dir in os.listdir(rootpths):
|
||||||
|
if dir == 'barcode_similarity':
|
||||||
|
continue
|
||||||
|
new_dir = os.sep.join([rootpths, 'barcode_similarity', dir])
|
||||||
|
if not os.path.exists(new_dir):
|
||||||
|
os.makedirs(new_dir)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
rootpth = os.sep.join([rootpths, dir]) # 6934660520292
|
||||||
|
imgs_pth = [os.sep.join([rootpth, name]) for name in os.listdir(rootpth)]
|
||||||
|
featList = get_feature_list(imgs_pth, False)
|
||||||
|
cosine_similarities = compute_similarity_matrix(featList)
|
||||||
|
num = 0
|
||||||
|
for i in range(cosine_similarities.shape[0]):
|
||||||
|
cols = np.where(cosine_similarities[i, :] > 0.5)[0]
|
||||||
|
if len(cols) > num:
|
||||||
|
num = len(cols)
|
||||||
|
max_cols = cols
|
||||||
|
imgPth = [os.sep.join([rootpth, imgName]) for imgName in [os.listdir(rootpth)[i] for i in max_cols]]
|
||||||
|
for img in imgPth:
|
||||||
|
try:
|
||||||
|
shutil.copy(img, new_dir)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
continue
|
||||||
|
# shutil.copy(img, new_dir)
|
||||||
|
# print(imgPth)
|
||||||
|
# print(featList)
|
||||||
|
# print(imgs_pth)
|
||||||
|
|
||||||
|
|
||||||
|
def compare_two_img(img1, img2):
|
||||||
|
img1_feature = get_feature_list(img1, False)[0]
|
||||||
|
img2_feature = get_feature_list(img2, False)[0]
|
||||||
|
cos_sim = img1_feature.dot(img2_feature) / (np.linalg.norm(img1_feature) * np.linalg.norm(img2_feature))
|
||||||
|
print(cos_sim)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
rootpth = 'Single_purchase_data'
|
||||||
|
'''
|
||||||
|
Build the similarity matrix for a single track id
|
||||||
|
'''
|
||||||
|
# vpth = 'test_video'
|
||||||
|
# shuntVideo_imgs(ai_obj, rootpth, vpth)
|
||||||
|
|
||||||
|
'''
|
||||||
|
Similarity between the mean feature vector and each image
|
||||||
|
'''
|
||||||
|
imgs_pth = os.sep.join([rootpth, 'images'])
|
||||||
|
average_similarity(imgs_pth)
|
||||||
|
|
||||||
|
'''
|
||||||
|
Similarity between each image and the k cluster-centre vectors
|
||||||
|
'''
|
||||||
|
# imgs_pth = os.sep.join([rootpth, 'images'])
|
||||||
|
# k_similarity(imgs_pth, k=3)
|
||||||
|
|
||||||
|
'''
|
||||||
|
Select the images of a single barcode whose similarities are most concentrated
|
||||||
|
'''
|
||||||
|
# rootpths = 'data_test'
|
||||||
|
# barcode_similarity(rootpths)
|
||||||
|
|
||||||
|
'''
|
||||||
|
Compare the similarity of two images
|
||||||
|
'''
|
||||||
|
# img1 = ['C:/Users/HP/Desktop/maskBackImg.jpg']
|
||||||
|
# img2 = ['C:/Users/HP/Desktop/frontImgMask.jpg']
|
||||||
|
# compare_two_img(img1, img2)
|
77
manualSortingTools.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
import os.path
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
|
||||||
|
from ytracking.track_ import *
|
||||||
|
from contrast.test_logic import group_image, inference
|
||||||
|
from tools.Interface import AiInterface, AiClass
|
||||||
|
|
||||||
|
from tools.initModel import models
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from sklearn.cluster import KMeans
|
||||||
|
|
||||||
|
models.initModel()
|
||||||
|
ai_obj = AiClass()
|
||||||
|
distance_lists = []
|
||||||
|
all_distance_lists = []
|
||||||
|
|
||||||
|
|
||||||
|
def showImg(newx):
|
||||||
|
plt.scatter(newx[:, 0], newx[:, 1])
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def cosine_similarity(vec1, vec2):
|
||||||
|
vec1 = np.array(vec1)
|
||||||
|
vec2 = np.array(vec2)
|
||||||
|
cos_sim = vec1.dot(vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
|
||||||
|
return cos_sim
|
||||||
|
|
||||||
|
|
||||||
|
def get_kclusters(data, k):
|
||||||
|
Kmeans = KMeans(n_clusters=k, n_init='auto')
|
||||||
|
y_pred = Kmeans.fit_predict(data)
|
||||||
|
return y_pred
|
||||||
|
|
||||||
|
|
||||||
|
def move_file(pth, dirs, labels, y_pred):
|
||||||
|
for label, y_label in zip(labels, y_pred):
|
||||||
|
if not os.path.isdir(os.sep.join([pth, dirs, str(y_label)])):
|
||||||
|
os.mkdir(os.sep.join([pth, dirs, str(y_label)]))
|
||||||
|
try:
|
||||||
|
shutil.move(os.sep.join([pth, dirs, label]),
|
||||||
|
os.sep.join([pth, dirs, str(y_label), label]))
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
|
||||||
|
def createDic(pth, dirs):
|
||||||
|
all_dics = {}
|
||||||
|
imgs = os.sep.join([pth, dirs])
|
||||||
|
for img in os.listdir(imgs):
|
||||||
|
imgPth = os.sep.join([imgs, img])
|
||||||
|
feat_tensor = inference([imgPth], models.similarityModel, actionModel=False)
|
||||||
|
all_dics[img] = feat_tensor
|
||||||
|
return all_dics
|
||||||
|
|
||||||
|
|
||||||
|
def main(pth, dirs):
|
||||||
|
global distance_lists
|
||||||
|
allDics = createDic(pth, dirs)
|
||||||
|
labels = []
|
||||||
|
temp = None
|
||||||
|
for key, value in allDics.items():
|
||||||
|
value = value.cpu().detach().numpy()
|
||||||
|
labels.append(key)
|
||||||
|
if temp is None:
|
||||||
|
temp = value
|
||||||
|
else:
|
||||||
|
temp = np.vstack((temp, value))
|
||||||
|
y_pred = get_kclusters(temp, 3)
|
||||||
|
move_file(pth, dirs, labels, y_pred)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pth = 'C:/Users/HP/Desktop/40084'
|
||||||
|
dirs = '1'
|
||||||
|
main(pth, dirs)
|
BIN
multiple_grid.png
Normal file
After Width: | Height: | Size: 60 KiB |
BIN
multiple_histograms.png
Normal file
After Width: | Height: | Size: 61 KiB |
BIN
tools/.usearch
Normal file
175
tools/Interface.py
Normal file
@ -0,0 +1,175 @@
|
|||||||
|
import abc
|
||||||
|
# import os
|
||||||
|
# import pdb
|
||||||
|
# import pickle
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from tools.config import gvalue
|
||||||
|
|
||||||
|
sys.path.append('./ytracking')
|
||||||
|
sys.path.append('./contrast')
|
||||||
|
# from ytracking.tracking.dotrack import init_tracker, VideoTracks, boxes_add_fid
|
||||||
|
from ytracking.tracking.have_tracking import have_tracked
|
||||||
|
from ytracking.track_ import *
|
||||||
|
from contrast.logic import datacollection, similarityResult, similarity
|
||||||
|
from PIL import Image
|
||||||
|
from tools.config import gvalue
|
||||||
|
|
||||||
|
class AiInterface(metaclass=abc.ABCMeta):
|
||||||
|
@abc.abstractmethod
|
||||||
|
def getTrackingBox(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def getSimilarity(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AiClass(AiInterface):
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_xyxy_coordinates(self, box, frame_id_img):
|
||||||
|
"""
|
||||||
|
Compute and return the bounding-box coordinates.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
x1 = max(0, int(box[0]))
|
||||||
|
x2 = min(frame_id_img.shape[1], int(box[2]))
|
||||||
|
y1 = max(0, int(box[1]))
|
||||||
|
y2 = min(frame_id_img.shape[0], int(box[3]))
|
||||||
|
return x1, y1, x2, y2
|
||||||
|
except IndexError as e:
|
||||||
|
raise ValueError("Bounding-box coordinates exceed the image size") from e
|
||||||
|
|
||||||
|
def getTrackingBox(self, bboxes, features_dict, camera_id, frame_id_img, save_imgs_dir):
|
||||||
|
"""
|
||||||
|
Return the image lists and track-ID list built from the given bounding boxes and frame images.
|
||||||
|
"""
|
||||||
|
|
||||||
|
image_lists = {}
|
||||||
|
track_id_list = []
|
||||||
|
|
||||||
|
gt = Profile()
|
||||||
|
with gt:
|
||||||
|
vts = have_tracked(bboxes, features_dict, camera_id)
|
||||||
|
|
||||||
|
nn = 0
|
||||||
|
for res in vts.Residual:
|
||||||
|
for box in res.boxes:
|
||||||
|
try:
|
||||||
|
box = [int(i) for i in box.tolist()]
|
||||||
|
print('box[7] >>>> {}'.format(box[7]))
|
||||||
|
x1, y1, x2, y2 = self.get_xyxy_coordinates(box, frame_id_img[box[7]])
|
||||||
|
gvalue.track_y_lists.append(y1)
|
||||||
|
c_img = frame_id_img[box[7]][y1:y2, x1:x2][:, :, ::-1]
|
||||||
|
|
||||||
|
# c_img = frame_id_img[box[7]][box[1]:box[3], box[0]:box[2]][:, :, ::-1]
|
||||||
|
img_pil = Image.fromarray(c_img.astype('uint8'), 'RGB')
|
||||||
|
|
||||||
|
img_pil.save(os.sep.join([save_imgs_dir, str(nn) + '.jpg']))
|
||||||
|
nn += 1
|
||||||
|
|
||||||
|
track_id = str(box[4])
|
||||||
|
track_id_list.append(track_id)
|
||||||
|
if track_id not in image_lists:
|
||||||
|
image_lists[track_id] = []
|
||||||
|
image_lists[track_id].append(img_pil)
|
||||||
|
except Exception as e:
|
||||||
|
print("y1: {}, y2: {}, x1:{} x2:{}".format(box[2], box[3], box[0], box[1]))
|
||||||
|
print("x:{}, y:{}".format(frame_id_img[box[7]].shape[1], frame_id_img[box[7]].shape[0]))
|
||||||
|
print(f"Error while processing bounding box: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
all_image_list = list(image_lists.values())
|
||||||
|
trackIdList = list(set(track_id_list))
|
||||||
|
|
||||||
|
return all_image_list, trackIdList
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def process_topn_data(source_data):
|
||||||
|
if source_data is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not isinstance(source_data, dict):
|
||||||
|
raise ValueError("Input data must be a dict")
|
||||||
|
|
||||||
|
if not source_data:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
total = {}
|
||||||
|
carId_barcode_trackId_list = []
|
||||||
|
data_category = []
|
||||||
|
|
||||||
|
for category, category_data in source_data.items():
|
||||||
|
carId_barcode_trackId_list.append(category)
|
||||||
|
for car_id, similarity in category_data.items():
|
||||||
|
data_category.append({'carId_barcode_trackId_n': car_id, 'similarity': similarity})
|
||||||
|
|
||||||
|
total['carId_barcode_trackId'] = carId_barcode_trackId_list
|
||||||
|
total['data'] = data_category
|
||||||
|
|
||||||
|
return total
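# Shape sketch (illustrative values): the nested {key: {candidate: similarity}}
# result is flattened into a key list plus a flat 'data' list.
#   process_topn_data({'mac01_6902088146356_1': {'mac01_6902088146356_2': 0.83}})
#   # -> {'carId_barcode_trackId': ['mac01_6902088146356_1'],
#   #     'data': [{'carId_barcode_trackId_n': 'mac01_6902088146356_2', 'similarity': 0.83}]}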
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def process_top10_data(source_data):
|
||||||
|
if source_data is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not isinstance(source_data, dict):
|
||||||
|
raise ValueError("Input data must be a dict")
|
||||||
|
|
||||||
|
if not source_data:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
total = {}
|
||||||
|
data_category = []
|
||||||
|
|
||||||
|
for category, category_data in source_data.items():
|
||||||
|
trackid = category.split('_')[-1]
|
||||||
|
barcode = category.split('_')[-2]
|
||||||
|
for car_id, similarity in category_data.items():
|
||||||
|
data_category.append({'barcode': car_id, 'similarity': similarity, 'trackid': trackid})
|
||||||
|
|
||||||
|
total['barcode'] = barcode
|
||||||
|
total['data'] = data_category
|
||||||
|
return total
|
||||||
|
|
||||||
|
def getSimilarity(self, model, queueImgs):
|
||||||
|
data_collection = datacollection()
|
||||||
|
similarityRes = similarityResult()
|
||||||
|
data_collection.barcode_flag = queueImgs['barcode_flag']
|
||||||
|
data_collection.add_flag = queueImgs['add_flag']
|
||||||
|
data_collection.barcode_list = queueImgs['barcode_list'].strip("'").split(',')
|
||||||
|
data_collection.queImgsDict = queueImgs
|
||||||
|
|
||||||
|
similarityRes = similarity().getSimilarity(model, data_collection, similarityRes)
|
||||||
|
# print('similarityRes.top10: ------------------ {}'.format(similarityRes.top10))
|
||||||
|
if similarityRes.top1:
|
||||||
|
similarityRes.top1 = {"barcode": list(similarityRes.top1.keys())[0],
|
||||||
|
"similarity": list(similarityRes.top1.values())[0]}
|
||||||
|
# similarityRes.tempLibList = gvalue.tempLibList
|
||||||
|
# print('-------------------------', gvalue.tempLibLists)
|
||||||
|
if gvalue.tempLibLists.get(gvalue.mac_id) is not None:
|
||||||
|
similarityRes.tempLibList = gvalue.tempLibLists[gvalue.mac_id]
|
||||||
|
else:
|
||||||
|
similarityRes.tempLibList = []
|
||||||
|
similarityresult = {
|
||||||
|
'top10': AiClass.process_top10_data(similarityRes.top10),
|
||||||
|
'top1': similarityRes.top1,
|
||||||
|
'topn': AiClass.process_topn_data(similarityRes.topn),
|
||||||
|
'tempLibList': similarityRes.tempLibList,
|
||||||
|
'sequenceId': queueImgs['sequenceId'],
|
||||||
|
}
|
||||||
|
return similarityresult
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
AI = AiClass()
|
||||||
|
|
||||||
|
# track_boxes, frame_id_img = run()
|
||||||
|
# AI.getTrackingBox(track_boxes, frame_id_img)
|
||||||
|
# print('=== test ===')
|
||||||
|
# AI.getSimilarity(cfg.queueImgs)
|
BIN
tools/Template_images/cartboarder.png
Normal file
After Width: | Height: | Size: 13 KiB |
BIN
tools/Template_images/cartedge.png
Normal file
After Width: | Height: | Size: 11 KiB |
BIN
tools/Template_images/edgeline.png
Normal file
After Width: | Height: | Size: 7.2 KiB |
BIN
tools/Template_images/incart.png
Normal file
After Width: | Height: | Size: 9.6 KiB |
BIN
tools/Template_images/incart_ftmp.png
Normal file
After Width: | Height: | Size: 4.0 KiB |
BIN
tools/Template_images/outcart.png
Normal file
After Width: | Height: | Size: 9.6 KiB |
0
tools/__init__.py
Normal file
BIN
tools/__pycache__/Interface.cpython-38.pyc
Normal file
BIN
tools/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
tools/__pycache__/config.cpython-38.pyc
Normal file
BIN
tools/__pycache__/getResult.cpython-38.pyc
Normal file
BIN
tools/__pycache__/getbox.cpython-38.pyc
Normal file
BIN
tools/__pycache__/initModel.cpython-38.pyc
Normal file
BIN
tools/__pycache__/operate_usearch.cpython-38.pyc
Normal file
BIN
tools/__pycache__/uploadvideos.cpython-38.pyc
Normal file
93
tools/config.py
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
|
||||||
|
# from yacs.config import CfgNode as CfgNode
|
||||||
|
import torchvision.transforms as T
|
||||||
|
import torch
|
||||||
|
import os
|
||||||
|
|
||||||
|
class globalVal:
|
||||||
|
tempLibList = []
|
||||||
|
tempLibLists = {}
|
||||||
|
track_y_lists = []
|
||||||
|
mac_id = None
|
||||||
|
back_return_similarity = []
|
||||||
|
back_add_similarity = []
|
||||||
|
front_return_similarity = []
|
||||||
|
front_add_similarity = []
|
||||||
|
comprehensive_similarity = []
|
||||||
|
|
||||||
|
class config:
|
||||||
|
save_videos_dir = 'videos'
|
||||||
|
|
||||||
|
#url
|
||||||
|
# push_url = 'http://api.test2.ieemoo.cn/emoo-api/intelligence/addVideoPathBySequenceId.do'
|
||||||
|
push_url = 'https://api.test2.ieemoo.cn/emoo-api/intelligence/addVideoPathBySequenceId.do' # idle-time upload
|
||||||
|
get_config_url = 'https://api.test2.ieemoo.cn/emoo-api/intelligence/addVideoPathByStoreId.do' # config endpoint for idle-time upload
|
||||||
|
storidPth = 'tools/storeId.txt'
|
||||||
|
|
||||||
|
#obs update
|
||||||
|
obs_access_key_id = 'LHXJC7GIC2NNUUHHTNVI'
|
||||||
|
obs_secret_access_key = 'sVWvEItrFKWPp5DxeMvX8jLFU69iXPpzkjuMX3iM'
|
||||||
|
obs_server = 'https://obs.cn-east-3.myhuaweicloud.com'
|
||||||
|
obs_bucketName = 'ieemoo-ai'
|
||||||
|
|
||||||
|
keys = ['x', 'y', 'w', 'h', 'track_id', 'score', 'cls', 'frame_index']
|
||||||
|
|
||||||
|
obs_root_dir = 'ieemoo_ai_data'
|
||||||
|
|
||||||
|
#contrast config
|
||||||
|
host = "192.168.1.28"
|
||||||
|
port = "19530"
|
||||||
|
embedding_size = 256
|
||||||
|
img_size = 224
|
||||||
|
test_transform = T.Compose([
|
||||||
|
T.ToTensor(),
|
||||||
|
T.Resize((224, 224)),
|
||||||
|
T.ConvertImageDtype(torch.float32),
|
||||||
|
T.Normalize(mean=[0.5], std=[0.5]),
|
||||||
|
])
|
||||||
|
|
||||||
|
# test_model = "./tools/ckpts/MobilenetV3Large_noParallel_2624.pth"
|
||||||
|
test_model = "./tools/ckpts/resnet18_0721_best.pth"
|
||||||
|
tracking_model = "./tools/ckpts/best_158734_cls11_noaug10.pt"
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
|
httpHost = '0.0.0.0'
|
||||||
|
httpPort = 8088
|
||||||
|
|
||||||
|
#tracking config
|
||||||
|
botsort = './ytracking/tracking/trackers/cfg/botsort.yaml'
|
||||||
|
incart = './tools/Template_images/incart.png'
|
||||||
|
outcart = './tools/Template_images/outcart.png'
|
||||||
|
cartboarder = './tools/Template_images/cartboarder.png'
|
||||||
|
edgeline = './tools/Template_images/edgeline.png'
|
||||||
|
cartedge = './tools/Template_images/cartedge.png'
|
||||||
|
incart_ftmp = './tools/Template_images/incart_ftmp.png'
|
||||||
|
|
||||||
|
|
||||||
|
action_type = {
|
||||||
|
"1": 'purchase',
|
||||||
|
'2': 'jettison',
|
||||||
|
'3': 'unswept_purchase',
|
||||||
|
'4': 'unswept_jettison'
|
||||||
|
}
|
||||||
|
camera_id = {
|
||||||
|
'0': 'back',
|
||||||
|
'1': 'front',
|
||||||
|
}
|
||||||
|
recognize_result = {
|
||||||
|
'01': 'uncatalogued',
|
||||||
|
'02': 'fail',
|
||||||
|
'03': 'exception',
|
||||||
|
'04': 'pass',
|
||||||
|
}
|
||||||
|
|
||||||
|
# reid config
|
||||||
|
backbone = 'resnet18' # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3]
|
||||||
|
batch_size = 8
|
||||||
|
|
||||||
|
model_path = './tools/ckpts/best_resnet18_0515.pth'
|
||||||
|
|
||||||
|
temp_video_name = None
|
||||||
|
|
||||||
|
cfg = config()
|
||||||
|
gvalue = globalVal()
|
78
tools/getResult.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
from tools.Interface import AiInterface, AiClass
|
||||||
|
# from Interface import AiInterface, AiClass
|
||||||
|
from config import cfg
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import pdb
|
||||||
|
from track_ import *
|
||||||
|
import time
|
||||||
|
|
||||||
|
'''
|
||||||
|
Tracking and comparison input
|
||||||
|
'''
|
||||||
|
ai_obj = AiClass()
|
||||||
|
|
||||||
|
|
||||||
|
def deal_similarity_data(message):
|
||||||
|
# print('message --- > {}'.format(message))
|
||||||
|
car_mac = message['videoIds'].split('_')[-3]
|
||||||
|
similar = {'add_flag': True if message['action'] in {'1', '3'} else False,
|
||||||
|
'barcode_flag': True if not message['barcode'] == 'null' else False}
|
||||||
|
if similar['add_flag'] and similar['barcode_flag']:  # add to cart, with barcode
|
||||||
|
for Id, image_list in zip(message['trackIdList'],
|
||||||
|
message['images']):
|
||||||
|
similar[car_mac + '_' + message['barcode'] + '_' + Id] = image_list
|
||||||
|
similar['barcode_list'] = message['barcodeList']
|
||||||
|
elif similar['add_flag'] and (not similar['barcode_flag']):  # add to cart, no barcode
|
||||||
|
for Id, image_list in zip(message['trackIdList'],
|
||||||
|
message['images']):
|
||||||
|
similar[car_mac + '_' + Id] = image_list
|
||||||
|
similar['barcode_list'] = message['barcodeList']
|
||||||
|
elif (not similar['add_flag']) and similar['barcode_flag']:  # return (remove from cart), with barcode
|
||||||
|
for Id, image_list in zip(message['trackIdList'],
|
||||||
|
message['images']):
|
||||||
|
similar[car_mac + '_' + message['barcode'] + '_' + Id] = image_list
|
||||||
|
similar['barcode_list'] = message['barcodeList']
|
||||||
|
else:  # return (remove from cart), no barcode
|
||||||
|
for Id, image_list in zip(message['trackIdList'],
|
||||||
|
message['images']):
|
||||||
|
similar[car_mac + '_' + Id] = image_list
|
||||||
|
similar['barcode_list'] = message['barcodeList']
|
||||||
|
similar['sequenceId'] = message['sequenceId']
|
||||||
|
return similar
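# Shape sketch (illustrative values): an "add to cart with barcode" message is
# flattened into one '<carMac>_<barcode>_<trackId>' image key per track, plus flags,
# the barcode list and the sequence id.
#   msg = {'videoIds': 'a_b_mac01_0_1', 'action': '1', 'barcode': '6902088146356',
#          'trackIdList': ['3'], 'images': [[img1, img2]],
#          'barcodeList': ['6902088146356'], 'sequenceId': 'seq-001'}
#   deal_similarity_data(msg)
#   # -> {'add_flag': True, 'barcode_flag': True,
#   #     'mac01_6902088146356_3': [img1, img2],
#   #     'barcode_list': ['6902088146356'], 'sequenceId': 'seq-001'}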
|
||||||
|
|
||||||
|
|
||||||
|
def get_similarity_result(obj: AiInterface, videopth, model, camera_id, message):
|
||||||
|
"""
|
||||||
|
Get the similarity result.
|
||||||
|
|
||||||
|
:param videopth: path of the video to process.
|
||||||
|
:param obj: AiInterface object used to obtain tracking boxes and similarity data.
|
||||||
|
:param message: dict, optional; carries the tracking data and the similarity processing result.
|
||||||
|
:return: the similarity result.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
track_boxes, features_dict, frame_id_img = run(model, source=videopth)
|
||||||
|
allimages, trackIdList = obj.getTrackingBox(track_boxes, features_dict, camera_id, frame_id_img)
|
||||||
|
message['trackIdList'] = trackIdList
|
||||||
|
message['images'] = allimages
|
||||||
|
message = deal_similarity_data(message)
|
||||||
|
similarityRes = obj.getSimilarity(model, message)
|
||||||
|
except ValueError as ve:
|
||||||
|
print('ve >>>> ', ve)
|
||||||
|
similarityRes = {'top10': {},
|
||||||
|
'top1': {},
|
||||||
|
'topn': {},
|
||||||
|
'tempLibList': [],
|
||||||
|
'sequenceId': message['sequenceId']
|
||||||
|
}
|
||||||
|
return similarityRes
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
message = {
|
||||||
|
'action': '1',
|
||||||
|
'barcode': '084501446314',
|
||||||
|
'sequenceId': 'test'
|
||||||
|
}
|
||||||
|
get_similarity_result(ai_obj, message)
|
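For reference, a minimal sketch of the message that deal_similarity_data expects and the keys it produces, assuming the repo's tracking dependencies are importable; every value below is a made-up placeholder, not real data.

# Hypothetical input message (placeholder values for illustration only)
from tools.getResult import deal_similarity_data

message = {
    'videoIds': 'seq20240709_front_AABBCCDDEEFF_0_1',  # car_mac is the third-to-last '_' field
    'action': '1',                                     # '1'/'3' -> add flag, otherwise a return
    'barcode': '6902088146356',                        # 'null' when no barcode was scanned
    'barcodeList': ['6902088146356'],
    'trackIdList': ['7', '9'],
    'images': [['img_7_0.jpg'], ['img_9_0.jpg']],
    'sequenceId': 'test'
}
similar = deal_similarity_data(message)
# With a barcode present each track is keyed as '<car_mac>_<barcode>_<trackId>',
# e.g. 'AABBCCDDEEFF_6902088146356_7'; without a barcode the key is '<car_mac>_<trackId>'.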
153
tools/getbox.py
Normal file
@ -0,0 +1,153 @@
# -*- coding: utf-8 -*-
"""
Created on Tue May 21 15:25:23 2024

@author: ieemoo-zl003
"""

import os
import numpy as np

# replace with your directory path
files_path = 'D:/Project/ieemoo/kmeans/comparisonData/deletedBarcode_10_0709_am/err_pair/6902088146356/20240709-105804_6902088146356'


def str_to_float_arr(s):
    # remove a trailing comma from the string (if present)
    if s.endswith(','):
        s = s[:-1]

    # split the string with split(), then convert each element to float
    float_array = [float(x) for x in s.split(",")]
    return float_array


def extract_tracker_input_boxes_feats(file_name):
    framesId = []
    boxes = []
    feats = []

    frame_id = 0
    with open(file_name, 'r', encoding='utf-8') as file:
        for line in file:
            line = line.strip()  # strip the trailing newline and any surrounding whitespace

            # skip empty lines
            if not line:
                continue
            if line.find("frameId") >= 0:
                frame_id = line[line.find("frameId:") + 8:].strip()

            # check whether the line carries a 'box:' or 'feat:' record
            if line.find("box:") >= 0 and line.find("output_box:") < 0:
                boxes.append(line[line.find("box:") + 4:].strip())  # drop the 'box:' prefix and strip whitespace
                framesId.append(frame_id)

            if line.find("feat:") >= 0:
                feats.append(line[line.find("feat:") + 5:].strip())  # drop the 'feat:' prefix and strip whitespace

    return boxes, feats, framesId


def find_string_in_array(arr, target):
    """
    Find the row (index) in a string array that corresponds to the target string.

    Arguments:
    arr -- array of strings
    target -- the target string to look for

    Returns:
    the index of the target string in the array, or -1 if it is not found.
    """
    parts = target.split(',')
    box_substrings = ','.join(parts[:4])
    conf_substring = ','.join(parts[5:6])
    for i, s in enumerate(arr):
        if s.find(box_substrings) >= 0 and s.find(conf_substring[:7]) >= 0:
            return i
    return -1


def extract_tracker_output_boxes_feats(read_file_name):
    input_boxes, input_feats, framesId = extract_tracker_input_boxes_feats(read_file_name)

    boxes = []
    feats = []
    with open(read_file_name, 'r', encoding='utf-8') as file:
        for line in file:
            line = line.strip()  # strip the trailing newline and any surrounding whitespace

            # skip empty lines
            if not line:
                continue

            # check whether the line carries an 'output_box:' record
            if line.find("output_box:") >= 0:
                boxes_str = line[line.find("output_box:") + 11:].strip()
                boxes.append(boxes_str)  # drop the 'output_box:' prefix and strip whitespace
                index = find_string_in_array(input_boxes, boxes_str)
                feat_f = str_to_float_arr(input_feats[index])
                norm_f = np.linalg.norm(feat_f)
                feat_f = feat_f / norm_f
                feats.append(feat_f)
    return input_boxes, input_feats, boxes, feats, framesId


def extract_tracking_output_boxes_feats(read_file_name):
    tracker_boxes, tracker_feats, input_boxes, input_feats, framesId = extract_tracker_output_boxes_feats(
        read_file_name)
    boxes = []
    feats = []
    boxes_frames_id = []
    tracking_flag = False
    tracking_num_cnt = 0
    with open(read_file_name, 'r', encoding='utf-8') as file:
        for line in file:
            line = line.strip()  # strip the trailing newline and any surrounding whitespace

            # skip empty lines
            if not line:
                continue

            if tracking_flag:
                if line.find("tracking_") >= 0:
                    tracking_flag = False
                    tracking_num_cnt = tracking_num_cnt + 1
                else:
                    boxes.append(line)
                    index = find_string_in_array(input_boxes, line)
                    feats.append(input_feats[index])

                    if tracking_num_cnt == 0:
                        index = find_string_in_array(tracker_boxes, line)
                        boxes_frames_id.append(framesId[index])
            # check whether the line starts a 'tracking_' section
            if line.find("tracking_") >= 0:
                tracking_flag = True

    return tracker_boxes, tracker_feats, input_boxes, input_feats, boxes, feats, boxes_frames_id


def find_index_feats(files_path):
    # iterate over all files and directories under the path
    all_boxes, boboxes_frames_ids, Boxes, framesIds = [], [], [], []
    for filename in os.listdir(files_path):
        # build the full file path
        file_path = os.path.join(files_path, filename)
        # only process regular files
        if os.path.isfile(file_path):
            # open the file
            if filename.endswith('data') and (not 'tracking' in filename):
                tracker_boxes, tracker_feats, input_boxes, input_feats, boxes, feats, boxes_frames_id = extract_tracking_output_boxes_feats(file_path)
                box, feats, framesId = extract_tracker_input_boxes_feats(file_path)
                Boxes += box
                framesIds += framesId
                all_boxes += boxes[:len(boxes_frames_id)]
                boboxes_frames_ids += boxes_frames_id
    # print(all_boxes)
    # print(boboxes_frames_ids)
    return all_boxes, boboxes_frames_ids, tracker_boxes, Boxes, framesIds


if __name__ == '__main__':
    find_index_feats(files_path)
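The parsing above assumes a plain-text tracker dump where each record line starts with a prefix such as "frameId:", "box:", "feat:" or "output_box:". A hypothetical fragment and a minimal parse of it could look like the sketch below; the numbers are invented for illustration only.

# Hypothetical excerpt of a *.data tracker dump (format inferred from the parser, values invented):
#   frameId: 12
#   box: 103.0,55.0,412.0,388.0,0,0.912341
#   feat: 0.0213,-0.0042,0.0178,0.0091,
#   output_box: 103.0,55.0,412.0,388.0,7,0.912341
from tools.getbox import str_to_float_arr

sample = "box: 103.0,55.0,412.0,388.0,0,0.912341"
box_str = sample[sample.find("box:") + 4:].strip()   # -> "103.0,55.0,412.0,388.0,0,0.912341"
values = str_to_float_arr(box_str)                   # -> [103.0, 55.0, 412.0, 388.0, 0.0, 0.912341]
# find_string_in_array matches on the first four fields (the box) plus field 5 (the confidence).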
45
tools/initModel.py
Normal file
@ -0,0 +1,45 @@
import torch
from ytracking.models.common import DetectMultiBackend
from ytracking.utils.torch_utils import select_device
from tools.config import cfg
from contrast.model.resnet_pre import resnet18
from ytracking.tracking.utils import Boxes, IterableSimpleNamespace, yaml_load
from ytracking.tracking.trackers import BOTSORT, BYTETracker
# import mediapipe as mp
# from pymilvus import (
#     connections,
#     utility,
#     FieldSchema, CollectionSchema, DataType,
#     Collection,
#     Milvus
# )


class Models:
    def __init__(self):
        self.yoloModel = None
        self.reidModel = None
        self.similarityModel = None
        self.Milvus = None
        self.device = 'cpu'

    def initSimilarityModel(self):
        # model = MobileNetV3_Large().to(cfg.device)
        model = resnet18().to(cfg.device)
        # model.load_state_dict(torch.load(cfg.test_model, map_location=cfg.device))
        model.load_state_dict(torch.load(cfg.model_path, map_location=cfg.device))
        model.eval()
        return model

    def initYoloModel(self):
        device = select_device(self.device)
        model = DetectMultiBackend(cfg.tracking_model, device=device, dnn=False, fp16=False)
        return model

    def initModel(self):
        self.yoloModel = self.initYoloModel()
        self.similarityModel = self.initSimilarityModel()


models = Models()

if __name__ == "__main__":
    Models().initModel()
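As a usage sketch (assuming the ytracking and contrast packages and the checkpoints referenced in cfg are available), the shared Models singleton is initialised once and its members are then passed to the tracking and comparison code:

from tools.initModel import models

models.initModel()            # loads the YOLO detector and the resnet18 similarity model
yolo = models.yoloModel       # DetectMultiBackend instance used for detection/tracking
reid = models.similarityModel # eval-mode resnet18 used for feature extraction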
153
tools/operate_usearch.py
Normal file
@ -0,0 +1,153 @@
import os
import numpy as np
from usearch.index import Index
import json
import statistics


def create_index():
    index = Index(
        ndim=256,
        metric='cos',
        # dtype='f32',
        dtype='f16',
        connectivity=32,
        expansion_add=40,  # 128,
        expansion_search=10,  # 64,
        multi=True
    )
    return index


def compare_feature(features1, features2, model='1'):
    """
    :param model: comparison strategy
        '0': simulate comparing one track's images (all of them, or a selected subset) against the standard
             library: take each image's maximum similarity to the library first, then average those maxima
        '1': mean of all pairwise similarities
        '2': maximum over the 1:1 comparisons
    :param features1:
    :param features2:
    :return:
    """
    similarity_group, similarity_groups = [], []
    if model == '0':
        for feature1 in features1:
            for feature2 in features2[0]:
                similarity = np.dot(feature1, feature2) / (np.linalg.norm(feature1) * np.linalg.norm(feature2))
                similarity_group.append(similarity)
            similarity_groups.append(max(similarity_group))
            similarity_group = []
        return sum(similarity_groups) / len(similarity_groups)

    elif model == '1':
        feature2 = features2[0]
        for feature1 in features1:
            for num in range(len(feature2)):
                similarity = np.dot(feature1, feature2[num]) / (np.linalg.norm(feature1) * np.linalg.norm(feature2[num]))
                similarity_group.append(similarity)
            similarity_groups.append(sum(similarity_group) / len(similarity_group))
            similarity_group = []
        # return sum(similarity_groups)/len(similarity_groups), max(similarity_groups)
        if len(similarity_groups) == 0:
            return -1
        return sum(similarity_groups) / len(similarity_groups)
    elif model == '2':
        feature2 = features2[0]
        for feature1 in features1:
            for num in range(len(feature2)):
                similarity = np.dot(feature1, feature2[num]) / (np.linalg.norm(feature1) * np.linalg.norm(feature2[num]))
                similarity_group.append(similarity)
        return max(similarity_group)


def get_barcode_feature(data):
    barcode = data['key']
    features = data['value']
    return [barcode] * len(features), features


def analysis_file(file_path):
    """
    :param file_path:
    :return:
    """
    barcodes, features = [], []
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
        for dic in data['total']:
            barcode, feature = get_barcode_feature(dic)
            barcodes.append(barcode)
            features.append(feature)
    return barcodes, features


def create_base_index(index_file_pth=None,
                      barcodes=None,
                      features=None,
                      save_index_name=None):
    index = create_index()
    if index_file_pth is not None:
        # save_index_name = index_file_pth.split('json')[0] + 'usearch'
        save_index_name = index_file_pth.split('json')[0] + 'data'
        barcodes, features = analysis_file(index_file_pth)
    else:
        assert barcodes is not None and features is not None, 'barcodes and features must be not None'
    for barcode, feature in zip(barcodes, features):
        index.add(np.array(barcode), np.array(feature))
    index.save(save_index_name)


def get_feature_index(index_file_pth=None,
                      barcodes=None):
    assert index_file_pth is not None, 'index_file_pth must be not None'
    index = Index.restore(index_file_pth, view=True)
    feature_lists = index.get(np.array(barcodes))
    print("memory {} size {}".format(index.memory_usage, index.size))
    return feature_lists


def search_in_index(query=None,
                    barcode=None,  # barcode -> int or np.ndarray
                    index_name=None,
                    temp_index=False,  # whether this is a temporary library
                    model='0',
                    ):
    if temp_index:
        assert index_name is not None, 'index_name must be not None'
        index = Index.restore(index_name, view=True)
        if barcode is not None:  # 1:1 comparison test
            feature_lists = index.get(np.array(barcode))
            results = compare_feature(query, feature_lists)
        else:
            results = index.search(query, count=5)
        return results
    else:  # standard library
        assert index_name is not None, 'index_name must be not None'
        index = Index.restore(index_name, view=True)
        if barcode is not None:  # 1:1 comparison test
            feature_lists = index.get(np.array(barcode))
            results = compare_feature(query, feature_lists, model)
        else:
            results = index.search(query, count=10)
        return results


def delete_index(index_name=None, key=None, index=None):
    assert key is not None, 'key must be not None'
    if index is None:
        assert index_name is not None, 'index_name must be not None'
        index = Index.restore(index_name, view=True)
        index.remove(key)
    else:
        index.remove(key)


if __name__ == '__main__':
    # index_file_pth = '../search_library/data_0923.json'
    # create_base_index(index_file_pth)

    # index_file_pth = '../search_library/test_index_10_normal_0717.usearch'
    # # index_file_pth = '../search_library/data_10_normal_0718.index'
    # search_in_index(query='693', index_name=index_file_pth, barcode='6934024590466')

    # check index data file
    index_file_pth = '../search_library/data_0923.data'
    # # get_feature_index(index_file_pth, ['6901070602818'])
    get_feature_index(index_file_pth, ['6934230050105'])
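To make the three comparison strategies concrete, here is a small worked example with toy 3-dimensional vectors (real features are 256-dimensional per create_index); the numbers are illustrative only.

import numpy as np
from tools.operate_usearch import compare_feature

query_feats = [np.array([1.0, 0.0, 0.0]), np.array([0.0, 1.0, 0.0])]      # two images of one track
library_feats = [[np.array([1.0, 0.0, 0.0]), np.array([0.5, 0.5, 0.0])]]  # features2[0] holds the barcode's features

compare_feature(query_feats, library_feats, model='0')  # mean of per-image maxima: (1.0 + 0.7071) / 2
compare_feature(query_feats, library_feats, model='1')  # mean of per-image average similarities
compare_feature(query_feats, library_feats, model='2')  # single maximum over all pairs: 1.0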
1
tools/storeId.txt
Normal file
@ -0,0 +1 @@
32011001
159
tools/uploadvideos.py
Normal file
@ -0,0 +1,159 @@
from obs import ObsClient
import obs
import sys
import os
from tools.config import cfg

import requests
import argparse


class uploadVideos:
    def __init__(self):
        self.obsClient, self.headers = self.InitObsClient()

    def InitObsClient(self):
        # create the ObsClient instance
        obsClient = ObsClient(
            access_key_id=cfg.obs_access_key_id,          # access key ID
            secret_access_key=cfg.obs_secret_access_key,  # secret access key
            server=cfg.obs_server)                        # server address
        headers = obs.SetObjectMetadataHeader()  # object metadata header
        headers.cacheControl = "no-cache"        # disable caching
        return obsClient, headers                # return the ObsClient instance and metadata header

    def upload(self, video_name, video_dir=None):
        # make sure the ObsClient instance has been initialized
        if not self.obsClient:
            raise ValueError("ObsClient instance has not been initialized")

        class uploadResult:
            def __init__(self, video_path=None, squenceId=None):
                self.video_path = video_path
                self.squenceId = squenceId

        uploadRes = uploadResult()

        # parse the video name to extract its metadata
        information = video_name.split('.')[0].split('_')
        action_category = cfg.action_type[information[-1]]       # action category
        camera_id = cfg.camera_id[information[-3]]                # camera ID
        recognize_result = cfg.recognize_result[information[0]]  # recognition result
        time = information[1].split('-')[0]                      # date
        squenceId = information[1]

        # build the OBS object key
        objectkey = os.path.join(cfg.obs_root_dir, time, action_category, camera_id, recognize_result, video_name)
        if video_dir is None:
            file_path = os.sep.join([cfg.save_videos_dir, squenceId, video_name])  # local file path
        else:
            file_path = os.sep.join([video_dir, squenceId, video_name])

        # upload the file to OBS
        resp = self.obsClient.putFile(cfg.obs_bucketName, objectkey, file_path)

        uploadRes.video_path = resp['body']['objectUrl']
        uploadRes.squenceId = squenceId
        os.remove(file_path)
        return uploadRes

    def get_information(self, video_squence, camera_type):
        """Get video information."""
        videos_path = []
        videos_dir = os.sep.join([cfg.save_videos_dir, video_squence])
        for video_name in os.listdir(videos_dir):
            if video_squence in video_name:
                if camera_type == video_name.split('_')[-3]:  # camera position ID
                    videos_path.append(self.upload(video_name).video_path)
                elif camera_type == '2':
                    videos_path.append(self.upload(video_name).video_path)
        return {"videos_path": videos_path}

    @staticmethod
    def exception_action(queueImgs):
        """Rename videos of an exception action."""
        # parse the video name to extract its metadata
        video_squence = queueImgs['videoIds'].split(',')[0]
        for video_name in os.listdir(cfg.save_videos_dir):
            if video_squence in video_name:
                os.rename(os.sep.join([cfg.save_videos_dir, video_name]),
                          os.sep.join([cfg.save_videos_dir, '03_' + video_name]))


class VideoUploader:
    @staticmethod
    def read_config_file(file_path):
        """Safely read the config file."""
        try:
            with open(file_path, 'r') as file:
                lines = file.readlines()
        except IOError as e:
            print(f"Error reading config file: {e}")
            return []
        return [line.strip() for line in lines if line.strip() != '']

    @staticmethod
    def upload_videos_for_ids(video_ids, video_dir):
        """Upload videos matching the given IDs."""
        tempdata = []
        print('----------->', video_dir)
        for root, dirs, files in os.walk(video_dir):
            for name in files:
                print(name)
                name_s = name.split('.')[0]  # avoid splitting repeatedly
                parts = name_s.split('_')
                if len(parts) < 7:
                    continue
                for data in video_ids:
                    if parts[-1] == data['action'] and (parts[-3] == data['type'] or data['type'] == '2'):
                        try:
                            upload_rs = uploadVideos().upload(name, video_dir)
                            if upload_rs:
                                video_path = upload_rs.video_path
                                sequence_id = upload_rs.squenceId
                                tempdata.append({
                                    "squenceId": sequence_id,
                                    "video_path": [video_path]
                                })
                            break  # stop once a match is found
                        except Exception as e:
                            print(f"Error uploading video {name}: {e}")
        return tempdata

    @staticmethod
    def timedUpload(rootPth='/home/lc/project/ieemoo'):
        """Scheduled video upload."""
        storidPth = os.sep.join([rootPth, 'tools', 'storeId.txt'])
        save_videos_dir = os.sep.join([rootPth, 'videos'])
        config_lines = VideoUploader.read_config_file(storidPth)
        if not config_lines:
            print("No valid configuration found.")
            return  # nothing to upload without a store id

        print('config lines --- >', config_lines)
        soreid_list = [{"storeId": line} for line in config_lines]
        try:
            rep = requests.post(url=cfg.get_config_url, data=soreid_list[0])
            rep.raise_for_status()  # check the response status
            video_ids = rep.json().get('data', [])
        except requests.RequestException as e:
            print(f"Failed to fetch configuration: {e}")
            video_ids = []

        if video_ids:
            tempdata = VideoUploader.upload_videos_for_ids(video_ids, save_videos_dir)
            if tempdata:
                tmpdata = {'videosPth': str(tempdata)}
                try:
                    requests.post(url=cfg.push_url, data=tmpdata)
                    print("Pushed data successfully")
                except requests.RequestException as e:
                    print(f"Failed to push data: {e}")
            else:
                tmpdata = {'videosPth': str([])}
                try:
                    requests.post(url=cfg.push_url, data=tmpdata)
                except requests.RequestException as e:
                    print(f"Failed to push empty data: {e}")
            # print(tmpdata)


if __name__ == '__main__':
    VideoUploader.timedUpload()
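For illustration, assuming the cfg mappings defined earlier (action_type, camera_id, recognize_result) and a hypothetical file name, upload() parses the name and builds the OBS object key as follows; the file name below is invented for the example.

# Hypothetical video name: <result>_<sequenceId>_..._<cameraId>_..._<action>.mp4
video_name = '04_20240709-105804_x_y_0_z_3.mp4'
information = video_name.split('.')[0].split('_')
# information[0]  -> '04' -> cfg.recognize_result['04'] == 'pass'
# information[-3] -> '0'  -> cfg.camera_id['0']         == 'back'
# information[-1] -> '3'  -> cfg.action_type['3']       == 'unswept_purchase'
# information[1]  -> '20240709-105804'; the date part used in the key is '20240709'
# objectkey: <cfg.obs_root_dir>/20240709/unswept_purchase/back/pass/04_20240709-105804_x_y_0_z_3.mp4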
0
ytracking/__init__.py
Normal file
BIN
ytracking/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
ytracking/__pycache__/export.cpython-39.pyc
Normal file
BIN
ytracking/__pycache__/track_.cpython-38.pyc
Normal file
0
ytracking/models/__init__.py
Normal file
BIN
ytracking/models/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
ytracking/models/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
ytracking/models/__pycache__/common.cpython-38.pyc
Normal file
BIN
ytracking/models/__pycache__/common.cpython-39.pyc
Normal file
BIN
ytracking/models/__pycache__/experimental.cpython-38.pyc
Normal file
BIN
ytracking/models/__pycache__/experimental.cpython-39.pyc
Normal file
BIN
ytracking/models/__pycache__/yolo.cpython-38.pyc
Normal file
BIN
ytracking/models/__pycache__/yolo.cpython-39.pyc
Normal file
883
ytracking/models/common.py
Normal file
@ -0,0 +1,883 @@
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
"""
Common modules
"""

import ast
import contextlib
import json
import math
import platform
import warnings
import zipfile
from collections import OrderedDict, namedtuple
from copy import copy
from pathlib import Path
from urllib.parse import urlparse

import cv2
import numpy as np
import pandas as pd
import requests
import torch
import torch.nn as nn
from PIL import Image
from torch.cuda import amp

# Import 'ultralytics' package or install it if missing
try:
    import ultralytics

    assert hasattr(ultralytics, '__version__')  # verify package is not directory
except (ImportError, AssertionError):
    import os

    os.system('pip install -U ultralytics')
    import ultralytics

from ytracking.ultralytics.utils.plotting import Annotator, colors, save_one_box

from ytracking.utils import TryExcept
from ytracking.utils.dataloaders import exif_transpose, letterbox
from ytracking.utils.general import (LOGGER, ROOT, Profile, check_requirements, check_suffix, check_version, colorstr,
                                     increment_path, is_jupyter, make_divisible, non_max_suppression, scale_boxes,
                                     xywh2xyxy, xyxy2xywh, yaml_load)
from ytracking.utils.torch_utils import copy_attr, smart_inference_mode


def autopad(k, p=None, d=1):  # kernel, padding, dilation
    # Pad to 'same' shape outputs
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p


class Conv(nn.Module):
|
||||||
|
# Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)
|
||||||
|
default_act = nn.SiLU() # default activation
|
||||||
|
|
||||||
|
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
|
||||||
|
self.bn = nn.BatchNorm2d(c2)
|
||||||
|
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.act(self.bn(self.conv(x)))
|
||||||
|
|
||||||
|
def forward_fuse(self, x):
|
||||||
|
return self.act(self.conv(x))
|
||||||
|
|
||||||
|
|
||||||
|
class DWConv(Conv):
|
||||||
|
# Depth-wise convolution
|
||||||
|
def __init__(self, c1, c2, k=1, s=1, d=1, act=True): # ch_in, ch_out, kernel, stride, dilation, activation
|
||||||
|
super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)
|
||||||
|
|
||||||
|
|
||||||
|
class DWConvTranspose2d(nn.ConvTranspose2d):
|
||||||
|
# Depth-wise transpose convolution
|
||||||
|
def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0): # ch_in, ch_out, kernel, stride, padding, padding_out
|
||||||
|
super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2))
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerLayer(nn.Module):
|
||||||
|
# Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
|
||||||
|
def __init__(self, c, num_heads):
|
||||||
|
super().__init__()
|
||||||
|
self.q = nn.Linear(c, c, bias=False)
|
||||||
|
self.k = nn.Linear(c, c, bias=False)
|
||||||
|
self.v = nn.Linear(c, c, bias=False)
|
||||||
|
self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
|
||||||
|
self.fc1 = nn.Linear(c, c, bias=False)
|
||||||
|
self.fc2 = nn.Linear(c, c, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
|
||||||
|
x = self.fc2(self.fc1(x)) + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerBlock(nn.Module):
|
||||||
|
# Vision Transformer https://arxiv.org/abs/2010.11929
|
||||||
|
def __init__(self, c1, c2, num_heads, num_layers):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = None
|
||||||
|
if c1 != c2:
|
||||||
|
self.conv = Conv(c1, c2)
|
||||||
|
self.linear = nn.Linear(c2, c2) # learnable position embedding
|
||||||
|
self.tr = nn.Sequential(*(TransformerLayer(c2, num_heads) for _ in range(num_layers)))
|
||||||
|
self.c2 = c2
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.conv is not None:
|
||||||
|
x = self.conv(x)
|
||||||
|
b, _, w, h = x.shape
|
||||||
|
p = x.flatten(2).permute(2, 0, 1)
|
||||||
|
return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, w, h)
|
||||||
|
|
||||||
|
|
||||||
|
class Bottleneck(nn.Module):
|
||||||
|
# Standard bottleneck
|
||||||
|
def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
|
||||||
|
super().__init__()
|
||||||
|
c_ = int(c2 * e) # hidden channels
|
||||||
|
self.cv1 = Conv(c1, c_, 1, 1)
|
||||||
|
self.cv2 = Conv(c_, c2, 3, 1, g=g)
|
||||||
|
self.add = shortcut and c1 == c2
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
|
||||||
|
|
||||||
|
|
||||||
|
class BottleneckCSP(nn.Module):
|
||||||
|
# CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
|
||||||
|
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
|
||||||
|
super().__init__()
|
||||||
|
c_ = int(c2 * e) # hidden channels
|
||||||
|
self.cv1 = Conv(c1, c_, 1, 1)
|
||||||
|
self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
|
||||||
|
self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
|
||||||
|
self.cv4 = Conv(2 * c_, c2, 1, 1)
|
||||||
|
self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
|
||||||
|
self.act = nn.SiLU()
|
||||||
|
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
y1 = self.cv3(self.m(self.cv1(x)))
|
||||||
|
y2 = self.cv2(x)
|
||||||
|
return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1))))
|
||||||
|
|
||||||
|
|
||||||
|
class CrossConv(nn.Module):
|
||||||
|
# Cross Convolution Downsample
|
||||||
|
def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
|
||||||
|
# ch_in, ch_out, kernel, stride, groups, expansion, shortcut
|
||||||
|
super().__init__()
|
||||||
|
c_ = int(c2 * e) # hidden channels
|
||||||
|
self.cv1 = Conv(c1, c_, (1, k), (1, s))
|
||||||
|
self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g)
|
||||||
|
self.add = shortcut and c1 == c2
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
|
||||||
|
|
||||||
|
|
||||||
|
class C3(nn.Module):
|
||||||
|
# CSP Bottleneck with 3 convolutions
|
||||||
|
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
|
||||||
|
super().__init__()
|
||||||
|
c_ = int(c2 * e) # hidden channels
|
||||||
|
self.cv1 = Conv(c1, c_, 1, 1)
|
||||||
|
self.cv2 = Conv(c1, c_, 1, 1)
|
||||||
|
self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
|
||||||
|
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
|
||||||
|
|
||||||
|
|
||||||
|
class C3x(C3):
|
||||||
|
# C3 module with cross-convolutions
|
||||||
|
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
|
||||||
|
super().__init__(c1, c2, n, shortcut, g, e)
|
||||||
|
c_ = int(c2 * e)
|
||||||
|
self.m = nn.Sequential(*(CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)))
|
||||||
|
|
||||||
|
|
||||||
|
class C3TR(C3):
|
||||||
|
# C3 module with TransformerBlock()
|
||||||
|
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
|
||||||
|
super().__init__(c1, c2, n, shortcut, g, e)
|
||||||
|
c_ = int(c2 * e)
|
||||||
|
self.m = TransformerBlock(c_, c_, 4, n)
|
||||||
|
|
||||||
|
|
||||||
|
class C3SPP(C3):
|
||||||
|
# C3 module with SPP()
|
||||||
|
def __init__(self, c1, c2, k=(5, 9, 13), n=1, shortcut=True, g=1, e=0.5):
|
||||||
|
super().__init__(c1, c2, n, shortcut, g, e)
|
||||||
|
c_ = int(c2 * e)
|
||||||
|
self.m = SPP(c_, c_, k)
|
||||||
|
|
||||||
|
|
||||||
|
class C3Ghost(C3):
|
||||||
|
# C3 module with GhostBottleneck()
|
||||||
|
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
|
||||||
|
super().__init__(c1, c2, n, shortcut, g, e)
|
||||||
|
c_ = int(c2 * e) # hidden channels
|
||||||
|
self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n)))
|
||||||
|
|
||||||
|
|
||||||
|
class SPP(nn.Module):
|
||||||
|
# Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
|
||||||
|
def __init__(self, c1, c2, k=(5, 9, 13)):
|
||||||
|
super().__init__()
|
||||||
|
c_ = c1 // 2 # hidden channels
|
||||||
|
self.cv1 = Conv(c1, c_, 1, 1)
|
||||||
|
self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
|
||||||
|
self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.cv1(x)
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
|
||||||
|
return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
|
||||||
|
|
||||||
|
|
||||||
|
class SPPF(nn.Module):
|
||||||
|
# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
|
||||||
|
def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
|
||||||
|
super().__init__()
|
||||||
|
c_ = c1 // 2 # hidden channels
|
||||||
|
self.cv1 = Conv(c1, c_, 1, 1)
|
||||||
|
self.cv2 = Conv(c_ * 4, c2, 1, 1)
|
||||||
|
self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.cv1(x)
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
|
||||||
|
y1 = self.m(x)
|
||||||
|
y2 = self.m(y1)
|
||||||
|
return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
|
||||||
|
|
||||||
|
|
||||||
|
class Focus(nn.Module):
|
||||||
|
# Focus wh information into c-space
|
||||||
|
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
|
||||||
|
super().__init__()
|
||||||
|
self.conv = Conv(c1 * 4, c2, k, s, p, g, act=act)
|
||||||
|
# self.contract = Contract(gain=2)
|
||||||
|
|
||||||
|
def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
|
||||||
|
return self.conv(torch.cat((x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]), 1))
|
||||||
|
# return self.conv(self.contract(x))
|
||||||
|
|
||||||
|
|
||||||
|
class GhostConv(nn.Module):
|
||||||
|
# Ghost Convolution https://github.com/huawei-noah/ghostnet
|
||||||
|
def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups
|
||||||
|
super().__init__()
|
||||||
|
c_ = c2 // 2 # hidden channels
|
||||||
|
self.cv1 = Conv(c1, c_, k, s, None, g, act=act)
|
||||||
|
self.cv2 = Conv(c_, c_, 5, 1, None, c_, act=act)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
y = self.cv1(x)
|
||||||
|
return torch.cat((y, self.cv2(y)), 1)
|
||||||
|
|
||||||
|
|
||||||
|
class GhostBottleneck(nn.Module):
|
||||||
|
# Ghost Bottleneck https://github.com/huawei-noah/ghostnet
|
||||||
|
def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride
|
||||||
|
super().__init__()
|
||||||
|
c_ = c2 // 2
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
GhostConv(c1, c_, 1, 1), # pw
|
||||||
|
DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw
|
||||||
|
GhostConv(c_, c2, 1, 1, act=False)) # pw-linear
|
||||||
|
self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1,
|
||||||
|
act=False)) if s == 2 else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.conv(x) + self.shortcut(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Contract(nn.Module):
|
||||||
|
# Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
|
||||||
|
def __init__(self, gain=2):
|
||||||
|
super().__init__()
|
||||||
|
self.gain = gain
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
b, c, h, w = x.size() # assert (h / s == 0) and (W / s == 0), 'Indivisible gain'
|
||||||
|
s = self.gain
|
||||||
|
x = x.view(b, c, h // s, s, w // s, s) # x(1,64,40,2,40,2)
|
||||||
|
x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # x(1,2,2,64,40,40)
|
||||||
|
return x.view(b, c * s * s, h // s, w // s) # x(1,256,40,40)
|
||||||
|
|
||||||
|
|
||||||
|
class Expand(nn.Module):
|
||||||
|
# Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
|
||||||
|
def __init__(self, gain=2):
|
||||||
|
super().__init__()
|
||||||
|
self.gain = gain
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
b, c, h, w = x.size() # assert C / s ** 2 == 0, 'Indivisible gain'
|
||||||
|
s = self.gain
|
||||||
|
x = x.view(b, s, s, c // s ** 2, h, w) # x(1,2,2,16,80,80)
|
||||||
|
x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # x(1,16,80,2,80,2)
|
||||||
|
return x.view(b, c // s ** 2, h * s, w * s) # x(1,16,160,160)
|
||||||
|
|
||||||
|
|
||||||
|
class Concat(nn.Module):
|
||||||
|
# Concatenate a list of tensors along dimension
|
||||||
|
def __init__(self, dimension=1):
|
||||||
|
super().__init__()
|
||||||
|
self.d = dimension
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.cat(x, self.d)
|
||||||
|
|
||||||
|
|
||||||
|
class DetectMultiBackend(nn.Module):
|
||||||
|
# YOLOv5 MultiBackend class for python inference on various backends
|
||||||
|
def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True):
|
||||||
|
# Usage:
|
||||||
|
# PyTorch: weights = *.pt
|
||||||
|
# TorchScript: *.torchscript
|
||||||
|
# ONNX Runtime: *.onnx
|
||||||
|
# ONNX OpenCV DNN: *.onnx --dnn
|
||||||
|
# OpenVINO: *_openvino_model
|
||||||
|
# CoreML: *.mlmodel
|
||||||
|
# TensorRT: *.engine
|
||||||
|
# TensorFlow SavedModel: *_saved_model
|
||||||
|
# TensorFlow GraphDef: *.pb
|
||||||
|
# TensorFlow Lite: *.tflite
|
||||||
|
# TensorFlow Edge TPU: *_edgetpu.tflite
|
||||||
|
# PaddlePaddle: *_paddle_model
|
||||||
|
from ytracking.models.experimental import attempt_download, attempt_load # scoped to avoid circular import
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
w = str(weights[0] if isinstance(weights, list) else weights)
|
||||||
|
pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, triton = self._model_type(w)
|
||||||
|
fp16 &= pt or jit or onnx or engine or triton # FP16
|
||||||
|
nhwc = coreml or saved_model or pb or tflite or edgetpu # BHWC formats (vs torch BCWH)
|
||||||
|
stride = 32 # default stride
|
||||||
|
cuda = torch.cuda.is_available() and device.type != 'cpu' # use CUDA
|
||||||
|
if not (pt or triton):
|
||||||
|
w = attempt_download(w) # download if not local
|
||||||
|
|
||||||
|
if pt: # PyTorch
|
||||||
|
model = attempt_load(weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse)
|
||||||
|
stride = max(int(model.stride.max()), 32) # model stride
|
||||||
|
names = model.module.names if hasattr(model, 'module') else model.names # get class names
|
||||||
|
model.half() if fp16 else model.float()
|
||||||
|
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
|
||||||
|
elif jit: # TorchScript
|
||||||
|
LOGGER.info(f'Loading {w} for TorchScript inference...')
|
||||||
|
extra_files = {'config.txt': ''} # model metadata
|
||||||
|
model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
|
||||||
|
model.half() if fp16 else model.float()
|
||||||
|
if extra_files['config.txt']: # load metadata dict
|
||||||
|
d = json.loads(extra_files['config.txt'],
|
||||||
|
object_hook=lambda d: {
|
||||||
|
int(k) if k.isdigit() else k: v
|
||||||
|
for k, v in d.items()})
|
||||||
|
stride, names = int(d['stride']), d['names']
|
||||||
|
elif dnn: # ONNX OpenCV DNN
|
||||||
|
LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
|
||||||
|
check_requirements('opencv-python>=4.5.4')
|
||||||
|
net = cv2.dnn.readNetFromONNX(w)
|
||||||
|
elif onnx: # ONNX Runtime
|
||||||
|
LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
|
||||||
|
check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
|
||||||
|
import onnxruntime
|
||||||
|
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
|
||||||
|
session = onnxruntime.InferenceSession(w, providers=providers)
|
||||||
|
output_names = [x.name for x in session.get_outputs()]
|
||||||
|
meta = session.get_modelmeta().custom_metadata_map # metadata
|
||||||
|
if 'stride' in meta:
|
||||||
|
stride, names = int(meta['stride']), eval(meta['names'])
|
||||||
|
elif xml: # OpenVINO
|
||||||
|
LOGGER.info(f'Loading {w} for OpenVINO inference...')
|
||||||
|
check_requirements('openvino>=2023.0') # requires openvino-dev: https://pypi.org/project/openvino-dev/
|
||||||
|
from openvino.runtime import Core, Layout, get_batch
|
||||||
|
core = Core()
|
||||||
|
if not Path(w).is_file(): # if not *.xml
|
||||||
|
w = next(Path(w).glob('*.xml')) # get *.xml file from *_openvino_model dir
|
||||||
|
ov_model = core.read_model(model=w, weights=Path(w).with_suffix('.bin'))
|
||||||
|
if ov_model.get_parameters()[0].get_layout().empty:
|
||||||
|
ov_model.get_parameters()[0].set_layout(Layout('NCHW'))
|
||||||
|
batch_dim = get_batch(ov_model)
|
||||||
|
if batch_dim.is_static:
|
||||||
|
batch_size = batch_dim.get_length()
|
||||||
|
ov_compiled_model = core.compile_model(ov_model, device_name='AUTO') # AUTO selects best available device
|
||||||
|
stride, names = self._load_metadata(Path(w).with_suffix('.yaml')) # load metadata
|
||||||
|
elif engine: # TensorRT
|
||||||
|
LOGGER.info(f'Loading {w} for TensorRT inference...')
|
||||||
|
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
|
||||||
|
check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
|
||||||
|
if device.type == 'cpu':
|
||||||
|
device = torch.device('cuda:0')
|
||||||
|
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
|
||||||
|
logger = trt.Logger(trt.Logger.INFO)
|
||||||
|
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
|
||||||
|
model = runtime.deserialize_cuda_engine(f.read())
|
||||||
|
context = model.create_execution_context()
|
||||||
|
bindings = OrderedDict()
|
||||||
|
output_names = []
|
||||||
|
fp16 = False # default updated below
|
||||||
|
dynamic = False
|
||||||
|
for i in range(model.num_bindings):
|
||||||
|
name = model.get_binding_name(i)
|
||||||
|
dtype = trt.nptype(model.get_binding_dtype(i))
|
||||||
|
if model.binding_is_input(i):
|
||||||
|
if -1 in tuple(model.get_binding_shape(i)): # dynamic
|
||||||
|
dynamic = True
|
||||||
|
context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2]))
|
||||||
|
if dtype == np.float16:
|
||||||
|
fp16 = True
|
||||||
|
else: # output
|
||||||
|
output_names.append(name)
|
||||||
|
shape = tuple(context.get_binding_shape(i))
|
||||||
|
im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
|
||||||
|
bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
|
||||||
|
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
|
||||||
|
batch_size = bindings['images'].shape[0] # if dynamic, this is instead max batch size
|
||||||
|
elif coreml: # CoreML
|
||||||
|
LOGGER.info(f'Loading {w} for CoreML inference...')
|
||||||
|
import coremltools as ct
|
||||||
|
model = ct.models.MLModel(w)
|
||||||
|
elif saved_model: # TF SavedModel
|
||||||
|
LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
|
||||||
|
import tensorflow as tf
|
||||||
|
keras = False # assume TF1 saved_model
|
||||||
|
model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
|
||||||
|
elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
|
||||||
|
LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
def wrap_frozen_graph(gd, inputs, outputs):
|
||||||
|
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=''), []) # wrapped
|
||||||
|
ge = x.graph.as_graph_element
|
||||||
|
return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
|
||||||
|
|
||||||
|
def gd_outputs(gd):
|
||||||
|
name_list, input_list = [], []
|
||||||
|
for node in gd.node: # tensorflow.core.framework.node_def_pb2.NodeDef
|
||||||
|
name_list.append(node.name)
|
||||||
|
input_list.extend(node.input)
|
||||||
|
return sorted(f'{x}:0' for x in list(set(name_list) - set(input_list)) if not x.startswith('NoOp'))
|
||||||
|
|
||||||
|
gd = tf.Graph().as_graph_def() # TF GraphDef
|
||||||
|
with open(w, 'rb') as f:
|
||||||
|
gd.ParseFromString(f.read())
|
||||||
|
frozen_func = wrap_frozen_graph(gd, inputs='x:0', outputs=gd_outputs(gd))
|
||||||
|
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
|
||||||
|
try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
|
||||||
|
from tflite_runtime.interpreter import Interpreter, load_delegate
|
||||||
|
except ImportError:
|
||||||
|
import tensorflow as tf
|
||||||
|
Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate,
|
||||||
|
if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
|
||||||
|
LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
|
||||||
|
delegate = {
|
||||||
|
'Linux': 'libedgetpu.so.1',
|
||||||
|
'Darwin': 'libedgetpu.1.dylib',
|
||||||
|
'Windows': 'edgetpu.dll'}[platform.system()]
|
||||||
|
interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
|
||||||
|
else: # TFLite
|
||||||
|
LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
|
||||||
|
interpreter = Interpreter(model_path=w) # load TFLite model
|
||||||
|
interpreter.allocate_tensors() # allocate
|
||||||
|
input_details = interpreter.get_input_details() # inputs
|
||||||
|
output_details = interpreter.get_output_details() # outputs
|
||||||
|
# load metadata
|
||||||
|
with contextlib.suppress(zipfile.BadZipFile):
|
||||||
|
with zipfile.ZipFile(w, 'r') as model:
|
||||||
|
meta_file = model.namelist()[0]
|
||||||
|
meta = ast.literal_eval(model.read(meta_file).decode('utf-8'))
|
||||||
|
stride, names = int(meta['stride']), meta['names']
|
||||||
|
elif tfjs: # TF.js
|
||||||
|
raise NotImplementedError('ERROR: YOLOv5 TF.js inference is not supported')
|
||||||
|
elif paddle: # PaddlePaddle
|
||||||
|
LOGGER.info(f'Loading {w} for PaddlePaddle inference...')
|
||||||
|
check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
|
||||||
|
import paddle.inference as pdi
|
||||||
|
if not Path(w).is_file(): # if not *.pdmodel
|
||||||
|
w = next(Path(w).rglob('*.pdmodel')) # get *.pdmodel file from *_paddle_model dir
|
||||||
|
weights = Path(w).with_suffix('.pdiparams')
|
||||||
|
config = pdi.Config(str(w), str(weights))
|
||||||
|
if cuda:
|
||||||
|
config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
|
||||||
|
predictor = pdi.create_predictor(config)
|
||||||
|
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
|
||||||
|
output_names = predictor.get_output_names()
|
||||||
|
elif triton: # NVIDIA Triton Inference Server
|
||||||
|
LOGGER.info(f'Using {w} as Triton Inference Server...')
|
||||||
|
check_requirements('tritonclient[all]')
|
||||||
|
from utils.triton import TritonRemoteModel
|
||||||
|
model = TritonRemoteModel(url=w)
|
||||||
|
nhwc = model.runtime.startswith('tensorflow')
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f'ERROR: {w} is not a supported format')
|
||||||
|
|
||||||
|
# class names
|
||||||
|
if 'names' not in locals():
|
||||||
|
names = yaml_load(data)['names'] if data else {i: f'class{i}' for i in range(999)}
|
||||||
|
if names[0] == 'n01440764' and len(names) == 1000: # ImageNet
|
||||||
|
names = yaml_load(ROOT / 'data/ImageNet.yaml')['names'] # human-readable names
|
||||||
|
|
||||||
|
self.__dict__.update(locals()) # assign all variables to self
|
||||||
|
|
||||||
|
def forward(self, im, augment=False, visualize=False):
|
||||||
|
# YOLOv5 MultiBackend inference
|
||||||
|
b, ch, h, w = im.shape # batch, channel, height, width
|
||||||
|
if self.fp16 and im.dtype != torch.float16:
|
||||||
|
im = im.half() # to FP16
|
||||||
|
if self.nhwc:
|
||||||
|
im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3)
|
||||||
|
|
||||||
|
if self.pt: # PyTorch
|
||||||
|
y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
|
||||||
|
elif self.jit: # TorchScript
|
||||||
|
y = self.model(im)
|
||||||
|
elif self.dnn: # ONNX OpenCV DNN
|
||||||
|
im = im.cpu().numpy() # torch to numpy
|
||||||
|
self.net.setInput(im)
|
||||||
|
y = self.net.forward()
|
||||||
|
elif self.onnx: # ONNX Runtime
|
||||||
|
im = im.cpu().numpy() # torch to numpy
|
||||||
|
y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
|
||||||
|
elif self.xml: # OpenVINO
|
||||||
|
im = im.cpu().numpy() # FP32
|
||||||
|
y = list(self.ov_compiled_model(im).values())
|
||||||
|
elif self.engine: # TensorRT
|
||||||
|
if self.dynamic and im.shape != self.bindings['images'].shape:
|
||||||
|
i = self.model.get_binding_index('images')
|
||||||
|
self.context.set_binding_shape(i, im.shape) # reshape if dynamic
|
||||||
|
self.bindings['images'] = self.bindings['images']._replace(shape=im.shape)
|
||||||
|
for name in self.output_names:
|
||||||
|
i = self.model.get_binding_index(name)
|
||||||
|
self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
|
||||||
|
s = self.bindings['images'].shape
|
||||||
|
assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
|
||||||
|
self.binding_addrs['images'] = int(im.data_ptr())
|
||||||
|
self.context.execute_v2(list(self.binding_addrs.values()))
|
||||||
|
y = [self.bindings[x].data for x in sorted(self.output_names)]
|
||||||
|
elif self.coreml: # CoreML
|
||||||
|
im = im.cpu().numpy()
|
||||||
|
im = Image.fromarray((im[0] * 255).astype('uint8'))
|
||||||
|
# im = im.resize((192, 320), Image.BILINEAR)
|
||||||
|
y = self.model.predict({'image': im}) # coordinates are xywh normalized
|
||||||
|
if 'confidence' in y:
|
||||||
|
box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
|
||||||
|
conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float)
|
||||||
|
y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
|
||||||
|
else:
|
||||||
|
y = list(reversed(y.values())) # reversed for segmentation models (pred, proto)
|
||||||
|
elif self.paddle: # PaddlePaddle
|
||||||
|
im = im.cpu().numpy().astype(np.float32)
|
||||||
|
self.input_handle.copy_from_cpu(im)
|
||||||
|
self.predictor.run()
|
||||||
|
y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
|
||||||
|
elif self.triton: # NVIDIA Triton Inference Server
|
||||||
|
y = self.model(im)
|
||||||
|
else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
|
||||||
|
im = im.cpu().numpy()
|
||||||
|
if self.saved_model: # SavedModel
|
||||||
|
y = self.model(im, training=False) if self.keras else self.model(im)
|
||||||
|
elif self.pb: # GraphDef
|
||||||
|
y = self.frozen_func(x=self.tf.constant(im))
|
||||||
|
else: # Lite or Edge TPU
|
||||||
|
input = self.input_details[0]
|
||||||
|
int8 = input['dtype'] == np.uint8 # is TFLite quantized uint8 model
|
||||||
|
if int8:
|
||||||
|
scale, zero_point = input['quantization']
|
||||||
|
im = (im / scale + zero_point).astype(np.uint8) # de-scale
|
||||||
|
self.interpreter.set_tensor(input['index'], im)
|
||||||
|
self.interpreter.invoke()
|
||||||
|
y = []
|
||||||
|
for output in self.output_details:
|
||||||
|
x = self.interpreter.get_tensor(output['index'])
|
||||||
|
if int8:
|
||||||
|
scale, zero_point = output['quantization']
|
||||||
|
x = (x.astype(np.float32) - zero_point) * scale # re-scale
|
||||||
|
y.append(x)
|
||||||
|
y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
|
||||||
|
y[0][..., :4] *= [w, h, w, h] # xywh normalized to pixels
|
||||||
|
|
||||||
|
if isinstance(y, (list, tuple)):
|
||||||
|
return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
|
||||||
|
else:
|
||||||
|
return self.from_numpy(y)
|
||||||
|
|
||||||
|
def from_numpy(self, x):
|
||||||
|
return torch.from_numpy(x).to(self.device) if isinstance(x, np.ndarray) else x
|
||||||
|
|
||||||
|
def warmup(self, imgsz=(1, 3, 640, 640)):
|
||||||
|
# Warmup model by running inference once
|
||||||
|
warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton
|
||||||
|
if any(warmup_types) and (self.device.type != 'cpu' or self.triton):
|
||||||
|
im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
|
||||||
|
for _ in range(2 if self.jit else 1): #
|
||||||
|
self.forward(im) # warmup
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _model_type(p='path/to/model.pt'):
|
||||||
|
# Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
|
||||||
|
# types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle]
|
||||||
|
from utils.export import export_formats
|
||||||
|
from utils.downloads import is_url
|
||||||
|
sf = list(export_formats().Suffix) # export suffixes
|
||||||
|
if not is_url(p, check=False):
|
||||||
|
check_suffix(p, sf) # checks
|
||||||
|
url = urlparse(p) # if url may be Triton inference server
|
||||||
|
types = [s in Path(p).name for s in sf]
|
||||||
|
types[8] &= not types[9] # tflite &= not edgetpu
|
||||||
|
triton = not any(types) and all([any(s in url.scheme for s in ['http', 'grpc']), url.netloc])
|
||||||
|
return types + [triton]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _load_metadata(f=Path('path/to/meta.yaml')):
|
||||||
|
# Load metadata from meta.yaml if it exists
|
||||||
|
if f.exists():
|
||||||
|
d = yaml_load(f)
|
||||||
|
return d['stride'], d['names'] # assign stride, names
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
|
class AutoShape(nn.Module):
|
||||||
|
# YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
|
||||||
|
conf = 0.25 # NMS confidence threshold
|
||||||
|
iou = 0.45 # NMS IoU threshold
|
||||||
|
agnostic = False # NMS class-agnostic
|
||||||
|
multi_label = False # NMS multiple labels per box
|
||||||
|
classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
|
||||||
|
max_det = 1000 # maximum number of detections per image
|
||||||
|
amp = False # Automatic Mixed Precision (AMP) inference
|
||||||
|
|
||||||
|
def __init__(self, model, verbose=True):
|
||||||
|
super().__init__()
|
||||||
|
if verbose:
|
||||||
|
LOGGER.info('Adding AutoShape... ')
|
||||||
|
copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes
|
||||||
|
self.dmb = isinstance(model, DetectMultiBackend) # DetectMultiBackend() instance
|
||||||
|
self.pt = not self.dmb or model.pt # PyTorch model
|
||||||
|
self.model = model.eval()
|
||||||
|
if self.pt:
|
||||||
|
m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
|
||||||
|
m.inplace = False # Detect.inplace=False for safe multithread inference
|
||||||
|
m.export = True # do not output loss values
|
||||||
|
|
||||||
|
def _apply(self, fn):
|
||||||
|
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
|
||||||
|
self = super()._apply(fn)
|
||||||
|
if self.pt:
|
||||||
|
m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
|
||||||
|
m.stride = fn(m.stride)
|
||||||
|
m.grid = list(map(fn, m.grid))
|
||||||
|
if isinstance(m.anchor_grid, list):
|
||||||
|
m.anchor_grid = list(map(fn, m.anchor_grid))
|
||||||
|
return self
|
||||||
|
|
||||||
|
    @smart_inference_mode()
    def forward(self, ims, size=640, augment=False, profile=False):
        # Inference from various sources. For size(height=640, width=1280), RGB images example inputs are:
        #   file:        ims = 'data/images/zidane.jpg'  # str or PosixPath
        #   URI:             = 'https://ultralytics.com/images/zidane.jpg'
        #   OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(640,1280,3)
        #   PIL:             = Image.open('image.jpg') or ImageGrab.grab()  # HWC x(640,1280,3)
        #   numpy:           = np.zeros((640,1280,3))  # HWC
        #   torch:           = torch.zeros(16,3,320,640)  # BCHW (scaled to size=640, 0-1 values)
        #   multiple:        = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images

        dt = (Profile(), Profile(), Profile())
        with dt[0]:
            if isinstance(size, int):  # expand
                size = (size, size)
            p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device)  # param
            autocast = self.amp and (p.device.type != 'cpu')  # Automatic Mixed Precision (AMP) inference
            if isinstance(ims, torch.Tensor):  # torch
                with amp.autocast(autocast):
                    return self.model(ims.to(p.device).type_as(p), augment=augment)  # inference

            # Pre-process
            n, ims = (len(ims), list(ims)) if isinstance(ims, (list, tuple)) else (1, [ims])  # number, list of images
            shape0, shape1, files = [], [], []  # image and inference shapes, filenames
            for i, im in enumerate(ims):
                f = f'image{i}'  # filename
                if isinstance(im, (str, Path)):  # filename or uri
                    im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
                    im = np.asarray(exif_transpose(im))
                elif isinstance(im, Image.Image):  # PIL Image
                    im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
                files.append(Path(f).with_suffix('.jpg').name)
                if im.shape[0] < 5:  # image in CHW
                    im = im.transpose((1, 2, 0))  # reverse dataloader .transpose(2, 0, 1)
                im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)  # enforce 3ch input
                s = im.shape[:2]  # HWC
                shape0.append(s)  # image shape
                g = max(size) / max(s)  # gain
                shape1.append([int(y * g) for y in s])
                ims[i] = im if im.data.contiguous else np.ascontiguousarray(im)  # update
            shape1 = [make_divisible(x, self.stride) for x in np.array(shape1).max(0)]  # inf shape
            x = [letterbox(im, shape1, auto=False)[0] for im in ims]  # pad
            x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2)))  # stack and BHWC to BCHW
            x = torch.from_numpy(x).to(p.device).type_as(p) / 255  # uint8 to fp16/32

        with amp.autocast(autocast):
            # Inference
            with dt[1]:
                y = self.model(x, augment=augment)  # forward

            # Post-process
            with dt[2]:
                y = non_max_suppression(y if self.dmb else y[0],
                                        self.conf,
                                        self.iou,
                                        self.classes,
                                        self.agnostic,
                                        self.multi_label,
                                        max_det=self.max_det)  # NMS
                for i in range(n):
                    scale_boxes(shape1, y[i][:, :4], shape0[i])

            return Detections(ims, y, files, dt, self.names, x.shape)
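
# Usage sketch (illustrative, not part of the original module): AutoShape accepts raw
# image sources and returns a Detections object. The hub repo and weight names below
# are assumptions for demonstration; any AutoShape-wrapped YOLOv5 model works the same.
def _autoshape_usage_example():  # hypothetical helper, shown for illustration only
    import torch
    model = torch.hub.load('ultralytics/yolov5', 'yolov5s')  # AutoShape-wrapped model
    results = model(['data/images/zidane.jpg'], size=640)  # one batched forward pass
    results.print()  # log a per-image detection summary
    return results
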
class Detections:
    # YOLOv5 detections class for inference results
    def __init__(self, ims, pred, files, times=(0, 0, 0), names=None, shape=None):
        super().__init__()
        d = pred[0].device  # device
        gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in ims]  # normalizations
        self.ims = ims  # list of images as numpy arrays
        self.pred = pred  # list of tensors pred[0] = (xyxy, conf, cls)
        self.names = names  # class names
        self.files = files  # image filenames
        self.times = times  # profiling times
        self.xyxy = pred  # xyxy pixels
        self.xywh = [xyxy2xywh(x) for x in pred]  # xywh pixels
        self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)]  # xyxy normalized
        self.xywhn = [x / g for x, g in zip(self.xywh, gn)]  # xywh normalized
        self.n = len(self.pred)  # number of images (batch size)
        self.t = tuple(x.t / self.n * 1E3 for x in times)  # timestamps (ms)
        self.s = tuple(shape)  # inference BCHW shape
    def _run(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
        s, crops = '', []
        for i, (im, pred) in enumerate(zip(self.ims, self.pred)):
            s += f'\nimage {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} '  # string
            if pred.shape[0]:
                for c in pred[:, -1].unique():
                    n = (pred[:, -1] == c).sum()  # detections per class
                    s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "  # add to string
                s = s.rstrip(', ')
                if show or save or render or crop:
                    annotator = Annotator(im, example=str(self.names))
                    for *box, conf, cls in reversed(pred):  # xyxy, confidence, class
                        label = f'{self.names[int(cls)]} {conf:.2f}'
                        if crop:
                            file = save_dir / 'crops' / self.names[int(cls)] / self.files[i] if save else None
                            crops.append({
                                'box': box,
                                'conf': conf,
                                'cls': cls,
                                'label': label,
                                'im': save_one_box(box, im, file=file, save=save)})
                        else:  # all others
                            annotator.box_label(box, label if labels else '', color=colors(cls))
                    im = annotator.im
            else:
                s += '(no detections)'

            im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im  # from np
            if show:
                if is_jupyter():
                    from IPython.display import display
                    display(im)
                else:
                    im.show(self.files[i])
            if save:
                f = self.files[i]
                im.save(save_dir / f)  # save
                if i == self.n - 1:
                    LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
            if render:
                self.ims[i] = np.asarray(im)
        if pprint:
            s = s.lstrip('\n')
            return f'{s}\nSpeed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {self.s}' % self.t
        if crop:
            if save:
                LOGGER.info(f'Saved results to {save_dir}\n')
            return crops
    @TryExcept('Showing images is not supported in this environment')
    def show(self, labels=True):
        self._run(show=True, labels=labels)  # show results

    def save(self, labels=True, save_dir='runs/detect/exp', exist_ok=False):
        save_dir = increment_path(save_dir, exist_ok, mkdir=True)  # increment save_dir
        self._run(save=True, labels=labels, save_dir=save_dir)  # save results

    def crop(self, save=True, save_dir='runs/detect/exp', exist_ok=False):
        save_dir = increment_path(save_dir, exist_ok, mkdir=True) if save else None
        return self._run(crop=True, save=save, save_dir=save_dir)  # crop results

    def render(self, labels=True):
        self._run(render=True, labels=labels)  # render results
        return self.ims

    def pandas(self):
        # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
        new = copy(self)  # return copy
        ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name'  # xyxy columns
        cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name'  # xywh columns
        for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
            a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)]  # update
            setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
        return new

    def tolist(self):
        # return a list of Detections objects, i.e. 'for result in results.tolist():'
        r = range(self.n)  # iterable
        x = [Detections([self.ims[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
        # for d in x:
        #    for k in ['ims', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
        #        setattr(d, k, getattr(d, k)[0])  # pop out of list
        return x

    def print(self):
        LOGGER.info(self.__str__())

    def __len__(self):  # override len(results)
        return self.n

    def __str__(self):  # override print(results)
        return self._run(pprint=True)  # print results

    def __repr__(self):
        return f'YOLOv5 {self.__class__} instance\n' + self.__str__()
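
# Usage sketch (illustrative, not part of the original module): the accessors above
# expose results as tensors or pandas DataFrames; 'results' is assumed to come from an
# AutoShape forward pass as in the earlier sketch.
def _detections_usage_example(results):  # hypothetical helper, shown for illustration only
    df = results.pandas().xyxy[0]  # columns: xmin, ymin, xmax, ymax, confidence, class, name
    people = df[df['name'] == 'person']  # filter detections of one class by name
    per_image = results.tolist()  # split the batch into single-image Detections objects
    return people, per_image
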
class Proto(nn.Module):
    # YOLOv5 mask Proto module for segmentation models
    def __init__(self, c1, c_=256, c2=32):  # ch_in, number of protos, number of masks
        super().__init__()
        self.cv1 = Conv(c1, c_, k=3)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.cv2 = Conv(c_, c_, k=3)
        self.cv3 = Conv(c_, c2)

    def forward(self, x):
        return self.cv3(self.cv2(self.upsample(self.cv1(x))))
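
# Shape sketch (illustrative, not part of the original module): Proto upsamples once
# (2x nearest) and maps c1 -> c2 prototype channels, so x(b,c1,h,w) -> x(b,c2,2h,2w).
def _proto_shape_example():  # hypothetical helper, shown for illustration only
    p = Proto(c1=256, c_=256, c2=32)
    out = p(torch.zeros(1, 256, 20, 20))
    assert out.shape == (1, 32, 40, 40)  # spatial size doubled, 32 mask prototypes
    return out
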
class Classify(nn.Module):
    # YOLOv5 classification head, i.e. x(b,c1,20,20) to x(b,c2)
    def __init__(self,
                 c1,
                 c2,
                 k=1,
                 s=1,
                 p=None,
                 g=1,
                 dropout_p=0.0):  # ch_in, ch_out, kernel, stride, padding, groups, dropout probability
        super().__init__()
        c_ = 1280  # efficientnet_b0 size
        self.conv = Conv(c1, c_, k, s, autopad(k, p), g)
        self.pool = nn.AdaptiveAvgPool2d(1)  # to x(b,c_,1,1)
        self.drop = nn.Dropout(p=dropout_p, inplace=True)
        self.linear = nn.Linear(c_, c2)  # to x(b,c2)

    def forward(self, x):
        if isinstance(x, list):
            x = torch.cat(x, 1)
        return self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
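
# Shape sketch (illustrative, not part of the original module): Classify pools to 1x1
# and projects through the 1280-wide bottleneck to c2 logits, so x(b,c1,h,w) -> x(b,c2)
# for any input spatial size.
def _classify_shape_example():  # hypothetical helper, shown for illustration only
    head = Classify(c1=512, c2=10)
    logits = head(torch.zeros(2, 512, 20, 20))
    assert logits.shape == (2, 10)
    return logits
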
111
ytracking/models/experimental.py
Normal file
@ -0,0 +1,111 @@
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
"""
Experimental modules
"""
import math

import numpy as np
import torch
import torch.nn as nn

from utils.downloads import attempt_download


class Sum(nn.Module):
    # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
    def __init__(self, n, weight=False):  # n: number of inputs
        super().__init__()
        self.weight = weight  # apply weights boolean
        self.iter = range(n - 1)  # iter object
        if weight:
            self.w = nn.Parameter(-torch.arange(1.0, n) / 2, requires_grad=True)  # layer weights

    def forward(self, x):
        y = x[0]  # no weight
        if self.weight:
            w = torch.sigmoid(self.w) * 2
            for i in self.iter:
                y = y + x[i + 1] * w[i]
        else:
            for i in self.iter:
                y = y + x[i + 1]
        return y
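
# Usage sketch (illustrative, not part of the original module): Sum fuses equal-shape
# feature maps; with weight=True the inputs after the first are scaled by learned
# weights squashed into (0, 2) by the sigmoid scaling above.
def _sum_usage_example():  # hypothetical helper, shown for illustration only
    fuse = Sum(n=3, weight=True)
    feats = [torch.randn(1, 64, 40, 40) for _ in range(3)]
    return fuse(feats)  # same shape as each input: (1, 64, 40, 40)
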
class MixConv2d(nn.Module):
    # Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595
    def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):  # ch_in, ch_out, kernel, stride, ch_strategy
        super().__init__()
        n = len(k)  # number of convolutions
        if equal_ch:  # equal c_ per group
            i = torch.linspace(0, n - 1E-6, c2).floor()  # c2 indices
            c_ = [(i == g).sum() for g in range(n)]  # intermediate channels
        else:  # equal weight.numel() per group
            b = [c2] + [0] * n
            a = np.eye(n + 1, n, k=-1)
            a -= np.roll(a, 1, axis=1)
            a *= np.array(k) ** 2
            a[0] = 1
            c_ = np.linalg.lstsq(a, b, rcond=None)[0].round()  # solve for equal weight indices, ax = b

        self.m = nn.ModuleList([
            nn.Conv2d(c1, int(c_), k, s, k // 2, groups=math.gcd(c1, int(c_)), bias=False) for k, c_ in zip(k, c_)])
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU()

    def forward(self, x):
        return self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
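
# Channel-split sketch (illustrative, not part of the original module): with
# equal_ch=True the c2 output channels are spread as evenly as possible across the
# kernel sizes in k via the floored linspace used above.
def _mixconv_split_example():  # hypothetical helper, shown for illustration only
    n, c2 = 2, 64  # two kernel sizes, 64 output channels
    i = torch.linspace(0, n - 1E-6, c2).floor()
    return [(i == g).sum().item() for g in range(n)]  # -> [32, 32]
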
class Ensemble(nn.ModuleList):
    # Ensemble of models
    def __init__(self):
        super().__init__()

    def forward(self, x, augment=False, profile=False, visualize=False):
        y = [module(x, augment, profile, visualize)[0] for module in self]
        # y = torch.stack(y).max(0)[0]  # max ensemble
        # y = torch.stack(y).mean(0)  # mean ensemble
        y = torch.cat(y, 1)  # nms ensemble
        return y, None  # inference, train output
def attempt_load(weights, device=None, inplace=True, fuse=True):
    # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
    from models.yolo import Detect, Model

    model = Ensemble()
    for w in weights if isinstance(weights, list) else [weights]:
        ckpt = torch.load(attempt_download(w), map_location='cpu')  # load
        ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float()  # FP32 model

        # Model compatibility updates
        if not hasattr(ckpt, 'stride'):
            ckpt.stride = torch.tensor([32.])
        if hasattr(ckpt, 'names') and isinstance(ckpt.names, (list, tuple)):
            ckpt.names = dict(enumerate(ckpt.names))  # convert to dict

        model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval())  # model in eval mode

    # Module updates
    for m in model.modules():
        t = type(m)
        if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model):
            m.inplace = inplace
            if t is Detect and not isinstance(m.anchor_grid, list):
                delattr(m, 'anchor_grid')
                setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
        elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
            m.recompute_scale_factor = None  # torch 1.11.0 compatibility

    # Return model
    if len(model) == 1:
        return model[-1]

    # Return detection ensemble
    print(f'Ensemble created with {weights}\n')
    for k in 'names', 'nc', 'yaml':
        setattr(model, k, getattr(model[0], k))
    model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride  # max stride
    assert all(model[0].nc == m.nc for m in model), f'Models have different class counts: {[m.nc for m in model]}'
    return model
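
# Usage sketch (illustrative, not part of the original module; the weight paths are
# assumptions): attempt_load returns a single fused model for one weights file and an
# NMS Ensemble for several.
def _attempt_load_usage_example():  # hypothetical helper, shown for illustration only
    single = attempt_load('yolov5s.pt', device='cpu')  # one model, fused, in eval mode
    ensemble = attempt_load(['yolov5s.pt', 'yolov5m.pt'], device='cpu')  # Ensemble of two
    return single, ensemble
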
59
ytracking/models/hub/anchors.yaml
Normal file
@ -0,0 +1,59 @@
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
# Default anchors for COCO data


# P5 -------------------------------------------------------------------------------------------------------------------
# P5-640:
anchors_p5_640:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32


# P6 -------------------------------------------------------------------------------------------------------------------
# P6-640:  thr=0.25: 0.9964 BPR, 5.54 anchors past thr, n=12, img_size=640, metric_all=0.281/0.716-mean/best, past_thr=0.469-mean: 9,11, 21,19, 17,41, 43,32, 39,70, 86,64, 65,131, 134,130, 120,265, 282,180, 247,354, 512,387
anchors_p6_640:
  - [9,11, 21,19, 17,41]  # P3/8
  - [43,32, 39,70, 86,64]  # P4/16
  - [65,131, 134,130, 120,265]  # P5/32
  - [282,180, 247,354, 512,387]  # P6/64

# P6-1280:  thr=0.25: 0.9950 BPR, 5.55 anchors past thr, n=12, img_size=1280, metric_all=0.281/0.714-mean/best, past_thr=0.468-mean: 19,27, 44,40, 38,94, 96,68, 86,152, 180,137, 140,301, 303,264, 238,542, 436,615, 739,380, 925,792
anchors_p6_1280:
  - [19,27, 44,40, 38,94]  # P3/8
  - [96,68, 86,152, 180,137]  # P4/16
  - [140,301, 303,264, 238,542]  # P5/32
  - [436,615, 739,380, 925,792]  # P6/64

# P6-1920:  thr=0.25: 0.9950 BPR, 5.55 anchors past thr, n=12, img_size=1920, metric_all=0.281/0.714-mean/best, past_thr=0.468-mean: 28,41, 67,59, 57,141, 144,103, 129,227, 270,205, 209,452, 455,396, 358,812, 653,922, 1109,570, 1387,1187
anchors_p6_1920:
  - [28,41, 67,59, 57,141]  # P3/8
  - [144,103, 129,227, 270,205]  # P4/16
  - [209,452, 455,396, 358,812]  # P5/32
  - [653,922, 1109,570, 1387,1187]  # P6/64


# P7 -------------------------------------------------------------------------------------------------------------------
# P7-640:  thr=0.25: 0.9962 BPR, 6.76 anchors past thr, n=15, img_size=640, metric_all=0.275/0.733-mean/best, past_thr=0.466-mean: 11,11, 13,30, 29,20, 30,46, 61,38, 39,92, 78,80, 146,66, 79,163, 149,150, 321,143, 157,303, 257,402, 359,290, 524,372
anchors_p7_640:
  - [11,11, 13,30, 29,20]  # P3/8
  - [30,46, 61,38, 39,92]  # P4/16
  - [78,80, 146,66, 79,163]  # P5/32
  - [149,150, 321,143, 157,303]  # P6/64
  - [257,402, 359,290, 524,372]  # P7/128

# P7-1280:  thr=0.25: 0.9968 BPR, 6.71 anchors past thr, n=15, img_size=1280, metric_all=0.273/0.732-mean/best, past_thr=0.463-mean: 19,22, 54,36, 32,77, 70,83, 138,71, 75,173, 165,159, 148,334, 375,151, 334,317, 251,626, 499,474, 750,326, 534,814, 1079,818
anchors_p7_1280:
  - [19,22, 54,36, 32,77]  # P3/8
  - [70,83, 138,71, 75,173]  # P4/16
  - [165,159, 148,334, 375,151]  # P5/32
  - [334,317, 251,626, 499,474]  # P6/64
  - [750,326, 534,814, 1079,818]  # P7/128

# P7-1920:  thr=0.25: 0.9968 BPR, 6.71 anchors past thr, n=15, img_size=1920, metric_all=0.273/0.732-mean/best, past_thr=0.463-mean: 29,34, 81,55, 47,115, 105,124, 207,107, 113,259, 247,238, 222,500, 563,227, 501,476, 376,939, 749,711, 1126,489, 801,1222, 1618,1227
anchors_p7_1920:
  - [29,34, 81,55, 47,115]  # P3/8
  - [105,124, 207,107, 113,259]  # P4/16
  - [247,238, 222,500, 563,227]  # P5/32
  - [501,476, 376,939, 749,711]  # P6/64
  - [1126,489, 801,1222, 1618,1227]  # P7/128
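
# Reading sketch (illustrative Python, assuming PyYAML is installed; shown as comments
# so the YAML above stays valid): each entry is one detection level holding three
# (w, h) anchor pairs.
#
#   import yaml
#   cfg = yaml.safe_load(open('ytracking/models/hub/anchors.yaml'))
#   assert len(cfg['anchors_p6_640']) == 4                  # P3/8 .. P6/64 levels
#   assert all(len(a) == 6 for a in cfg['anchors_p6_640'])  # 3 anchors x (w, h) per level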
51
ytracking/models/hub/yolov3-spp.yaml
Normal file
@ -0,0 +1,51 @@
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# darknet53 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [32, 3, 1]],  # 0
   [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2
   [-1, 1, Bottleneck, [64]],
   [-1, 1, Conv, [128, 3, 2]],  # 3-P2/4
   [-1, 2, Bottleneck, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 5-P3/8
   [-1, 8, Bottleneck, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 7-P4/16
   [-1, 8, Bottleneck, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 9-P5/32
   [-1, 4, Bottleneck, [1024]],  # 10
  ]

# YOLOv3-SPP head
head:
  [[-1, 1, Bottleneck, [1024, False]],
   [-1, 1, SPP, [512, [5, 9, 13]]],
   [-1, 1, Conv, [1024, 3, 1]],
   [-1, 1, Conv, [512, 1, 1]],
   [-1, 1, Conv, [1024, 3, 1]],  # 15 (P5/32-large)

   [-2, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 8], 1, Concat, [1]],  # cat backbone P4
   [-1, 1, Bottleneck, [512, False]],
   [-1, 1, Bottleneck, [512, False]],
   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, Conv, [512, 3, 1]],  # 22 (P4/16-medium)

   [-2, 1, Conv, [128, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P3
   [-1, 1, Bottleneck, [256, False]],
   [-1, 2, Bottleneck, [256, False]],  # 27 (P3/8-small)

   [[27, 22, 15], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]