commit 3a5214c796
Author: lee
Date: 2024-11-27 15:37:10 +08:00

696 changed files with 56947 additions and 0 deletions

contrast/.idea/contrast_nettest.iml (generated, new file)

@@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="GOOGLE" />
<option name="myDocStringFormat" value="Google" />
</component>
</module>

contrast/.idea/inspectionProfiles/Project_Default.xml (generated, new file)

@@ -0,0 +1,19 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="6">
<item index="0" class="java.lang.String" itemvalue="thop" />
<item index="1" class="java.lang.String" itemvalue="regex" />
<item index="2" class="java.lang.String" itemvalue="tensorboardX" />
<item index="3" class="java.lang.String" itemvalue="torch" />
<item index="4" class="java.lang.String" itemvalue="numpy" />
<item index="5" class="java.lang.String" itemvalue="terminaltables" />
</list>
</value>
</option>
</inspection_tool>
</profile>
</component>

contrast/.idea/inspectionProfiles/profiles_settings.xml (generated, new file)

@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

contrast/.idea/misc.xml (generated, new file)

@@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10" project-jdk-type="Python SDK" />
</project>

contrast/.idea/modules.xml (generated, new file)

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/contrast_nettest.iml" filepath="$PROJECT_DIR$/.idea/contrast_nettest.iml" />
</modules>
</component>
</project>

contrast/.idea/vcs.xml (generated, new file)

@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$/.." vcs="Git" />
</component>
</project>

contrast/README.md (new file)

@@ -0,0 +1,60 @@
# Build Your Own Face Recognition Model

Train your own face recognition model!

Face recognition has moved from the original softmax embedding, through the triplet-loss metric learning led by FaceNet in 2015, to additive-margin metric learning. This blog series implements ArcFace, proposed in 2018.
### Dependencies

```
Python >= 3.6
pytorch >= 1.0
torchvision
imutils
pillow == 6.2.0
tqdm
```
### Data Preparation

+ Download WebFace (search for it online) together with the cleaned image list ([BaiduYun](http://pan.baidu.com/s/1hrKpbm8)) for training.
+ Download LFW ([BaiduYun](https://pan.baidu.com/s/12IKEpvM8-tYgSaUiz_adGA), extraction code u7z4) and the [test pair list](https://github.com/ronghuaiyang/arcface-pytorch/blob/master/lfw_test_pair.txt) for testing.
+ Remove the dirty data from WebFace using `utils.py` (see the sketch below).
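The helpers referenced here are defined in `contrast/utils.py` further down in this commit; a minimal cleaning sketch, with placeholder data paths:

```py
from utils import generate_list, transform_clean_list, remove_dirty_image

webface_dir = '/data/CASIA-WebFace/'         # placeholder paths
cleaned_list_path = '/data/cleaned_list.txt'

# Map the cleaned list onto local paths, then delete every image not on it.
cleaned_list = transform_clean_list(webface_dir, cleaned_list_path)
remove_dirty_image(webface_dir, cleaned_list)

# Optionally regenerate the "<path> <label>" training list afterwards.
generate_list(webface_dir, saved_name='train_list.txt')
```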
### Configuration

All parameters are defined in `config.py`.
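`config.py` itself is not part of this diff (only `config.py.bak` is), so the following is only a sketch of the fields that `dataset.py` and the models read, with placeholder values:

```py
import torch
import torchvision.transforms as T

class Config:
    # fields read by dataset.py (placeholder values)
    train_root = '/data/CASIA-WebFace/'
    test_root = '/data/lfw/'
    train_batch_size = 64
    test_batch_size = 64
    pin_memory = True
    num_workers = 4
    # fields read by the models and by config.py.bak
    embedding_size = 256
    img_size = 224
    train_transform = T.Compose([T.ToTensor(), T.Resize((img_size, img_size))])
    test_transform = T.Compose([T.ToTensor(), T.Resize((img_size, img_size))])
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

config = Config()
```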
### Training

Single-machine multi-GPU training is supported out of the box:

```sh
export CUDA_VISIBLE_DEVICES=0,1
python train.py
```
### Testing

```sh
python test.py
```
### Blog Posts

There are already plenty of introductions to face recognition, but inspired by the many [Build-Your-Own-x](https://github.com/danistefanovic/build-your-own-x) articles, I wanted to write a Build Your Own Face Model series, in the hope that it helps others.
+ 001 [Data Preparation](./blog/data.md)
+ 002 [Model Architecture](./blog/model.md)
+ 003 [Loss Function](./blog/loss.md)
+ 004 [Metric Function](./blog/metric.md)
+ 005 [Training](./blog/train.md)
+ 006 [Testing](./blog/test.md)
### Acknowledgements

Although not marked inline, some code in this project is copied from or adapted from the following repositories, under the same licenses:
+ [insightFace](https://github.com/deepinsight/insightface/tree/master/recognition)
+ [insightFace_Pytorch](https://github.com/TreB1eN/InsightFace_Pytorch)
+ [arcface-pytorch](https://github.com/ronghuaiyang/arcface-pytorch)

Binary files not shown (7 files).

contrast/config.py.bak (new file)

@@ -0,0 +1,21 @@
import torch
import torchvision.transforms as T
class Config:
host = "192.168.1.28"
port = "19530"
embedding_size = 256
img_size = 224
test_transform = T.Compose([
T.ToTensor(),
T.Resize((img_size, img_size)),
T.ConvertImageDtype(torch.float32),
T.Normalize(mean=[0.5], std=[0.5]),
])
# test_model = "checkpoints/resnet18_our388.pth"
test_model = "checkpoints/mobilenetv3Large_our388_noPara.pth"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config = Config()
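Assuming this backup is restored as `config.py`, applying the pipeline to one image looks like the sketch below (the image path is a placeholder):

```py
from PIL import Image
from config import config as conf

img = Image.open('example.jpg')              # placeholder image path
tensor = conf.test_transform(img)            # float tensor, shape (3, 224, 224)
batch = tensor.unsqueeze(0).to(conf.device)  # add a batch dimension for the model
```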

contrast/dataset.py (new file)

@@ -0,0 +1,21 @@
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from config import config as conf
def load_data(conf, training=True):
if training:
dataroot = conf.train_root
transform = conf.train_transform
batch_size = conf.train_batch_size
else:
dataroot = conf.test_root
transform = conf.test_transform
batch_size = conf.test_batch_size
data = ImageFolder(dataroot, transform=transform)
class_num = len(data.classes)
loader = DataLoader(data, batch_size=batch_size, shuffle=True,
pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
return loader, class_num
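A usage sketch, assuming `config.py` defines the fields read above (`train_root`, `train_transform`, `train_batch_size`, `pin_memory`, `num_workers`, and their test counterparts):

```py
from config import config as conf
from dataset import load_data

train_loader, class_num = load_data(conf, training=True)
print(f"{class_num} classes")
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # (batch, 3, H, W), (batch,)
```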

contrast/img_data.py (new file, +129)

File diff suppressed because one or more lines are too long

contrast/logic.py (new file)

@@ -0,0 +1,66 @@
import sys
import torch
from tools.config import cfg as conf
sys.path.append('contrast')
# from config import config as conf
# from model import resnet18, MobileNetV3_Large
from test_logic import similarity_interface
from img_data import queueImgs_add
# import pymilvus
class datacollection:
barcode_flag = None
add_flag = None
queImgsDict = None
mainMilvus = None
tempLibList = None
model = None
barcode_list = None
actionModel = True # mode flag: True = production mode, False = test mode
class similarityResult:
top10 = None
top1 = None
tempLibList = None
topn = None
class similarity:
def __init__(self):
pass
def getSimilarity(self, model, dataCollection, similarityRes):
dataCollection.mainMilvus = model.milvusModel
dataCollection.model = model.similarityModel
# try:
if dataCollection.add_flag:
if dataCollection.barcode_flag: # add-to-cart with barcode -> output top10 and top1
similarityRes.top10, similarityRes.top1, similarityRes.tempLibList = similarity_interface(
dataCollection)
print(f"top10: {similarityRes.top10}\ntop1: {similarityRes.top1}")
else: # add-to-cart without barcode -> output top10
similarityRes.top10, similarityRes.tempLibList = similarity_interface(dataCollection)
else: # return flow -> output top10 and topn
if dataCollection.barcode_flag:
similarityRes.top10, similarityRes.top1, similarityRes.topn = similarity_interface(dataCollection)
else:
similarityRes.top10, similarityRes.topn = similarity_interface(dataCollection)
return similarityRes
# except pymilvus.exceptions.SchemaNotReadyException as SchemaNotReadyException: ### the current feature collection does not exist
# print('pymilvus.exceptions.SchemaNotReadyException', SchemaNotReadyException)
def main():
data_collection = datacollection()
similarityRes = similarityResult()
data_collection.barcode_flag = queueImgs_add['barcode_flag']
data_collection.add_flag = queueImgs_add['add_flag']
data_collection.queImgsDict = queueImgs_add
# NOTE: getSimilarity(self, model, dataCollection, similarityRes) also expects a model
# object (with .milvusModel and .similarityModel) that is not constructed here.
similarity().getSimilarity(data_collection, similarityRes)
if __name__ == '__main__':
main()

contrast/main_barcodes.json (new file, +1894)

File diff suppressed because it is too large

contrast/main_library.py (new file)

@@ -0,0 +1,17 @@
"""搭建主特征库"""
from test_logic import create_milvus, img2feature
from config import config as conf
from img_data import library_imgs, temp_imgs
def createMainMilvus(imgs_dict): ##imgs->{barcode1:[img1_1...img1_n], barcode2:[img2_1...img2_n]}
barcode_list, imgs_list = img2feature(imgs_dict)
mainMilvus = create_milvus('main_features', conf.host, conf.port, barcode_list, imgs_list)
return mainMilvus
def createTempMilvus(imgs_dict): ##imgs->{barcode1:[img1_1...img1_n], barcode2:[img2_1...img2_n]}
barcode_list, imgs_list = img2feature(imgs_dict)
tempMilvus = create_milvus('temp_features', conf.host, conf.port, barcode_list, imgs_list)
return tempMilvus
if __name__ == "__main__":
createMainMilvus(library_imgs)
# createTempMilvus(temp_imgs)

contrast/model/__init__.py (new file)

@@ -0,0 +1,2 @@
from .resnet_pre import resnet18
from .mobilenet_v3 import MobileNetV3_Small, MobileNetV3_Large

Binary files not shown (12 files).

contrast/model/mobilenet_v3.py (new file)

@@ -0,0 +1,200 @@
'''MobileNetV3 in PyTorch.
See the paper "Searching for MobileNetV3" for more details.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from tools.config import config as conf
class hswish(nn.Module):
def forward(self, x):
out = x * F.relu6(x + 3, inplace=True) / 6
return out
class hsigmoid(nn.Module):
def forward(self, x):
out = F.relu6(x + 3, inplace=True) / 6
return out
class SeModule(nn.Module):
def __init__(self, in_size, reduction=4):
super(SeModule, self).__init__()
self.se = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_size, in_size // reduction, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(in_size // reduction),
nn.ReLU(inplace=True),
nn.Conv2d(in_size // reduction, in_size, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(in_size),
hsigmoid()
)
def forward(self, x):
return x * self.se(x)
class Block(nn.Module):
'''expand + depthwise + pointwise'''
def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, semodule, stride):
super(Block, self).__init__()
self.stride = stride
self.se = semodule
self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, stride=1, padding=0, bias=False)
self.bn1 = nn.BatchNorm2d(expand_size)
self.nolinear1 = nolinear
self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size=kernel_size, stride=stride, padding=kernel_size//2, groups=expand_size, bias=False)
self.bn2 = nn.BatchNorm2d(expand_size)
self.nolinear2 = nolinear
self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, stride=1, padding=0, bias=False)
self.bn3 = nn.BatchNorm2d(out_size)
self.shortcut = nn.Sequential()
if stride == 1 and in_size != out_size:
self.shortcut = nn.Sequential(
nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(out_size),
)
def forward(self, x):
out = self.nolinear1(self.bn1(self.conv1(x)))
out = self.nolinear2(self.bn2(self.conv2(out)))
out = self.bn3(self.conv3(out))
if self.se is not None:
out = self.se(out)
out = out + self.shortcut(x) if self.stride==1 else out
return out
class MobileNetV3_Large(nn.Module):
def __init__(self, num_classes=conf.embedding_size):
super(MobileNetV3_Large, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(16)
self.hs1 = hswish()
self.bneck = nn.Sequential(
Block(3, 16, 16, 16, nn.ReLU(inplace=True), None, 1),
Block(3, 16, 64, 24, nn.ReLU(inplace=True), None, 2),
Block(3, 24, 72, 24, nn.ReLU(inplace=True), None, 1),
Block(5, 24, 72, 40, nn.ReLU(inplace=True), SeModule(40), 2),
Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
Block(3, 40, 240, 80, hswish(), None, 2),
Block(3, 80, 200, 80, hswish(), None, 1),
Block(3, 80, 184, 80, hswish(), None, 1),
Block(3, 80, 184, 80, hswish(), None, 1),
Block(3, 80, 480, 112, hswish(), SeModule(112), 1),
Block(3, 112, 672, 112, hswish(), SeModule(112), 1),
Block(5, 112, 672, 160, hswish(), SeModule(160), 1),
Block(5, 160, 672, 160, hswish(), SeModule(160), 2),
Block(5, 160, 960, 160, hswish(), SeModule(160), 1),
)
self.conv2 = nn.Conv2d(160, 960, kernel_size=1, stride=1, padding=0, bias=False)
self.bn2 = nn.BatchNorm2d(960)
self.hs2 = hswish()
self.linear3 = nn.Linear(960, 1280)
self.bn3 = nn.BatchNorm1d(1280)
self.hs3 = hswish()
self.linear4 = nn.Linear(1280, num_classes)
self.init_params()
def init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal_(m.weight, std=0.001)
if m.bias is not None:
init.constant_(m.bias, 0)
def forward(self, x):
out = self.hs1(self.bn1(self.conv1(x)))
out = self.bneck(out)
out = self.hs2(self.bn2(self.conv2(out)))
out = F.avg_pool2d(out, conf.img_size // 32)
out = out.view(out.size(0), -1)
out = self.hs3(self.bn3(self.linear3(out)))
out = self.linear4(out)
return out
class MobileNetV3_Small(nn.Module):
def __init__(self, num_classes=conf.embedding_size):
super(MobileNetV3_Small, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(16)
self.hs1 = hswish()
self.bneck = nn.Sequential(
Block(3, 16, 16, 16, nn.ReLU(inplace=True), SeModule(16), 2),
Block(3, 16, 72, 24, nn.ReLU(inplace=True), None, 2),
Block(3, 24, 88, 24, nn.ReLU(inplace=True), None, 1),
Block(5, 24, 96, 40, hswish(), SeModule(40), 2),
Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
Block(5, 40, 120, 48, hswish(), SeModule(48), 1),
Block(5, 48, 144, 48, hswish(), SeModule(48), 1),
Block(5, 48, 288, 96, hswish(), SeModule(96), 2),
Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
)
self.conv2 = nn.Conv2d(96, 576, kernel_size=1, stride=1, padding=0, bias=False)
self.bn2 = nn.BatchNorm2d(576)
self.hs2 = hswish()
self.linear3 = nn.Linear(576, 1280)
self.bn3 = nn.BatchNorm1d(1280)
self.hs3 = hswish()
self.linear4 = nn.Linear(1280, num_classes)
self.init_params()
def init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal_(m.weight, std=0.001)
if m.bias is not None:
init.constant_(m.bias, 0)
def forward(self, x):
out = self.hs1(self.bn1(self.conv1(x)))
out = self.bneck(out)
out = self.hs2(self.bn2(self.conv2(out)))
out = F.avg_pool2d(out, conf.img_size // 32)
out = out.view(out.size(0), -1)
out = self.hs3(self.bn3(self.linear3(out)))
out = self.linear4(out)
return out
def test():
net = MobileNetV3_Small()
x = torch.randn(2,3,224,224)
y = net(x)
print(y.size())
# test()

contrast/model/resnet_pre.py (new file)

@@ -0,0 +1,462 @@
import torch
import torch.nn as nn
from tools.config import config as conf
try:
from torch.hub import load_state_dict_from_url
except ImportError:
from torch.utils.model_zoo import load_url as load_state_dict_from_url
# from .utils import load_state_dict_from_url
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
'wide_resnet50_2', 'wide_resnet101_2']
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class SpatialAttention(nn.Module):
def __init__(self, kernel_size=7):
super(SpatialAttention, self).__init__()
assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
padding = 3 if kernel_size == 7 else 1
self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out, _ = torch.max(x, dim=1, keepdim=True)
x = torch.cat([avg_out, max_out], dim=1)
x = self.conv1(x)
return self.sigmoid(x)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None, cam=False, bam=False):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
self.cam = cam
self.bam = bam
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
if self.cam:
if planes == 64:
self.globalAvgPool = nn.AvgPool2d(56, stride=1)
elif planes == 128:
self.globalAvgPool = nn.AvgPool2d(28, stride=1)
elif planes == 256:
self.globalAvgPool = nn.AvgPool2d(14, stride=1)
elif planes == 512:
self.globalAvgPool = nn.AvgPool2d(7, stride=1)
self.fc1 = nn.Linear(in_features=planes, out_features=round(planes / 16))
self.fc2 = nn.Linear(in_features=round(planes / 16), out_features=planes)
self.sigmod = nn.Sigmoid()
if self.bam:
self.bam = SpatialAttention()
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
if self.cam:
ori_out = self.globalAvgPool(out)
out = out.view(out.size(0), -1)
out = self.fc1(out)
out = self.relu(out)
out = self.fc2(out)
out = self.sigmod(out)
out = out.view(out.size(0), out.size(-1), 1, 1)
out = out * ori_out
if self.bam:
out = out*self.bam(out)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
# Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2),
# while the original implementation places the stride at the first 1x1 convolution (self.conv1),
# according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
# This variant is also known as ResNet V1.5 and improves accuracy according to
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None, cam=False, bam=False):
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
self.cam = cam
self.bam = bam
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
if self.cam:
if planes == 64:
self.globalAvgPool = nn.AvgPool2d(56, stride=1)
elif planes == 128:
self.globalAvgPool = nn.AvgPool2d(28, stride=1)
elif planes == 256:
self.globalAvgPool = nn.AvgPool2d(14, stride=1)
elif planes == 512:
self.globalAvgPool = nn.AvgPool2d(7, stride=1)
self.fc1 = nn.Linear(planes * self.expansion, round(planes / 4))
self.fc2 = nn.Linear(round(planes / 4), planes * self.expansion)
self.sigmod = nn.Sigmoid()
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
if self.cam:
ori_out = self.globalAvgPool(out)
out = out.view(out.size(0), -1)
out = self.fc1(out)
out = self.relu(out)
out = self.fc2(out)
out = self.sigmod(out)
out = out.view(out.size(0), out.size(-1), 1, 1)
out = out * ori_out
out += identity
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=conf.embedding_size, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None, scale=0.75):
super(ResNet, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, int(64*scale), layers[0])
self.layer2 = self._make_layer(block, int(128*scale), layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, int(256*scale), layers[2], stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, int(512*scale), layers[3], stride=2,
dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(int(512 * block.expansion*scale), num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers)
def _forward_impl(self, x):
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
# print('poolBefore', x.shape)
x = self.avgpool(x)
# print('poolAfter', x.shape)
x = torch.flatten(x, 1)
# print('fcBefore',x.shape)
x = self.fc(x)
# print('fcAfter',x.shape)
return x
def forward(self, x):
return self._forward_impl(x)
# def _resnet(arch, block, layers, pretrained, progress, **kwargs):
# model = ResNet(block, layers, **kwargs)
# if pretrained:
# state_dict = load_state_dict_from_url(model_urls[arch],
# progress=progress)
# model.load_state_dict(state_dict, strict=False)
# return model
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
model = ResNet(block, layers, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch],
progress=progress)
src_state_dict = state_dict
target_state_dict = model.state_dict()
skip_keys = []
# skip mismatch size tensors in case of pretraining
for k in src_state_dict.keys():
if k not in target_state_dict:
continue
if src_state_dict[k].size() != target_state_dict[k].size():
skip_keys.append(k)
for k in skip_keys:
del src_state_dict[k]
missing_keys, unexpected_keys = model.load_state_dict(src_state_dict, strict=False)
return model
def resnet14(pretrained=True, progress=True, **kwargs):
r"""ResNet-14 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet18', BasicBlock, [2, 1, 1, 2], pretrained, progress,
**kwargs)
def resnet18(pretrained=True, progress=True, **kwargs):
r"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
**kwargs)
def resnet34(pretrained=False, progress=True, **kwargs):
r"""ResNet-34 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
**kwargs)
def resnet50(pretrained=False, progress=True, **kwargs):
r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
**kwargs)
def resnet101(pretrained=False, progress=True, **kwargs):
r"""ResNet-101 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
**kwargs)
def resnet152(pretrained=False, progress=True, **kwargs):
r"""ResNet-152 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
**kwargs)
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
r"""ResNeXt-50 32x4d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 4
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
pretrained, progress, **kwargs)
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
r"""ResNeXt-101 32x8d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 8
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
pretrained, progress, **kwargs)
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
r"""Wide ResNet-50-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
The model is the same as ResNet except for the bottleneck number of channels
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['width_per_group'] = 64 * 2
return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
pretrained, progress, **kwargs)
def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
r"""Wide ResNet-101-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
The model is the same as ResNet except for the bottleneck number of channels
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['width_per_group'] = 64 * 2
return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
pretrained, progress, **kwargs)
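A quick sanity-check sketch for the embedding backbone (run from the `contrast/` directory so `model` and `tools.config` resolve; `pretrained=False` avoids the weight download, and the loader above skips size-mismatched tensors anyway):

```py
import torch
from model import resnet18  # exported by contrast/model/__init__.py

net = resnet18(pretrained=False).eval()
with torch.no_grad():
    x = torch.randn(1, 3, 224, 224)  # dummy image batch
    emb = net(x)
print(emb.shape)  # (1, conf.embedding_size), e.g. (1, 256)
```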

contrast/search.py (new file)

@@ -0,0 +1,101 @@
import pdb
class ImgSearch():
def __init__(self):
self.search_params = {
"metric_type": "COSINE",
}
def get_max(self, a, b):
return max(a, b)
def check_keys(self, dict1_last, dict2):
for key2 in list(dict2.keys()):
if key2 in list(dict1_last.keys()):
value = self.get_max(dict1_last[key2], dict2[key2])
dict1_last[key2] = value
else:
dict1_last[key2] = dict2[key2]
return dict1_last
def result_analysis(self, result, top1_flag=False):
result_dict = dict() ## collect the comparison results of all images of the same barcode
for hits in result:
for hit in hits:
if hit.id not in result_dict: ## barcode (hit.id) not yet in the result dict
result_dict.update({hit.id: round(hit.distance, 2)})
else: ## keep the higher similarity for the same barcode
distance = result_dict.get(hit.id)
distance_new = self.get_max(distance, hit.distance)
result_dict.update({hit.id: round(distance_new, 2)})
if top1_flag:
return result_dict
else:
## sort all barcode similarity results and keep at most the top 10
if len(result_dict) > 10:
result_sort_dict = dict(sorted(result_dict.items(), key=lambda x: x[1], reverse=True)[:10])
else:
result_sort_dict = dict(sorted(result_dict.items(), key=lambda x: x[1], reverse=True))
return result_sort_dict
def result_update(self, temp_result, last_result):
temp_keys = list(temp_result.keys())
last_keys = list(last_result.keys())
for ke in temp_keys:
temp_value = temp_result[ke]
if ke in last_keys: ## the track_id1 and track_id2 results share this barcode; update only if track_id2's similarity is higher
last_value = last_result[ke]
if temp_value > last_value:
last_result.update({ke: temp_value})
else: ## no barcode shared between the track_id1 and track_id2 results
last_result.update({ke: temp_value})
return last_result
def mainSearch10(self, mainMilvus, queBarIdList, queueFeatures): ### queBarIdList -> incoming box barcode-track_Id list
result_last = dict()
for i in range(len(queBarIdList)):
vectorsSearch = queueFeatures[i]
result = mainMilvus.search(vectorsSearch, anns_field='embeddings', param=self.search_params, limit=10)
result_sort_dic = self.result_analysis(result)
result_last.update({queBarIdList[i]: result_sort_dic})
return result_last
def tempSearch(self, tempMilvus, queueList, queueFeatures, barIdList, tempbarId):
newBarList = []
### tempbarId format -> [macID_barcode_trackId1, ..., macID_barcode_trackIdn]
for bar in tempbarId: ### find the barcodes shared by barIdList and tempbarId
if len(bar.split('_')) == 3:
mac_barcode = bar.split('_')[0] + '_' + bar.split('_')[1]
if mac_barcode in barIdList:
newBarList.append(bar) ## newBarList format -> [macID_barcode_trackId1, ..., macID_barcode_trackIdm]
if len(newBarList) == 0:
return {}
else:
expr = f"pk in {newBarList}"
result_last = dict()
for i in range(len(queueList)):
vectorsSearch = queueFeatures[i]
result = tempMilvus.search(vectorsSearch, anns_field='embeddings', expr=expr, param=self.search_params,
limit=len(newBarList))
result_sort_dic = self.result_analysis(result)
result_last.update({queueList[i]: result_sort_dic})
return result_last
def mainSearch1(self, mainMilvus, queBarIdList, queFeatures): ### queBarIdList -> incoming box macID_barcode_trackId list
result_last = dict()
for i in range(len(queBarIdList)):
pk_barcode = queBarIdList[i].split('_')[1] #### parse the barcode; query image names follow macID_barcode_trackId
vectorsSearch = queFeatures[i]
result = mainMilvus.search(vectorsSearch, anns_field='embeddings', expr=f"pk=='{pk_barcode}'",
param=self.search_params, limit=1)
result_dic = self.result_analysis(result, top1_flag=True)
if (len(result_dic) != 0) and (len(result_last) != 0):
result_last = self.result_update(result_dic, result_last)
else:
result_last.update({key: value for key, value in result_dic.items()})
if len(result_last) == 0:
pk_barcode = queBarIdList[0].split('_')[1]
result_last.update({pk_barcode: 0})
return result_last
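`result_update` works on plain dicts, so its merge rule can be exercised without Milvus; the barcodes below are made up:

```py
from search import ImgSearch

s = ImgSearch()
prev = {"6901234567890": 0.81, "6909876543210": 0.64}
new = {"6901234567890": 0.92, "6900000000000": 0.55}
merged = s.result_update(new, prev)
# shared barcodes keep the higher score; new ones are added:
print(merged)  # {'6901234567890': 0.92, '6909876543210': 0.64, '6900000000000': 0.55}
```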

contrast/test_logic.py (new file)

@@ -0,0 +1,317 @@
# -*- coding: utf-8 -*-
import pdb
import random
import json
import time
import torch
from PIL import Image
from contrast.model import resnet18, MobileNetV3_Large
# import pymilvus
# from pymilvus import (
# connections,
# utility,
# FieldSchema, CollectionSchema, DataType,
# Collection,
# Milvus
# )
# from config import config as conf
from contrast.search import ImgSearch
from contrast.img_data import queueImgs_add
import sys
from threading import Thread
sys.path.append('../tools')
from tools.config import cfg as conf
from tools.config import gvalue
def test_preprocess(images: list, actionModel) -> torch.Tensor:
res = []
for img in images:
# print(img)
try:
im = conf.test_transform(img) if actionModel else conf.test_transform(Image.open(img))
res.append(im)
except Exception: # skip images that fail to load or transform
continue
data = torch.stack(res)
return data
def inference(images, model, actionModel):
data = test_preprocess(images, actionModel)
if torch.cuda.is_available():
data = data.to(conf.device)
features = model(data)
return features
def group_image(images, batch=64) -> list:
"""Group image paths by batch size"""
size = len(images)
res = []
for i in range(0, size, batch):
end = min(batch + i, size)
res.append(images[i:end])
return res
def barcode_state(barcodeIDList):
with open('contrast/main_barcodes.json', 'r') as file:
data = json.load(file)
main_barcode = list(data.values())[0]
barIdList_true = []
barIdList_false = []
for barId in barcodeIDList:
bar = barId.split('_')[1]
if bar in main_barcode:
barIdList_true.append(barId)
else:
barIdList_false.append(barId)
return barIdList_true, barIdList_false
def getFeatureList(barList, imgList, model, actionModel):
featList = [[] for i in range(len(barList))]
for index, feat in enumerate(imgList):
groups = group_image(feat)
for group in groups:
feat_tensor = inference(group, model, actionModel)
for fe in feat_tensor:
# .cpu() is a no-op for tensors already on the CPU, so one branch suffices
fe_np = fe.squeeze().detach().cpu().numpy()
featList[index].append(fe_np)
return featList
def img2feature(imgs_dict, model, actionModel, barcode_flag):
if not len(imgs_dict) > 0:
raise ValueError("Tracking failed: no image files provided")
queBarIdList = list(imgs_dict.keys())
if barcode_flag:
# ======== check whether the barcodes exist in the feature library ========
queBarIdList_t, barIdList_f = barcode_state(queBarIdList)
queFeatList_t = []
if len(queBarIdList_t) == 0:
print(f"All barcodes are not in the main_library: {barIdList_f}")
return queBarIdList_t, queFeatList_t
else:
if len(barIdList_f) > 0: ## drop the barcodes (and their images) that are not in the barcode library
print(f"These barcodes are not in the main_library: {barIdList_f}")
for bar_f in barIdList_f:
del imgs_dict[bar_f]
queImgList_t = list(imgs_dict.values())
queFeatList_t = getFeatureList(queBarIdList_t, queImgList_t, model, actionModel)
return queBarIdList_t, queFeatList_t
else:
queImgsList = list(imgs_dict.values())
queFeatList = getFeatureList(queBarIdList, queImgsList, model, actionModel)
return queBarIdList, queFeatList
# def create_milvus(collection_name, host, port, barcode_list, features):
# # 1. connect to Milvus
# fmt = "\n=== {:30} ===\n"
# connections.connect('default', host=host, port=port) # connect to the Milvus server
# has = utility.has_collection(collection_name) ## check whether collection_name already exists in Milvus
# print(f"Does collection {collection_name} exist in Milvus: {has}")
# # if has: ## drop the existing collection named collection_name
# # utility.drop_collection(collection_name)
#
# # 2. create collection
# fields = [
# FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100), ###图片路径
# FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=256)
# ]
# schema = CollectionSchema(fields)
# print(fmt.format(f"Create collection {collection_name}"))
# hello_milvus = Collection(collection_name, schema, consistency_level="Strong")
# # 3. insert data
# for i in range(len(features)):
# entities = [
# # provide the pk field because `auto_id` is set to False
# [barcode_list[i]] * len(features[i]), ## pk count must match the vector count: one barcode per vector
# features[i],
# ]
# print(fmt.format("Start inserting entities"))
# insert_result = hello_milvus.insert(entities)
# hello_milvus.flush()
# print(f"Number of entities in {collection_name}: {hello_milvus.num_entities}") # check the num_entities
# return hello_milvus
# def load_collection(collection_name):
# collection = Collection(collection_name)
# # collection.release() ### unload the collection (loaded -> unloaded)
# # collection.drop_index() ### drop the index
#
# index_params = {
# "index_type": "IVF_FLAT",
# # "index_type": "IVF_SQ8",
# # "index_type": "GPU_IVF_FLAT",
# "metric_type": "COSINE",
# "params": {
# "nlist": 10000
# }
# }
# #### low accuracy
# # index_params = {
# # "index_type": "IVF_PQ",
# # "metric_type": "COSINE",
# # "params": {
# # "nlist": 99,
# # "m": 2,
# # "nbits": 8
# # }
# # }
# collection.create_index(
# field_name="embeddings",
# index_params=index_params,
# index_name="SQ8"
# )
# collection.load()
# return collection
# def similarity(queImgsDict, add_flag, barcode_flag, main_milvus, model, barcode_list, actionModel):
# searchImg = ImgSearch() ## similarity comparison
# # add the incoming images to the temporary library
# if add_flag:
# if actionModel:
# queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[2:-2]), model, actionModel, barcode_flag)
# else:
# queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[:-2]), model, actionModel, barcode_flag)
#
# if barcode_flag: ### add-to-cart with barcode -> output top10 and top1
# if len(queBarIdList) == 0:
# top10, top1 = {}, {}
# else:
# for bar in queBarIdList:
# # gvalue.tempLibList.append(bar) ## temp library keys are macID_barcode_trackID
# if gvalue.tempLibLists.get(gvalue.mac_id) is not None:
# gvalue.tempLibLists[gvalue.mac_id] += [bar] ## temp library keys are macID_barcode_trackID
# else:
# gvalue.tempLibLists[gvalue.mac_id] = [bar]
# # insert into the temporary feature library
# # create_milvus('temp_features', conf.host, conf.port, queBarIdList, queBarIdFeatures)
#
# thread = Thread(target=create_milvus, kwargs={'collection_name': 'temp_features',
# 'host': conf.host,
# 'port': conf.port,
# 'barcode_list': queBarIdList,
# 'features': queBarIdFeatures})
# thread.start()
# start1 = time.time()
# top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures)
# start2 = time.time()
# print('search top10 time>>>> {}'.format(start2-start1))
# top1 = searchImg.mainSearch1(main_milvus, queBarIdList, queBarIdFeatures)
# start3 = time.time()
# print('search top1 time>>>>> {}'.format(start3-start2))
# return top10, top1, gvalue.tempLibLists
# else: # add-to-cart without barcode -> output top10
# # without a barcode, generate a random number as the dict key
# queBarIdList_rand = []
# for i in range(len(queBarIdList)):
# random_number = ''.join(random.choices('0123456789', k=10))
# queBarIdList_rand.append(str(random_number))
# # gvalue.tempLibList.append(str(random_number))
# if gvalue.tempLibLists.get(gvalue.mac_id) is not None:
# gvalue.tempLibLists[gvalue.mac_id] += [str(random_number)] ## temp library keys are macID_barcode_trackID
# else:
# gvalue.tempLibLists[gvalue.mac_id] = [str(random_number)]
# # create_milvus('temp_features', conf.host, conf.port, queBarIdList_rand, queBarIdFeatures)
# thread = Thread(target=create_milvus, kwargs={'collection_name': 'temp_features',
# 'host': conf.host,
# 'port': conf.port,
# 'barcode_list': queBarIdList_rand,
# 'features': queBarIdFeatures})
# thread.start()
# top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures)
# # print(f'top10: {top10}')
# return top10, gvalue.tempLibLists
# else: # return flow -> output top10 and topn
# if gvalue.tempLibLists.get(gvalue.mac_id) is None:
# gvalue.tempLibList = []
# else:
# gvalue.tempLibList = gvalue.tempLibLists[gvalue.mac_id]
# ## load the temporary feature library
# tempMilvusName = "temp_features"
# has = utility.has_collection(tempMilvusName)
# print(f"Does collection {tempMilvusName} exist in Milvus: {has}")
# tempMilvus = load_collection(tempMilvusName)
# print(f"Number of entities in {tempMilvusName}: {tempMilvus.num_entities}")
# if actionModel:
# barcode_list = barcode_list
# else:
# barcode_list = queueImgs_add['barcode_list']
# if actionModel:
# queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[2:-1]), model, actionModel, barcode_flag)
# else:
# queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[:-3]), model, actionModel, barcode_flag)
# if barcode_flag:
# if len(queBarIdList) == 0:
# top10, top1, top_n = {}, {}, {}
# else:
# start1 = time.time()
# top1 = searchImg.mainSearch1(main_milvus, queBarIdList, queBarIdFeatures)
# start2 = time.time()
# print('search top1 time>>>> {}'.format(start2 - start1))
# top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures)
# start3 = time.time()
# print('search top10 time>>>> {}'.format(start3 - start2))
# top_n = searchImg.tempSearch(tempMilvus, queBarIdList, queBarIdFeatures, barcode_list, gvalue.tempLibList)
# # print(f'top10: {top10}, top1: {top1}, topn: {top_n}')
# return top10, top1, top_n
# else:
# top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures)
# top_n = searchImg.tempSearch(tempMilvus, queBarIdList, queBarIdFeatures, barcode_list, gvalue.tempLibList)
# # print(f'top10: {top10}, topn: {top_n}')
# return top10, top_n
def similarity_interface(dataCollection):
queImgsDict = dataCollection.queImgsDict
add_flag = dataCollection.add_flag
barcode_flag = dataCollection.barcode_flag
main_milvus = dataCollection.mainMilvus
#tempLibList = dataCollection.tempLibList
model = dataCollection.model
actionModel = dataCollection.actionModel
barcode_list = dataCollection.barcode_list
#return similarity(queImgsDict, add_flag, barcode_flag, main_milvus, tempLibList, model, barcode_list, actionModel)
return 0
if __name__ == '__main__':
pass
# connections.connect('default', host=conf.host, port=conf.port)
# # load the main feature library
# mainMilvusName = "main_features"
# has = utility.has_collection(mainMilvusName)
# print(f"Does collection {mainMilvusName} exist in Milvus: {has}")
# mainMilvus = Collection(mainMilvusName)
# mainMilvus.load()
# model = initModel()
# # queueImgs_add / queueImgs_back are the inputs for the add-to-cart and return flows, respectively
# add_flag = queueImgs_add['add_flag']
# barcode_flag = queueImgs_add['barcode_flag']
# tempLibList = [] # barcodeId list of the temporary feature library
# # tempLibList = ['3500610085338_01', '4260290263776_01'] ## test
# if add_flag:
# if barcode_flag: # add-to-cart with barcode -> output top10 and top1
# top10, top1, tempLibList = similarity(queueImgs_add, add_flag, barcode_flag, mainMilvus, tempLibList, model)
# print(f"top10: {top10}\ntop1: {top1}")
# else: # add-to-cart without barcode -> output top10
# top10, tempLibList = similarity(queueImgs_add, add_flag, barcode_flag, mainMilvus, tempLibList, model)
# else: # return flow -> output top10 and topn
# top10, topn = similarity(queueImgs_back, add_flag, barcode_flag, mainMilvus, tempLibList, model)
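A short sketch of the extraction helpers above in path mode (`actionModel=False`), assuming the `contrast` and `tools` packages are importable and `tools.config` supplies `test_transform`:

```py
import torch
from contrast.model import resnet18
from contrast.test_logic import group_image, inference

model = resnet18(pretrained=False).eval()
if torch.cuda.is_available():
    model = model.cuda()  # inference() moves the input batch to conf.device

image_paths = ["imgs/a.jpg", "imgs/b.jpg", "imgs/c.jpg"]  # placeholder paths
feats = []
with torch.no_grad():
    for batch in group_image(image_paths, batch=64):
        feats.append(inference(batch, model, actionModel=False))
features = torch.cat(feats)  # (num_images, embedding_size)
```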

contrast/utils.py (new file)

@@ -0,0 +1,64 @@
"""Train List 训练列表
格式:
ImagePath Label
示例:
/data/WebFace/0124920/003.jpg 10572
/data/WebFace/0124920/012.jpg 10572
/data/WebFace/0124920/020.jpg 10572
"""
import os
import os.path as osp
from imutils import paths
def generate_list(images_directory, saved_name=None):
"""生成数据列表
Args:
images_directory: 人脸数据目录,通常包含多个子文件夹。如
WebFace和LFW的格式
Returns:
data_list: [<路径> <标签>]
"""
subdirs = os.listdir(images_directory)
num_ids = len(subdirs)
data_list = []
for i in range(num_ids):
subdir = osp.join(images_directory, subdirs[i])
files = os.listdir(subdir)
file_paths = [osp.join(subdir, file) for file in files] # renamed to avoid shadowing imutils.paths
# use the subdirectory index as the face label
paths_with_Id = [f"{p} {i}\n" for p in file_paths]
data_list.extend(paths_with_Id)
if saved_name:
with open(saved_name, 'w', encoding='utf-8') as f:
f.writelines(data_list)
return data_list
def transform_clean_list(webface_directory, cleaned_list_path):
"""转换webface的干净列表格式
Args:
webface_directory: WebFace数据目录
cleaned_list_path: cleaned_list.txt路径
Returns:
cleaned_list: 转换后的数据列表
"""
with open(cleaned_list_path, encoding='utf-8') as f:
cleaned_list = f.readlines()
cleaned_list = [p.replace('\\', '/') for p in cleaned_list]
cleaned_list = [osp.join(webface_directory, p) for p in cleaned_list]
return cleaned_list
def remove_dirty_image(webface_directory, cleaned_list):
cleaned_list = set([c.split()[0] for c in cleaned_list])
for p in paths.list_images(webface_directory):
if p not in cleaned_list:
print(f"remove {p}")
os.remove(p)
if __name__ == '__main__':
data = '/data/CASIA-WebFace/'
lst = '/data/cleaned_list.txt'
cleaned_list = transform_clean_list(data, lst)
remove_dirty_image(data, cleaned_list)