update
contrast/.idea/contrast_nettest.iml (generated, new file: 12 lines)
@@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$" />
    <orderEntry type="inheritedJdk" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
  <component name="PyDocumentationSettings">
    <option name="format" value="GOOGLE" />
    <option name="myDocStringFormat" value="Google" />
  </component>
</module>
contrast/.idea/inspectionProfiles/Project_Default.xml (generated, new file: 19 lines)
@@ -0,0 +1,19 @@
<component name="InspectionProjectProfileManager">
  <profile version="1.0">
    <option name="myName" value="Project Default" />
    <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
      <option name="ignoredPackages">
        <value>
          <list size="6">
            <item index="0" class="java.lang.String" itemvalue="thop" />
            <item index="1" class="java.lang.String" itemvalue="regex" />
            <item index="2" class="java.lang.String" itemvalue="tensorboardX" />
            <item index="3" class="java.lang.String" itemvalue="torch" />
            <item index="4" class="java.lang.String" itemvalue="numpy" />
            <item index="5" class="java.lang.String" itemvalue="terminaltables" />
          </list>
        </value>
      </option>
    </inspection_tool>
  </profile>
</component>
contrast/.idea/inspectionProfiles/profiles_settings.xml (generated, new file: 6 lines)
@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
  <settings>
    <option name="USE_PROJECT_PROFILE" value="false" />
    <version value="1.0" />
  </settings>
</component>
contrast/.idea/misc.xml (generated, new file: 4 lines)
@@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10" project-jdk-type="Python SDK" />
</project>
contrast/.idea/modules.xml (generated, new file: 8 lines)
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectModuleManager">
    <modules>
      <module fileurl="file://$PROJECT_DIR$/.idea/contrast_nettest.iml" filepath="$PROJECT_DIR$/.idea/contrast_nettest.iml" />
    </modules>
  </component>
</project>
contrast/.idea/vcs.xml (generated, new file: 6 lines)
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="VcsDirectoryMappings">
    <mapping directory="$PROJECT_DIR$/.." vcs="Git" />
  </component>
</project>
contrast/README.md (new file: 60 lines)
@@ -0,0 +1,60 @@
# Build Your Own Face Recognition Model

Train your own face recognition model!

Face recognition has moved from the original softmax embedding, through the triplet-loss metric learning led by FaceNet in 2015, to additive-margin metric learning. This blog series implements ArcFace, proposed in 2018.

### Dependencies

```
Python >= 3.6
pytorch >= 1.0
torchvision
imutils
pillow == 6.2.0
tqdm
```

### Data preparation

+ Download WebFace (search Baidu) and the cleaned image list ([BaiduYun](http://pan.baidu.com/s/1hrKpbm8)) for training
+ Download LFW ([BaiduYun](https://pan.baidu.com/s/12IKEpvM8-tYgSaUiz_adGA), extraction code u7z4) and the [test pair list](https://github.com/ronghuaiyang/arcface-pytorch/blob/master/lfw_test_pair.txt) for testing
+ Remove the dirty data from WebFace with `utils.py`

### Configuration

See `config.py`

### Training

Single-machine multi-GPU training is supported out of the box:

```bash
export CUDA_VISIBLE_DEVICES=0,1
python train.py
```

### Testing

```bash
python test.py
```

### Blog posts

There are already plenty of face recognition write-ups, but inspired by the many [Build-Your-Own-x](https://github.com/danistefanovic/build-your-own-x) articles, I wanted to write a Build Your Own Face Model series, in the hope that it is useful to others.

+ 001 [Data preparation](./blog/data.md)
+ 002 [Model architecture](./blog/model.md)
+ 003 [Loss function](./blog/loss.md)
+ 004 [Metric function](./blog/metric.md)
+ 005 [Training](./blog/train.md)
+ 006 [Testing](./blog/test.md)

### Acknowledgements

Although not credited inline, some code in this project was copied or adapted from the repositories below, under the same licenses:

+ [insightFace](https://github.com/deepinsight/insightface/tree/master/recognition)
+ [insightFace_Pytorch](https://github.com/TreB1eN/InsightFace_Pytorch)
+ [arcface-pytorch](https://github.com/ronghuaiyang/arcface-pytorch)
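For reference, the additive angular margin that gives ArcFace its name can be sketched in a few lines. This is a minimal illustration of the formula only, not this repository's metric implementation:

```python
import torch
import torch.nn.functional as F

def arcface_logits(embeddings, weight, labels, s=64.0, m=0.5):
    # cosine similarity between L2-normalized embeddings and class centers
    cosine = F.linear(F.normalize(embeddings), F.normalize(weight))
    # add the angular margin m to the true-class angle only
    theta = torch.acos(cosine.clamp(-1 + 1e-7, 1 - 1e-7))
    one_hot = F.one_hot(labels, num_classes=weight.size(0)).float()
    logits = s * (one_hot * torch.cos(theta + m) + (1.0 - one_hot) * cosine)
    return logits  # feed into an ordinary cross-entropy loss
```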
BIN contrast/__pycache__/img_data.cpython-310.pyc (new binary file, not shown)
BIN contrast/__pycache__/img_data.cpython-38.pyc (new binary file, not shown)
BIN contrast/__pycache__/logic.cpython-38.pyc (new binary file, not shown)
BIN contrast/__pycache__/search.cpython-310.pyc (new binary file, not shown)
BIN contrast/__pycache__/search.cpython-38.pyc (new binary file, not shown)
BIN contrast/__pycache__/test_logic.cpython-310.pyc (new binary file, not shown)
BIN contrast/__pycache__/test_logic.cpython-38.pyc (new binary file, not shown)
contrast/config.py.bak (new file: 21 lines)
@@ -0,0 +1,21 @@
import torch
import torchvision.transforms as T


class Config:
    host = "192.168.1.28"
    port = "19530"

    embedding_size = 256
    img_size = 224
    test_transform = T.Compose([
        T.ToTensor(),
        T.Resize((img_size, img_size)),
        T.ConvertImageDtype(torch.float32),
        T.Normalize(mean=[0.5], std=[0.5]),
    ])

    # test_model = "checkpoints/resnet18_our388.pth"
    test_model = "checkpoints/mobilenetv3Large_our388_noPara.pth"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


config = Config()
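A minimal sketch of applying the configured preprocessing to a single image; the file path is hypothetical:

from PIL import Image
from config import config as conf

img = Image.open("example.jpg").convert("RGB")   # hypothetical image path
tensor = conf.test_transform(img).unsqueeze(0)   # shape: (1, 3, 224, 224)
tensor = tensor.to(conf.device)                  # move to GPU if available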
contrast/dataset.py (new file: 21 lines)
@@ -0,0 +1,21 @@
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

from config import config as conf


def load_data(conf, training=True):
    if training:
        dataroot = conf.train_root
        transform = conf.train_transform
        batch_size = conf.train_batch_size
    else:
        dataroot = conf.test_root
        transform = conf.test_transform
        batch_size = conf.test_batch_size

    data = ImageFolder(dataroot, transform=transform)
    class_num = len(data.classes)
    loader = DataLoader(data, batch_size=batch_size, shuffle=True,
                        pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
    return loader, class_num
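A minimal usage sketch, assuming config.py also defines the train_root, train_transform, train_batch_size, pin_memory and num_workers fields that load_data reads:

from config import config as conf
from dataset import load_data

train_loader, class_num = load_data(conf, training=True)
images, labels = next(iter(train_loader))
print(images.shape, labels.shape, class_num)  # e.g. (batch, 3, H, W), (batch,), number of classes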
contrast/img_data.py (new file: 129 lines)
File diff suppressed because one or more lines are too long
contrast/logic.py (new file: 66 lines)
@@ -0,0 +1,66 @@
import sys
import torch

from tools.config import cfg as conf
sys.path.append('contrast')
# from config import config as conf
# from model import resnet18, MobileNetV3_Large
from test_logic import similarity_interface
from img_data import queueImgs_add

# import pymilvus


class datacollection:
    barcode_flag = None
    add_flag = None
    queImgsDict = None
    mainMilvus = None
    tempLibList = None
    model = None
    barcode_list = None
    actionModel = True  # run mode flag: True = production mode, False = test mode


class similarityResult:
    top10 = None
    top1 = None
    tempLibList = None
    topn = None


class similarity:
    def __init__(self):
        pass

    def getSimilarity(self, model, dataCollection, similarityRes):
        dataCollection.mainMilvus = model.milvusModel
        dataCollection.model = model.similarityModel
        # try:
        if dataCollection.add_flag:
            if dataCollection.barcode_flag:  # add to cart with barcode -> return top10 and top1
                similarityRes.top10, similarityRes.top1, similarityRes.tempLibList = similarity_interface(
                    dataCollection)
                print(f"top10: {similarityRes.top10}\ntop1: {similarityRes.top1}")
            else:  # add to cart without barcode -> return top10
                similarityRes.top10, similarityRes.tempLibList = similarity_interface(dataCollection)
        else:  # return/remove from cart -> return top10 and topn
            if dataCollection.barcode_flag:
                similarityRes.top10, similarityRes.top1, similarityRes.topn = similarity_interface(dataCollection)
            else:
                similarityRes.top10, similarityRes.topn = similarity_interface(dataCollection)
        return similarityRes
        # except pymilvus.exceptions.SchemaNotReadyException as SchemaNotReadyException:  ### the feature collection does not exist yet
        #     print('pymilvus.exceptions.SchemaNotReadyException', SchemaNotReadyException)


def main():
    data_collection = datacollection()
    similarityRes = similarityResult()
    data_collection.barcode_flag = queueImgs_add['barcode_flag']
    data_collection.add_flag = queueImgs_add['add_flag']
    data_collection.queImgsDict = queueImgs_add
    # getSimilarity expects a model object (exposing .milvusModel and
    # .similarityModel) as its first argument; supply one before calling.
    model = None  # placeholder: load the Milvus handle and embedding model here
    similarity().getSimilarity(model, data_collection, similarityRes)


if __name__ == '__main__':
    main()
contrast/main_barcodes.json (new file: 1894 lines)
File diff suppressed because it is too large
contrast/main_library.py (new file: 17 lines)
@@ -0,0 +1,17 @@
"""Build the main feature library.

NOTE: create_milvus currently exists only as commented-out code in
test_logic.py, and img2feature there now takes (imgs_dict, model,
actionModel, barcode_flag); this script predates those changes.
"""
from test_logic import create_milvus, img2feature
from config import config as conf
from img_data import library_imgs, temp_imgs


def createMainMilvus(imgs_dict):  ## imgs -> {barcode1: [img1_1...img1_n], barcode2: [img2_1...img2_n]}
    barcode_list, imgs_list = img2feature(imgs_dict)
    mainMilvus = create_milvus('main_features', conf.host, conf.port, barcode_list, imgs_list)
    return mainMilvus


def createTempMilvus(imgs_dict):  ## imgs -> {barcode1: [img1_1...img1_n], barcode2: [img2_1...img2_n]}
    barcode_list, imgs_list = img2feature(imgs_dict)
    tempMilvus = create_milvus('temp_features', conf.host, conf.port, barcode_list, imgs_list)
    return tempMilvus


if __name__ == "__main__":
    createMainMilvus(library_imgs)
    # createTempMilvus(temp_imgs)
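For context, collection creation can be condensed from the commented-out create_milvus block in test_logic.py; a rough sketch (pymilvus assumed installed, the barcode and vector are placeholders):

from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection
from config import config as conf

connections.connect('default', host=conf.host, port=conf.port)
fields = [
    # pk holds the barcode; embeddings holds the 256-d feature vector
    FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100),
    FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=256),
]
collection = Collection('main_features', CollectionSchema(fields), consistency_level="Strong")
collection.insert([["6901234567890"], [[0.0] * 256]])  # hypothetical barcode + vector
collection.flush()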
contrast/model/__init__.py (new file: 2 lines)
@@ -0,0 +1,2 @@
from .resnet_pre import resnet18
from .mobilenet_v3 import MobileNetV3_Small, MobileNetV3_Large
BIN contrast/model/__pycache__/__init__.cpython-310.pyc (new binary file, not shown)
BIN contrast/model/__pycache__/__init__.cpython-38.pyc (new binary file, not shown)
BIN contrast/model/__pycache__/fmobilenet.cpython-310.pyc (new binary file, not shown)
BIN contrast/model/__pycache__/fmobilenet.cpython-38.pyc (new binary file, not shown)
BIN contrast/model/__pycache__/loss.cpython-38.pyc (new binary file, not shown)
BIN contrast/model/__pycache__/metric.cpython-38.pyc (new binary file, not shown)
BIN contrast/model/__pycache__/mobilenet_v3.cpython-310.pyc (new binary file, not shown)
BIN contrast/model/__pycache__/mobilenet_v3.cpython-38.pyc (new binary file, not shown)
BIN contrast/model/__pycache__/mobilevit.cpython-310.pyc (new binary file, not shown)
BIN contrast/model/__pycache__/mobilevit.cpython-38.pyc (new binary file, not shown)
BIN contrast/model/__pycache__/resnet_pre.cpython-310.pyc (new binary file, not shown)
BIN contrast/model/__pycache__/resnet_pre.cpython-38.pyc (new binary file, not shown)
BIN contrast/model/__pycache__/utils.cpython-310.pyc (new binary file, not shown)
contrast/model/mobilenet_v3.py (new file: 200 lines)
@@ -0,0 +1,200 @@
'''MobileNetV3 in PyTorch.

See the paper "Searching for MobileNetV3" for more details.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from tools.config import config as conf


class hswish(nn.Module):
    def forward(self, x):
        out = x * F.relu6(x + 3, inplace=True) / 6
        return out


class hsigmoid(nn.Module):
    def forward(self, x):
        out = F.relu6(x + 3, inplace=True) / 6
        return out


class SeModule(nn.Module):
    def __init__(self, in_size, reduction=4):
        super(SeModule, self).__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_size, in_size // reduction, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(in_size // reduction),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_size // reduction, in_size, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(in_size),
            hsigmoid()
        )

    def forward(self, x):
        return x * self.se(x)


class Block(nn.Module):
    '''expand + depthwise + pointwise'''

    def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, semodule, stride):
        super(Block, self).__init__()
        self.stride = stride
        self.se = semodule

        self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(expand_size)
        self.nolinear1 = nolinear
        self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size=kernel_size, stride=stride,
                               padding=kernel_size // 2, groups=expand_size, bias=False)
        self.bn2 = nn.BatchNorm2d(expand_size)
        self.nolinear2 = nolinear
        self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_size)

        self.shortcut = nn.Sequential()
        if stride == 1 and in_size != out_size:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_size),
            )

    def forward(self, x):
        out = self.nolinear1(self.bn1(self.conv1(x)))
        out = self.nolinear2(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        if self.se is not None:
            out = self.se(out)
        out = out + self.shortcut(x) if self.stride == 1 else out
        return out


class MobileNetV3_Large(nn.Module):
    def __init__(self, num_classes=conf.embedding_size):
        super(MobileNetV3_Large, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.hs1 = hswish()

        self.bneck = nn.Sequential(
            Block(3, 16, 16, 16, nn.ReLU(inplace=True), None, 1),
            Block(3, 16, 64, 24, nn.ReLU(inplace=True), None, 2),
            Block(3, 24, 72, 24, nn.ReLU(inplace=True), None, 1),
            Block(5, 24, 72, 40, nn.ReLU(inplace=True), SeModule(40), 2),
            Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
            Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
            Block(3, 40, 240, 80, hswish(), None, 2),
            Block(3, 80, 200, 80, hswish(), None, 1),
            Block(3, 80, 184, 80, hswish(), None, 1),
            Block(3, 80, 184, 80, hswish(), None, 1),
            Block(3, 80, 480, 112, hswish(), SeModule(112), 1),
            Block(3, 112, 672, 112, hswish(), SeModule(112), 1),
            Block(5, 112, 672, 160, hswish(), SeModule(160), 1),
            Block(5, 160, 672, 160, hswish(), SeModule(160), 2),
            Block(5, 160, 960, 160, hswish(), SeModule(160), 1),
        )

        self.conv2 = nn.Conv2d(160, 960, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(960)
        self.hs2 = hswish()
        self.linear3 = nn.Linear(960, 1280)
        self.bn3 = nn.BatchNorm1d(1280)
        self.hs3 = hswish()
        self.linear4 = nn.Linear(1280, num_classes)
        self.init_params()

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        out = self.hs1(self.bn1(self.conv1(x)))
        out = self.bneck(out)
        out = self.hs2(self.bn2(self.conv2(out)))
        out = F.avg_pool2d(out, conf.img_size // 32)
        out = out.view(out.size(0), -1)
        out = self.hs3(self.bn3(self.linear3(out)))
        out = self.linear4(out)
        return out


class MobileNetV3_Small(nn.Module):
    def __init__(self, num_classes=conf.embedding_size):
        super(MobileNetV3_Small, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.hs1 = hswish()

        self.bneck = nn.Sequential(
            Block(3, 16, 16, 16, nn.ReLU(inplace=True), SeModule(16), 2),
            Block(3, 16, 72, 24, nn.ReLU(inplace=True), None, 2),
            Block(3, 24, 88, 24, nn.ReLU(inplace=True), None, 1),
            Block(5, 24, 96, 40, hswish(), SeModule(40), 2),
            Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
            Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
            Block(5, 40, 120, 48, hswish(), SeModule(48), 1),
            Block(5, 48, 144, 48, hswish(), SeModule(48), 1),
            Block(5, 48, 288, 96, hswish(), SeModule(96), 2),
            Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
            Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
        )

        self.conv2 = nn.Conv2d(96, 576, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(576)
        self.hs2 = hswish()
        self.linear3 = nn.Linear(576, 1280)
        self.bn3 = nn.BatchNorm1d(1280)
        self.hs3 = hswish()
        self.linear4 = nn.Linear(1280, num_classes)
        self.init_params()

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        out = self.hs1(self.bn1(self.conv1(x)))
        out = self.bneck(out)
        out = self.hs2(self.bn2(self.conv2(out)))
        out = F.avg_pool2d(out, conf.img_size // 32)
        out = out.view(out.size(0), -1)

        out = self.hs3(self.bn3(self.linear3(out)))
        out = self.linear4(out)
        return out


def test():
    net = MobileNetV3_Small()
    x = torch.randn(2, 3, 224, 224)
    y = net(x)
    print(y.size())

# test()
contrast/model/resnet_pre.py (new file: 462 lines)
@@ -0,0 +1,462 @@
import torch
import torch.nn as nn
from tools.config import config as conf

try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url
# from .utils import load_state_dict_from_url

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
           'wide_resnet50_2', 'wide_resnet101_2']

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, cam=False, bam=False):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.cam = cam
        self.bam = bam
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride
        if self.cam:
            if planes == 64:
                self.globalAvgPool = nn.AvgPool2d(56, stride=1)
            elif planes == 128:
                self.globalAvgPool = nn.AvgPool2d(28, stride=1)
            elif planes == 256:
                self.globalAvgPool = nn.AvgPool2d(14, stride=1)
            elif planes == 512:
                self.globalAvgPool = nn.AvgPool2d(7, stride=1)

            self.fc1 = nn.Linear(in_features=planes, out_features=round(planes / 16))
            self.fc2 = nn.Linear(in_features=round(planes / 16), out_features=planes)
            self.sigmod = nn.Sigmoid()
        if self.bam:
            self.bam = SpatialAttention()

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        if self.cam:
            ori_out = self.globalAvgPool(out)
            out = out.view(out.size(0), -1)
            out = self.fc1(out)
            out = self.relu(out)
            out = self.fc2(out)
            out = self.sigmod(out)
            out = out.view(out.size(0), out.size(-1), 1, 1)
            out = out * ori_out

        if self.bam:
            out = out * self.bam(out)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution (self.conv2)
    # while original implementation places the stride at the first 1x1 convolution (self.conv1)
    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, cam=False, bam=False):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        self.cam = cam
        self.bam = bam
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        if self.cam:
            if planes == 64:
                self.globalAvgPool = nn.AvgPool2d(56, stride=1)
            elif planes == 128:
                self.globalAvgPool = nn.AvgPool2d(28, stride=1)
            elif planes == 256:
                self.globalAvgPool = nn.AvgPool2d(14, stride=1)
            elif planes == 512:
                self.globalAvgPool = nn.AvgPool2d(7, stride=1)

            self.fc1 = nn.Linear(planes * self.expansion, round(planes / 4))
            self.fc2 = nn.Linear(round(planes / 4), planes * self.expansion)
            self.sigmod = nn.Sigmoid()

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        if self.cam:
            ori_out = self.globalAvgPool(out)
            out = out.view(out.size(0), -1)
            out = self.fc1(out)
            out = self.relu(out)
            out = self.fc2(out)
            out = self.sigmod(out)
            out = out.view(out.size(0), out.size(-1), 1, 1)
            out = out * ori_out
        out += identity
        out = self.relu(out)
        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=conf.embedding_size, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, scale=0.75):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, int(64 * scale), layers[0])
        self.layer2 = self._make_layer(block, int(128 * scale), layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, int(256 * scale), layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, int(512 * scale), layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(int(512 * block.expansion * scale), num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))
        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # print('poolBefore', x.shape)
        x = self.avgpool(x)
        # print('poolAfter', x.shape)
        x = torch.flatten(x, 1)
        # print('fcBefore', x.shape)
        x = self.fc(x)
        # print('fcAfter', x.shape)

        return x

    def forward(self, x):
        return self._forward_impl(x)


# def _resnet(arch, block, layers, pretrained, progress, **kwargs):
#     model = ResNet(block, layers, **kwargs)
#     if pretrained:
#         state_dict = load_state_dict_from_url(model_urls[arch],
#                                               progress=progress)
#         model.load_state_dict(state_dict, strict=False)
#     return model
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)

        src_state_dict = state_dict
        target_state_dict = model.state_dict()
        skip_keys = []
        # skip mismatch size tensors in case of pretraining
        for k in src_state_dict.keys():
            if k not in target_state_dict:
                continue
            if src_state_dict[k].size() != target_state_dict[k].size():
                skip_keys.append(k)
        for k in skip_keys:
            del src_state_dict[k]
        missing_keys, unexpected_keys = model.load_state_dict(src_state_dict, strict=False)

    return model


def resnet14(pretrained=True, progress=True, **kwargs):
    r"""ResNet-14 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 1, 1, 2], pretrained, progress,
                   **kwargs)


def resnet18(pretrained=True, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)


def resnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
                   **kwargs)


def resnet152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
                   **kwargs)


def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)


def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)
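A minimal sketch of using the modified factory (note that _resnet drops pretrained tensors whose shapes no longer match, since scale=0.75 shrinks the layer widths; assumes tools.config is importable):

import torch
from contrast.model import resnet18  # re-exported by contrast/model/__init__.py

net = resnet18(pretrained=False)     # pretrained=True downloads ImageNet weights and skips mismatches
net.eval()
with torch.no_grad():
    feats = net(torch.randn(2, 3, 224, 224))
print(feats.shape)                   # (2, conf.embedding_size)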
contrast/search.py (new file: 101 lines)
@@ -0,0 +1,101 @@
import pdb


class ImgSearch():
    def __init__(self):
        self.search_params = {
            "metric_type": "COSINE",
        }

    def get_max(self, a, b):
        if a > b:
            return a
        else:
            return b

    def check_keys(self, dict1_last, dict2):
        for key2 in list(dict2.keys()):
            if key2 in list(dict1_last.keys()):
                value = self.get_max(dict1_last[key2], dict2[key2])
                dict1_last[key2] = value
            else:
                dict1_last[key2] = dict2[key2]
        return dict1_last

    def result_analysis(self, result, top1_flag=False):
        result_dict = dict()  ## best match score per barcode across all query images
        for hits in result:
            for hit in hits:
                if hit.id not in result_dict:  ## barcode (hit.id) not yet in the result dict
                    result_dict.update({hit.id: round(hit.distance, 2)})
                else:  ## keep the higher similarity for the same barcode
                    distance = result_dict.get(hit.id)
                    distance_new = self.get_max(distance, hit.distance)
                    result_dict.update({hit.id: round(distance_new, 2)})
        if top1_flag:
            return result_dict
        else:
            ## sort all barcode similarities (descending) and keep at most 10
            if len(result_dict) > 10:
                result_sort_dict = dict(sorted(result_dict.items(), key=lambda x: x[1], reverse=True)[:10])
            else:
                result_sort_dict = dict(sorted(result_dict.items(), key=lambda x: x[1], reverse=True))
            return result_sort_dict

    def result_update(self, temp_result, last_result):
        temp_keys = list(temp_result.keys())
        last_keys = list(last_result.keys())
        for ke in temp_keys:
            temp_value = temp_result[ke]
            if ke in last_keys:  ## both track results share this barcode; update only if the new score is higher
                last_value = last_result[ke]
                if temp_value > last_value:
                    last_result.update({ke: temp_value})
            else:  ## barcode only in the new result; add it
                last_result.update({ke: temp_value})
        return last_result

    def mainSearch10(self, mainMilvus, queBarIdList, queueFeatures):  ### queBarIdList -> incoming box barcode-track_Id entries
        result_last = dict()
        for i in range(len(queBarIdList)):
            vectorsSearch = queueFeatures[i]
            result = mainMilvus.search(vectorsSearch, anns_field='embeddings', param=self.search_params, limit=10)
            result_sort_dic = self.result_analysis(result)
            result_last.update({queBarIdList[i]: result_sort_dic})
        return result_last

    def tempSearch(self, tempMilvus, queueList, queueFeatures, barIdList, tempbarId):
        newBarList = []
        ### tempbarId format -> [macID_barcode_trackId1, ..., macID_barcode_trackIdn]
        for bar in tempbarId:  ### keep entries whose barcode also appears in barIdList
            if len(bar.split('_')) == 3:
                mac_barcode = bar.split('_')[0] + '_' + bar.split('_')[1]
                if mac_barcode in barIdList:
                    newBarList.append(bar)  ## newBarList format -> [macID_barcode_trackId1, ..., macID_barcode_trackIdm]
        if len(newBarList) == 0:
            return {}
        else:
            expr = f"pk in {newBarList}"
            result_last = dict()
            for i in range(len(queueList)):
                vectorsSearch = queueFeatures[i]
                result = tempMilvus.search(vectorsSearch, anns_field='embeddings', expr=expr, param=self.search_params,
                                           limit=len(newBarList))
                result_sort_dic = self.result_analysis(result)
                result_last.update({queueList[i]: result_sort_dic})
            return result_last

    def mainSearch1(self, mainMilvus, queBarIdList, queFeatures):  ### queBarIdList -> incoming box macID_barcode_trackId
        result_last = dict()
        for i in range(len(queBarIdList)):
            pk_barcode = queBarIdList[i].split('_')[1]  #### parse the barcode; query image names are macID_barcode_trackId
            vectorsSearch = queFeatures[i]
            result = mainMilvus.search(vectorsSearch, anns_field='embeddings', expr=f"pk=='{pk_barcode}'",
                                       param=self.search_params, limit=1)
            result_dic = self.result_analysis(result, top1_flag=True)
            if (len(result_dic) != 0) and (len(result_last) != 0):
                result_last = self.result_update(result_dic, result_last)
            else:
                result_last.update({key: value for key, value in result_dic.items()})
        if len(result_last) == 0:
            pk_barcode = queBarIdList[0].split('_')[1]
            result_last.update({pk_barcode: 0})
        return result_last
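A minimal sketch of result_analysis on stubbed hits, assuming each hit exposes .id (barcode) and .distance (cosine similarity) the way pymilvus search results do; the barcodes are hypothetical:

from types import SimpleNamespace
from search import ImgSearch

hits = [[SimpleNamespace(id="6901234567890", distance=0.913),
         SimpleNamespace(id="6901234567890", distance=0.87),
         SimpleNamespace(id="6909876543210", distance=0.55)]]
top = ImgSearch().result_analysis(hits)
print(top)  # {'6901234567890': 0.91, '6909876543210': 0.55}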
contrast/test_logic.py (new file: 317 lines)
@@ -0,0 +1,317 @@
# -*- coding: utf-8 -*-
import pdb
import random
import json
import time

import torch
from PIL import Image

from contrast.model import resnet18, MobileNetV3_Large
# import pymilvus
# from pymilvus import (
#     connections,
#     utility,
#     FieldSchema, CollectionSchema, DataType,
#     Collection,
#     Milvus
# )
# from config import config as conf
from contrast.search import ImgSearch
from contrast.img_data import queueImgs_add
import sys
from threading import Thread

sys.path.append('../tools')
from tools.config import cfg as conf
from tools.config import gvalue


def test_preprocess(images: list, actionModel) -> torch.Tensor:
    res = []
    for img in images:
        # print(img)
        try:
            im = conf.test_transform(img) if actionModel else conf.test_transform(Image.open(img))
            res.append(im)
        except Exception:
            continue
    data = torch.stack(res)
    return data


def inference(images, model, actionModel):
    data = test_preprocess(images, actionModel)
    if torch.cuda.is_available():
        data = data.to(conf.device)
    features = model(data)
    return features


def group_image(images, batch=64) -> list:
    """Group image paths by batch size"""
    size = len(images)
    res = []
    for i in range(0, size, batch):
        end = min(batch + i, size)
        res.append(images[i:end])
    return res


def barcode_state(barcodeIDList):
    with open('contrast/main_barcodes.json', 'r') as file:
        data = json.load(file)
    main_barcode = list(data.values())[0]
    barIdList_true = []
    barIdList_false = []
    for barId in barcodeIDList:
        bar = barId.split('_')[1]
        if bar in main_barcode:
            barIdList_true.append(barId)
        else:
            barIdList_false.append(barId)
    return barIdList_true, barIdList_false


def getFeatureList(barList, imgList, model, actionModel):
    featList = [[] for i in range(len(barList))]
    for index, feat in enumerate(imgList):
        groups = group_image(feat)
        for group in groups:
            feat_tensor = inference(group, model, actionModel)
            for fe in feat_tensor:
                if fe.device == 'cpu':
                    fe_np = fe.squeeze().detach().numpy()
                else:
                    fe_np = fe.squeeze().detach().cpu().numpy()
                featList[index].append(fe_np)
    return featList


def img2feature(imgs_dict, model, actionModel, barcode_flag):
    if not len(imgs_dict) > 0:
        raise ValueError("Tracking failed: no image files provided")
    queBarIdList = list(imgs_dict.keys())
    if barcode_flag:
        # # ======== check whether each barcode is in the feature library ========
        queBarIdList_t, barIdList_f = barcode_state(queBarIdList)
        queFeatList_t = []
        if len(queBarIdList_t) == 0:
            print(f"All barcodes are not in the main_library: {barIdList_f}")
            return queBarIdList_t, queFeatList_t
        else:
            if len(barIdList_f) > 0:  ## drop barcodes (and their images) that are not in the barcode library
                print(f"These barcodes are not in the main_library: {barIdList_f}")
                for bar_f in barIdList_f:
                    del imgs_dict[bar_f]
            queImgList_t = list(imgs_dict.values())
            queFeatList_t = getFeatureList(queBarIdList_t, queImgList_t, model, actionModel)
        return queBarIdList_t, queFeatList_t
    else:
        queImgsList = list(imgs_dict.values())
        queFeatList = getFeatureList(queBarIdList, queImgsList, model, actionModel)
        return queBarIdList, queFeatList


# def create_milvus(collection_name, host, port, barcode_list, features):
#     # 1. connect to Milvus
#     fmt = "\n=== {:30} ===\n"
#     connections.connect('default', host=host, port=port)  # connect to the Milvus server
#     has = utility.has_collection(collection_name)  ## check whether collection_name already exists in Milvus
#     print(f"Does collection {collection_name} exist in Milvus: {has}")
#     # if has:  ## drop the collection_name collection
#     #     utility.drop_collection(collection_name)
#
#     # 2. create collection
#     fields = [
#         FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100),  ### image path
#         FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=256)
#     ]
#     schema = CollectionSchema(fields)
#     print(fmt.format(f"Create collection {collection_name}"))
#     hello_milvus = Collection(collection_name, schema, consistency_level="Strong")
#     # 3. insert data
#     for i in range(len(features)):
#         entities = [
#             # provide the pk field because `auto_id` is set to False
#             [barcode_list[i]] * len(features[i]),  ## one barcode per vector; image and vector counts must match
#             features[i],
#         ]
#         print(fmt.format("Start inserting entities"))
#         insert_result = hello_milvus.insert(entities)
#         hello_milvus.flush()
#     print(f"Number of entities in {collection_name}: {hello_milvus.num_entities}")  # check the num_entities
#     return hello_milvus


# def load_collection(collection_name):
#     collection = Collection(collection_name)
#     # collection.release()  ### unload the collection from memory
#     # collection.drop_index()  ### drop the index
#
#     index_params = {
#         "index_type": "IVF_FLAT",
#         # "index_type": "IVF_SQ8",
#         # "index_type": "GPU_IVF_FLAT",
#         "metric_type": "COSINE",
#         "params": {
#             "nlist": 10000
#         }
#     }
#     #### lower accuracy
#     # index_params = {
#     #     "index_type": "IVF_PQ",
#     #     "metric_type": "COSINE",
#     #     "params": {
#     #         "nlist": 99,
#     #         "m": 2,
#     #         "nbits": 8
#     #     }
#     # }
#     collection.create_index(
#         field_name="embeddings",
#         index_params=index_params,
#         index_name="SQ8"
#     )
#     collection.load()
#     return collection


# def similarity(queImgsDict, add_flag, barcode_flag, main_milvus, model, barcode_list, actionModel):
#     searchImg = ImgSearch()  ## similarity comparison
#     # add the incoming images to the temporary library
#     if add_flag:
#         if actionModel:
#             queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[2:-2]), model, actionModel, barcode_flag)
#         else:
#             queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[:-2]), model, actionModel, barcode_flag)
#
#         if barcode_flag:  ### add to cart with barcode -> return top10 and top1
#             if len(queBarIdList) == 0:
#                 top10, top1 = {}, {}
#             else:
#                 for bar in queBarIdList:
#                     # gvalue.tempLibList.append(bar)  ## temp library keys are macID_barcode_trackID
#                     if gvalue.tempLibLists.get(gvalue.mac_id) is not None:
#                         gvalue.tempLibLists[gvalue.mac_id] += [bar]  ## temp library keys are macID_barcode_trackID
#                     else:
#                         gvalue.tempLibLists[gvalue.mac_id] = [bar]
#                 # store into the temporary feature library
#                 # create_milvus('temp_features', conf.host, conf.port, queBarIdList, queBarIdFeatures)
#
#                 thread = Thread(target=create_milvus, kwargs={'collection_name': 'temp_features',
#                                                               'host': conf.host,
#                                                               'port': conf.port,
#                                                               'barcode_list': queBarIdList,
#                                                               'features': queBarIdFeatures})
#                 thread.start()
#                 start1 = time.time()
#                 top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures)
#                 start2 = time.time()
#                 print('search top10 time>>>> {}'.format(start2-start1))
#                 top1 = searchImg.mainSearch1(main_milvus, queBarIdList, queBarIdFeatures)
#                 start3 = time.time()
#                 print('search top1 time>>>>> {}'.format(start3-start2))
#             return top10, top1, gvalue.tempLibLists
#         else:  # add to cart without barcode -> return top10
#             # without a barcode, generate random numbers as dict keys
#             queBarIdList_rand = []
#             for i in range(len(queBarIdList)):
#                 random_number = ''.join(random.choices('0123456789', k=10))
#                 queBarIdList_rand.append(str(random_number))
#                 # gvalue.tempLibList.append(str(random_number))
#                 if gvalue.tempLibLists.get(gvalue.mac_id) is not None:
#                     gvalue.tempLibLists[gvalue.mac_id] += [str(random_number)]  ## temp library keys are macID_barcode_trackID
#                 else:
#                     gvalue.tempLibLists[gvalue.mac_id] = [str(random_number)]
#             # create_milvus('temp_features', conf.host, conf.port, queBarIdList_rand, queBarIdFeatures)
#             thread = Thread(target=create_milvus, kwargs={'collection_name': 'temp_features',
#                                                           'host': conf.host,
#                                                           'port': conf.port,
#                                                           'barcode_list': queBarIdList_rand,
#                                                           'features': queBarIdFeatures})
#             thread.start()
#             top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures)
#             # print(f'top10: {top10}')
#             return top10, gvalue.tempLibLists
#     else:  # return/remove from cart -> return top10 and topn
#         if gvalue.tempLibLists.get(gvalue.mac_id) is None:
#             gvalue.tempLibList = []
#         else:
#             gvalue.tempLibList = gvalue.tempLibLists[gvalue.mac_id]
#         ## load the temporary feature library
#         tempMilvusName = "temp_features"
#         has = utility.has_collection(tempMilvusName)
#         print(f"Does collection {tempMilvusName} exist in Milvus: {has}")
#         tempMilvus = load_collection(tempMilvusName)
#         print(f"Number of entities in {tempMilvusName}: {tempMilvus.num_entities}")
#         if actionModel:
#             barcode_list = barcode_list
#         else:
#             barcode_list = queueImgs_add['barcode_list']
#         if actionModel:
#             queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[2:-1]), model, actionModel, barcode_flag)
#         else:
#             queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[:-3]), model, actionModel, barcode_flag)
#         if barcode_flag:
#             if len(queBarIdList) == 0:
#                 top10, top1, top_n = {}, {}, {}
#             else:
#                 start1 = time.time()
#                 top1 = searchImg.mainSearch1(main_milvus, queBarIdList, queBarIdFeatures)
#                 start2 = time.time()
#                 print('search top1 time>>>> {}'.format(start2 - start1))
#                 top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures)
#                 start3 = time.time()
#                 print('search top10 time>>>> {}'.format(start3 - start2))
#                 top_n = searchImg.tempSearch(tempMilvus, queBarIdList, queBarIdFeatures, barcode_list, gvalue.tempLibList)
#             # print(f'top10: {top10}, top1: {top1}, topn: {top_n}')
#             return top10, top1, top_n
#         else:
#             top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures)
#             top_n = searchImg.tempSearch(tempMilvus, queBarIdList, queBarIdFeatures, barcode_list, gvalue.tempLibList)
#             # print(f'top10: {top10}, topn: {top_n}')
#             return top10, top_n


def similarity_interface(dataCollection):
    queImgsDict = dataCollection.queImgsDict
    add_flag = dataCollection.add_flag
    barcode_flag = dataCollection.barcode_flag
    main_milvus = dataCollection.mainMilvus
    # tempLibList = dataCollection.tempLibList
    model = dataCollection.model
    actionModel = dataCollection.actionModel
    barcode_list = dataCollection.barcode_list
    # return similarity(queImgsDict, add_flag, barcode_flag, main_milvus, tempLibList, model, barcode_list, actionModel)
    return 0


if __name__ == '__main__':
    pass

    # connections.connect('default', host=conf.host, port=conf.port)
    # # load the main feature library
    # mainMilvusName = "main_features"
    # has = utility.has_collection(mainMilvusName)
    # print(f"Does collection {mainMilvusName} exist in Milvus: {has}")
    # mainMilvus = Collection(mainMilvusName)
    # mainMilvus.load()
    # model = initModel()
    # # queueImgs_add / queueImgs_back are the inputs for adding and returning items, respectively
    # add_flag = queueImgs_add['add_flag']
    # barcode_flag = queueImgs_add['barcode_flag']
    # tempLibList = []  # barcodeId_list of the temporary feature library
    # # tempLibList = ['3500610085338_01', '4260290263776_01']  ## test
    # if add_flag:
    #     if barcode_flag:  # add to cart with barcode -> return top10 and top1
    #         top10, top1, tempLibList = similarity(queueImgs_add, add_flag, barcode_flag, mainMilvus, tempLibList, model)
    #         print(f"top10: {top10}\ntop1: {top1}")
    #     else:  # add to cart without barcode -> return top10
    #         top10, tempLibList = similarity(queueImgs_add, add_flag, barcode_flag, mainMilvus, tempLibList, model)
    # else:  # return/remove -> return top10 and topn
    #     top10, topn = similarity(queueImgs_back, add_flag, barcode_flag, mainMilvus, tempLibList, model)
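A minimal sketch of the live feature-extraction path; the image paths are hypothetical placeholders (they must exist on disk, since actionModel=False makes test_preprocess open them), and the untrained resnet18 stands in for a loaded checkpoint:

from contrast.model import resnet18
from contrast.test_logic import img2feature

model = resnet18(pretrained=False).eval()
imgs = {"mac01_6901234567890_1": ["imgs/a.jpg", "imgs/b.jpg"]}  # hypothetical paths
bar_ids, feats = img2feature(imgs, model, actionModel=False, barcode_flag=False)
print(len(bar_ids), len(feats[0]))  # one embedding list per barcode-track id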
contrast/utils.py (new file: 64 lines)
@@ -0,0 +1,64 @@
"""Training list utilities.

Format:
    ImagePath Label

Example:
    /data/WebFace/0124920/003.jpg 10572
    /data/WebFace/0124920/012.jpg 10572
    /data/WebFace/0124920/020.jpg 10572
"""

import os
import os.path as osp
from imutils import paths


def generate_list(images_directory, saved_name=None):
    """Generate a data list.

    Args:
        images_directory: face-data directory, usually containing many
            subfolders, in the layout used by WebFace and LFW.
    Returns:
        data_list: ["<path> <label>", ...]
    """
    subdirs = os.listdir(images_directory)
    num_ids = len(subdirs)
    data_list = []
    for i in range(num_ids):
        subdir = osp.join(images_directory, subdirs[i])
        files = os.listdir(subdir)
        # named img_paths so it does not shadow the imported imutils "paths" module
        img_paths = [osp.join(subdir, file) for file in files]
        # use the folder index as the face label
        paths_with_Id = [f"{p} {i}\n" for p in img_paths]
        data_list.extend(paths_with_Id)

    if saved_name:
        with open(saved_name, 'w', encoding='utf-8') as f:
            f.writelines(data_list)
    return data_list


def transform_clean_list(webface_directory, cleaned_list_path):
    """Convert WebFace's cleaned-list format.

    Args:
        webface_directory: the WebFace data directory
        cleaned_list_path: path to cleaned_list.txt
    Returns:
        cleaned_list: the converted data list
    """
    with open(cleaned_list_path, encoding='utf-8') as f:
        cleaned_list = f.readlines()
    cleaned_list = [p.replace('\\', '/') for p in cleaned_list]
    cleaned_list = [osp.join(webface_directory, p) for p in cleaned_list]
    return cleaned_list


def remove_dirty_image(webface_directory, cleaned_list):
    cleaned_list = set([c.split()[0] for c in cleaned_list])
    for p in paths.list_images(webface_directory):
        if p not in cleaned_list:
            print(f"remove {p}")
            os.remove(p)


if __name__ == '__main__':
    data = '/data/CASIA-WebFace/'
    lst = '/data/cleaned_list.txt'
    cleaned_list = transform_clean_list(data, lst)
    remove_dirty_image(data, cleaned_list)