rebuild
This commit is contained in:
11
.gitignore
vendored
Normal file
11
.gitignore
vendored
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
*.pth
|
||||||
|
blog/
|
||||||
|
data/
|
||||||
|
experiment/
|
||||||
|
log/
|
||||||
|
shop_xlsx/
|
||||||
|
loss/
|
||||||
|
checkpoints/
|
||||||
|
search_library/
|
||||||
|
quant_imgs/
|
||||||
|
README.md
|
8
.idea/.gitignore
generated
vendored
Normal file
8
.idea/.gitignore
generated
vendored
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
# 默认忽略的文件
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
|
# 基于编辑器的 HTTP 客户端请求
|
||||||
|
/httpRequests/
|
||||||
|
# Datasource local storage ignored files
|
||||||
|
/dataSources/
|
||||||
|
/dataSources.local.xml
|
869
.idea/CopilotChatHistory.xml
generated
Normal file
869
.idea/CopilotChatHistory.xml
generated
Normal file
File diff suppressed because one or more lines are too long
6
.idea/CopilotSideBarWebPersist.xml
generated
Normal file
6
.idea/CopilotSideBarWebPersist.xml
generated
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="CopilotSideBarWebPersist">
|
||||||
|
<option name="autoAddFileCloseState" value="true" />
|
||||||
|
</component>
|
||||||
|
</project>
|
19326
.idea/CopilotWebChatHistory.xml
generated
Normal file
19326
.idea/CopilotWebChatHistory.xml
generated
Normal file
File diff suppressed because one or more lines are too long
8
.idea/contrast_nettest.iml
generated
Normal file
8
.idea/contrast_nettest.iml
generated
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module type="PYTHON_MODULE" version="4">
|
||||||
|
<component name="NewModuleRootManager">
|
||||||
|
<content url="file://$MODULE_DIR$" />
|
||||||
|
<orderEntry type="jdk" jdkName="服务器3-NV4090-env:py-contrast-nettest" jdkType="Python SDK" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
</component>
|
||||||
|
</module>
|
114
.idea/deployment.xml
generated
Normal file
114
.idea/deployment.xml
generated
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="PublishConfigData" autoUpload="Always" serverName="lc@192.168.10.89:22 password (6)" exclude=".svn;.cvs;.idea;.DS_Store;.git;.hg;*.hprof;*.pyc;*.jpg;*.mp4;data/" remoteFilesAllowedToDisappearOnAutoupload="false" confirmBeforeUploading="false">
|
||||||
|
<option name="confirmBeforeUploading" value="false" />
|
||||||
|
<serverData>
|
||||||
|
<paths name="contrast_nettest">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping deploy="/contrast_nettest" local="$PROJECT_DIR$" web="/" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
|
<paths name="ieemoo0169@192.168.10.93:22 password">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping deploy="/home/ieemoo0169/contrast_nettest" local="$PROJECT_DIR$" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
|
<paths name="lc@192.168.10.56:22 password">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping local="$PROJECT_DIR$" web="/" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
|
<paths name="lc@192.168.10.89:22 password">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping local="$PROJECT_DIR$" web="/" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
|
<paths name="lc@192.168.10.89:22 password (10)">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping local="$PROJECT_DIR$" web="/" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
|
<paths name="lc@192.168.10.89:22 password (11)">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping local="$PROJECT_DIR$" web="/" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
|
<paths name="lc@192.168.10.89:22 password (2)">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping local="$PROJECT_DIR$" web="/" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
|
<paths name="lc@192.168.10.89:22 password (3)">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping local="$PROJECT_DIR$" web="/" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
|
<paths name="lc@192.168.10.89:22 password (4)">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping local="$PROJECT_DIR$" web="/" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
|
<paths name="lc@192.168.10.89:22 password (5)">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping local="$PROJECT_DIR$" web="/" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
|
<paths name="lc@192.168.10.89:22 password (6)">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping deploy="/home/lc/contrast_nettest" local="$PROJECT_DIR$" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
|
<paths name="lc@192.168.10.89:22 password (7)">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping local="$PROJECT_DIR$" web="/" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
|
<paths name="lc@192.168.10.89:22 password (8)">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping local="$PROJECT_DIR$" web="/" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
|
<paths name="lc@192.168.10.89:22 password (9)">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping local="$PROJECT_DIR$" web="/" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
|
<paths name="yolov5">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping local="$PROJECT_DIR$" web="/" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
|
</serverData>
|
||||||
|
<option name="myAutoUpload" value="ALWAYS" />
|
||||||
|
</component>
|
||||||
|
</project>
|
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<settings>
|
||||||
|
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||||
|
<version value="1.0" />
|
||||||
|
</settings>
|
||||||
|
</component>
|
7
.idea/misc.xml
generated
Normal file
7
.idea/misc.xml
generated
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="Black">
|
||||||
|
<option name="sdkName" value="Remote Python 3.8.18 (sftp://lc@192.168.1.142:22/home/lc/project/miniconda3/envs/my_env/bin/python)" />
|
||||||
|
</component>
|
||||||
|
<component name="ProjectRootManager" version="2" project-jdk-name="服务器3-NV4090-env:py-contrast-nettest" project-jdk-type="Python SDK" />
|
||||||
|
</project>
|
8
.idea/modules.xml
generated
Normal file
8
.idea/modules.xml
generated
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectModuleManager">
|
||||||
|
<modules>
|
||||||
|
<module fileurl="file://$PROJECT_DIR$/.idea/contrast_nettest.iml" filepath="$PROJECT_DIR$/.idea/contrast_nettest.iml" />
|
||||||
|
</modules>
|
||||||
|
</component>
|
||||||
|
</project>
|
8
.idea/sshConfigs.xml
generated
Normal file
8
.idea/sshConfigs.xml
generated
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="SshConfigs">
|
||||||
|
<configs>
|
||||||
|
<sshConfig authType="PASSWORD" connectionConfig="{"serverAliveInterval":300}" host="192.168.1.28" id="f9cd63ee-d39a-42a7-b369-1eb74d4f71ae" port="22" nameFormat="DESCRIPTIVE" username="ieemoo0169" useOpenSSHConfig="true" />
|
||||||
|
</configs>
|
||||||
|
</component>
|
||||||
|
</project>
|
6
.idea/vcs.xml
generated
Normal file
6
.idea/vcs.xml
generated
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="VcsDirectoryMappings">
|
||||||
|
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||||
|
</component>
|
||||||
|
</project>
|
14
.idea/webServers.xml
generated
Normal file
14
.idea/webServers.xml
generated
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="WebServers">
|
||||||
|
<option name="servers">
|
||||||
|
<webServer id="422a5cdc-8aff-4e1f-9f9a-2377f5a31f0b" name="contrast_nettest">
|
||||||
|
<fileTransfer rootFolder="/home/ieemoo0169" accessType="SFTP" host="192.168.1.28" port="22" sshConfigId="74dc3f38-9a9b-4eb8-ae6f-ed04cca88f27" sshConfig="ieemoo0169@192.168.1.28:22 password">
|
||||||
|
<advancedOptions>
|
||||||
|
<advancedOptions dataProtectionLevel="Private" passiveMode="true" shareSSLContext="true" />
|
||||||
|
</advancedOptions>
|
||||||
|
</fileTransfer>
|
||||||
|
</webServer>
|
||||||
|
</option>
|
||||||
|
</component>
|
||||||
|
</project>
|
9
.vscode/sftp.json
vendored
Normal file
9
.vscode/sftp.json
vendored
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
{
|
||||||
|
"name": "My Server",
|
||||||
|
"host": "localhost",
|
||||||
|
"protocol": "sftp",
|
||||||
|
"port": 22,
|
||||||
|
"username": "username",
|
||||||
|
"remotePath": "/",
|
||||||
|
"uploadOnSave": true
|
||||||
|
}
|
BIN
__pycache__/config.cpython-38.pyc
Normal file
BIN
__pycache__/config.cpython-38.pyc
Normal file
Binary file not shown.
BIN
__pycache__/test_ori.cpython-38.pyc
Normal file
BIN
__pycache__/test_ori.cpython-38.pyc
Normal file
Binary file not shown.
122
config.py
Normal file
122
config.py
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
import torch
|
||||||
|
import torchvision.transforms as T
|
||||||
|
import torchvision.transforms.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def pad_to_square(img):
|
||||||
|
w, h = img.size
|
||||||
|
max_wh = max(w, h)
|
||||||
|
padding = [(max_wh - w) // 2, (max_wh - h) // 2, (max_wh - w) // 2, (max_wh - h) // 2] # (left, top, right, bottom)
|
||||||
|
return F.pad(img, padding, fill=0, padding_mode='constant')
|
||||||
|
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
# network settings
|
||||||
|
backbone = 'resnet18' # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large,
|
||||||
|
# mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5, vit_base]
|
||||||
|
metric = 'arcface' # [cosface, arcface, softmax]
|
||||||
|
cbam = False
|
||||||
|
embedding_size = 256 # 256 # gift:2 contrast:256
|
||||||
|
drop_ratio = 0.5
|
||||||
|
img_size = 224
|
||||||
|
multiple_cards = True # 多卡加载
|
||||||
|
model_half = False # 模型半精度测试
|
||||||
|
data_half = True # 数据半精度测试
|
||||||
|
channel_ratio = 0.75 # 通道剪枝比例
|
||||||
|
quantization_test = False # int8量化模型测试
|
||||||
|
|
||||||
|
# custom base_data settings
|
||||||
|
custom_backbone = False # 迁移学习载入除最后一层的所有层
|
||||||
|
custom_num_classes = 128 # 迁移学习的类别数量
|
||||||
|
|
||||||
|
# if quantization_test:
|
||||||
|
# device = torch.device('cpu')
|
||||||
|
# else:
|
||||||
|
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
teacher = 'vit' # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large, mobilenet_v1,
|
||||||
|
# PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5]
|
||||||
|
|
||||||
|
student = 'resnet'
|
||||||
|
# data preprocess
|
||||||
|
"""transforms.RandomCrop(size),
|
||||||
|
transforms.RandomVerticalFlip(p=0.5),
|
||||||
|
transforms.RandomHorizontalFlip(),
|
||||||
|
RandomRotate(15, 0.3),
|
||||||
|
# RandomGaussianBlur()"""
|
||||||
|
train_transform = T.Compose([
|
||||||
|
T.Lambda(pad_to_square), # 补边
|
||||||
|
T.ToTensor(),
|
||||||
|
T.Resize((img_size, img_size), antialias=True),
|
||||||
|
# T.RandomCrop(img_size * 4 // 5),
|
||||||
|
T.RandomHorizontalFlip(p=0.5),
|
||||||
|
T.RandomRotation(180),
|
||||||
|
T.ColorJitter(brightness=0.5),
|
||||||
|
T.ConvertImageDtype(torch.float32),
|
||||||
|
T.Normalize(mean=[0.5], std=[0.5]),
|
||||||
|
])
|
||||||
|
test_transform = T.Compose([
|
||||||
|
# T.Lambda(pad_to_square), # 补边
|
||||||
|
T.ToTensor(),
|
||||||
|
T.Resize((img_size, img_size), antialias=True),
|
||||||
|
T.ConvertImageDtype(torch.float32),
|
||||||
|
# T.Normalize(mean=[0,0,0], std=[255,255,255]),
|
||||||
|
T.Normalize(mean=[0.5], std=[0.5]),
|
||||||
|
])
|
||||||
|
|
||||||
|
# dataset
|
||||||
|
train_root = '../data_center/scatter/train' # ['./data/2250_train/base_data', # './data/2000_train/base_data', './data/zhanting/base_data', './data/base_train/one_stage/train']
|
||||||
|
test_root = '../data_center/scatter/val' # ["./data/2250_train/val", "./data/2000_train/val/", './data/zhanting/val', './data/base_train/one_stage/val']
|
||||||
|
|
||||||
|
# training settings
|
||||||
|
checkpoints = "checkpoints/resnet18_scatter_6.2/" # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3]
|
||||||
|
restore = True
|
||||||
|
# restore_model = "checkpoints/renet18_2250_0315/best_resnet18_2250_0315.pth" # best_resnet18_1491_0306.pth
|
||||||
|
restore_model = "checkpoints/resnet18_scatter_6.2/best.pth" # best_resnet18_1491_0306.pth
|
||||||
|
|
||||||
|
# test settings
|
||||||
|
testbackbone = 'resnet18' # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5]
|
||||||
|
|
||||||
|
# test_val = "./data/2250_train"
|
||||||
|
# test_list = "./data/2250_train/val_pair.txt"
|
||||||
|
# test_group_json = "./data/2250_train/cross_same.json"
|
||||||
|
|
||||||
|
test_val = "../data_center/scatter/" # [../data_center/contrast_learning/model_test_data/val_2250]
|
||||||
|
test_list = "../data_center/scatter/val_pair.txt" # [./data/test/public_single_pairs.txt]
|
||||||
|
test_group_json = "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json" # [./data/2250_train/cross_same.json]
|
||||||
|
# test_group_json = "./data/test/inner_group_pairs.json"
|
||||||
|
|
||||||
|
# test_model = "checkpoints/resnet18_scatter_6.2/best.pth"
|
||||||
|
test_model = "checkpoints/resnet18_1009/best.pth"
|
||||||
|
# test_model = "checkpoints/zhanting/inland/res_801.pth"
|
||||||
|
# test_model = "checkpoints/resnet18_20250504/best.pth"
|
||||||
|
# test_model = "checkpoints/resnet18_vit-base_20250430/best.pth"
|
||||||
|
group_test = False
|
||||||
|
# group_test = False
|
||||||
|
|
||||||
|
train_batch_size = 128 # 256
|
||||||
|
test_batch_size = 128 # 256
|
||||||
|
|
||||||
|
epoch = 5 # 512
|
||||||
|
optimizer = 'sgd' # ['sgd', 'adam', 'adamw']
|
||||||
|
lr = 5e-3 # 1e-2
|
||||||
|
lr_step = 10 # 10
|
||||||
|
lr_decay = 0.98 # 0.98
|
||||||
|
weight_decay = 5e-4
|
||||||
|
loss = 'cross_entropy' # ['focal_loss', 'cross_entropy']
|
||||||
|
log_path = './log'
|
||||||
|
lr_min = 1e-6 # min lr
|
||||||
|
|
||||||
|
pin_memory = False # if memory is large, set it True to speed up a bit
|
||||||
|
num_workers = 32 # 64
|
||||||
|
compare = False # compare the result of different models
|
||||||
|
|
||||||
|
'''
|
||||||
|
train_distill settings
|
||||||
|
'''
|
||||||
|
warmup_epochs = 3 # warmup_epoch
|
||||||
|
distributed = True # distributed training
|
||||||
|
teacher_path = "./checkpoints/resnet50_0519/best.pth"
|
||||||
|
distill_weight = 0.8 # 蒸馏权重
|
||||||
|
|
||||||
|
config = Config()
|
1
configs/__init__.py
Normal file
1
configs/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
from .utils import trainer_tools
|
69
configs/compare.yml
Normal file
69
configs/compare.yml
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
# configs/compare.yml
|
||||||
|
# 专为模型训练对比设计的配置文件
|
||||||
|
# 支持对比不同训练策略(如蒸馏vs独立训练)
|
||||||
|
|
||||||
|
# 基础配置
|
||||||
|
base:
|
||||||
|
experiment_name: "model_comparison" # 实验名称(用于结果保存目录)
|
||||||
|
seed: 42 # 随机种子(保证可复现性)
|
||||||
|
device: "cuda" # 训练设备(cuda/cpu)
|
||||||
|
log_level: "info" # 日志级别(debug/info/warning/error)
|
||||||
|
embedding_size: 256 # 特征维度
|
||||||
|
pin_memory: true # 是否启用pin_memory
|
||||||
|
distributed: true # 是否启用分布式训练
|
||||||
|
|
||||||
|
|
||||||
|
# 模型配置
|
||||||
|
models:
|
||||||
|
backbone: 'resnet18'
|
||||||
|
channel_ratio: 0.75
|
||||||
|
|
||||||
|
# 训练参数
|
||||||
|
training:
|
||||||
|
epochs: 600 # 总训练轮次
|
||||||
|
batch_size: 128 # 批次大小
|
||||||
|
lr: 0.001 # 初始学习率
|
||||||
|
optimizer: "sgd" # 优化器类型
|
||||||
|
metric: 'arcface' # 损失函数类型(可选:arcface/cosface/sphereface/softmax)
|
||||||
|
loss: "cross_entropy" # 损失函数类型(可选:cross_entropy/cross_entropy_smooth/center_loss/center_loss_smooth/arcface/cosface/sphereface/softmax)
|
||||||
|
lr_step: 10 # 学习率调整间隔(epoch)
|
||||||
|
lr_decay: 0.98 # 学习率衰减率
|
||||||
|
weight_decay: 0.0005 # 权重衰减
|
||||||
|
scheduler: "cosine_annealing" # 学习率调度器(可选:cosine_annealing/step/none)
|
||||||
|
num_workers: 32 # 数据加载线程数
|
||||||
|
checkpoints: "./checkpoints/resnet18_test/" # 模型保存目录
|
||||||
|
restore: false
|
||||||
|
restore_model: "resnet18_test/epoch_600.pth" # 模型恢复路径
|
||||||
|
|
||||||
|
# 验证参数
|
||||||
|
validation:
|
||||||
|
num_workers: 32 # 数据加载线程数
|
||||||
|
val_batch_size: 128 # 测试批次大小
|
||||||
|
|
||||||
|
# 数据配置
|
||||||
|
data:
|
||||||
|
dataset: "imagenet" # 数据集名称(示例用,可替换为实际数据集)
|
||||||
|
train_batch_size: 128 # 训练批次大小
|
||||||
|
val_batch_size: 128 # 验证批次大小
|
||||||
|
num_workers: 32 # 数据加载线程数
|
||||||
|
data_train_dir: "../data_center/contrast_learning/data_base/train" # 训练数据集根目录
|
||||||
|
data_val_dir: "../data_center/contrast_learning/data_base/val" # 验证数据集根目录
|
||||||
|
|
||||||
|
transform:
|
||||||
|
img_size: 224 # 图像尺寸
|
||||||
|
img_mean: 0.5 # 图像均值
|
||||||
|
img_std: 0.5 # 图像方差
|
||||||
|
RandomHorizontalFlip: 0.5 # 随机水平翻转概率
|
||||||
|
RandomRotation: 180 # 随机旋转角度
|
||||||
|
ColorJitter: 0.5 # 随机颜色抖动强度
|
||||||
|
|
||||||
|
# 日志与监控
|
||||||
|
logging:
|
||||||
|
logging_dir: "./logs" # 日志保存目录
|
||||||
|
tensorboard: true # 是否启用TensorBoard
|
||||||
|
checkpoint_interval: 30 # 检查点保存间隔(epoch)
|
||||||
|
|
||||||
|
# 分布式训练(可选)
|
||||||
|
distributed:
|
||||||
|
enabled: false # 是否启用分布式训练
|
||||||
|
backend: "nccl" # 分布式后端(nccl/gloo)
|
75
configs/distill.yml
Normal file
75
configs/distill.yml
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
# configs/compare.yml
|
||||||
|
# 专为模型训练对比设计的配置文件
|
||||||
|
# 支持对比不同训练策略(如蒸馏vs独立训练)
|
||||||
|
|
||||||
|
# 基础配置
|
||||||
|
base:
|
||||||
|
experiment_name: "model_comparison" # 实验名称(用于结果保存目录)
|
||||||
|
seed: 42 # 随机种子(保证可复现性)
|
||||||
|
device: "cuda" # 训练设备(cuda/cpu)
|
||||||
|
log_level: "info" # 日志级别(debug/info/warning/error)
|
||||||
|
embedding_size: 256 # 特征维度
|
||||||
|
pin_memory: true # 是否启用pin_memory
|
||||||
|
distributed: true # 是否启用分布式训练
|
||||||
|
|
||||||
|
|
||||||
|
# 模型配置
|
||||||
|
models:
|
||||||
|
backbone: 'resnet18'
|
||||||
|
channel_ratio: 1.0 # 主干特征通道缩放比例(默认)
|
||||||
|
student_channel_ratio: 0.75
|
||||||
|
teacher_model_path: "./checkpoints/resnet50_0519/best.pth"
|
||||||
|
|
||||||
|
# 训练参数
|
||||||
|
training:
|
||||||
|
epochs: 600 # 总训练轮次
|
||||||
|
batch_size: 128 # 批次大小
|
||||||
|
lr: 0.001 # 初始学习率
|
||||||
|
optimizer: "sgd" # 优化器类型
|
||||||
|
metric: 'arcface' # 损失函数类型(可选:arcface/cosface/sphereface/softmax)
|
||||||
|
loss: "cross_entropy" # 损失函数类型(可选:cross_entropy/cross_entropy_smooth/center_loss/center_loss_smooth/arcface/cosface/sphereface/softmax)
|
||||||
|
lr_step: 10 # 学习率调整间隔(epoch)
|
||||||
|
lr_decay: 0.98 # 学习率衰减率
|
||||||
|
weight_decay: 0.0005 # 权重衰减
|
||||||
|
scheduler: "cosine_annealing" # 学习率调度器(可选:cosine_annealing/step/none)
|
||||||
|
num_workers: 32 # 数据加载线程数
|
||||||
|
checkpoints: "./checkpoints/resnet18_test/" # 模型保存目录
|
||||||
|
restore: false
|
||||||
|
restore_model: "resnet18_test/epoch_600.pth" # 模型恢复路径
|
||||||
|
distill_weight: 0.8 # 蒸馏损失权重
|
||||||
|
temperature: 4 # 蒸馏温度
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# 验证参数
|
||||||
|
validation:
|
||||||
|
num_workers: 32 # 数据加载线程数
|
||||||
|
val_batch_size: 128 # 测试批次大小
|
||||||
|
|
||||||
|
# 数据配置
|
||||||
|
data:
|
||||||
|
dataset: "imagenet" # 数据集名称(示例用,可替换为实际数据集)
|
||||||
|
train_batch_size: 128 # 训练批次大小
|
||||||
|
val_batch_size: 100 # 验证批次大小
|
||||||
|
num_workers: 4 # 数据加载线程数
|
||||||
|
data_train_dir: "../data_center/contrast_learning/data_base/train" # 训练数据集根目录
|
||||||
|
data_val_dir: "../data_center/contrast_learning/data_base/val" # 验证数据集根目录
|
||||||
|
|
||||||
|
transform:
|
||||||
|
img_size: 224 # 图像尺寸
|
||||||
|
img_mean: 0.5 # 图像均值
|
||||||
|
img_std: 0.5 # 图像方差
|
||||||
|
RandomHorizontalFlip: 0.5 # 随机水平翻转概率
|
||||||
|
RandomRotation: 180 # 随机旋转角度
|
||||||
|
ColorJitter: 0.5 # 随机颜色抖动强度
|
||||||
|
|
||||||
|
# 日志与监控
|
||||||
|
logging:
|
||||||
|
logging_dir: "./logs" # 日志保存目录
|
||||||
|
tensorboard: true # 是否启用TensorBoard
|
||||||
|
checkpoint_interval: 30 # 检查点保存间隔(epoch)
|
||||||
|
|
||||||
|
# 分布式训练(可选)
|
||||||
|
distributed:
|
||||||
|
enabled: false # 是否启用分布式训练
|
||||||
|
backend: "nccl" # 分布式后端(nccl/gloo)
|
69
configs/scatter.yml
Normal file
69
configs/scatter.yml
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
# configs/scatter.yml
|
||||||
|
# 专为模型训练对比设计的配置文件
|
||||||
|
# 支持对比不同训练策略(如蒸馏vs独立训练)
|
||||||
|
|
||||||
|
# 基础配置
|
||||||
|
base:
|
||||||
|
device: "cuda" # 训练设备(cuda/cpu)
|
||||||
|
log_level: "info" # 日志级别(debug/info/warning/error)
|
||||||
|
embedding_size: 256 # 特征维度
|
||||||
|
pin_memory: true # 是否启用pin_memory
|
||||||
|
distributed: true # 是否启用分布式训练
|
||||||
|
|
||||||
|
|
||||||
|
# 模型配置
|
||||||
|
models:
|
||||||
|
backbone: 'resnet18'
|
||||||
|
channel_ratio: 1.0
|
||||||
|
|
||||||
|
# 训练参数
|
||||||
|
training:
|
||||||
|
epochs: 300 # 总训练轮次
|
||||||
|
batch_size: 64 # 批次大小
|
||||||
|
lr: 0.005 # 初始学习率
|
||||||
|
optimizer: "sgd" # 优化器类型
|
||||||
|
metric: 'arcface' # 损失函数类型(可选:arcface/cosface/sphereface/softmax)
|
||||||
|
loss: "cross_entropy" # 损失函数类型(可选:cross_entropy/cross_entropy_smooth/center_loss/center_loss_smooth/arcface/cosface/sphereface/softmax)
|
||||||
|
lr_step: 10 # 学习率调整间隔(epoch)
|
||||||
|
lr_decay: 0.98 # 学习率衰减率
|
||||||
|
weight_decay: 0.0005 # 权重衰减
|
||||||
|
scheduler: "cosine_annealing" # 学习率调度器(可选:cosine_annealing/step/none)
|
||||||
|
num_workers: 32 # 数据加载线程数
|
||||||
|
checkpoints: "./checkpoints/resnet18_scatter_6.2/" # 模型保存目录
|
||||||
|
restore: True
|
||||||
|
restore_model: "checkpoints/resnet18_scatter_6.2/best.pth" # 模型恢复路径
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# 验证参数
|
||||||
|
validation:
|
||||||
|
num_workers: 32 # 数据加载线程数
|
||||||
|
val_batch_size: 128 # 测试批次大小
|
||||||
|
|
||||||
|
# 数据配置
|
||||||
|
data:
|
||||||
|
dataset: "imagenet" # 数据集名称(示例用,可替换为实际数据集)
|
||||||
|
train_batch_size: 128 # 训练批次大小
|
||||||
|
val_batch_size: 100 # 验证批次大小
|
||||||
|
num_workers: 32 # 数据加载线程数
|
||||||
|
data_train_dir: "../data_center/scatter/train" # 训练数据集根目录
|
||||||
|
data_val_dir: "../data_center/scatter/val" # 验证数据集根目录
|
||||||
|
|
||||||
|
transform:
|
||||||
|
img_size: 224 # 图像尺寸
|
||||||
|
img_mean: 0.5 # 图像均值
|
||||||
|
img_std: 0.5 # 图像方差
|
||||||
|
RandomHorizontalFlip: 0.5 # 随机水平翻转概率
|
||||||
|
RandomRotation: 180 # 随机旋转角度
|
||||||
|
ColorJitter: 0.5 # 随机颜色抖动强度
|
||||||
|
|
||||||
|
# 日志与监控
|
||||||
|
logging:
|
||||||
|
logging_dir: "./log/2025.6.2-scatter.txt" # 日志保存目录
|
||||||
|
tensorboard: true # 是否启用TensorBoard
|
||||||
|
checkpoint_interval: 30 # 检查点保存间隔(epoch)
|
||||||
|
|
||||||
|
# 分布式训练(可选)
|
||||||
|
distributed:
|
||||||
|
enabled: false # 是否启用分布式训练
|
||||||
|
backend: "nccl" # 分布式后端(nccl/gloo)
|
41
configs/test.yml
Normal file
41
configs/test.yml
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
# configs/test.yml
|
||||||
|
# 专为模型训练对比设计的配置文件
|
||||||
|
# 支持对比不同训练策略(如蒸馏vs独立训练)
|
||||||
|
|
||||||
|
# 基础配置
|
||||||
|
base:
|
||||||
|
device: "cuda" # 训练设备(cuda/cpu)
|
||||||
|
log_level: "info" # 日志级别(debug/info/warning/error)
|
||||||
|
embedding_size: 256 # 特征维度
|
||||||
|
pin_memory: true # 是否启用pin_memory
|
||||||
|
distributed: true # 是否启用分布式训练
|
||||||
|
|
||||||
|
# 模型配置
|
||||||
|
models:
|
||||||
|
backbone: 'resnet18'
|
||||||
|
channel_ratio: 1.0
|
||||||
|
model_path: "./checkpoints/resnet18_scatter_6.2/best.pth"
|
||||||
|
half: false # 是否启用半精度测试(fp16)
|
||||||
|
|
||||||
|
# 数据配置
|
||||||
|
data:
|
||||||
|
group_test: False # 数据集名称(示例用,可替换为实际数据集)
|
||||||
|
test_batch_size: 128 # 训练批次大小
|
||||||
|
num_workers: 32 # 数据加载线程数
|
||||||
|
test_dir: "../data_center/scatter/" # 验证数据集根目录
|
||||||
|
test_group_json: "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json"
|
||||||
|
test_list: "../data_center/scatter/val_pair.txt"
|
||||||
|
|
||||||
|
transform:
|
||||||
|
img_size: 224 # 图像尺寸
|
||||||
|
img_mean: 0.5 # 图像均值
|
||||||
|
img_std: 0.5 # 图像方差
|
||||||
|
RandomHorizontalFlip: 0.5 # 随机水平翻转概率
|
||||||
|
RandomRotation: 180 # 随机旋转角度
|
||||||
|
ColorJitter: 0.5 # 随机颜色抖动强度
|
||||||
|
|
||||||
|
save:
|
||||||
|
save_dir: ""
|
||||||
|
save_name: ""
|
||||||
|
|
||||||
|
|
56
configs/utils.py
Normal file
56
configs/utils.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
from model import (resnet18, mobilevit_s, MobileNetV3_Small, MobileNetV3_Large, mobilenet_v1,
|
||||||
|
PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5)
|
||||||
|
from timm.models import vit_base_patch16_224 as vit_base_16
|
||||||
|
from model.metric import ArcFace, CosFace
|
||||||
|
import torch.optim as optim
|
||||||
|
import torch.nn as nn
|
||||||
|
import timm
|
||||||
|
|
||||||
|
|
||||||
|
class trainer_tools:
|
||||||
|
def __init__(self, conf):
|
||||||
|
self.conf = conf
|
||||||
|
|
||||||
|
def get_backbone(self):
|
||||||
|
backbone_mapping = {
|
||||||
|
'resnet18': lambda: resnet18(scale=self.conf['models']['channel_ratio']),
|
||||||
|
'mobilevit_s': lambda: mobilevit_s(),
|
||||||
|
'mobilenetv3_small': lambda: MobileNetV3_Small(),
|
||||||
|
'PPLCNET_x1_0': lambda: PPLCNET_x1_0(),
|
||||||
|
'PPLCNET_x0_5': lambda: PPLCNET_x0_5(),
|
||||||
|
'PPLCNET_x2_5': lambda: PPLCNET_x2_5(),
|
||||||
|
'mobilenetv3_large': lambda: MobileNetV3_Large(),
|
||||||
|
'vit_base': lambda: vit_base_16(pretrained=True),
|
||||||
|
'efficientnet': lambda: timm.create_model('efficientnet_b0', pretrained=True,
|
||||||
|
num_classes=self.conf.embedding_size)
|
||||||
|
}
|
||||||
|
return backbone_mapping
|
||||||
|
|
||||||
|
def get_metric(self, class_num):
|
||||||
|
# 优化后的metric选择代码块,使用字典映射提高可读性和扩展性
|
||||||
|
metric_mapping = {
|
||||||
|
'arcface': lambda: ArcFace(self.conf['base']['embedding_size'], class_num).to(self.conf['base']['device']),
|
||||||
|
'cosface': lambda: CosFace(self.conf['base']['embedding_size'], class_num).to(self.conf['base']['device']),
|
||||||
|
'softmax': lambda: nn.Linear(self.conf['base']['embedding_size'], class_num).to(self.conf['base']['device'])
|
||||||
|
}
|
||||||
|
return metric_mapping
|
||||||
|
|
||||||
|
def get_optimizer(self, model, metric):
|
||||||
|
optimizer_mapping = {
|
||||||
|
'sgd': lambda: optim.SGD(
|
||||||
|
[{'params': model.parameters()}, {'params': metric.parameters()}],
|
||||||
|
lr=self.conf['training']['lr'],
|
||||||
|
weight_decay=self.conf['training']['weight_decay']
|
||||||
|
),
|
||||||
|
'adam': lambda: optim.Adam(
|
||||||
|
[{'params': model.parameters()}, {'params': metric.parameters()}],
|
||||||
|
lr=self.conf['training']['lr'],
|
||||||
|
weight_decay=self.conf['training']['weight_decay']
|
||||||
|
),
|
||||||
|
'adamw': lambda: optim.AdamW(
|
||||||
|
[{'params': model.parameters()}, {'params': metric.parameters()}],
|
||||||
|
lr=self.conf['training']['lr'],
|
||||||
|
weight_decay=self.conf['training']['weight_decay']
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return optimizer_mapping
|
47
configs/write_feature.yml
Normal file
47
configs/write_feature.yml
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
# configs/write_feature.yml
|
||||||
|
# 专为模型训练对比设计的配置文件
|
||||||
|
# 支持对比不同训练策略(如蒸馏vs独立训练)
|
||||||
|
|
||||||
|
# 基础配置
|
||||||
|
base:
|
||||||
|
device: "cuda" # 训练设备(cuda/cpu)
|
||||||
|
log_level: "info" # 日志级别(debug/info/warning/error)
|
||||||
|
embedding_size: 256 # 特征维度
|
||||||
|
distributed: true # 是否启用分布式训练
|
||||||
|
pin_memory: true # 是否启用pin_memory
|
||||||
|
|
||||||
|
# 模型配置
|
||||||
|
models:
|
||||||
|
backbone: 'resnet18'
|
||||||
|
channel_ratio: 0.75
|
||||||
|
checkpoints: "../checkpoints/resnet18_1009/best.pth"
|
||||||
|
|
||||||
|
# 数据配置
|
||||||
|
data:
|
||||||
|
train_batch_size: 128 # 训练批次大小
|
||||||
|
test_batch_size: 128 # 验证批次大小
|
||||||
|
num_workers: 32 # 数据加载线程数
|
||||||
|
half: true # 是否启用半精度数据
|
||||||
|
img_dirs_path: "/shareData/temp_data/comparison/Hangzhou_Yunhe/base_data/05-09"
|
||||||
|
# img_dirs_path: "/home/lc/contrast_nettest/data/feature_json"
|
||||||
|
xlsx_pth: false # 过滤商品, 默认None不进行过滤
|
||||||
|
|
||||||
|
transform:
|
||||||
|
img_size: 224 # 图像尺寸
|
||||||
|
img_mean: 0.5 # 图像均值
|
||||||
|
img_std: 0.5 # 图像方差
|
||||||
|
RandomHorizontalFlip: 0.5 # 随机水平翻转概率
|
||||||
|
RandomRotation: 180 # 随机旋转角度
|
||||||
|
ColorJitter: 0.5 # 随机颜色抖动强度
|
||||||
|
|
||||||
|
# 日志与监控
|
||||||
|
logging:
|
||||||
|
logging_dir: "./logs" # 日志保存目录
|
||||||
|
tensorboard: true # 是否启用TensorBoard
|
||||||
|
checkpoint_interval: 30 # 检查点保存间隔(epoch)
|
||||||
|
|
||||||
|
save:
|
||||||
|
json_bin: "../search_library/yunhedian_05-09.json" # 保存整个json文件
|
||||||
|
json_path: "../data/feature_json_compare/" # 保存单个json文件
|
||||||
|
error_barcodes: "error_barcodes.txt"
|
||||||
|
barcodes_statistics: "../search_library/barcodes_statistics.txt"
|
88
model/BAM.py
Normal file
88
model/BAM.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
import torch.nn as nn
|
||||||
|
import torchvision
|
||||||
|
from torch.nn import init
|
||||||
|
|
||||||
|
|
||||||
|
class Flatten(nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
return x.view(x.shape[0], -1)
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelAttention(nn.Module):
|
||||||
|
def __int__(self, channel, reduction, num_layers):
|
||||||
|
super(ChannelAttention, self).__init__()
|
||||||
|
self.avgpool = nn.AdaptiveAvgPool2d(1)
|
||||||
|
gate_channels = [channel]
|
||||||
|
gate_channels += [len(channel) // reduction] * num_layers
|
||||||
|
gate_channels += [channel]
|
||||||
|
|
||||||
|
self.ca = nn.Sequential()
|
||||||
|
self.ca.add_module('flatten', Flatten())
|
||||||
|
for i in range(len(gate_channels) - 2):
|
||||||
|
self.ca.add_module('', nn.Linear(gate_channels[i], gate_channels[i + 1]))
|
||||||
|
self.ca.add_module('', nn.BatchNorm1d(gate_channels[i + 1]))
|
||||||
|
self.ca.add_module('', nn.ReLU())
|
||||||
|
self.ca.add_module('', nn.Linear(gate_channels[-2], gate_channels[-1]))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
res = self.avgpool(x)
|
||||||
|
res = self.ca(res)
|
||||||
|
res = res.unsqueeze(-1).unsqueeze(-1).expand_as(x)
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
class SpatialAttention(nn.Module):
|
||||||
|
def __int__(self, channel, reduction=16, num_lay=3, dilation=2):
|
||||||
|
super(SpatialAttention).__init__()
|
||||||
|
self.sa = nn.Sequential()
|
||||||
|
self.sa.add_module('', nn.Conv2d(kernel_size=1, in_channels=channel, out_channels=(channel // reduction) * 3))
|
||||||
|
self.sa.add_module('', nn.BatchNorm2d(num_features=(channel // reduction)))
|
||||||
|
self.sa.add_module('', nn.ReLU())
|
||||||
|
for i in range(num_lay):
|
||||||
|
self.sa.add_module('', nn.Conv2d(kernel_size=3,
|
||||||
|
in_channels=(channel // reduction),
|
||||||
|
out_channels=(channel // reduction),
|
||||||
|
padding=1,
|
||||||
|
dilation=2))
|
||||||
|
self.sa.add_module('', nn.BatchNorm2d(channel // reduction))
|
||||||
|
self.sa.add_module('', nn.ReLU())
|
||||||
|
self.sa.add_module('', nn.Conv2d(channel // reduction, 1, kernel_size=1))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
res = self.sa(x)
|
||||||
|
res = res.expand_as(x)
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
class BAMblock(nn.Module):
|
||||||
|
def __init__(self, channel=512, reduction=16, dia_val=2):
|
||||||
|
super(BAMblock, self).__init__()
|
||||||
|
self.ca = ChannelAttention(channel, reduction)
|
||||||
|
self.sa = SpatialAttention(channel, reduction, dia_val)
|
||||||
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
|
||||||
|
def init_weights(self):
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
init.kaiming_normal(m.weight, mode='fan_out')
|
||||||
|
if m.bais is not None:
|
||||||
|
init.constant_(m.bias, 0)
|
||||||
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
|
init.constant_(m.weight, 1)
|
||||||
|
init.constant_(m.bias, 0)
|
||||||
|
elif isinstance(m, nn.Linear):
|
||||||
|
init.normal_(m.weight, std=0.001)
|
||||||
|
if m.bias is not None:
|
||||||
|
init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
b, c, _, _ = x.size()
|
||||||
|
sa_out = self.sa(x)
|
||||||
|
ca_out = self.ca(x)
|
||||||
|
weight = self.sigmoid(sa_out + ca_out)
|
||||||
|
out = (1 + weight) * x
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print(512 // 14)
|
70
model/CBAM.py
Normal file
70
model/CBAM.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.init as init
|
||||||
|
|
||||||
|
class channelAttention(nn.Module):
|
||||||
|
def __init__(self, channel, reduction=16):
|
||||||
|
super(channelAttention, self).__init__()
|
||||||
|
self.Maxpooling = nn.AdaptiveMaxPool2d(1)
|
||||||
|
self.Avepooling = nn.AdaptiveAvgPool2d(1)
|
||||||
|
self.ca = nn.Sequential()
|
||||||
|
self.ca.add_module('conv1',nn.Conv2d(channel, channel//reduction, 1, bias=False))
|
||||||
|
self.ca.add_module('Relu', nn.ReLU())
|
||||||
|
self.ca.add_module('conv2',nn.Conv2d(channel//reduction, channel, 1, bias=False))
|
||||||
|
self.sigmod = nn.Sigmoid()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
M_out = self.Maxpooling(x)
|
||||||
|
A_out = self.Avepooling(x)
|
||||||
|
M_out = self.ca(M_out)
|
||||||
|
A_out = self.ca(A_out)
|
||||||
|
out = self.sigmod(M_out+A_out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
class SpatialAttention(nn.Module):
|
||||||
|
def __init__(self, kernel_size=7):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=kernel_size, padding=kernel_size // 2)
|
||||||
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
max_result, _ = torch.max(x, dim=1, keepdim=True)
|
||||||
|
avg_result = torch.mean(x, dim=1, keepdim=True)
|
||||||
|
result = torch.cat([max_result, avg_result], dim=1)
|
||||||
|
output = self.conv(result)
|
||||||
|
output = self.sigmoid(output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
class CBAM(nn.Module):
|
||||||
|
def __init__(self, channel, reduction=16, kernel_size=7):
|
||||||
|
super().__init__()
|
||||||
|
self.ca = channelAttention(channel, reduction)
|
||||||
|
self.sa = SpatialAttention(kernel_size)
|
||||||
|
|
||||||
|
def init_weights(self):
|
||||||
|
for m in self.modules():#权重初始化
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
init.kaiming_normal_(m.weight, mode='fan_out')
|
||||||
|
if m.bias is not None:
|
||||||
|
init.constant_(m.bias, 0)
|
||||||
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
|
init.constant_(m.weight, 1)
|
||||||
|
init.constant_(m.bias, 0)
|
||||||
|
elif isinstance(m, nn.Linear):
|
||||||
|
init.normal_(m.weight, std=0.001)
|
||||||
|
if m.bias is not None:
|
||||||
|
init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# b,c_,_ = x.size()
|
||||||
|
# residual = x
|
||||||
|
out = x*self.ca(x)
|
||||||
|
out = out*self.sa(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
input=torch.randn(50,512,7,7)
|
||||||
|
kernel_size=input.shape[2]
|
||||||
|
cbam = CBAM(channel=512,reduction=16,kernel_size=kernel_size)
|
||||||
|
output=cbam(input)
|
||||||
|
print(output.shape)
|
37
model/Tool.py
Normal file
37
model/Tool.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class GeM(nn.Module):
|
||||||
|
def __init__(self, p=3, eps=1e-6):
|
||||||
|
super(GeM, self).__init__()
|
||||||
|
self.p = nn.Parameter(torch.ones(1) * p)
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.gem(x, p=self.p, eps=self.eps, stride=2)
|
||||||
|
|
||||||
|
def gem(self, x, p=3, eps=1e-6, stride=2):
|
||||||
|
return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1)), stride=2).pow(1. / p)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return self.__class__.__name__ + \
|
||||||
|
'(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + \
|
||||||
|
', ' + 'eps=' + str(self.eps) + ')'
|
||||||
|
|
||||||
|
|
||||||
|
class TripletLoss(nn.Module):
|
||||||
|
def __init__(self, margin):
|
||||||
|
super(TripletLoss, self).__init__()
|
||||||
|
self.margin = margin
|
||||||
|
|
||||||
|
def forward(self, anchor, positive, negative, size_average=True):
|
||||||
|
distance_positive = (anchor - positive).pow(2).sum(1)
|
||||||
|
distance_negative = (anchor - negative).pow(2).sum(1)
|
||||||
|
losses = F.relu(distance_negative - distance_positive + self.margin)
|
||||||
|
return losses.mean() if size_average else losses.sum()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
print('')
|
14
model/__init__.py
Normal file
14
model/__init__.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
from .fmobilenet import FaceMobileNet
|
||||||
|
# from .resnet_face import ResIRSE
|
||||||
|
from .mobilevit import mobilevit_s
|
||||||
|
from .metric import ArcFace, CosFace
|
||||||
|
from .loss import FocalLoss
|
||||||
|
from .resbam import resnet
|
||||||
|
from .resnet_pre import resnet18, resnet34, resnet50, resnet14, CustomResNet18
|
||||||
|
from .mobilenet_v2 import mobilenet_v2
|
||||||
|
from .mobilenet_v3 import MobileNetV3_Small, MobileNetV3_Large
|
||||||
|
# from .mobilenet_v1 import mobilenet_v1
|
||||||
|
from .lcnet import PPLCNET_x0_25, PPLCNET_x0_35, PPLCNET_x0_5, PPLCNET_x0_75, PPLCNET_x1_0, PPLCNET_x1_5, PPLCNET_x2_0, \
|
||||||
|
PPLCNET_x2_5
|
||||||
|
from .vit import vit_base
|
||||||
|
from .mlp import MLP
|
BIN
model/__pycache__/CBAM.cpython-38.pyc
Normal file
BIN
model/__pycache__/CBAM.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/Tool.cpython-38.pyc
Normal file
BIN
model/__pycache__/Tool.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
model/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/fmobilenet.cpython-38.pyc
Normal file
BIN
model/__pycache__/fmobilenet.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/lcnet.cpython-38.pyc
Normal file
BIN
model/__pycache__/lcnet.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/loss.cpython-38.pyc
Normal file
BIN
model/__pycache__/loss.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/metric.cpython-38.pyc
Normal file
BIN
model/__pycache__/metric.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/mlp.cpython-38.pyc
Normal file
BIN
model/__pycache__/mlp.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/mobilenet_v1.cpython-38.pyc
Normal file
BIN
model/__pycache__/mobilenet_v1.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/mobilenet_v2.cpython-38.pyc
Normal file
BIN
model/__pycache__/mobilenet_v2.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/mobilenet_v3.cpython-38.pyc
Normal file
BIN
model/__pycache__/mobilenet_v3.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/mobilevit.cpython-38.pyc
Normal file
BIN
model/__pycache__/mobilevit.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/resbam.cpython-38.pyc
Normal file
BIN
model/__pycache__/resbam.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/resnet_pre.cpython-38.pyc
Normal file
BIN
model/__pycache__/resnet_pre.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/utils.cpython-38.pyc
Normal file
BIN
model/__pycache__/utils.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/vit.cpython-38.pyc
Normal file
BIN
model/__pycache__/vit.cpython-38.pyc
Normal file
Binary file not shown.
142
model/benchmark.py
Normal file
142
model/benchmark.py
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
from resnet_attention import resnet18_cbam, resnet34_cbam, resnet50_cbam
|
||||||
|
|
||||||
|
# 设置随机种子以确保结果可复现
|
||||||
|
torch.manual_seed(42)
|
||||||
|
np.random.seed(42)
|
||||||
|
|
||||||
|
# 设备配置
|
||||||
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"测试设备: {device}")
|
||||||
|
|
||||||
|
# 测试参数
|
||||||
|
batch_sizes = [1, 4, 8, 16]
|
||||||
|
image_sizes = [224, 384, 512]
|
||||||
|
num_runs = 100 # 每个配置运行的次数
|
||||||
|
warmup_runs = 20 # 预热运行次数,排除启动开销
|
||||||
|
|
||||||
|
# 模型配置
|
||||||
|
model_configs = {
|
||||||
|
"resnet18": {
|
||||||
|
"base_model": lambda: resnet18_cbam(use_cbam=False),
|
||||||
|
"attention_model": lambda: resnet18_cbam(use_cbam=True)
|
||||||
|
},
|
||||||
|
"resnet34": {
|
||||||
|
"base_model": lambda: resnet34_cbam(use_cbam=False),
|
||||||
|
"attention_model": lambda: resnet34_cbam(use_cbam=True)
|
||||||
|
},
|
||||||
|
"resnet50": {
|
||||||
|
"base_model": lambda: resnet50_cbam(use_cbam=False),
|
||||||
|
"attention_model": lambda: resnet50_cbam(use_cbam=True)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 基准测试函数
|
||||||
|
def benchmark_model(model, input_size, batch_size, num_runs, warmup_runs):
|
||||||
|
"""
|
||||||
|
测试模型的推理性能
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- model: 待测试的模型
|
||||||
|
- input_size: 输入图像尺寸
|
||||||
|
- batch_size: 批次大小
|
||||||
|
- num_runs: 测试运行次数
|
||||||
|
- warmup_runs: 预热运行次数
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 平均推理时间(毫秒)
|
||||||
|
- 吞吐量(样本/秒)
|
||||||
|
"""
|
||||||
|
# 设置为评估模式
|
||||||
|
model.eval()
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
# 创建随机输入
|
||||||
|
input_tensor = torch.randn(batch_size, 3, input_size, input_size, device=device)
|
||||||
|
|
||||||
|
# 预热
|
||||||
|
with torch.no_grad():
|
||||||
|
for _ in range(warmup_runs):
|
||||||
|
_ = model(input_tensor)
|
||||||
|
if device.type == 'cuda':
|
||||||
|
torch.cuda.synchronize() # 同步GPU操作
|
||||||
|
|
||||||
|
# 测量推理时间
|
||||||
|
start_time = time.time()
|
||||||
|
with torch.no_grad():
|
||||||
|
for _ in range(num_runs):
|
||||||
|
_ = model(input_tensor)
|
||||||
|
if device.type == 'cuda':
|
||||||
|
torch.cuda.synchronize() # 同步GPU操作
|
||||||
|
end_time = time.time()
|
||||||
|
|
||||||
|
# 计算指标
|
||||||
|
total_time = end_time - start_time
|
||||||
|
avg_time_per_batch = total_time / num_runs * 1000 # 毫秒
|
||||||
|
throughput = batch_size * num_runs / total_time # 样本/秒
|
||||||
|
|
||||||
|
return avg_time_per_batch, throughput
|
||||||
|
|
||||||
|
|
||||||
|
# 运行测试
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
for model_name, config in model_configs.items():
|
||||||
|
results[model_name] = {}
|
||||||
|
|
||||||
|
# 创建模型
|
||||||
|
base_model = config["base_model"]()
|
||||||
|
attention_model = config["attention_model"]()
|
||||||
|
|
||||||
|
# 计算参数量
|
||||||
|
base_params = sum(p.numel() for p in base_model.parameters() if p.requires_grad)
|
||||||
|
attention_params = sum(p.numel() for p in attention_model.parameters() if p.requires_grad)
|
||||||
|
param_increase = (attention_params - base_params) / base_params * 100
|
||||||
|
|
||||||
|
print(f"\n测试模型: {model_name}")
|
||||||
|
print(f" 基础参数量: {base_params / 1e6:.2f}M")
|
||||||
|
print(f" 带注意力参数量: {attention_params / 1e6:.2f}M")
|
||||||
|
print(f" 参数量增加: {param_increase:.2f}%")
|
||||||
|
|
||||||
|
for batch_size in batch_sizes:
|
||||||
|
for image_size in image_sizes:
|
||||||
|
key = f"batch_{batch_size}_size_{image_size}"
|
||||||
|
results[model_name][key] = {}
|
||||||
|
|
||||||
|
# 测试基础模型
|
||||||
|
base_time, base_throughput = benchmark_model(
|
||||||
|
base_model, image_size, batch_size, num_runs, warmup_runs
|
||||||
|
)
|
||||||
|
|
||||||
|
# 测试注意力模型
|
||||||
|
attention_time, attention_throughput = benchmark_model(
|
||||||
|
attention_model, image_size, batch_size, num_runs, warmup_runs
|
||||||
|
)
|
||||||
|
|
||||||
|
# 计算增加的百分比
|
||||||
|
time_increase = (attention_time - base_time) / base_time * 100
|
||||||
|
throughput_decrease = (base_throughput - attention_throughput) / base_throughput * 100
|
||||||
|
|
||||||
|
results[model_name][key]["base_time"] = base_time
|
||||||
|
results[model_name][key]["attention_time"] = attention_time
|
||||||
|
results[model_name][key]["time_increase"] = time_increase
|
||||||
|
results[model_name][key]["base_throughput"] = base_throughput
|
||||||
|
results[model_name][key]["attention_throughput"] = attention_throughput
|
||||||
|
results[model_name][key]["throughput_decrease"] = throughput_decrease
|
||||||
|
|
||||||
|
print(f" 配置: 批次大小={batch_size}, 图像尺寸={image_size}x{image_size}")
|
||||||
|
print(f" 基础模型: 平均时间={base_time:.2f}ms, 吞吐量={base_throughput:.2f}样本/秒")
|
||||||
|
print(f" 注意力模型: 平均时间={attention_time:.2f}ms, 吞吐量={attention_throughput:.2f}样本/秒")
|
||||||
|
print(f" 时间增加: {time_increase:.2f}%, 吞吐量下降: {throughput_decrease:.2f}%")
|
||||||
|
|
||||||
|
# 保存结果
|
||||||
|
import json
|
||||||
|
|
||||||
|
with open('benchmark_results.json', 'w') as f:
|
||||||
|
json.dump(results, f, indent=2)
|
||||||
|
|
||||||
|
print("\n测试完成,结果已保存到 benchmark_results.json")
|
48
model/compare.py
Normal file
48
model/compare.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
import torch
|
||||||
|
from config import config as conf
|
||||||
|
import torch.nn as nn
|
||||||
|
import torchvision.models as models
|
||||||
|
from model.resnet_pre import resnet18, resnet50
|
||||||
|
# from model.vit import vit_base_patch16_224, vit_base_patch32_224
|
||||||
|
|
||||||
|
|
||||||
|
class ContrastiveModel(nn.Module):
|
||||||
|
def __init__(self, projection_dim, model_name, contraposition=False):
|
||||||
|
super(ContrastiveModel, self).__init__()
|
||||||
|
self.contraposition = contraposition
|
||||||
|
self.base_model = self._get_model(model_name)
|
||||||
|
if not self.contraposition:
|
||||||
|
if 'vit' in model_name:
|
||||||
|
dim_mlp = self.base_model.head.weight.shape[1]
|
||||||
|
self.base_model.head = self._get_projection_layer(dim_mlp, projection_dim)
|
||||||
|
else:
|
||||||
|
dim_mlp = self.base_model.fc.weight.shape[1]
|
||||||
|
self.base_model.fc = self._get_projection_layer(dim_mlp, projection_dim)
|
||||||
|
# # 冻结除 FC 层之外的所有层
|
||||||
|
# for name, param in self.base_model.named_parameters():
|
||||||
|
# if 'fc' not in name:
|
||||||
|
# param.requires_grad = False
|
||||||
|
|
||||||
|
def _get_projection_layer(self, dim_mlp, projection_dim):
|
||||||
|
return nn.Sequential(
|
||||||
|
nn.Linear(dim_mlp, dim_mlp),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Linear(dim_mlp, projection_dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_model(self, model_name):
|
||||||
|
base_model = None
|
||||||
|
if model_name == 'resnet18':
|
||||||
|
base_model = resnet18(pretrained=True)
|
||||||
|
elif model_name == 'resnet50':
|
||||||
|
base_model = resnet50(pretrained=True)
|
||||||
|
# elif model_name == 'vit':
|
||||||
|
# base_model = vit_base_patch32_224()
|
||||||
|
return base_model
|
||||||
|
def forward(self, x):
|
||||||
|
assert self.base_model is not None, 'base_model is none'
|
||||||
|
x = self.base_model(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pass
|
182
model/distill.py
Normal file
182
model/distill.py
Normal file
@ -0,0 +1,182 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import Module
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from vit_pytorch.vit import ViT
|
||||||
|
from vit_pytorch.t2t import T2TViT
|
||||||
|
from vit_pytorch.efficient import ViT as EfficientViT
|
||||||
|
|
||||||
|
from einops import repeat
|
||||||
|
from config import config as conf
|
||||||
|
# helpers
|
||||||
|
# Data Setup
|
||||||
|
from tools.dataset import load_data
|
||||||
|
train_dataloader, class_num = load_data(conf, training=True)
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
def exists(val):
|
||||||
|
return val is not None
|
||||||
|
|
||||||
|
|
||||||
|
def default(val, d):
|
||||||
|
return val if exists(val) else d
|
||||||
|
|
||||||
|
|
||||||
|
# classes
|
||||||
|
|
||||||
|
class DistillMixin:
|
||||||
|
def forward(self, img, distill_token=None):
|
||||||
|
distilling = exists(distill_token)
|
||||||
|
x = self.to_patch_embedding(img)
|
||||||
|
b, n, _ = x.shape
|
||||||
|
|
||||||
|
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b=b)
|
||||||
|
x = torch.cat((cls_tokens, x), dim=1)
|
||||||
|
x += self.pos_embedding[:, :(n + 1)]
|
||||||
|
|
||||||
|
if distilling:
|
||||||
|
distill_tokens = repeat(distill_token, '1 n d -> b n d', b=b)
|
||||||
|
x = torch.cat((x, distill_tokens), dim=1)
|
||||||
|
|
||||||
|
x = self._attend(x)
|
||||||
|
|
||||||
|
if distilling:
|
||||||
|
x, distill_tokens = x[:, :-1], x[:, -1]
|
||||||
|
|
||||||
|
x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
|
||||||
|
|
||||||
|
x = self.to_latent(x)
|
||||||
|
out = self.mlp_head(x)
|
||||||
|
|
||||||
|
if distilling:
|
||||||
|
return out, distill_tokens
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class DistillableViT(DistillMixin, ViT):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(DistillableViT, self).__init__(*args, **kwargs)
|
||||||
|
self.args = args
|
||||||
|
self.kwargs = kwargs
|
||||||
|
self.dim = kwargs['dim']
|
||||||
|
self.num_classes = kwargs['num_classes']
|
||||||
|
|
||||||
|
def to_vit(self):
|
||||||
|
v = ViT(*self.args, **self.kwargs)
|
||||||
|
v.load_state_dict(self.state_dict())
|
||||||
|
return v
|
||||||
|
|
||||||
|
def _attend(self, x):
|
||||||
|
x = self.dropout(x)
|
||||||
|
x = self.transformer(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DistillableT2TViT(DistillMixin, T2TViT):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(DistillableT2TViT, self).__init__(*args, **kwargs)
|
||||||
|
self.args = args
|
||||||
|
self.kwargs = kwargs
|
||||||
|
self.dim = kwargs['dim']
|
||||||
|
self.num_classes = kwargs['num_classes']
|
||||||
|
|
||||||
|
def to_vit(self):
|
||||||
|
v = T2TViT(*self.args, **self.kwargs)
|
||||||
|
v.load_state_dict(self.state_dict())
|
||||||
|
return v
|
||||||
|
|
||||||
|
def _attend(self, x):
|
||||||
|
x = self.dropout(x)
|
||||||
|
x = self.transformer(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DistillableEfficientViT(DistillMixin, EfficientViT):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(DistillableEfficientViT, self).__init__(*args, **kwargs)
|
||||||
|
self.args = args
|
||||||
|
self.kwargs = kwargs
|
||||||
|
self.dim = kwargs['dim']
|
||||||
|
self.num_classes = kwargs['num_classes']
|
||||||
|
|
||||||
|
|
||||||
|
def to_vit(self):
|
||||||
|
v = EfficientViT(*self.args, **self.kwargs)
|
||||||
|
v.load_state_dict(self.state_dict())
|
||||||
|
return v
|
||||||
|
|
||||||
|
def _attend(self, x):
|
||||||
|
return self.transformer(x)
|
||||||
|
|
||||||
|
|
||||||
|
# knowledge distillation wrapper
|
||||||
|
|
||||||
|
class DistillWrapper(Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
teacher,
|
||||||
|
student,
|
||||||
|
temperature=1.,
|
||||||
|
alpha=0.5,
|
||||||
|
hard=False,
|
||||||
|
mlp_layernorm=False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
# assert (isinstance(student, (
|
||||||
|
# DistillableViT, DistillableT2TViT, DistillableEfficientViT))), 'student must be a vision transformer'
|
||||||
|
if isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT)):
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.teacher = teacher
|
||||||
|
self.student = student
|
||||||
|
|
||||||
|
dim = conf.embedding_size # student.dim
|
||||||
|
num_classes = class_num # class_num # student.num_classes
|
||||||
|
self.temperature = temperature
|
||||||
|
self.alpha = alpha
|
||||||
|
self.hard = hard
|
||||||
|
|
||||||
|
self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||||
|
|
||||||
|
# student is vit
|
||||||
|
# self.distill_mlp = nn.Sequential(
|
||||||
|
# nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),
|
||||||
|
# nn.Linear(dim, num_classes)
|
||||||
|
# )
|
||||||
|
|
||||||
|
# student is resnet
|
||||||
|
self.distill_mlp = nn.Sequential(
|
||||||
|
nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),
|
||||||
|
nn.Linear(dim, num_classes).to(device)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, img, labels, temperature=None, alpha=None, **kwargs):
|
||||||
|
|
||||||
|
alpha = default(alpha, self.alpha)
|
||||||
|
T = default(temperature, self.temperature)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
teacher_logits = self.teacher(img)
|
||||||
|
teacher_logits = self.distill_mlp(teacher_logits) # teach is vit 初始化
|
||||||
|
# student is vit
|
||||||
|
# student_logits, distill_tokens = self.student(img, distill_token=self.distillation_token, **kwargs)
|
||||||
|
# distill_logits = self.distill_mlp(distill_tokens)
|
||||||
|
|
||||||
|
# student is resnet
|
||||||
|
student_logits = self.student(img)
|
||||||
|
distill_logits = self.distill_mlp(student_logits)
|
||||||
|
loss = F.cross_entropy(distill_logits, labels)
|
||||||
|
# pdb.set_trace()
|
||||||
|
if not self.hard:
|
||||||
|
distill_loss = F.kl_div(
|
||||||
|
F.log_softmax(distill_logits / T, dim=-1),
|
||||||
|
F.softmax(teacher_logits / T, dim=-1).detach(),
|
||||||
|
reduction='batchmean')
|
||||||
|
distill_loss *= T ** 2
|
||||||
|
else:
|
||||||
|
teacher_labels = teacher_logits.argmax(dim=-1)
|
||||||
|
distill_loss = F.cross_entropy(distill_logits, teacher_labels)
|
||||||
|
# pdb.set_trace()
|
||||||
|
return loss * (1 - alpha) + distill_loss * alpha
|
124
model/fmobilenet.py
Normal file
124
model/fmobilenet.py
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Flatten(nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
return x.view(x.shape[0], -1)
|
||||||
|
|
||||||
|
class ConvBn(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_c, out_c, kernel=(1, 1), stride=1, padding=0, groups=1):
|
||||||
|
super().__init__()
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
nn.Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
|
||||||
|
nn.BatchNorm2d(out_c)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
|
class ConvBnPrelu(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_c, out_c, kernel=(1, 1), stride=1, padding=0, groups=1):
|
||||||
|
super().__init__()
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
ConvBn(in_c, out_c, kernel, stride, padding, groups),
|
||||||
|
nn.PReLU(out_c)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
|
class DepthWise(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_c, out_c, kernel=(3, 3), stride=2, padding=1, groups=1):
|
||||||
|
super().__init__()
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
ConvBnPrelu(in_c, groups, kernel=(1, 1), stride=1, padding=0),
|
||||||
|
ConvBnPrelu(groups, groups, kernel=kernel, stride=stride, padding=padding, groups=groups),
|
||||||
|
ConvBn(groups, out_c, kernel=(1, 1), stride=1, padding=0),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
|
class DepthWiseRes(nn.Module):
|
||||||
|
"""DepthWise with Residual"""
|
||||||
|
|
||||||
|
def __init__(self, in_c, out_c, kernel=(3, 3), stride=2, padding=1, groups=1):
|
||||||
|
super().__init__()
|
||||||
|
self.net = DepthWise(in_c, out_c, kernel, stride, padding, groups)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.net(x) + x
|
||||||
|
|
||||||
|
|
||||||
|
class MultiDepthWiseRes(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, num_block, channels, kernel=(3, 3), stride=1, padding=1, groups=1):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.net = nn.Sequential(*[
|
||||||
|
DepthWiseRes(channels, channels, kernel, stride, padding, groups)
|
||||||
|
for _ in range(num_block)
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
|
class FaceMobileNet(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, embedding_size):
|
||||||
|
super().__init__()
|
||||||
|
self.conv1 = ConvBnPrelu(1, 64, kernel=(3, 3), stride=2, padding=1)
|
||||||
|
self.conv2 = ConvBn(64, 64, kernel=(3, 3), stride=1, padding=1, groups=64)
|
||||||
|
self.conv3 = DepthWise(64, 64, kernel=(3, 3), stride=2, padding=1, groups=128)
|
||||||
|
self.conv4 = MultiDepthWiseRes(num_block=4, channels=64, kernel=3, stride=1, padding=1, groups=128)
|
||||||
|
self.conv5 = DepthWise(64, 128, kernel=(3, 3), stride=2, padding=1, groups=256)
|
||||||
|
self.conv6 = MultiDepthWiseRes(num_block=6, channels=128, kernel=(3, 3), stride=1, padding=1, groups=256)
|
||||||
|
self.conv7 = DepthWise(128, 128, kernel=(3, 3), stride=2, padding=1, groups=512)
|
||||||
|
self.conv8 = MultiDepthWiseRes(num_block=2, channels=128, kernel=(3, 3), stride=1, padding=1, groups=256)
|
||||||
|
self.conv9 = ConvBnPrelu(128, 512, kernel=(1, 1))
|
||||||
|
self.conv10 = ConvBn(512, 512, groups=512, kernel=(7, 7))
|
||||||
|
self.flatten = Flatten()
|
||||||
|
self.linear = nn.Linear(2048, embedding_size, bias=False)
|
||||||
|
self.bn = nn.BatchNorm1d(embedding_size)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
#print('x',x.shape)
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.conv2(out)
|
||||||
|
out = self.conv3(out)
|
||||||
|
out = self.conv4(out)
|
||||||
|
out = self.conv5(out)
|
||||||
|
out = self.conv6(out)
|
||||||
|
out = self.conv7(out)
|
||||||
|
out = self.conv8(out)
|
||||||
|
out = self.conv9(out)
|
||||||
|
out = self.conv10(out)
|
||||||
|
out = self.flatten(out)
|
||||||
|
out = self.linear(out)
|
||||||
|
out = self.bn(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
x = Image.open("../samples/009.jpg").convert('L')
|
||||||
|
x = x.resize((128, 128))
|
||||||
|
x = np.asarray(x, dtype=np.float32)
|
||||||
|
x = x[None, None, ...]
|
||||||
|
x = torch.from_numpy(x)
|
||||||
|
net = FaceMobileNet(512)
|
||||||
|
net.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
out = net(x)
|
||||||
|
print(out.shape)
|
233
model/lcnet.py
Normal file
233
model/lcnet.py
Normal file
@ -0,0 +1,233 @@
|
|||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import thop
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# import softpool_cuda
|
||||||
|
# from SoftPool import soft_pool2d, SoftPool2d
|
||||||
|
# except ImportError:
|
||||||
|
# print('Please install SoftPool first: https://github.com/alexandrosstergiou/SoftPool')
|
||||||
|
# exit(0)
|
||||||
|
|
||||||
|
NET_CONFIG = {
|
||||||
|
# k, in_c, out_c, s, use_se
|
||||||
|
"blocks2": [[3, 16, 32, 1, False]],
|
||||||
|
"blocks3": [[3, 32, 64, 2, False], [3, 64, 64, 1, False]],
|
||||||
|
"blocks4": [[3, 64, 128, 2, False], [3, 128, 128, 1, False]],
|
||||||
|
"blocks5": [[3, 128, 256, 2, False], [5, 256, 256, 1, False],
|
||||||
|
[5, 256, 256, 1, False], [5, 256, 256, 1, False],
|
||||||
|
[5, 256, 256, 1, False], [5, 256, 256, 1, False]],
|
||||||
|
"blocks6": [[5, 256, 512, 2, True], [5, 512, 512, 1, True]]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def autopad(k, p=None):
|
||||||
|
if p is None:
|
||||||
|
p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
def make_divisible(v, divisor=8, min_value=None):
|
||||||
|
if min_value is None:
|
||||||
|
min_value = divisor
|
||||||
|
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||||
|
if new_v < 0.9 * v:
|
||||||
|
new_v += divisor
|
||||||
|
return new_v
|
||||||
|
|
||||||
|
|
||||||
|
class HardSwish(nn.Module):
|
||||||
|
def __init__(self, inplace=True):
|
||||||
|
super(HardSwish, self).__init__()
|
||||||
|
self.relu6 = nn.ReLU6(inplace=inplace)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x * self.relu6(x+3) / 6
|
||||||
|
|
||||||
|
|
||||||
|
class HardSigmoid(nn.Module):
|
||||||
|
def __init__(self, inplace=True):
|
||||||
|
super(HardSigmoid, self).__init__()
|
||||||
|
self.relu6 = nn.ReLU6(inplace=inplace)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return (self.relu6(x+3)) / 6
|
||||||
|
|
||||||
|
|
||||||
|
class SELayer(nn.Module):
|
||||||
|
def __init__(self, channel, reduction=16):
|
||||||
|
super(SELayer, self).__init__()
|
||||||
|
self.avgpool = nn.AdaptiveAvgPool2d(1)
|
||||||
|
self.fc = nn.Sequential(
|
||||||
|
nn.Linear(channel, channel // reduction, bias=False),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Linear(channel // reduction, channel, bias=False),
|
||||||
|
HardSigmoid()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
b, c, h, w = x.size()
|
||||||
|
y = self.avgpool(x).view(b, c)
|
||||||
|
y = self.fc(y).view(b, c, 1, 1)
|
||||||
|
return x * y.expand_as(x)
|
||||||
|
|
||||||
|
|
||||||
|
class DepthwiseSeparable(nn.Module):
|
||||||
|
def __init__(self, inp, oup, dw_size, stride, use_se=False):
|
||||||
|
super(DepthwiseSeparable, self).__init__()
|
||||||
|
self.use_se = use_se
|
||||||
|
self.stride = stride
|
||||||
|
self.inp = inp
|
||||||
|
self.oup = oup
|
||||||
|
self.dw_size = dw_size
|
||||||
|
self.dw_sp = nn.Sequential(
|
||||||
|
nn.Conv2d(self.inp, self.inp, kernel_size=self.dw_size, stride=self.stride,
|
||||||
|
padding=autopad(self.dw_size, None), groups=self.inp, bias=False),
|
||||||
|
nn.BatchNorm2d(self.inp),
|
||||||
|
HardSwish(),
|
||||||
|
|
||||||
|
nn.Conv2d(self.inp, self.oup, kernel_size=1, stride=1, padding=0, bias=False),
|
||||||
|
nn.BatchNorm2d(self.oup),
|
||||||
|
HardSwish(),
|
||||||
|
)
|
||||||
|
self.se = SELayer(self.oup)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.dw_sp(x)
|
||||||
|
if self.use_se:
|
||||||
|
x = self.se(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class PP_LCNet(nn.Module):
|
||||||
|
def __init__(self, scale=1.0, class_num=256, class_expand=1280, dropout_prob=0.2):
|
||||||
|
super(PP_LCNet, self).__init__()
|
||||||
|
self.scale = scale
|
||||||
|
self.conv1 = nn.Conv2d(3, out_channels=make_divisible(16 * self.scale),
|
||||||
|
kernel_size=3, stride=2, padding=1, bias=False)
|
||||||
|
# k, in_c, out_c, s, use_se inp, oup, dw_size, stride, use_se=False
|
||||||
|
self.blocks2 = nn.Sequential(*[
|
||||||
|
DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
|
||||||
|
oup=make_divisible(out_c * self.scale),
|
||||||
|
dw_size=k, stride=s, use_se=use_se)
|
||||||
|
for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks2"])
|
||||||
|
])
|
||||||
|
|
||||||
|
self.blocks3 = nn.Sequential(*[
|
||||||
|
DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
|
||||||
|
oup=make_divisible(out_c * self.scale),
|
||||||
|
dw_size=k, stride=s, use_se=use_se)
|
||||||
|
for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks3"])
|
||||||
|
])
|
||||||
|
|
||||||
|
self.blocks4 = nn.Sequential(*[
|
||||||
|
DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
|
||||||
|
oup=make_divisible(out_c * self.scale),
|
||||||
|
dw_size=k, stride=s, use_se=use_se)
|
||||||
|
for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks4"])
|
||||||
|
])
|
||||||
|
# k, in_c, out_c, s, use_se inp, oup, dw_size, stride, use_se=False
|
||||||
|
self.blocks5 = nn.Sequential(*[
|
||||||
|
DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
|
||||||
|
oup=make_divisible(out_c * self.scale),
|
||||||
|
dw_size=k, stride=s, use_se=use_se)
|
||||||
|
for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks5"])
|
||||||
|
])
|
||||||
|
|
||||||
|
self.blocks6 = nn.Sequential(*[
|
||||||
|
DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
|
||||||
|
oup=make_divisible(out_c * self.scale),
|
||||||
|
dw_size=k, stride=s, use_se=use_se)
|
||||||
|
for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks6"])
|
||||||
|
])
|
||||||
|
|
||||||
|
self.GAP = nn.AdaptiveAvgPool2d(1)
|
||||||
|
|
||||||
|
self.last_conv = nn.Conv2d(in_channels=make_divisible(NET_CONFIG["blocks6"][-1][2] * scale),
|
||||||
|
out_channels=class_expand,
|
||||||
|
kernel_size=1, stride=1, padding=0, bias=False)
|
||||||
|
|
||||||
|
self.hardswish = HardSwish()
|
||||||
|
self.dropout = nn.Dropout(p=dropout_prob)
|
||||||
|
|
||||||
|
self.fc = nn.Linear(class_expand, class_num)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv1(x)
|
||||||
|
# print(x.shape)
|
||||||
|
x = self.blocks2(x)
|
||||||
|
# print(x.shape)
|
||||||
|
x = self.blocks3(x)
|
||||||
|
# print(x.shape)
|
||||||
|
x = self.blocks4(x)
|
||||||
|
# print(x.shape)
|
||||||
|
x = self.blocks5(x)
|
||||||
|
# print(x.shape)
|
||||||
|
x = self.blocks6(x)
|
||||||
|
# print(x.shape)
|
||||||
|
|
||||||
|
x = self.GAP(x)
|
||||||
|
x = self.last_conv(x)
|
||||||
|
x = self.hardswish(x)
|
||||||
|
x = self.dropout(x)
|
||||||
|
x = torch.flatten(x, start_dim=1, end_dim=-1)
|
||||||
|
x = self.fc(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def PPLCNET_x0_25(**kwargs):
|
||||||
|
model = PP_LCNet(scale=0.25, **kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def PPLCNET_x0_35(**kwargs):
|
||||||
|
model = PP_LCNet(scale=0.35, **kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def PPLCNET_x0_5(**kwargs):
|
||||||
|
model = PP_LCNet(scale=0.5, **kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def PPLCNET_x0_75(**kwargs):
|
||||||
|
model = PP_LCNet(scale=0.75, **kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def PPLCNET_x1_0(**kwargs):
|
||||||
|
model = PP_LCNet(scale=1.0, **kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def PPLCNET_x1_5(**kwargs):
|
||||||
|
model = PP_LCNet(scale=1.5, **kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def PPLCNET_x2_0(**kwargs):
|
||||||
|
model = PP_LCNet(scale=2.0, **kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def PPLCNET_x2_5(**kwargs):
|
||||||
|
model = PP_LCNet(scale=2.5, **kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# input = torch.randn(1, 3, 640, 640)
|
||||||
|
# model = PPLCNET_x2_5()
|
||||||
|
# flops, params = thop.profile(model, inputs=(input,))
|
||||||
|
# print('flops:', flops / 1000000000)
|
||||||
|
# print('params:', params / 1000000)
|
||||||
|
|
||||||
|
model = PPLCNET_x1_0()
|
||||||
|
# model_1 = PW_Conv(3, 16)
|
||||||
|
input = torch.randn(2, 3, 256, 256)
|
||||||
|
print(input.shape)
|
||||||
|
output = model(input)
|
||||||
|
print(output.shape) # [1, num_class]
|
||||||
|
|
18
model/loss.py
Normal file
18
model/loss.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class FocalLoss(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, gamma=2):
|
||||||
|
super().__init__()
|
||||||
|
self.gamma = gamma
|
||||||
|
self.ce = torch.nn.CrossEntropyLoss()
|
||||||
|
|
||||||
|
def forward(self, input, target):
|
||||||
|
|
||||||
|
#print(f'theta {input.shape, input[0]}, target {target.shape, target}')
|
||||||
|
logp = self.ce(input, target)
|
||||||
|
p = torch.exp(-logp)
|
||||||
|
loss = (1 - p) ** self.gamma * logp
|
||||||
|
return loss.mean()
|
94
model/metric.py
Normal file
94
model/metric.py
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
# Definition of ArcFace loss and CosFace loss
|
||||||
|
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class ArcFace(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, embedding_size, class_num, s=30.0, m=0.50):
|
||||||
|
"""ArcFace formula:
|
||||||
|
cos(m + theta) = cos(m)cos(theta) - sin(m)sin(theta)
|
||||||
|
Note that:
|
||||||
|
0 <= m + theta <= Pi
|
||||||
|
So if (m + theta) >= Pi, then theta >= Pi - m. In [0, Pi]
|
||||||
|
we have:
|
||||||
|
cos(theta) < cos(Pi - m)
|
||||||
|
So we can use cos(Pi - m) as threshold to check whether
|
||||||
|
(m + theta) go out of [0, Pi]
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedding_size: usually 128, 256, 512 ...
|
||||||
|
class_num: num of people when training
|
||||||
|
s: scale, see normface https://arxiv.org/abs/1704.06369
|
||||||
|
m: margin, see SphereFace, CosFace, and ArcFace paper
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.in_features = embedding_size
|
||||||
|
self.out_features = class_num
|
||||||
|
self.s = s
|
||||||
|
self.m = m
|
||||||
|
self.weight = nn.Parameter(torch.FloatTensor(class_num, embedding_size))
|
||||||
|
nn.init.xavier_uniform_(self.weight)
|
||||||
|
|
||||||
|
self.cos_m = math.cos(m)
|
||||||
|
self.sin_m = math.sin(m)
|
||||||
|
self.th = math.cos(math.pi - m)
|
||||||
|
self.mm = math.sin(math.pi - m) * m
|
||||||
|
|
||||||
|
def forward(self, input, label):
|
||||||
|
#print(f"embding {self.in_features}, class_num {self.out_features}, input {len(input)}, label {len(label)}")
|
||||||
|
cosine = F.linear(F.normalize(input), F.normalize(self.weight))
|
||||||
|
# print('F.normalize(input)',input.shape)
|
||||||
|
# print('F.normalize(self.weight)',F.normalize(self.weight).shape)
|
||||||
|
sine = ((1.0 - cosine.pow(2)).clamp(0, 1)).sqrt()
|
||||||
|
phi = cosine * self.cos_m - sine * self.sin_m
|
||||||
|
phi = torch.where(cosine > self.th, phi, cosine - self.mm) # drop to CosFace
|
||||||
|
#print(f'consine {cosine.shape, cosine}, sine {sine.shape, sine}, phi {phi.shape, phi}')
|
||||||
|
# update y_i by phi in cosine
|
||||||
|
output = cosine * 1.0 # make backward works
|
||||||
|
batch_size = len(output)
|
||||||
|
output[range(batch_size), label] = phi[range(batch_size), label]
|
||||||
|
# print(f'output {(output * self.s).shape}')
|
||||||
|
# print(f'phi[range(batch_size), label] {phi[range(batch_size), label]}')
|
||||||
|
return output * self.s
|
||||||
|
|
||||||
|
|
||||||
|
class CosFace(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_features, out_features, s=30.0, m=0.40):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
embedding_size: usually 128, 256, 512 ...
|
||||||
|
class_num: num of people when training
|
||||||
|
s: scale, see normface https://arxiv.org/abs/1704.06369
|
||||||
|
m: margin, see SphereFace, CosFace, and ArcFace paper
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.in_features = in_features
|
||||||
|
self.out_features = out_features
|
||||||
|
self.s = s
|
||||||
|
self.m = m
|
||||||
|
self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
|
||||||
|
nn.init.xavier_uniform_(self.weight)
|
||||||
|
|
||||||
|
def forward(self, input, label):
|
||||||
|
cosine = F.linear(F.normalize(input), F.normalize(self.weight))
|
||||||
|
phi = cosine - self.m
|
||||||
|
output = cosine * 1.0 # make backward works
|
||||||
|
batch_size = len(output)
|
||||||
|
output[range(batch_size), label] = phi[range(batch_size), label]
|
||||||
|
return output * self.s
|
||||||
|
|
||||||
|
class Distillation(nn.Module):
|
||||||
|
def __init__(self, in_features, out_features, T=1.0):
|
||||||
|
super(Distillation, self).__init__()
|
||||||
|
self.T = T
|
||||||
|
self.in_features = in_features
|
||||||
|
self.out_features = out_features
|
||||||
|
self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
|
||||||
|
nn.init.xavier_uniform_(self.weight)
|
||||||
|
def forward(self, input, labels):
|
||||||
|
pass
|
274
model/mlp.py
Normal file
274
model/mlp.py
Normal file
@ -0,0 +1,274 @@
|
|||||||
|
import pdb
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.init as init
|
||||||
|
from model.resnet_pre import resnet18, conv1x1, BasicBlock, load_state_dict_from_url, model_urls
|
||||||
|
|
||||||
|
class MLP(nn.Module):
|
||||||
|
def __init__(self, input_dim=256, output_dim=1):
|
||||||
|
super(MLP, self).__init__()
|
||||||
|
self.input_dim = input_dim
|
||||||
|
self.output_dim = output_dim
|
||||||
|
self.fc1 = nn.Linear(self.input_dim, 128) # 32
|
||||||
|
self.fc2 = nn.Linear(128, 64)
|
||||||
|
self.fc3 = nn.Linear(64, 32)
|
||||||
|
self.fc4 = nn.Linear(32, 16)
|
||||||
|
self.fc5 = nn.Linear(16, self.output_dim)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
self.dropout = nn.Dropout(0.5)
|
||||||
|
self.bn1 = nn.BatchNorm1d(128)
|
||||||
|
self.bn2 = nn.BatchNorm1d(64)
|
||||||
|
self.bn3 = nn.BatchNorm1d(32)
|
||||||
|
self.bn4 = nn.BatchNorm1d(16)
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Linear):
|
||||||
|
init.kaiming_normal_(m.weight)
|
||||||
|
if m.bias is not None:
|
||||||
|
init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.fc1(x)
|
||||||
|
x = self.relu(self.bn1(x))
|
||||||
|
x = self.fc2(x)
|
||||||
|
x = self.relu(self.bn2(x))
|
||||||
|
x = self.fc3(x)
|
||||||
|
x = self.relu(self.bn3(x))
|
||||||
|
x = self.fc4(x)
|
||||||
|
x = self.relu(self.bn4(x))
|
||||||
|
x = self.sigmoid(self.fc5(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Net2(nn.Module): # 该网络部署有风险,dnn推理有障碍
|
||||||
|
def __init__(self, input_dim=960, output_dim=1):
|
||||||
|
super(Net2, self).__init__()
|
||||||
|
self.input_dim = input_dim
|
||||||
|
self.output_dim = output_dim
|
||||||
|
self.conv1 = nn.Conv1d(1, 16, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv2 = nn.Conv1d(16, 32, kernel_size=3, stride=2, padding=1)
|
||||||
|
# self.conv3 = nn.Conv1d(32, 64, kernel_size=3, stride=2, padding=1)
|
||||||
|
# self.conv4 = nn.Conv1d(64, 64, kernel_size=5, stride=2, padding=1)
|
||||||
|
self.maxPool1 = nn.MaxPool1d(kernel_size=3, stride=2)
|
||||||
|
self.conv5 = nn.Conv1d(32, 64, kernel_size=5, stride=2, padding=1)
|
||||||
|
self.maxPool2 = nn.MaxPool1d(kernel_size=3, stride=2)
|
||||||
|
|
||||||
|
self.avgPool = nn.AdaptiveAvgPool1d(1)
|
||||||
|
self.MaxPool = nn.AdaptiveMaxPool1d(1)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
self.dropout = nn.Dropout(0.5)
|
||||||
|
self.flatten = nn.Flatten()
|
||||||
|
# self.conv6 = nn.Conv1d(128, 128, kernel_size=5, stride=2, padding=1)
|
||||||
|
self.fc1 = nn.Linear(960, 128)
|
||||||
|
self.fc21 = nn.Linear(960, 32)
|
||||||
|
self.fc22 = nn.Linear(32, 128)
|
||||||
|
self.fc3 = nn.Linear(128, 1)
|
||||||
|
self.bn1 = nn.BatchNorm1d(16)
|
||||||
|
self.bn2 = nn.BatchNorm1d(32)
|
||||||
|
self.bn3 = nn.BatchNorm1d(64)
|
||||||
|
self.bn4 = nn.BatchNorm1d(128)
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Linear):
|
||||||
|
init.kaiming_normal_(m.weight)
|
||||||
|
if m.bias is not None:
|
||||||
|
init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
def conv1x1(in_planes, out_planes, stride=1):
|
||||||
|
"""1x1 convolution"""
|
||||||
|
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv1(x) # 16
|
||||||
|
x = self.relu(x)
|
||||||
|
x = self.conv2(x) # 32
|
||||||
|
x = self.relu(x)
|
||||||
|
# x = self.conv3(x)
|
||||||
|
# x = self.relu(x)
|
||||||
|
# x = self.conv4(x) # 64
|
||||||
|
# x = self.relu(x)
|
||||||
|
# x = self.maxPool1(x)
|
||||||
|
|
||||||
|
x = self.conv5(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
# x = self.conv6(x)
|
||||||
|
# x = self.relu(x)
|
||||||
|
# x = self.maxPool2(x)
|
||||||
|
# x = self.MaxPool(x)
|
||||||
|
|
||||||
|
x = x.view(x.size(0), -1)
|
||||||
|
x = self.dropout(x)
|
||||||
|
x = self.flatten(x)
|
||||||
|
|
||||||
|
# pdb.set_trace()
|
||||||
|
x1 = self.fc1(x)
|
||||||
|
x2 = self.fc22(self.fc21(x))
|
||||||
|
x = self.fc3(x1 + x2)
|
||||||
|
x = self.sigmoid(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class Net3(nn.Module): # 目前较合适的网络结构,相较于Net2,Net3的输出结果更加准确
|
||||||
|
def __init__(self, pretrained=True, progress=True, num_classes=1, scale=0.75):
|
||||||
|
super(Net3, self).__init__()
|
||||||
|
self.resnet18 = resnet18(pretrained=pretrained, progress=progress)
|
||||||
|
|
||||||
|
# Remove the last three layers (layer3, layer4, avgpool, fc)
|
||||||
|
# self.resnet18.layer3 = nn.Identity()
|
||||||
|
# self.resnet18.layer4 = nn.Identity()
|
||||||
|
self.resnet18.avgpool = nn.Identity()
|
||||||
|
self.resnet18.fc = nn.Identity()
|
||||||
|
self.flatten = nn.Flatten()
|
||||||
|
# Calculate the output size after layer2
|
||||||
|
# Assuming input size is 224x224, layer2 will have output size of 56x56
|
||||||
|
# So, the flattened size will be 128 * scale * 56 * 56
|
||||||
|
self.flattened_size = int(128 * (56 * 56) * scale * scale)
|
||||||
|
|
||||||
|
# Add new layers for classification
|
||||||
|
self.classifier = nn.Sequential(
|
||||||
|
nn.AdaptiveAvgPool2d((1, 1)),
|
||||||
|
nn.Flatten(),
|
||||||
|
nn.Linear(384, num_classes), # layer1, layer2 in_features=96 # layer1 in_features=48 #layer3 in_features=192
|
||||||
|
# nn.ReLU(),
|
||||||
|
nn.Dropout(0.6),
|
||||||
|
# nn.Linear(256, num_classes),
|
||||||
|
nn.Sigmoid()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.resnet18.layer1(x)
|
||||||
|
x = self.resnet18.layer2(x)
|
||||||
|
x = self.resnet18.layer3(x)
|
||||||
|
x = self.resnet18.layer4(x)
|
||||||
|
|
||||||
|
# Debugging: Print the shape of the tensor before flattening
|
||||||
|
# print("Shape before flattening:", x.shape)
|
||||||
|
|
||||||
|
# Ensure the tensor is flattened correctly
|
||||||
|
# x = x.view(x.size(0), -1)
|
||||||
|
x = self.classifier(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class ResNet(nn.Module):
|
||||||
|
def __init__(self, block, layers, num_classes=1, zero_init_residual=False,
|
||||||
|
groups=1, width_per_group=64, replace_stride_with_dilation=None,
|
||||||
|
norm_layer=None, scale=0.75):
|
||||||
|
super(ResNet, self).__init__()
|
||||||
|
if norm_layer is None:
|
||||||
|
norm_layer = nn.BatchNorm2d
|
||||||
|
self._norm_layer = norm_layer
|
||||||
|
|
||||||
|
self.inplanes = 64
|
||||||
|
self.dilation = 1
|
||||||
|
if replace_stride_with_dilation is None:
|
||||||
|
# each element in the tuple indicates if we should replace
|
||||||
|
# the 2x2 stride with a dilated convolution instead
|
||||||
|
replace_stride_with_dilation = [False, False, False]
|
||||||
|
if len(replace_stride_with_dilation) != 3:
|
||||||
|
raise ValueError("replace_stride_with_dilation should be None "
|
||||||
|
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
||||||
|
self.groups = groups
|
||||||
|
self.base_width = width_per_group
|
||||||
|
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
|
||||||
|
bias=False)
|
||||||
|
self.bn1 = norm_layer(self.inplanes)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||||
|
|
||||||
|
self.layer1 = self._make_layer(block, int(64 * scale), layers[0])
|
||||||
|
self.layer2 = self._make_layer(block, int(128 * scale), layers[1], stride=2,
|
||||||
|
dilate=replace_stride_with_dilation[0])
|
||||||
|
self.layer3 = self._make_layer(block, int(256 * scale), layers[2], stride=2,
|
||||||
|
dilate=replace_stride_with_dilation[1])
|
||||||
|
self.layer4 = self._make_layer(block, int(512 * scale), layers[3], stride=2,
|
||||||
|
dilate=replace_stride_with_dilation[2])
|
||||||
|
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||||
|
self.fc = nn.Linear(int(512 * block.expansion * scale), num_classes)
|
||||||
|
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||||
|
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||||
|
nn.init.constant_(m.weight, 1)
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
|
||||||
|
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
||||||
|
norm_layer = self._norm_layer
|
||||||
|
downsample = None
|
||||||
|
previous_dilation = self.dilation
|
||||||
|
if dilate:
|
||||||
|
self.dilation *= stride
|
||||||
|
stride = 1
|
||||||
|
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||||
|
downsample = nn.Sequential(
|
||||||
|
conv1x1(self.inplanes, planes * block.expansion, stride),
|
||||||
|
norm_layer(planes * block.expansion),
|
||||||
|
)
|
||||||
|
|
||||||
|
layers = []
|
||||||
|
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
|
||||||
|
self.base_width, previous_dilation, norm_layer))
|
||||||
|
self.inplanes = planes * block.expansion
|
||||||
|
for _ in range(1, blocks):
|
||||||
|
layers.append(block(self.inplanes, planes, groups=self.groups,
|
||||||
|
base_width=self.base_width, dilation=self.dilation,
|
||||||
|
norm_layer=norm_layer))
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def _forward_impl(self, x):
|
||||||
|
# See note [TorchScript super()]
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.bn1(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x = self.maxpool(x)
|
||||||
|
|
||||||
|
x = self.layer1(x)
|
||||||
|
x = self.layer2(x)
|
||||||
|
x = self.layer3(x)
|
||||||
|
x = self.layer4(x)
|
||||||
|
|
||||||
|
x = self.avgpool(x)
|
||||||
|
x = torch.flatten(x, 1)
|
||||||
|
x = self.fc(x)
|
||||||
|
x = self.sigmoid(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self._forward_impl(x)
|
||||||
|
|
||||||
|
def Net4(arch, pretrained, progress, **kwargs):
|
||||||
|
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
|
||||||
|
if pretrained:
|
||||||
|
state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
|
||||||
|
src_state_dict = state_dict
|
||||||
|
target_state_dict = model.state_dict()
|
||||||
|
skip_keys = []
|
||||||
|
# skip mismatch size tensors in case of pretraining
|
||||||
|
for k in src_state_dict.keys():
|
||||||
|
if k not in target_state_dict:
|
||||||
|
continue
|
||||||
|
if src_state_dict[k].size() != target_state_dict[k].size():
|
||||||
|
skip_keys.append(k)
|
||||||
|
for k in skip_keys:
|
||||||
|
del src_state_dict[k]
|
||||||
|
missing_keys, unexpected_keys = model.load_state_dict(src_state_dict, strict=False)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
'''
|
||||||
|
net2 = Net2()
|
||||||
|
input_tensor = torch.randn(10, 1, 64)
|
||||||
|
# 前向传播
|
||||||
|
output_tensor = net2(input_tensor)
|
||||||
|
# pdb.set_trace()
|
||||||
|
print("输入张量形状:", input_tensor.shape)
|
||||||
|
print("输出张量形状:", output_tensor.shape)
|
||||||
|
'''
|
||||||
|
|
||||||
|
# model = Net3(pretrained=True, num_classes=1) # 预训练从resnet中间结果获取数据训练模型
|
||||||
|
model = Net4('resnet18', True, True)
|
||||||
|
input_tensor = torch.randn(1, 3, 224, 244) # Adjust batch size to 10
|
||||||
|
output = model(input_tensor)
|
||||||
|
print(output.shape) # Should be [10, 2]
|
148
model/mobilenet_v1.py
Normal file
148
model/mobilenet_v1.py
Normal file
@ -0,0 +1,148 @@
|
|||||||
|
# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
from typing import Callable, Any, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
from torch import nn
|
||||||
|
from torchvision.ops.misc import Conv2dNormActivation
|
||||||
|
from config import config as conf
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"MobileNetV1",
|
||||||
|
"DepthWiseSeparableConv2d",
|
||||||
|
"mobilenet_v1",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class MobileNetV1(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_classes: int = conf.embedding_size,
|
||||||
|
) -> None:
|
||||||
|
super(MobileNetV1, self).__init__()
|
||||||
|
self.features = nn.Sequential(
|
||||||
|
Conv2dNormActivation(3,
|
||||||
|
32,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
padding=1,
|
||||||
|
norm_layer=nn.BatchNorm2d,
|
||||||
|
activation_layer=nn.ReLU,
|
||||||
|
inplace=True,
|
||||||
|
bias=False,
|
||||||
|
),
|
||||||
|
|
||||||
|
DepthWiseSeparableConv2d(32, 64, 1),
|
||||||
|
DepthWiseSeparableConv2d(64, 128, 2),
|
||||||
|
DepthWiseSeparableConv2d(128, 128, 1),
|
||||||
|
DepthWiseSeparableConv2d(128, 256, 2),
|
||||||
|
DepthWiseSeparableConv2d(256, 256, 1),
|
||||||
|
DepthWiseSeparableConv2d(256, 512, 2),
|
||||||
|
DepthWiseSeparableConv2d(512, 512, 1),
|
||||||
|
DepthWiseSeparableConv2d(512, 512, 1),
|
||||||
|
DepthWiseSeparableConv2d(512, 512, 1),
|
||||||
|
DepthWiseSeparableConv2d(512, 512, 1),
|
||||||
|
DepthWiseSeparableConv2d(512, 512, 1),
|
||||||
|
DepthWiseSeparableConv2d(512, 1024, 2),
|
||||||
|
DepthWiseSeparableConv2d(1024, 1024, 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.avgpool = nn.AvgPool2d((7, 7))
|
||||||
|
|
||||||
|
self.classifier = nn.Linear(1024, num_classes)
|
||||||
|
|
||||||
|
# Initialize neural network weights
|
||||||
|
self._initialize_weights()
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
out = self._forward_impl(x)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
# Support torch.script function
|
||||||
|
def _forward_impl(self, x: Tensor) -> Tensor:
|
||||||
|
out = self.features(x)
|
||||||
|
out = self.avgpool(out)
|
||||||
|
out = torch.flatten(out, 1)
|
||||||
|
out = self.classifier(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def _initialize_weights(self) -> None:
|
||||||
|
for module in self.modules():
|
||||||
|
if isinstance(module, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
|
||||||
|
if module.bias is not None:
|
||||||
|
nn.init.zeros_(module.bias)
|
||||||
|
elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||||
|
nn.init.ones_(module.weight)
|
||||||
|
nn.init.zeros_(module.bias)
|
||||||
|
elif isinstance(module, nn.Linear):
|
||||||
|
nn.init.normal_(module.weight, 0, 0.01)
|
||||||
|
nn.init.zeros_(module.bias)
|
||||||
|
|
||||||
|
|
||||||
|
class DepthWiseSeparableConv2d(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
stride: int,
|
||||||
|
norm_layer: Optional[Callable[..., nn.Module]] = None
|
||||||
|
) -> None:
|
||||||
|
super(DepthWiseSeparableConv2d, self).__init__()
|
||||||
|
self.stride = stride
|
||||||
|
if stride not in [1, 2]:
|
||||||
|
raise ValueError(f"stride should be 1 or 2 instead of {stride}")
|
||||||
|
|
||||||
|
if norm_layer is None:
|
||||||
|
norm_layer = nn.BatchNorm2d
|
||||||
|
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
Conv2dNormActivation(in_channels,
|
||||||
|
in_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=stride,
|
||||||
|
padding=1,
|
||||||
|
groups=in_channels,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
activation_layer=nn.ReLU,
|
||||||
|
inplace=True,
|
||||||
|
bias=False,
|
||||||
|
),
|
||||||
|
Conv2dNormActivation(in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
activation_layer=nn.ReLU,
|
||||||
|
inplace=True,
|
||||||
|
bias=False,
|
||||||
|
),
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
out = self.conv(x)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def mobilenet_v1(**kwargs: Any) -> MobileNetV1:
|
||||||
|
model = MobileNetV1(**kwargs)
|
||||||
|
|
||||||
|
return model
|
200
model/mobilenet_v2.py
Normal file
200
model/mobilenet_v2.py
Normal file
@ -0,0 +1,200 @@
|
|||||||
|
from torch import nn
|
||||||
|
from .utils import load_state_dict_from_url
|
||||||
|
from config import config as conf
|
||||||
|
|
||||||
|
__all__ = ['MobileNetV2', 'mobilenet_v2']
|
||||||
|
|
||||||
|
|
||||||
|
model_urls = {
|
||||||
|
'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_divisible(v, divisor, min_value=None):
|
||||||
|
"""
|
||||||
|
This function is taken from the original tf repo.
|
||||||
|
It ensures that all layers have a channel number that is divisible by 8
|
||||||
|
It can be seen here:
|
||||||
|
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
||||||
|
:param v:
|
||||||
|
:param divisor:
|
||||||
|
:param min_value:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if min_value is None:
|
||||||
|
min_value = divisor
|
||||||
|
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||||
|
# Make sure that round down does not go down by more than 10%.
|
||||||
|
if new_v < 0.9 * v:
|
||||||
|
new_v += divisor
|
||||||
|
return new_v
|
||||||
|
|
||||||
|
|
||||||
|
class ConvBNReLU(nn.Sequential):
|
||||||
|
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None):
|
||||||
|
padding = (kernel_size - 1) // 2
|
||||||
|
if norm_layer is None:
|
||||||
|
norm_layer = nn.BatchNorm2d
|
||||||
|
super(ConvBNReLU, self).__init__(
|
||||||
|
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
|
||||||
|
norm_layer(out_planes),
|
||||||
|
nn.ReLU6(inplace=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class InvertedResidual(nn.Module):
|
||||||
|
def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None):
|
||||||
|
super(InvertedResidual, self).__init__()
|
||||||
|
self.stride = stride
|
||||||
|
assert stride in [1, 2]
|
||||||
|
|
||||||
|
if norm_layer is None:
|
||||||
|
norm_layer = nn.BatchNorm2d
|
||||||
|
|
||||||
|
hidden_dim = int(round(inp * expand_ratio))
|
||||||
|
self.use_res_connect = self.stride == 1 and inp == oup
|
||||||
|
|
||||||
|
layers = []
|
||||||
|
if expand_ratio != 1:
|
||||||
|
# pw
|
||||||
|
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
|
||||||
|
layers.extend([
|
||||||
|
# dw
|
||||||
|
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer),
|
||||||
|
# pw-linear
|
||||||
|
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||||
|
norm_layer(oup),
|
||||||
|
])
|
||||||
|
self.conv = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.use_res_connect:
|
||||||
|
return x + self.conv(x)
|
||||||
|
else:
|
||||||
|
return self.conv(x)
|
||||||
|
|
||||||
|
|
||||||
|
class MobileNetV2(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
num_classes=conf.embedding_size,
|
||||||
|
width_mult=1.0,
|
||||||
|
inverted_residual_setting=None,
|
||||||
|
round_nearest=8,
|
||||||
|
block=None,
|
||||||
|
norm_layer=None):
|
||||||
|
"""
|
||||||
|
MobileNet V2 main class
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_classes (int): Number of classes
|
||||||
|
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
|
||||||
|
inverted_residual_setting: Network structure
|
||||||
|
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
|
||||||
|
Set to 1 to turn off rounding
|
||||||
|
block: Module specifying inverted residual building block for mobilenet
|
||||||
|
norm_layer: Module specifying the normalization layer to use
|
||||||
|
|
||||||
|
"""
|
||||||
|
super(MobileNetV2, self).__init__()
|
||||||
|
|
||||||
|
if block is None:
|
||||||
|
block = InvertedResidual
|
||||||
|
|
||||||
|
if norm_layer is None:
|
||||||
|
norm_layer = nn.BatchNorm2d
|
||||||
|
|
||||||
|
input_channel = 32
|
||||||
|
last_channel = 1280
|
||||||
|
|
||||||
|
if inverted_residual_setting is None:
|
||||||
|
inverted_residual_setting = [
|
||||||
|
# t, c, n, s
|
||||||
|
[1, 16, 1, 1],
|
||||||
|
[6, 24, 2, 2],
|
||||||
|
[6, 32, 3, 2],
|
||||||
|
[6, 64, 4, 2],
|
||||||
|
[6, 96, 3, 1],
|
||||||
|
[6, 160, 3, 2],
|
||||||
|
[6, 320, 1, 1],
|
||||||
|
]
|
||||||
|
|
||||||
|
# only check the first element, assuming user knows t,c,n,s are required
|
||||||
|
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
|
||||||
|
raise ValueError("inverted_residual_setting should be non-empty "
|
||||||
|
"or a 4-element list, got {}".format(inverted_residual_setting))
|
||||||
|
|
||||||
|
# building first layer
|
||||||
|
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
|
||||||
|
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
|
||||||
|
features = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
|
||||||
|
# building inverted residual blocks
|
||||||
|
for t, c, n, s in inverted_residual_setting:
|
||||||
|
output_channel = _make_divisible(c * width_mult, round_nearest)
|
||||||
|
for i in range(n):
|
||||||
|
stride = s if i == 0 else 1
|
||||||
|
features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
|
||||||
|
input_channel = output_channel
|
||||||
|
# building last several layers
|
||||||
|
features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer))
|
||||||
|
# make it nn.Sequential
|
||||||
|
self.features = nn.Sequential(*features)
|
||||||
|
|
||||||
|
# building classifier
|
||||||
|
self.classifier = nn.Sequential(
|
||||||
|
nn.Dropout(0.2),
|
||||||
|
nn.Linear(self.last_channel, num_classes),
|
||||||
|
)
|
||||||
|
|
||||||
|
# weight initialization
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
||||||
|
if m.bias is not None:
|
||||||
|
nn.init.zeros_(m.bias)
|
||||||
|
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||||
|
nn.init.ones_(m.weight)
|
||||||
|
nn.init.zeros_(m.bias)
|
||||||
|
elif isinstance(m, nn.Linear):
|
||||||
|
nn.init.normal_(m.weight, 0, 0.01)
|
||||||
|
nn.init.zeros_(m.bias)
|
||||||
|
|
||||||
|
def _forward_impl(self, x):
|
||||||
|
# This exists since TorchScript doesn't support inheritance, so the superclass method
|
||||||
|
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
|
||||||
|
x = self.features(x)
|
||||||
|
# Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
|
||||||
|
x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1)
|
||||||
|
x = self.classifier(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self._forward_impl(x)
|
||||||
|
|
||||||
|
|
||||||
|
def mobilenet_v2(pretrained=True, progress=True, **kwargs):
|
||||||
|
"""
|
||||||
|
Constructs a MobileNetV2 architecture from
|
||||||
|
`"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
model = MobileNetV2(**kwargs)
|
||||||
|
if pretrained:
|
||||||
|
state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
|
||||||
|
progress=progress)
|
||||||
|
src_state_dict = state_dict
|
||||||
|
target_state_dict = model.state_dict()
|
||||||
|
skip_keys = []
|
||||||
|
# skip mismatch size tensors in case of pretraining
|
||||||
|
for k in src_state_dict.keys():
|
||||||
|
if k not in target_state_dict:
|
||||||
|
continue
|
||||||
|
if src_state_dict[k].size() != target_state_dict[k].size():
|
||||||
|
skip_keys.append(k)
|
||||||
|
for k in skip_keys:
|
||||||
|
del src_state_dict[k]
|
||||||
|
missing_keys, unexpected_keys = model.load_state_dict(src_state_dict, strict=False)
|
||||||
|
#.load_state_dict(state_dict)
|
||||||
|
return model
|
200
model/mobilenet_v3.py
Normal file
200
model/mobilenet_v3.py
Normal file
@ -0,0 +1,200 @@
|
|||||||
|
'''MobileNetV3 in PyTorch.
|
||||||
|
|
||||||
|
See the paper "Inverted Residuals and Linear Bottlenecks:
|
||||||
|
Mobile Networks for Classification, Detection and Segmentation" for more details.
|
||||||
|
'''
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.nn import init
|
||||||
|
from config import config as conf
|
||||||
|
|
||||||
|
|
||||||
|
class hswish(nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
out = x * F.relu6(x + 3, inplace=True) / 6
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class hsigmoid(nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
out = F.relu6(x + 3, inplace=True) / 6
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class SeModule(nn.Module):
|
||||||
|
def __init__(self, in_size, reduction=4):
|
||||||
|
super(SeModule, self).__init__()
|
||||||
|
self.se = nn.Sequential(
|
||||||
|
nn.AdaptiveAvgPool2d(1),
|
||||||
|
nn.Conv2d(in_size, in_size // reduction, kernel_size=1, stride=1, padding=0, bias=False),
|
||||||
|
nn.BatchNorm2d(in_size // reduction),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Conv2d(in_size // reduction, in_size, kernel_size=1, stride=1, padding=0, bias=False),
|
||||||
|
nn.BatchNorm2d(in_size),
|
||||||
|
hsigmoid()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x * self.se(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Block(nn.Module):
|
||||||
|
'''expand + depthwise + pointwise'''
|
||||||
|
def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, semodule, stride):
|
||||||
|
super(Block, self).__init__()
|
||||||
|
self.stride = stride
|
||||||
|
self.se = semodule
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, stride=1, padding=0, bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(expand_size)
|
||||||
|
self.nolinear1 = nolinear
|
||||||
|
self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size=kernel_size, stride=stride, padding=kernel_size//2, groups=expand_size, bias=False)
|
||||||
|
self.bn2 = nn.BatchNorm2d(expand_size)
|
||||||
|
self.nolinear2 = nolinear
|
||||||
|
self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, stride=1, padding=0, bias=False)
|
||||||
|
self.bn3 = nn.BatchNorm2d(out_size)
|
||||||
|
|
||||||
|
self.shortcut = nn.Sequential()
|
||||||
|
if stride == 1 and in_size != out_size:
|
||||||
|
self.shortcut = nn.Sequential(
|
||||||
|
nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0, bias=False),
|
||||||
|
nn.BatchNorm2d(out_size),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = self.nolinear1(self.bn1(self.conv1(x)))
|
||||||
|
out = self.nolinear2(self.bn2(self.conv2(out)))
|
||||||
|
out = self.bn3(self.conv3(out))
|
||||||
|
if self.se != None:
|
||||||
|
out = self.se(out)
|
||||||
|
out = out + self.shortcut(x) if self.stride==1 else out
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class MobileNetV3_Large(nn.Module):
|
||||||
|
def __init__(self, num_classes=conf.embedding_size):
|
||||||
|
super(MobileNetV3_Large, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(16)
|
||||||
|
self.hs1 = hswish()
|
||||||
|
|
||||||
|
self.bneck = nn.Sequential(
|
||||||
|
Block(3, 16, 16, 16, nn.ReLU(inplace=True), None, 1),
|
||||||
|
Block(3, 16, 64, 24, nn.ReLU(inplace=True), None, 2),
|
||||||
|
Block(3, 24, 72, 24, nn.ReLU(inplace=True), None, 1),
|
||||||
|
Block(5, 24, 72, 40, nn.ReLU(inplace=True), SeModule(40), 2),
|
||||||
|
Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
|
||||||
|
Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
|
||||||
|
Block(3, 40, 240, 80, hswish(), None, 2),
|
||||||
|
Block(3, 80, 200, 80, hswish(), None, 1),
|
||||||
|
Block(3, 80, 184, 80, hswish(), None, 1),
|
||||||
|
Block(3, 80, 184, 80, hswish(), None, 1),
|
||||||
|
Block(3, 80, 480, 112, hswish(), SeModule(112), 1),
|
||||||
|
Block(3, 112, 672, 112, hswish(), SeModule(112), 1),
|
||||||
|
Block(5, 112, 672, 160, hswish(), SeModule(160), 1),
|
||||||
|
Block(5, 160, 672, 160, hswish(), SeModule(160), 2),
|
||||||
|
Block(5, 160, 960, 160, hswish(), SeModule(160), 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
self.conv2 = nn.Conv2d(160, 960, kernel_size=1, stride=1, padding=0, bias=False)
|
||||||
|
self.bn2 = nn.BatchNorm2d(960)
|
||||||
|
self.hs2 = hswish()
|
||||||
|
self.linear3 = nn.Linear(960, 1280)
|
||||||
|
self.bn3 = nn.BatchNorm1d(1280)
|
||||||
|
self.hs3 = hswish()
|
||||||
|
self.linear4 = nn.Linear(1280, num_classes)
|
||||||
|
self.init_params()
|
||||||
|
|
||||||
|
def init_params(self):
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
init.kaiming_normal_(m.weight, mode='fan_out')
|
||||||
|
if m.bias is not None:
|
||||||
|
init.constant_(m.bias, 0)
|
||||||
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
|
init.constant_(m.weight, 1)
|
||||||
|
init.constant_(m.bias, 0)
|
||||||
|
elif isinstance(m, nn.Linear):
|
||||||
|
init.normal_(m.weight, std=0.001)
|
||||||
|
if m.bias is not None:
|
||||||
|
init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = self.hs1(self.bn1(self.conv1(x)))
|
||||||
|
out = self.bneck(out)
|
||||||
|
out = self.hs2(self.bn2(self.conv2(out)))
|
||||||
|
out = F.avg_pool2d(out, conf.img_size // 32)
|
||||||
|
out = out.view(out.size(0), -1)
|
||||||
|
out = self.hs3(self.bn3(self.linear3(out)))
|
||||||
|
out = self.linear4(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class MobileNetV3_Small(nn.Module):
|
||||||
|
def __init__(self, num_classes=conf.embedding_size):
|
||||||
|
super(MobileNetV3_Small, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(16)
|
||||||
|
self.hs1 = hswish()
|
||||||
|
|
||||||
|
self.bneck = nn.Sequential(
|
||||||
|
Block(3, 16, 16, 16, nn.ReLU(inplace=True), SeModule(16), 2),
|
||||||
|
Block(3, 16, 72, 24, nn.ReLU(inplace=True), None, 2),
|
||||||
|
Block(3, 24, 88, 24, nn.ReLU(inplace=True), None, 1),
|
||||||
|
Block(5, 24, 96, 40, hswish(), SeModule(40), 2),
|
||||||
|
Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
|
||||||
|
Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
|
||||||
|
Block(5, 40, 120, 48, hswish(), SeModule(48), 1),
|
||||||
|
Block(5, 48, 144, 48, hswish(), SeModule(48), 1),
|
||||||
|
Block(5, 48, 288, 96, hswish(), SeModule(96), 2),
|
||||||
|
Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
|
||||||
|
Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
self.conv2 = nn.Conv2d(96, 576, kernel_size=1, stride=1, padding=0, bias=False)
|
||||||
|
self.bn2 = nn.BatchNorm2d(576)
|
||||||
|
self.hs2 = hswish()
|
||||||
|
self.linear3 = nn.Linear(576, 1280)
|
||||||
|
self.bn3 = nn.BatchNorm1d(1280)
|
||||||
|
self.hs3 = hswish()
|
||||||
|
self.linear4 = nn.Linear(1280, num_classes)
|
||||||
|
self.init_params()
|
||||||
|
|
||||||
|
def init_params(self):
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
init.kaiming_normal_(m.weight, mode='fan_out')
|
||||||
|
if m.bias is not None:
|
||||||
|
init.constant_(m.bias, 0)
|
||||||
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
|
init.constant_(m.weight, 1)
|
||||||
|
init.constant_(m.bias, 0)
|
||||||
|
elif isinstance(m, nn.Linear):
|
||||||
|
init.normal_(m.weight, std=0.001)
|
||||||
|
if m.bias is not None:
|
||||||
|
init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = self.hs1(self.bn1(self.conv1(x)))
|
||||||
|
out = self.bneck(out)
|
||||||
|
out = self.hs2(self.bn2(self.conv2(out)))
|
||||||
|
out = F.avg_pool2d(out, conf.img_size // 32)
|
||||||
|
out = out.view(out.size(0), -1)
|
||||||
|
|
||||||
|
out = self.hs3(self.bn3(self.linear3(out)))
|
||||||
|
out = self.linear4(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test():
|
||||||
|
net = MobileNetV3_Small()
|
||||||
|
x = torch.randn(2,3,224,224)
|
||||||
|
y = net(x)
|
||||||
|
print(y.size())
|
||||||
|
|
||||||
|
# test()
|
265
model/mobilevit.py
Normal file
265
model/mobilevit.py
Normal file
@ -0,0 +1,265 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from einops import rearrange
|
||||||
|
from config import config as conf
|
||||||
|
|
||||||
|
|
||||||
|
def conv_1x1_bn(inp, oup):
|
||||||
|
return nn.Sequential(
|
||||||
|
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
|
||||||
|
nn.BatchNorm2d(oup),
|
||||||
|
nn.SiLU()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
|
||||||
|
return nn.Sequential(
|
||||||
|
nn.Conv2d(inp, oup, kernal_size, stride, 1, bias=False),
|
||||||
|
nn.BatchNorm2d(oup),
|
||||||
|
nn.SiLU()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PreNorm(nn.Module):
|
||||||
|
def __init__(self, dim, fn):
|
||||||
|
super().__init__()
|
||||||
|
self.norm = nn.LayerNorm(dim)
|
||||||
|
self.fn = fn
|
||||||
|
|
||||||
|
def forward(self, x, **kwargs):
|
||||||
|
return self.fn(self.norm(x), **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
def __init__(self, dim, hidden_dim, dropout=0.):
|
||||||
|
super().__init__()
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
nn.Linear(dim, hidden_dim),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(hidden_dim, dim),
|
||||||
|
nn.Dropout(dropout)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = dim_head * heads
|
||||||
|
project_out = not (heads == 1 and dim_head == dim)
|
||||||
|
|
||||||
|
self.heads = heads
|
||||||
|
self.scale = dim_head ** -0.5
|
||||||
|
|
||||||
|
self.attend = nn.Softmax(dim=-1)
|
||||||
|
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
|
||||||
|
|
||||||
|
self.to_out = nn.Sequential(
|
||||||
|
nn.Linear(inner_dim, dim),
|
||||||
|
nn.Dropout(dropout)
|
||||||
|
) if project_out else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
qkv = self.to_qkv(x).chunk(3, dim=-1)
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
|
||||||
|
|
||||||
|
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||||
|
attn = self.attend(dots)
|
||||||
|
out = torch.matmul(attn, v)
|
||||||
|
out = rearrange(out, 'b p h n d -> b p n (h d)')
|
||||||
|
return self.to_out(out)
|
||||||
|
|
||||||
|
|
||||||
|
class Transformer(nn.Module):
|
||||||
|
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
|
||||||
|
super().__init__()
|
||||||
|
self.layers = nn.ModuleList([])
|
||||||
|
for _ in range(depth):
|
||||||
|
self.layers.append(nn.ModuleList([
|
||||||
|
PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
|
||||||
|
PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
|
||||||
|
]))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for attn, ff in self.layers:
|
||||||
|
x = attn(x) + x
|
||||||
|
x = ff(x) + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class MV2Block(nn.Module):
|
||||||
|
def __init__(self, inp, oup, stride=1, expansion=4):
|
||||||
|
super().__init__()
|
||||||
|
self.stride = stride
|
||||||
|
assert stride in [1, 2]
|
||||||
|
|
||||||
|
hidden_dim = int(inp * expansion)
|
||||||
|
self.use_res_connect = self.stride == 1 and inp == oup
|
||||||
|
|
||||||
|
if expansion == 1:
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
# dw
|
||||||
|
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
|
||||||
|
nn.BatchNorm2d(hidden_dim),
|
||||||
|
nn.SiLU(),
|
||||||
|
# pw-linear
|
||||||
|
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||||
|
nn.BatchNorm2d(oup),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
# pw
|
||||||
|
nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
|
||||||
|
nn.BatchNorm2d(hidden_dim),
|
||||||
|
nn.SiLU(),
|
||||||
|
# dw
|
||||||
|
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
|
||||||
|
nn.BatchNorm2d(hidden_dim),
|
||||||
|
nn.SiLU(),
|
||||||
|
# pw-linear
|
||||||
|
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||||
|
nn.BatchNorm2d(oup),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.use_res_connect:
|
||||||
|
return x + self.conv(x)
|
||||||
|
else:
|
||||||
|
return self.conv(x)
|
||||||
|
|
||||||
|
|
||||||
|
class MobileViTBlock(nn.Module):
|
||||||
|
def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
|
||||||
|
super().__init__()
|
||||||
|
self.ph, self.pw = patch_size
|
||||||
|
|
||||||
|
self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
|
||||||
|
self.conv2 = conv_1x1_bn(channel, dim)
|
||||||
|
|
||||||
|
self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)
|
||||||
|
|
||||||
|
self.conv3 = conv_1x1_bn(dim, channel)
|
||||||
|
self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
y = x.clone()
|
||||||
|
|
||||||
|
# Local representations
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.conv2(x)
|
||||||
|
|
||||||
|
# Global representations
|
||||||
|
_, _, h, w = x.shape
|
||||||
|
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
|
||||||
|
x = self.transformer(x)
|
||||||
|
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph,
|
||||||
|
pw=self.pw)
|
||||||
|
|
||||||
|
# Fusion
|
||||||
|
x = self.conv3(x)
|
||||||
|
x = torch.cat((x, y), 1)
|
||||||
|
x = self.conv4(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class MobileViT(nn.Module):
|
||||||
|
def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)):
|
||||||
|
super().__init__()
|
||||||
|
ih, iw = image_size
|
||||||
|
ph, pw = patch_size
|
||||||
|
assert ih % ph == 0 and iw % pw == 0
|
||||||
|
|
||||||
|
L = [2, 4, 3]
|
||||||
|
|
||||||
|
self.conv1 = conv_nxn_bn(3, channels[0], stride=2)
|
||||||
|
|
||||||
|
self.mv2 = nn.ModuleList([])
|
||||||
|
self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
|
||||||
|
self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
|
||||||
|
self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
|
||||||
|
self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion)) # Repeat
|
||||||
|
self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
|
||||||
|
self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
|
||||||
|
self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))
|
||||||
|
|
||||||
|
self.mvit = nn.ModuleList([])
|
||||||
|
self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2)))
|
||||||
|
self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4)))
|
||||||
|
self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4)))
|
||||||
|
|
||||||
|
self.conv2 = conv_1x1_bn(channels[-2], channels[-1])
|
||||||
|
|
||||||
|
self.pool = nn.AvgPool2d(ih // 32, 1)
|
||||||
|
self.fc = nn.Linear(channels[-1], num_classes, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
#print('x',x.shape)
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.mv2[0](x)
|
||||||
|
|
||||||
|
x = self.mv2[1](x)
|
||||||
|
x = self.mv2[2](x)
|
||||||
|
x = self.mv2[3](x) # Repeat
|
||||||
|
|
||||||
|
x = self.mv2[4](x)
|
||||||
|
x = self.mvit[0](x)
|
||||||
|
|
||||||
|
x = self.mv2[5](x)
|
||||||
|
x = self.mvit[1](x)
|
||||||
|
|
||||||
|
x = self.mv2[6](x)
|
||||||
|
x = self.mvit[2](x)
|
||||||
|
x = self.conv2(x)
|
||||||
|
|
||||||
|
|
||||||
|
#print('pool_before',x.shape)
|
||||||
|
x = self.pool(x).view(-1, x.shape[1])
|
||||||
|
#print('self_pool',self.pool)
|
||||||
|
#print('pool_after',x.shape)
|
||||||
|
x = self.fc(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def mobilevit_xxs():
|
||||||
|
dims = [64, 80, 96]
|
||||||
|
channels = [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320]
|
||||||
|
return MobileViT((256, 256), dims, channels, num_classes=1000, expansion=2)
|
||||||
|
|
||||||
|
|
||||||
|
def mobilevit_xs():
|
||||||
|
dims = [96, 120, 144]
|
||||||
|
channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384]
|
||||||
|
return MobileViT((256, 256), dims, channels, num_classes=1000)
|
||||||
|
|
||||||
|
|
||||||
|
def mobilevit_s():
|
||||||
|
dims = [144, 192, 240]
|
||||||
|
channels = [16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 640]
|
||||||
|
return MobileViT((conf.img_size, conf.img_size), dims, channels, num_classes=conf.embedding_size)
|
||||||
|
|
||||||
|
|
||||||
|
def count_parameters(model):
|
||||||
|
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
img = torch.randn(5, 3, 256, 256)
|
||||||
|
|
||||||
|
vit = mobilevit_xxs()
|
||||||
|
out = vit(img)
|
||||||
|
print(out.shape)
|
||||||
|
print(count_parameters(vit))
|
||||||
|
|
||||||
|
vit = mobilevit_xs()
|
||||||
|
out = vit(img)
|
||||||
|
print(out.shape)
|
||||||
|
print(count_parameters(vit))
|
||||||
|
|
||||||
|
vit = mobilevit_s()
|
||||||
|
out = vit(img)
|
||||||
|
print(out.shape)
|
||||||
|
print(count_parameters(vit))
|
412
model/quant_test_resnet.py
Normal file
412
model/quant_test_resnet.py
Normal file
@ -0,0 +1,412 @@
|
|||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
import torch.nn as nn
|
||||||
|
from .utils import load_state_dict_from_url
|
||||||
|
from typing import Type, Any, Callable, Union, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
||||||
|
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
|
||||||
|
'wide_resnet50_2', 'wide_resnet101_2']
|
||||||
|
|
||||||
|
|
||||||
|
model_urls = {
|
||||||
|
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
||||||
|
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
||||||
|
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
||||||
|
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
||||||
|
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
||||||
|
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
|
||||||
|
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
|
||||||
|
'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
|
||||||
|
'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
|
||||||
|
"""3x3 convolution with padding"""
|
||||||
|
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||||
|
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
||||||
|
|
||||||
|
|
||||||
|
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
|
||||||
|
"""1x1 convolution"""
|
||||||
|
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||||
|
|
||||||
|
|
||||||
|
class BasicBlock(nn.Module):
|
||||||
|
expansion: int = 1
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
inplanes: int,
|
||||||
|
planes: int,
|
||||||
|
stride: int = 1,
|
||||||
|
downsample: Optional[nn.Module] = None,
|
||||||
|
groups: int = 1,
|
||||||
|
base_width: int = 64,
|
||||||
|
dilation: int = 1,
|
||||||
|
norm_layer: Optional[Callable[..., nn.Module]] = None
|
||||||
|
) -> None:
|
||||||
|
super(BasicBlock, self).__init__()
|
||||||
|
if norm_layer is None:
|
||||||
|
norm_layer = nn.BatchNorm2d
|
||||||
|
if groups != 1 or base_width != 64:
|
||||||
|
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
||||||
|
if dilation > 1:
|
||||||
|
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
||||||
|
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
||||||
|
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||||
|
self.bn1 = norm_layer(planes)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
self.conv2 = conv3x3(planes, planes)
|
||||||
|
self.bn2 = norm_layer(planes)
|
||||||
|
self.downsample = downsample
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
identity = x
|
||||||
|
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.bn1(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
out = self.conv2(out)
|
||||||
|
out = self.bn2(out)
|
||||||
|
|
||||||
|
if self.downsample is not None:
|
||||||
|
identity = self.downsample(x)
|
||||||
|
|
||||||
|
out += identity
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class QuantizableBasicBlock(BasicBlock):
|
||||||
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.add_relu = torch.nn.quantized.FloatFunctional()
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
identity = x
|
||||||
|
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.bn1(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
out = self.conv2(out)
|
||||||
|
out = self.bn2(out)
|
||||||
|
|
||||||
|
if self.downsample is not None:
|
||||||
|
identity = self.downsample(x)
|
||||||
|
|
||||||
|
out = self.add_relu.add_relu(out, identity)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Bottleneck(nn.Module):
|
||||||
|
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
|
||||||
|
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
|
||||||
|
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
|
||||||
|
# This variant is also known as ResNet V1.5 and improves accuracy according to
|
||||||
|
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
|
||||||
|
|
||||||
|
expansion: int = 4
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
inplanes: int,
|
||||||
|
planes: int,
|
||||||
|
stride: int = 1,
|
||||||
|
downsample: Optional[nn.Module] = None,
|
||||||
|
groups: int = 1,
|
||||||
|
base_width: int = 64,
|
||||||
|
dilation: int = 1,
|
||||||
|
norm_layer: Optional[Callable[..., nn.Module]] = None
|
||||||
|
) -> None:
|
||||||
|
super(Bottleneck, self).__init__()
|
||||||
|
if norm_layer is None:
|
||||||
|
norm_layer = nn.BatchNorm2d
|
||||||
|
width = int(planes * (base_width / 64.)) * groups
|
||||||
|
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
||||||
|
self.conv1 = conv1x1(inplanes, width)
|
||||||
|
self.bn1 = norm_layer(width)
|
||||||
|
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
||||||
|
self.bn2 = norm_layer(width)
|
||||||
|
self.conv3 = conv1x1(width, planes * self.expansion)
|
||||||
|
self.bn3 = norm_layer(planes * self.expansion)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
self.downsample = downsample
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
identity = x
|
||||||
|
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.bn1(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
out = self.conv2(out)
|
||||||
|
out = self.bn2(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
out = self.conv3(out)
|
||||||
|
out = self.bn3(out)
|
||||||
|
|
||||||
|
if self.downsample is not None:
|
||||||
|
identity = self.downsample(x)
|
||||||
|
|
||||||
|
out += identity
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class ResNet(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
block: Type[Union[BasicBlock, Bottleneck]],
|
||||||
|
layers: List[int],
|
||||||
|
num_classes: int = 1000,
|
||||||
|
zero_init_residual: bool = False,
|
||||||
|
groups: int = 1,
|
||||||
|
width_per_group: int = 64,
|
||||||
|
replace_stride_with_dilation: Optional[List[bool]] = None,
|
||||||
|
norm_layer: Optional[Callable[..., nn.Module]] = None
|
||||||
|
) -> None:
|
||||||
|
super(ResNet, self).__init__()
|
||||||
|
if norm_layer is None:
|
||||||
|
norm_layer = nn.BatchNorm2d
|
||||||
|
self._norm_layer = norm_layer
|
||||||
|
|
||||||
|
self.inplanes = 64
|
||||||
|
self.dilation = 1
|
||||||
|
if replace_stride_with_dilation is None:
|
||||||
|
# each element in the tuple indicates if we should replace
|
||||||
|
# the 2x2 stride with a dilated convolution instead
|
||||||
|
replace_stride_with_dilation = [False, False, False]
|
||||||
|
if len(replace_stride_with_dilation) != 3:
|
||||||
|
raise ValueError("replace_stride_with_dilation should be None "
|
||||||
|
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
||||||
|
self.groups = groups
|
||||||
|
self.base_width = width_per_group
|
||||||
|
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
|
||||||
|
bias=False)
|
||||||
|
self.bn1 = norm_layer(self.inplanes)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||||
|
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||||
|
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
|
||||||
|
dilate=replace_stride_with_dilation[0])
|
||||||
|
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
|
||||||
|
dilate=replace_stride_with_dilation[1])
|
||||||
|
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
|
||||||
|
dilate=replace_stride_with_dilation[2])
|
||||||
|
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||||
|
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||||
|
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||||
|
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||||
|
nn.init.constant_(m.weight, 1)
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
# Zero-initialize the last BN in each residual branch,
|
||||||
|
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
||||||
|
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
||||||
|
if zero_init_residual:
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, Bottleneck):
|
||||||
|
nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
|
||||||
|
elif isinstance(m, BasicBlock):
|
||||||
|
nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
|
||||||
|
stride: int = 1, dilate: bool = False) -> nn.Sequential:
|
||||||
|
norm_layer = self._norm_layer
|
||||||
|
downsample = None
|
||||||
|
previous_dilation = self.dilation
|
||||||
|
if dilate:
|
||||||
|
self.dilation *= stride
|
||||||
|
stride = 1
|
||||||
|
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||||
|
downsample = nn.Sequential(
|
||||||
|
conv1x1(self.inplanes, planes * block.expansion, stride),
|
||||||
|
norm_layer(planes * block.expansion),
|
||||||
|
)
|
||||||
|
|
||||||
|
layers = []
|
||||||
|
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
|
||||||
|
self.base_width, previous_dilation, norm_layer))
|
||||||
|
self.inplanes = planes * block.expansion
|
||||||
|
for _ in range(1, blocks):
|
||||||
|
layers.append(block(self.inplanes, planes, groups=self.groups,
|
||||||
|
base_width=self.base_width, dilation=self.dilation,
|
||||||
|
norm_layer=norm_layer))
|
||||||
|
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def _forward_impl(self, x: Tensor) -> Tensor:
|
||||||
|
# See note [TorchScript super()]
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.bn1(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x = self.maxpool(x)
|
||||||
|
|
||||||
|
x = self.layer1(x)
|
||||||
|
x = self.layer2(x)
|
||||||
|
x = self.layer3(x)
|
||||||
|
x = self.layer4(x)
|
||||||
|
|
||||||
|
x = self.avgpool(x)
|
||||||
|
x = torch.flatten(x, 1)
|
||||||
|
x = self.fc(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return self._forward_impl(x)
|
||||||
|
|
||||||
|
|
||||||
|
def _resnet(
|
||||||
|
arch: str,
|
||||||
|
block: Type[Union[BasicBlock, Bottleneck]],
|
||||||
|
layers: List[int],
|
||||||
|
pretrained: bool,
|
||||||
|
progress: bool,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> ResNet:
|
||||||
|
model = ResNet(block, layers, **kwargs)
|
||||||
|
if pretrained:
|
||||||
|
state_dict = load_state_dict_from_url(model_urls[arch],
|
||||||
|
progress=progress)
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||||
|
r"""ResNet-18 model from
|
||||||
|
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
# return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
|
||||||
|
return _resnet('resnet18', QuantizableBasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||||
|
r"""ResNet-34 model from
|
||||||
|
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||||
|
r"""ResNet-50 model from
|
||||||
|
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||||
|
r"""ResNet-101 model from
|
||||||
|
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||||
|
r"""ResNet-152 model from
|
||||||
|
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||||
|
r"""ResNeXt-50 32x4d model from
|
||||||
|
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
kwargs['groups'] = 32
|
||||||
|
kwargs['width_per_group'] = 4
|
||||||
|
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
|
||||||
|
pretrained, progress, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||||
|
r"""ResNeXt-101 32x8d model from
|
||||||
|
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
kwargs['groups'] = 32
|
||||||
|
kwargs['width_per_group'] = 8
|
||||||
|
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
|
||||||
|
pretrained, progress, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||||
|
r"""Wide ResNet-50-2 model from
|
||||||
|
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
|
||||||
|
|
||||||
|
The model is the same as ResNet except for the bottleneck number of channels
|
||||||
|
which is twice larger in every block. The number of channels in outer 1x1
|
||||||
|
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||||
|
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
kwargs['width_per_group'] = 64 * 2
|
||||||
|
return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
|
||||||
|
pretrained, progress, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||||
|
r"""Wide ResNet-101-2 model from
|
||||||
|
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
|
||||||
|
|
||||||
|
The model is the same as ResNet except for the bottleneck number of channels
|
||||||
|
which is twice larger in every block. The number of channels in outer 1x1
|
||||||
|
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||||
|
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
kwargs['width_per_group'] = 64 * 2
|
||||||
|
return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
|
||||||
|
pretrained, progress, **kwargs)
|
142
model/resbam.py
Normal file
142
model/resbam.py
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
from model.CBAM import CBAM
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from model.Tool import GeM as gem
|
||||||
|
|
||||||
|
|
||||||
|
class Bottleneck(nn.Module):
|
||||||
|
expansion = 4
|
||||||
|
|
||||||
|
def __init__(self, inchannel, outchannel, stride=1, dowsample=None):
|
||||||
|
# super(Bottleneck, self).__init__()
|
||||||
|
super().__init__()
|
||||||
|
self.conv1 = nn.Conv2d(in_channels=inchannel, out_channels=outchannel, kernel_size=1, stride=1, bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(outchannel)
|
||||||
|
self.conv2 = nn.Conv2d(in_channels=outchannel, out_channels=outchannel, kernel_size=3, bias=False,
|
||||||
|
stride=stride, padding=1)
|
||||||
|
self.bn2 = nn.BatchNorm2d(outchannel)
|
||||||
|
self.conv3 = nn.Conv2d(in_channels=outchannel, out_channels=outchannel * self.expansion, stride=1, bias=False,
|
||||||
|
kernel_size=1)
|
||||||
|
self.bn3 = nn.BatchNorm2d(outchannel * self.expansion)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
self.downsample = dowsample
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
self.identity = x
|
||||||
|
# print('>>>>>>>>',type(x))
|
||||||
|
if self.downsample is not None:
|
||||||
|
# print('>>>>downsample>>>>', type(self.downsample))
|
||||||
|
self.identity = self.downsample(x)
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.bn1(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
out = self.conv2(out)
|
||||||
|
out = self.bn2(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
out = self.conv3(out)
|
||||||
|
out = self.bn3(out)
|
||||||
|
# print('>>>>out>>>identity',out.size(),self.identity.size())
|
||||||
|
out = out + self.identity
|
||||||
|
out = self.relu(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class resnet(nn.Module):
|
||||||
|
def __init__(self, block=Bottleneck, block_num=[3, 4, 6, 3], num_class=1000):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channel = 64
|
||||||
|
self.conv1 = nn.Conv2d(in_channels=3,
|
||||||
|
out_channels=self.in_channel,
|
||||||
|
stride=2,
|
||||||
|
kernel_size=7,
|
||||||
|
padding=3,
|
||||||
|
bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(self.in_channel)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
self.cbam = CBAM(self.in_channel)
|
||||||
|
self.cbam1 = CBAM(2048)
|
||||||
|
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||||
|
self.layer1 = self._make_layer(block, 64, block_num[0], stride=1)
|
||||||
|
self.layer2 = self._make_layer(block, 128, block_num[1], stride=2)
|
||||||
|
self.layer3 = self._make_layer(block, 256, block_num[2], stride=2)
|
||||||
|
self.layer4 = self._make_layer(block, 512, block_num[3], stride=2)
|
||||||
|
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||||
|
self.gem = gem()
|
||||||
|
self.fc = nn.Linear(512 * block.expansion, num_class)
|
||||||
|
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal(m.weight, mode='fan_out',
|
||||||
|
nonlinearity='relu')
|
||||||
|
if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||||
|
nn.init.constant_(m.weight, 1.0)
|
||||||
|
nn.init.constant_(m.bias, 1.0)
|
||||||
|
|
||||||
|
def _make_layer(self, block, channel, block_num, stride=1):
|
||||||
|
downsample = None
|
||||||
|
if stride != 1 or self.in_channel != channel * block.expansion:
|
||||||
|
downsample = nn.Sequential(
|
||||||
|
nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
|
||||||
|
nn.BatchNorm2d(channel * block.expansion))
|
||||||
|
layer = []
|
||||||
|
layer.append(block(self.in_channel, channel, stride, downsample))
|
||||||
|
self.in_channel = channel * block.expansion
|
||||||
|
for _ in range(1, block_num):
|
||||||
|
layer.append(block(self.in_channel, channel))
|
||||||
|
return nn.Sequential(*layer)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.bn1(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x = self.maxpool(x)
|
||||||
|
x = self.cbam(x)
|
||||||
|
|
||||||
|
x = self.layer1(x)
|
||||||
|
x = self.layer2(x)
|
||||||
|
x = self.layer3(x)
|
||||||
|
x = self.layer4(x)
|
||||||
|
|
||||||
|
x = self.cbam1(x)
|
||||||
|
# x = self.avgpool(x)
|
||||||
|
x = self.gem(x)
|
||||||
|
x = torch.flatten(x, 1)
|
||||||
|
x = self.fc(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TripletNet(nn.Module):
|
||||||
|
def __init__(self, num_class, flag=True):
|
||||||
|
super(TripletNet, self).__init__()
|
||||||
|
self.initnet = rescbam(num_class)
|
||||||
|
self.flag = flag
|
||||||
|
|
||||||
|
def forward(self, x1, x2=None, x3=None):
|
||||||
|
if self.flag:
|
||||||
|
output1 = self.initnet(x1)
|
||||||
|
output2 = self.initnet(x2)
|
||||||
|
output3 = self.initnet(x3)
|
||||||
|
return output1, output2, output3
|
||||||
|
else:
|
||||||
|
output = self.initnet(x1)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def rescbam(num_class):
|
||||||
|
return resnet(block=Bottleneck, block_num=[3, 4, 6, 3], num_class=num_class)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
input1 = torch.randn(4, 3, 640, 640)
|
||||||
|
input2 = torch.randn(4, 3, 640, 640)
|
||||||
|
input3 = torch.randn(4, 3, 640, 640)
|
||||||
|
|
||||||
|
# rescbam测试
|
||||||
|
# Resnet50 = rescbam(512)
|
||||||
|
# output = Resnet50.forward(input1)
|
||||||
|
# print(Resnet50)
|
||||||
|
|
||||||
|
# trnet测试
|
||||||
|
trnet = TripletNet(512)
|
||||||
|
output = trnet(input1, input2, input3)
|
||||||
|
print(output)
|
189
model/resnet.py
Normal file
189
model/resnet.py
Normal file
@ -0,0 +1,189 @@
|
|||||||
|
"""resnet in pytorch
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
|
||||||
|
|
||||||
|
Deep Residual Learning for Image Recognition
|
||||||
|
https://arxiv.org/abs/1512.03385v1
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from config import config as conf
|
||||||
|
from CBAM import CBAM
|
||||||
|
|
||||||
|
class BasicBlock(nn.Module):
|
||||||
|
"""Basic Block for resnet 18 and resnet 34
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
#BasicBlock and BottleNeck block
|
||||||
|
#have different output size
|
||||||
|
#we use class attribute expansion
|
||||||
|
#to distinct
|
||||||
|
expansion = 1
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, stride=1):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
#residual function
|
||||||
|
self.residual_function = nn.Sequential(
|
||||||
|
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
|
||||||
|
nn.BatchNorm2d(out_channels),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
|
||||||
|
nn.BatchNorm2d(out_channels * BasicBlock.expansion)
|
||||||
|
)
|
||||||
|
|
||||||
|
#shortcut
|
||||||
|
self.shortcut = nn.Sequential()
|
||||||
|
|
||||||
|
#the shortcut output dimension is not the same with residual function
|
||||||
|
#use 1*1 convolution to match the dimension
|
||||||
|
if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
|
||||||
|
self.shortcut = nn.Sequential(
|
||||||
|
nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
|
||||||
|
nn.BatchNorm2d(out_channels * BasicBlock.expansion)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
|
||||||
|
|
||||||
|
class BottleNeck(nn.Module):
|
||||||
|
"""Residual block for resnet over 50 layers
|
||||||
|
|
||||||
|
"""
|
||||||
|
expansion = 4
|
||||||
|
def __init__(self, in_channels, out_channels, stride=1):
|
||||||
|
super().__init__()
|
||||||
|
self.residual_function = nn.Sequential(
|
||||||
|
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
|
||||||
|
nn.BatchNorm2d(out_channels),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
|
||||||
|
nn.BatchNorm2d(out_channels),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
|
||||||
|
nn.BatchNorm2d(out_channels * BottleNeck.expansion),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.shortcut = nn.Sequential()
|
||||||
|
|
||||||
|
if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
|
||||||
|
self.shortcut = nn.Sequential(
|
||||||
|
nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
|
||||||
|
nn.BatchNorm2d(out_channels * BottleNeck.expansion)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
|
||||||
|
|
||||||
|
class ResNet(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, block, num_block, cbam = False, num_classes=conf.embedding_size):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.in_channels = 64
|
||||||
|
|
||||||
|
# self.conv1 = nn.Sequential(
|
||||||
|
# nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
|
||||||
|
# nn.BatchNorm2d(64),
|
||||||
|
# nn.ReLU(inplace=True))
|
||||||
|
|
||||||
|
self.conv1 = nn.Sequential(
|
||||||
|
nn.Conv2d(3, 64,stride=2,kernel_size=7,padding=3,bias=False),
|
||||||
|
nn.BatchNorm2d(64),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
|
||||||
|
|
||||||
|
self.cbam = CBAM(self.in_channels)
|
||||||
|
|
||||||
|
#we use a different inputsize than the original paper
|
||||||
|
#so conv2_x's stride is 1
|
||||||
|
self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
|
||||||
|
self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
|
||||||
|
self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
|
||||||
|
self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
|
||||||
|
self.cbam1 = CBAM(self.in_channels)
|
||||||
|
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
|
||||||
|
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||||
|
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal(m.weight,mode = 'fan_out',
|
||||||
|
nonlinearity='relu')
|
||||||
|
if isinstance(m, (nn.BatchNorm2d)):
|
||||||
|
nn.init.constant_(m.weight, 1.0)
|
||||||
|
nn.init.constant_(m.bias, 1.0)
|
||||||
|
|
||||||
|
def _make_layer(self, block, out_channels, num_blocks, stride):
|
||||||
|
"""make resnet layers(by layer i didnt mean this 'layer' was the
|
||||||
|
same as a neuron netowork layer, ex. conv layer), one layer may
|
||||||
|
contain more than one residual block
|
||||||
|
|
||||||
|
Args:
|
||||||
|
block: block type, basic block or bottle neck block
|
||||||
|
out_channels: output depth channel number of this layer
|
||||||
|
num_blocks: how many blocks per layer
|
||||||
|
stride: the stride of the first block of this layer
|
||||||
|
|
||||||
|
Return:
|
||||||
|
return a resnet layer
|
||||||
|
"""
|
||||||
|
|
||||||
|
# we have num_block blocks per layer, the first block
|
||||||
|
# could be 1 or 2, other blocks would always be 1
|
||||||
|
strides = [stride] + [1] * (num_blocks - 1)
|
||||||
|
layers = []
|
||||||
|
for stride in strides:
|
||||||
|
layers.append(block(self.in_channels, out_channels, stride))
|
||||||
|
self.in_channels = out_channels * block.expansion
|
||||||
|
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
output = self.conv1(x)
|
||||||
|
if cbam:
|
||||||
|
output = self.cbam(x)
|
||||||
|
output = self.conv2_x(output)
|
||||||
|
output = self.conv3_x(output)
|
||||||
|
output = self.conv4_x(output)
|
||||||
|
output = self.conv5_x(output)
|
||||||
|
if cbam:
|
||||||
|
output = self.cbam1(x)
|
||||||
|
print('pollBefore',output.shape)
|
||||||
|
output = self.avg_pool(output)
|
||||||
|
print('poolAfter',output.shape)
|
||||||
|
output = output.view(output.size(0), -1)
|
||||||
|
print('fcBefore',output.shape)
|
||||||
|
output = self.fc(output)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def resnet18(cbam = False):
|
||||||
|
""" return a ResNet 18 object
|
||||||
|
"""
|
||||||
|
return ResNet(BasicBlock, [2, 2, 2, 2], cbam)
|
||||||
|
|
||||||
|
def resnet34():
|
||||||
|
""" return a ResNet 34 object
|
||||||
|
"""
|
||||||
|
return ResNet(BasicBlock, [3, 4, 6, 3])
|
||||||
|
|
||||||
|
def resnet50():
|
||||||
|
""" return a ResNet 50 object
|
||||||
|
"""
|
||||||
|
return ResNet(BottleNeck, [3, 4, 6, 3])
|
||||||
|
|
||||||
|
def resnet101():
|
||||||
|
""" return a ResNet 101 object
|
||||||
|
"""
|
||||||
|
return ResNet(BottleNeck, [3, 4, 23, 3])
|
||||||
|
|
||||||
|
def resnet152():
|
||||||
|
""" return a ResNet 152 object
|
||||||
|
"""
|
||||||
|
return ResNet(BottleNeck, [3, 8, 36, 3])
|
||||||
|
|
||||||
|
|
271
model/resnet_attention.py
Normal file
271
model/resnet_attention.py
Normal file
@ -0,0 +1,271 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelAttention(nn.Module):
|
||||||
|
"""通道注意力模块,通过全局平均池化和最大池化提取特征,经过MLP生成通道权重"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels, reduction_ratio=16):
|
||||||
|
super(ChannelAttention, self).__init__()
|
||||||
|
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||||
|
self.max_pool = nn.AdaptiveMaxPool2d(1)
|
||||||
|
|
||||||
|
# 共享的MLP层
|
||||||
|
self.fc = nn.Sequential(
|
||||||
|
nn.Conv2d(in_channels, in_channels // reduction_ratio, 1, bias=False),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2d(in_channels // reduction_ratio, in_channels, 1, bias=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
avg_out = self.fc(self.avg_pool(x))
|
||||||
|
max_out = self.fc(self.max_pool(x))
|
||||||
|
out = avg_out + max_out
|
||||||
|
return torch.sigmoid(out)
|
||||||
|
|
||||||
|
|
||||||
|
class SpatialAttention(nn.Module):
|
||||||
|
"""空间注意力模块,通过通道维度的平均和最大值操作,生成空间权重"""
|
||||||
|
|
||||||
|
def __init__(self, kernel_size=7):
|
||||||
|
super(SpatialAttention, self).__init__()
|
||||||
|
self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
avg_out = torch.mean(x, dim=1, keepdim=True)
|
||||||
|
max_out, _ = torch.max(x, dim=1, keepdim=True)
|
||||||
|
out = torch.cat([avg_out, max_out], dim=1)
|
||||||
|
out = self.conv(out)
|
||||||
|
return torch.sigmoid(out)
|
||||||
|
|
||||||
|
|
||||||
|
class CBAM(nn.Module):
|
||||||
|
"""CBAM注意力模块,串联通道注意力和空间注意力"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels, reduction_ratio=16, kernel_size=7):
|
||||||
|
super(CBAM, self).__init__()
|
||||||
|
self.channel_att = ChannelAttention(in_channels, reduction_ratio)
|
||||||
|
self.spatial_att = SpatialAttention(kernel_size)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x * self.channel_att(x)
|
||||||
|
x = x * self.spatial_att(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class BasicBlock(nn.Module):
|
||||||
|
"""ResNet基础残差块,适用于ResNet18和ResNet34"""
|
||||||
|
expansion = 1
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, stride=1, downsample=None, use_cbam=False):
|
||||||
|
super(BasicBlock, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(out_channels)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
|
||||||
|
self.bn2 = nn.BatchNorm2d(out_channels)
|
||||||
|
|
||||||
|
self.downsample = downsample
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
|
# 是否使用CBAM注意力机制
|
||||||
|
self.use_cbam = use_cbam
|
||||||
|
if use_cbam:
|
||||||
|
self.cbam = CBAM(out_channels)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
identity = x
|
||||||
|
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.bn1(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
out = self.conv2(out)
|
||||||
|
out = self.bn2(out)
|
||||||
|
|
||||||
|
# # 如果使用注意力机制,应用CBAM
|
||||||
|
if self.use_cbam:
|
||||||
|
out = self.cbam(out)
|
||||||
|
|
||||||
|
# 如果有下采样,调整shortcut连接
|
||||||
|
if self.downsample is not None:
|
||||||
|
identity = self.downsample(x)
|
||||||
|
|
||||||
|
# 残差连接
|
||||||
|
out += identity
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Bottleneck(nn.Module):
|
||||||
|
"""ResNet瓶颈残差块,适用于ResNet50及更深的网络"""
|
||||||
|
expansion = 4
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, stride=1, downsample=None, use_cbam=False):
|
||||||
|
super(Bottleneck, self).__init__()
|
||||||
|
# 1x1卷积降维
|
||||||
|
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(out_channels)
|
||||||
|
# 3x3卷积
|
||||||
|
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||||
|
self.bn2 = nn.BatchNorm2d(out_channels)
|
||||||
|
# 1x1卷积升维
|
||||||
|
self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False)
|
||||||
|
self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
self.downsample = downsample
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
|
# 是否使用CBAM注意力机制
|
||||||
|
self.use_cbam = use_cbam
|
||||||
|
if use_cbam:
|
||||||
|
self.cbam = CBAM(out_channels * self.expansion)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
identity = x
|
||||||
|
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.bn1(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
out = self.conv2(out)
|
||||||
|
out = self.bn2(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
out = self.conv3(out)
|
||||||
|
out = self.bn3(out)
|
||||||
|
|
||||||
|
# # 如果使用注意力机制,应用CBAM
|
||||||
|
if self.use_cbam:
|
||||||
|
out = self.cbam(out)
|
||||||
|
|
||||||
|
# 如果有下采样,调整shortcut连接
|
||||||
|
if self.downsample is not None:
|
||||||
|
identity = self.downsample(x)
|
||||||
|
|
||||||
|
# 残差连接
|
||||||
|
out += identity
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class ResNet(nn.Module):
|
||||||
|
"""集成了CBAM注意力机制的ResNet模型"""
|
||||||
|
|
||||||
|
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, use_cbam=True):
|
||||||
|
super(ResNet, self).__init__()
|
||||||
|
self.in_channels = 64
|
||||||
|
self.use_cbam = use_cbam
|
||||||
|
|
||||||
|
# 初始卷积层
|
||||||
|
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
||||||
|
self.cbam1 = CBAM(64)
|
||||||
|
self.bn1 = nn.BatchNorm2d(64)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||||
|
|
||||||
|
# 残差块层
|
||||||
|
self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
|
||||||
|
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
||||||
|
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
||||||
|
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
||||||
|
|
||||||
|
self.cbam2 = CBAM(512)
|
||||||
|
# 全局平均池化和分类器
|
||||||
|
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||||
|
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||||
|
|
||||||
|
# 初始化权重
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||||
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
|
nn.init.constant_(m.weight, 1)
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
# 零初始化最后一个BN层的权重,使残差分支初始为0
|
||||||
|
if zero_init_residual:
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, Bottleneck):
|
||||||
|
nn.init.constant_(m.bn3.weight, 0)
|
||||||
|
elif isinstance(m, BasicBlock):
|
||||||
|
nn.init.constant_(m.bn2.weight, 0)
|
||||||
|
|
||||||
|
def _make_layer(self, block, out_channels, blocks, stride=1):
|
||||||
|
downsample = None
|
||||||
|
# 如果通道数不匹配或需要调整步长,创建下采样层
|
||||||
|
if stride != 1 or self.in_channels != out_channels * block.expansion:
|
||||||
|
downsample = nn.Sequential(
|
||||||
|
nn.Conv2d(self.in_channels, out_channels * block.expansion, kernel_size=1, stride=stride, bias=False),
|
||||||
|
nn.BatchNorm2d(out_channels * block.expansion),
|
||||||
|
)
|
||||||
|
|
||||||
|
layers = []
|
||||||
|
# 第一个块可能需要下采样
|
||||||
|
layers.append(block(self.in_channels, out_channels, stride, downsample, use_cbam=self.use_cbam))
|
||||||
|
self.in_channels = out_channels * block.expansion
|
||||||
|
|
||||||
|
# 添加剩余的块
|
||||||
|
for _ in range(1, blocks):
|
||||||
|
layers.append(block(self.in_channels, out_channels, use_cbam=self.use_cbam))
|
||||||
|
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# 特征提取
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.bn1(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x = self.maxpool(x)
|
||||||
|
# if self.use_cbam:
|
||||||
|
# x = self.cbam1(x)
|
||||||
|
|
||||||
|
x = self.layer1(x)
|
||||||
|
x = self.layer2(x)
|
||||||
|
x = self.layer3(x)
|
||||||
|
x = self.layer4(x)
|
||||||
|
|
||||||
|
# if self.use_cbam:
|
||||||
|
# x = self.cbam2(x)
|
||||||
|
# 分类
|
||||||
|
x = self.avgpool(x)
|
||||||
|
x = torch.flatten(x, 1)
|
||||||
|
x = self.fc(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# 工厂函数,创建不同深度的ResNet模型
|
||||||
|
def resnet18_cbam(pretrained=False, **kwargs):
|
||||||
|
return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnet34_cbam(pretrained=False, **kwargs):
|
||||||
|
return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnet50_cbam(pretrained=False, **kwargs):
|
||||||
|
return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnet101_cbam(pretrained=False, **kwargs):
|
||||||
|
return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnet152_cbam(pretrained=False, **kwargs):
|
||||||
|
return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# 测试模型
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 创建一个带有CBAM注意力机制的ResNet50模型
|
||||||
|
model = resnet50_cbam(num_classes=10)
|
||||||
|
# 测试输入
|
||||||
|
x = torch.randn(1, 3, 224, 224)
|
||||||
|
y = model(x)
|
||||||
|
print(f"输入形状: {x.shape}")
|
||||||
|
print(f"输出形状: {y.shape}")
|
480
model/resnet_pre.py
Normal file
480
model/resnet_pre.py
Normal file
@ -0,0 +1,480 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from config import config as conf
|
||||||
|
|
||||||
|
try:
|
||||||
|
from torch.hub import load_state_dict_from_url
|
||||||
|
except ImportError:
|
||||||
|
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
||||||
|
# from .utils import load_state_dict_from_url
|
||||||
|
|
||||||
|
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
||||||
|
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
|
||||||
|
'wide_resnet50_2', 'wide_resnet101_2']
|
||||||
|
|
||||||
|
model_urls = {
|
||||||
|
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
||||||
|
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
||||||
|
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
||||||
|
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
||||||
|
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
||||||
|
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
|
||||||
|
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
|
||||||
|
'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
|
||||||
|
'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
||||||
|
"""3x3 convolution with padding"""
|
||||||
|
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||||
|
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
||||||
|
|
||||||
|
|
||||||
|
def conv1x1(in_planes, out_planes, stride=1):
|
||||||
|
"""1x1 convolution"""
|
||||||
|
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||||
|
|
||||||
|
|
||||||
|
class SpatialAttention(nn.Module):
|
||||||
|
def __init__(self, kernel_size=7):
|
||||||
|
super(SpatialAttention, self).__init__()
|
||||||
|
|
||||||
|
assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
|
||||||
|
padding = 3 if kernel_size == 7 else 1
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
|
||||||
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
avg_out = torch.mean(x, dim=1, keepdim=True)
|
||||||
|
max_out, _ = torch.max(x, dim=1, keepdim=True)
|
||||||
|
x = torch.cat([avg_out, max_out], dim=1)
|
||||||
|
x = self.conv1(x)
|
||||||
|
return self.sigmoid(x)
|
||||||
|
|
||||||
|
|
||||||
|
class BasicBlock(nn.Module):
|
||||||
|
expansion = 1
|
||||||
|
|
||||||
|
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
||||||
|
base_width=64, dilation=1, norm_layer=None, cam=False, bam=False):
|
||||||
|
super(BasicBlock, self).__init__()
|
||||||
|
if norm_layer is None:
|
||||||
|
norm_layer = nn.BatchNorm2d
|
||||||
|
if groups != 1 or base_width != 64:
|
||||||
|
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
||||||
|
if dilation > 1:
|
||||||
|
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
||||||
|
self.cam = cam
|
||||||
|
self.bam = bam
|
||||||
|
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
||||||
|
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||||
|
self.bn1 = norm_layer(planes)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
self.conv2 = conv3x3(planes, planes)
|
||||||
|
self.bn2 = norm_layer(planes)
|
||||||
|
self.downsample = downsample
|
||||||
|
self.stride = stride
|
||||||
|
if self.cam:
|
||||||
|
if planes == 64:
|
||||||
|
self.globalAvgPool = nn.AvgPool2d(56, stride=1)
|
||||||
|
elif planes == 128:
|
||||||
|
self.globalAvgPool = nn.AvgPool2d(28, stride=1)
|
||||||
|
elif planes == 256:
|
||||||
|
self.globalAvgPool = nn.AvgPool2d(14, stride=1)
|
||||||
|
elif planes == 512:
|
||||||
|
self.globalAvgPool = nn.AvgPool2d(7, stride=1)
|
||||||
|
|
||||||
|
self.fc1 = nn.Linear(in_features=planes, out_features=round(planes / 16))
|
||||||
|
self.fc2 = nn.Linear(in_features=round(planes / 16), out_features=planes)
|
||||||
|
self.sigmod = nn.Sigmoid()
|
||||||
|
if self.bam:
|
||||||
|
self.bam = SpatialAttention()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
identity = x
|
||||||
|
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.bn1(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
out = self.conv2(out)
|
||||||
|
out = self.bn2(out)
|
||||||
|
|
||||||
|
if self.downsample is not None:
|
||||||
|
identity = self.downsample(x)
|
||||||
|
|
||||||
|
if self.cam:
|
||||||
|
ori_out = self.globalAvgPool(out)
|
||||||
|
out = out.view(out.size(0), -1)
|
||||||
|
out = self.fc1(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
out = self.fc2(out)
|
||||||
|
out = self.sigmod(out)
|
||||||
|
out = out.view(out.size(0), out.size(-1), 1, 1)
|
||||||
|
out = out * ori_out
|
||||||
|
|
||||||
|
if self.bam:
|
||||||
|
out = out * self.bam(out)
|
||||||
|
|
||||||
|
out += identity
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Bottleneck(nn.Module):
|
||||||
|
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
|
||||||
|
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
|
||||||
|
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
|
||||||
|
# This variant is also known as ResNet V1.5 and improves accuracy according to
|
||||||
|
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
|
||||||
|
|
||||||
|
expansion = 4
|
||||||
|
|
||||||
|
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
||||||
|
base_width=64, dilation=1, norm_layer=None, cam=False, bam=False):
|
||||||
|
super(Bottleneck, self).__init__()
|
||||||
|
if norm_layer is None:
|
||||||
|
norm_layer = nn.BatchNorm2d
|
||||||
|
width = int(planes * (base_width / 64.)) * groups
|
||||||
|
self.cam = cam
|
||||||
|
self.bam = bam
|
||||||
|
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
||||||
|
self.conv1 = conv1x1(inplanes, width)
|
||||||
|
self.bn1 = norm_layer(width)
|
||||||
|
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
||||||
|
self.bn2 = norm_layer(width)
|
||||||
|
self.conv3 = conv1x1(width, planes * self.expansion)
|
||||||
|
self.bn3 = norm_layer(planes * self.expansion)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
self.downsample = downsample
|
||||||
|
self.stride = stride
|
||||||
|
if self.cam:
|
||||||
|
if planes == 64:
|
||||||
|
self.globalAvgPool = nn.AvgPool2d(56, stride=1)
|
||||||
|
elif planes == 128:
|
||||||
|
self.globalAvgPool = nn.AvgPool2d(28, stride=1)
|
||||||
|
elif planes == 256:
|
||||||
|
self.globalAvgPool = nn.AvgPool2d(14, stride=1)
|
||||||
|
elif planes == 512:
|
||||||
|
self.globalAvgPool = nn.AvgPool2d(7, stride=1)
|
||||||
|
|
||||||
|
self.fc1 = nn.Linear(planes * self.expansion, round(planes / 4))
|
||||||
|
self.fc2 = nn.Linear(round(planes / 4), planes * self.expansion)
|
||||||
|
self.sigmod = nn.Sigmoid()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
identity = x
|
||||||
|
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.bn1(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
out = self.conv2(out)
|
||||||
|
out = self.bn2(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
out = self.conv3(out)
|
||||||
|
out = self.bn3(out)
|
||||||
|
|
||||||
|
if self.downsample is not None:
|
||||||
|
identity = self.downsample(x)
|
||||||
|
|
||||||
|
if self.cam:
|
||||||
|
ori_out = self.globalAvgPool(out)
|
||||||
|
out = out.view(out.size(0), -1)
|
||||||
|
out = self.fc1(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
out = self.fc2(out)
|
||||||
|
out = self.sigmod(out)
|
||||||
|
out = out.view(out.size(0), out.size(-1), 1, 1)
|
||||||
|
out = out * ori_out
|
||||||
|
out += identity
|
||||||
|
out = self.relu(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class ResNet(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, block, layers, num_classes=conf.embedding_size, zero_init_residual=False,
|
||||||
|
groups=1, width_per_group=64, replace_stride_with_dilation=None,
|
||||||
|
norm_layer=None, scale=conf.channel_ratio):
|
||||||
|
super(ResNet, self).__init__()
|
||||||
|
if norm_layer is None:
|
||||||
|
norm_layer = nn.BatchNorm2d
|
||||||
|
self._norm_layer = norm_layer
|
||||||
|
print("ResNet scale: >>>>>>>>>> ", scale)
|
||||||
|
self.inplanes = 64
|
||||||
|
self.dilation = 1
|
||||||
|
if replace_stride_with_dilation is None:
|
||||||
|
# each element in the tuple indicates if we should replace
|
||||||
|
# the 2x2 stride with a dilated convolution instead
|
||||||
|
replace_stride_with_dilation = [False, False, False]
|
||||||
|
if len(replace_stride_with_dilation) != 3:
|
||||||
|
raise ValueError("replace_stride_with_dilation should be None "
|
||||||
|
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
||||||
|
self.groups = groups
|
||||||
|
self.base_width = width_per_group
|
||||||
|
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
|
||||||
|
bias=False)
|
||||||
|
self.bn1 = norm_layer(self.inplanes)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||||
|
self.adaptiveMaxPool = nn.AdaptiveMaxPool2d((1, 1))
|
||||||
|
self.maxpool2 = nn.Sequential(
|
||||||
|
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
|
||||||
|
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
|
||||||
|
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
|
||||||
|
nn.MaxPool2d(kernel_size=2, stride=1, padding=0)
|
||||||
|
)
|
||||||
|
self.layer1 = self._make_layer(block, int(64 * scale), layers[0])
|
||||||
|
self.layer2 = self._make_layer(block, int(128 * scale), layers[1], stride=2,
|
||||||
|
dilate=replace_stride_with_dilation[0])
|
||||||
|
self.layer3 = self._make_layer(block, int(256 * scale), layers[2], stride=2,
|
||||||
|
dilate=replace_stride_with_dilation[1])
|
||||||
|
self.layer4 = self._make_layer(block, int(512 * scale), layers[3], stride=2,
|
||||||
|
dilate=replace_stride_with_dilation[2])
|
||||||
|
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||||
|
self.fc = nn.Linear(int(512 * block.expansion * scale), num_classes)
|
||||||
|
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||||
|
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||||
|
nn.init.constant_(m.weight, 1)
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
# Zero-initialize the last BN in each residual branch,
|
||||||
|
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
||||||
|
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
||||||
|
if zero_init_residual:
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, Bottleneck):
|
||||||
|
nn.init.constant_(m.bn3.weight, 0)
|
||||||
|
elif isinstance(m, BasicBlock):
|
||||||
|
nn.init.constant_(m.bn2.weight, 0)
|
||||||
|
|
||||||
|
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
||||||
|
norm_layer = self._norm_layer
|
||||||
|
downsample = None
|
||||||
|
previous_dilation = self.dilation
|
||||||
|
if dilate:
|
||||||
|
self.dilation *= stride
|
||||||
|
stride = 1
|
||||||
|
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||||
|
downsample = nn.Sequential(
|
||||||
|
conv1x1(self.inplanes, planes * block.expansion, stride),
|
||||||
|
norm_layer(planes * block.expansion),
|
||||||
|
)
|
||||||
|
|
||||||
|
layers = []
|
||||||
|
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
|
||||||
|
self.base_width, previous_dilation, norm_layer))
|
||||||
|
self.inplanes = planes * block.expansion
|
||||||
|
for _ in range(1, blocks):
|
||||||
|
layers.append(block(self.inplanes, planes, groups=self.groups,
|
||||||
|
base_width=self.base_width, dilation=self.dilation,
|
||||||
|
norm_layer=norm_layer))
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def _forward_impl(self, x):
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.bn1(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x = self.maxpool(x)
|
||||||
|
|
||||||
|
|
||||||
|
x = self.layer1(x)
|
||||||
|
x = self.layer2(x)
|
||||||
|
x = self.layer3(x)
|
||||||
|
x = self.layer4(x)
|
||||||
|
|
||||||
|
x = self.avgpool(x)
|
||||||
|
x = torch.flatten(x, 1)
|
||||||
|
x = self.fc(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self._forward_impl(x)
|
||||||
|
|
||||||
|
|
||||||
|
# def _resnet(arch, block, layers, pretrained, progress, **kwargs):
|
||||||
|
# model = ResNet(block, layers, **kwargs)
|
||||||
|
# if pretrained:
|
||||||
|
# state_dict = load_state_dict_from_url(model_urls[arch],
|
||||||
|
# progress=progress)
|
||||||
|
# model.load_state_dict(state_dict, strict=False)
|
||||||
|
# return model
|
||||||
|
|
||||||
|
class CustomResNet18(nn.Module):
|
||||||
|
def __init__(self, model, num_classes=conf.custom_num_classes):
|
||||||
|
super(CustomResNet18, self).__init__()
|
||||||
|
self.custom_model = nn.Sequential(*list(model.children())[:-1])
|
||||||
|
self.fc = nn.Linear(model.fc.in_features, num_classes)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.custom_model(x)
|
||||||
|
x = x.view(x.size(0), -1)
|
||||||
|
x = self.fc(x)
|
||||||
|
return x
|
||||||
|
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
|
||||||
|
model = ResNet(block, layers, **kwargs)
|
||||||
|
if pretrained:
|
||||||
|
state_dict = load_state_dict_from_url(model_urls[arch],
|
||||||
|
progress=progress)
|
||||||
|
|
||||||
|
src_state_dict = state_dict
|
||||||
|
target_state_dict = model.state_dict()
|
||||||
|
skip_keys = []
|
||||||
|
# skip mismatch size tensors in case of pretraining
|
||||||
|
for k in src_state_dict.keys():
|
||||||
|
if k not in target_state_dict:
|
||||||
|
continue
|
||||||
|
if src_state_dict[k].size() != target_state_dict[k].size():
|
||||||
|
skip_keys.append(k)
|
||||||
|
for k in skip_keys:
|
||||||
|
del src_state_dict[k]
|
||||||
|
missing_keys, unexpected_keys = model.load_state_dict(src_state_dict, strict=False)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def resnet14(pretrained=True, progress=True, **kwargs):
|
||||||
|
r"""ResNet-14 model from
|
||||||
|
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
return _resnet('resnet18', BasicBlock, [2, 1, 1, 2], pretrained, progress,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnet18(pretrained=True, progress=True, **kwargs):
|
||||||
|
r"""ResNet-18 model from
|
||||||
|
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
**kwargs: Additional arguments passed to ResNet, including:
|
||||||
|
scale (float): Channel scaling ratio (default: conf.channel_ratio)
|
||||||
|
"""
|
||||||
|
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnet34(pretrained=False, progress=True, **kwargs):
|
||||||
|
r"""ResNet-34 model from
|
||||||
|
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnet50(pretrained=False, progress=True, **kwargs):
|
||||||
|
r"""ResNet-50 model from
|
||||||
|
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnet101(pretrained=False, progress=True, **kwargs):
|
||||||
|
r"""ResNet-101 model from
|
||||||
|
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnet152(pretrained=False, progress=True, **kwargs):
|
||||||
|
r"""ResNet-152 model from
|
||||||
|
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
|
||||||
|
r"""ResNeXt-50 32x4d model from
|
||||||
|
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
kwargs['groups'] = 32
|
||||||
|
kwargs['width_per_group'] = 4
|
||||||
|
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
|
||||||
|
pretrained, progress, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
|
||||||
|
r"""ResNeXt-101 32x8d model from
|
||||||
|
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
kwargs['groups'] = 32
|
||||||
|
kwargs['width_per_group'] = 8
|
||||||
|
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
|
||||||
|
pretrained, progress, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
|
||||||
|
r"""Wide ResNet-50-2 model from
|
||||||
|
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
|
||||||
|
|
||||||
|
The model is the same as ResNet except for the bottleneck number of channels
|
||||||
|
which is twice larger in every block. The number of channels in outer 1x1
|
||||||
|
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||||
|
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
kwargs['width_per_group'] = 64 * 2
|
||||||
|
return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
|
||||||
|
pretrained, progress, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
|
||||||
|
r"""Wide ResNet-101-2 model from
|
||||||
|
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
|
||||||
|
|
||||||
|
The model is the same as ResNet except for the bottleneck number of channels
|
||||||
|
which is twice larger in every block. The number of channels in outer 1x1
|
||||||
|
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||||
|
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
kwargs['width_per_group'] = 64 * 2
|
||||||
|
return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
|
||||||
|
pretrained, progress, **kwargs)
|
4
model/utils.py
Normal file
4
model/utils.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
try:
|
||||||
|
from torch.hub import load_state_dict_from_url
|
||||||
|
except ImportError:
|
||||||
|
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
42
model/vit.py
Normal file
42
model/vit.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from functools import partial, reduce
|
||||||
|
from operator import mul
|
||||||
|
|
||||||
|
from timm.models.vision_transformer import VisionTransformer, _cfg
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'vit_small',
|
||||||
|
'vit_base',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def vit_small(**kwargs):
|
||||||
|
model = VisionTransformer(
|
||||||
|
patch_size=16, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, num_classes=256,
|
||||||
|
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||||
|
# model.default_cfg = _cfg()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def vit_base(**kwargs):
|
||||||
|
model = VisionTransformer(
|
||||||
|
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, num_classes=256,
|
||||||
|
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||||
|
model.default_cfg = _cfg(num_classes=256)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
img = torch.randn(8, 3, 224, 224)
|
||||||
|
vit = vit_base()
|
||||||
|
out = vit(img)
|
||||||
|
print(out.shape)
|
||||||
|
# print(count_parameters(vit))
|
331
test_ori.py
Normal file
331
test_ori.py
Normal file
@ -0,0 +1,331 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import os.path as osp
|
||||||
|
from typing import Dict, List, Set, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import json
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
# from config import config as conf
|
||||||
|
from tools.dataset import get_transform
|
||||||
|
from configs import trainer_tools
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
with open('configs/test.yml', 'r') as f:
|
||||||
|
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
|
|
||||||
|
# Constants from config
|
||||||
|
embedding_size = conf["base"]["embedding_size"]
|
||||||
|
img_size = conf["transform"]["img_size"]
|
||||||
|
device = conf["base"]["device"]
|
||||||
|
|
||||||
|
def unique_image(pair_list: str) -> Set[str]:
|
||||||
|
unique_images = set()
|
||||||
|
try:
|
||||||
|
with open(pair_list, 'r') as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
img1, img2, _ = line.split()
|
||||||
|
unique_images.update([img1, img2])
|
||||||
|
except ValueError as e:
|
||||||
|
print(f"Skipping malformed line: {line}")
|
||||||
|
except IOError as e:
|
||||||
|
print(f"Error reading pair list file: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
return unique_images
|
||||||
|
|
||||||
|
|
||||||
|
def group_image(images: Set[str], batch_size: int) -> List[List[str]]:
|
||||||
|
"""
|
||||||
|
Group image paths into batches of specified size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images: Set of image paths to group
|
||||||
|
batch_size: Number of images per batch
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of batches, where each batch is a list of image paths
|
||||||
|
"""
|
||||||
|
image_list = list(images)
|
||||||
|
num_images = len(image_list)
|
||||||
|
batches = []
|
||||||
|
|
||||||
|
for i in range(0, num_images, batch_size):
|
||||||
|
batch_end = min(i + batch_size, num_images)
|
||||||
|
batches.append(image_list[i:batch_end])
|
||||||
|
|
||||||
|
return batches
|
||||||
|
|
||||||
|
|
||||||
|
def _preprocess(images: list, transform) -> torch.Tensor:
|
||||||
|
res = []
|
||||||
|
for img in images:
|
||||||
|
im = Image.open(img)
|
||||||
|
im = transform(im)
|
||||||
|
res.append(im)
|
||||||
|
# data = torch.cat(res, dim=0) # shape: (batch, 128, 128)
|
||||||
|
# data = data[:, None, :, :] # shape: (batch, 1, 128, 128)
|
||||||
|
data = torch.stack(res)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def test_preprocess(images: list, transform) -> torch.Tensor:
|
||||||
|
res = []
|
||||||
|
for img in images:
|
||||||
|
im = Image.open(img)
|
||||||
|
if im.mode == 'RGBA':
|
||||||
|
im = im.convert('RGB')
|
||||||
|
im = transform(im)
|
||||||
|
res.append(im)
|
||||||
|
data = torch.stack(res)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def featurize(
|
||||||
|
images: List[str],
|
||||||
|
transform: callable,
|
||||||
|
net: nn.Module,
|
||||||
|
device: torch.device,
|
||||||
|
train: bool = False
|
||||||
|
) -> Dict[str, torch.Tensor]:
|
||||||
|
try:
|
||||||
|
# Select appropriate preprocessing
|
||||||
|
preprocess_fn = _preprocess if train else test_preprocess
|
||||||
|
|
||||||
|
# Preprocess and move to device
|
||||||
|
data = preprocess_fn(images, transform)
|
||||||
|
data = data.to(device)
|
||||||
|
net = net.to(device)
|
||||||
|
|
||||||
|
# Extract features with automatic mixed precision
|
||||||
|
with torch.no_grad():
|
||||||
|
if conf['models']['half']:
|
||||||
|
data = data.half()
|
||||||
|
features = net(data)
|
||||||
|
# Create path-to-feature mapping
|
||||||
|
return {img: feature for img, feature in zip(images, features)}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error in feature extraction: {e}")
|
||||||
|
raise
|
||||||
|
def cosin_metric(x1, x2):
|
||||||
|
return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))
|
||||||
|
def threshold_search(y_score, y_true):
|
||||||
|
y_score = np.asarray(y_score)
|
||||||
|
y_true = np.asarray(y_true)
|
||||||
|
best_acc = 0
|
||||||
|
best_th = 0
|
||||||
|
for i in range(len(y_score)):
|
||||||
|
th = y_score[i]
|
||||||
|
y_test = (y_score >= th)
|
||||||
|
acc = np.mean((y_test == y_true).astype(int))
|
||||||
|
if acc > best_acc:
|
||||||
|
best_acc = acc
|
||||||
|
best_th = th
|
||||||
|
return best_acc, best_th
|
||||||
|
|
||||||
|
|
||||||
|
def showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct):
|
||||||
|
x = np.linspace(start=0, stop=1.0, num=50, endpoint=True).tolist()
|
||||||
|
plt.figure(figsize=(10, 6))
|
||||||
|
plt.plot(x, recall, color='red', label='recall:TP/TPFN')
|
||||||
|
plt.plot(x, recall_TN, color='black', label='recall_TN:TN/TNFP')
|
||||||
|
plt.plot(x, PrecisePos, color='blue', label='PrecisePos:TP/TPFN')
|
||||||
|
plt.plot(x, PreciseNeg, color='green', label='PreciseNeg:TN/TNFP')
|
||||||
|
plt.plot(x, Correct, color='m', label='Correct:(TN+TP)/(TPFN+TNFP)')
|
||||||
|
plt.legend()
|
||||||
|
plt.xlabel('threshold')
|
||||||
|
# plt.ylabel('Similarity')
|
||||||
|
plt.grid(True, linestyle='--', alpha=0.5)
|
||||||
|
plt.savefig('grid.png')
|
||||||
|
plt.show()
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
def showHist(same, cross):
|
||||||
|
Same = np.array(same)
|
||||||
|
Cross = np.array(cross)
|
||||||
|
|
||||||
|
fig, axs = plt.subplots(2, 1)
|
||||||
|
axs[0].hist(Same, bins=50, edgecolor='black')
|
||||||
|
axs[0].set_xlim([-0.1, 1])
|
||||||
|
axs[0].set_title('Same Barcode')
|
||||||
|
|
||||||
|
axs[1].hist(Cross, bins=50, edgecolor='black')
|
||||||
|
axs[1].set_xlim([-0.1, 1])
|
||||||
|
axs[1].set_title('Cross Barcode')
|
||||||
|
plt.savefig('plot.png')
|
||||||
|
|
||||||
|
|
||||||
|
def compute_accuracy_recall(score, labels):
|
||||||
|
th = 0.1
|
||||||
|
squence = np.linspace(-1, 1, num=50)
|
||||||
|
recall, PrecisePos, PreciseNeg, recall_TN, Correct = [], [], [], [], []
|
||||||
|
Same = score[:len(score) // 2]
|
||||||
|
Cross = score[len(score) // 2:]
|
||||||
|
for th in squence:
|
||||||
|
t_score = (score > th)
|
||||||
|
t_labels = (labels == 1)
|
||||||
|
TP = np.sum(np.logical_and(t_score, t_labels))
|
||||||
|
FN = np.sum(np.logical_and(np.logical_not(t_score), t_labels))
|
||||||
|
f_score = (score < th)
|
||||||
|
f_labels = (labels == 0)
|
||||||
|
TN = np.sum(np.logical_and(f_score, f_labels))
|
||||||
|
FP = np.sum(np.logical_and(np.logical_not(f_score), f_labels))
|
||||||
|
print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))
|
||||||
|
|
||||||
|
PrecisePos.append(0 if TP / (TP + FP) == 'nan' else TP / (TP + FP))
|
||||||
|
PreciseNeg.append(0 if TN == 0 else TN / (TN + FN))
|
||||||
|
recall.append(0 if TP == 0 else TP / (TP + FN))
|
||||||
|
recall_TN.append(0 if TN == 0 else TN / (TN + FP))
|
||||||
|
Correct.append(0 if TP == 0 else (TP + TN) / (TP + FP + TN + FN))
|
||||||
|
|
||||||
|
showHist(Same, Cross)
|
||||||
|
showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_accuracy(
|
||||||
|
feature_dict: Dict[str, torch.Tensor],
|
||||||
|
pair_list: str,
|
||||||
|
test_root: str
|
||||||
|
) -> Tuple[float, float]:
|
||||||
|
try:
|
||||||
|
with open(pair_list, 'r') as f:
|
||||||
|
pairs = f.readlines()
|
||||||
|
except IOError as e:
|
||||||
|
print(f"Error reading pair list: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
similarities = []
|
||||||
|
labels = []
|
||||||
|
|
||||||
|
for pair in pairs:
|
||||||
|
pair = pair.strip()
|
||||||
|
if not pair:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
img1, img2, label = pair.split()
|
||||||
|
img1_path = osp.join(test_root, img1)
|
||||||
|
img2_path = osp.join(test_root, img2)
|
||||||
|
|
||||||
|
# Verify features exist
|
||||||
|
if img1_path not in feature_dict or img2_path not in feature_dict:
|
||||||
|
raise ValueError(f"Missing features for image pair: {img1_path}, {img2_path}")
|
||||||
|
|
||||||
|
# Get features and compute similarity
|
||||||
|
feat1 = feature_dict[img1_path].cpu().numpy()
|
||||||
|
feat2 = feature_dict[img2_path].cpu().numpy()
|
||||||
|
similarity = cosin_metric(feat1, feat2)
|
||||||
|
|
||||||
|
similarities.append(similarity)
|
||||||
|
labels.append(int(label))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Skipping invalid pair: {pair}. Error: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Find optimal threshold and accuracy
|
||||||
|
accuracy, threshold = threshold_search(similarities, labels)
|
||||||
|
compute_accuracy_recall(np.array(similarities), np.array(labels))
|
||||||
|
|
||||||
|
return accuracy, threshold
|
||||||
|
|
||||||
|
|
||||||
|
def deal_group_pair(pairList1, pairList2):
|
||||||
|
allsimilarity = []
|
||||||
|
one_similarity = []
|
||||||
|
for pair1 in pairList1:
|
||||||
|
for pair2 in pairList2:
|
||||||
|
similarity = cosin_metric(pair1.cpu().numpy(), pair2.cpu().numpy())
|
||||||
|
one_similarity.append(similarity)
|
||||||
|
allsimilarity.append(max(one_similarity)) # 最大值
|
||||||
|
# allsimilarity.append(sum(one_similarity) / len(one_similarity)) # 均值
|
||||||
|
# allsimilarity.append(statistics.median(one_similarity)) # 中位数
|
||||||
|
# print(allsimilarity)
|
||||||
|
# print(labels)
|
||||||
|
return allsimilarity
|
||||||
|
|
||||||
|
|
||||||
|
def compute_group_accuracy(content_list_read):
|
||||||
|
allSimilarity, allLabel = [], []
|
||||||
|
Same, Cross = [], []
|
||||||
|
for data_loaded in content_list_read:
|
||||||
|
print(data_loaded)
|
||||||
|
one_group_list = []
|
||||||
|
try:
|
||||||
|
for i in range(2):
|
||||||
|
images = [osp.join(conf.test_val, img) for img in data_loaded[i]]
|
||||||
|
group = group_image(images, conf.test_batch_size)
|
||||||
|
d = featurize(group[0], conf.test_transform, model, conf.device)
|
||||||
|
one_group_list.append(d.values())
|
||||||
|
if data_loaded[-1] == '1':
|
||||||
|
similarity = deal_group_pair(one_group_list[0], one_group_list[1])
|
||||||
|
Same.append(similarity)
|
||||||
|
else:
|
||||||
|
similarity = deal_group_pair(one_group_list[0], one_group_list[1])
|
||||||
|
Cross.append(similarity)
|
||||||
|
allLabel.append(data_loaded[-1])
|
||||||
|
allSimilarity.extend(similarity)
|
||||||
|
except Exception as e:
|
||||||
|
continue
|
||||||
|
# print(allSimilarity)
|
||||||
|
# print(allLabel)
|
||||||
|
return allSimilarity, allLabel
|
||||||
|
|
||||||
|
|
||||||
|
def init_model():
|
||||||
|
tr_tools = trainer_tools(conf)
|
||||||
|
backbone_mapping = tr_tools.get_backbone()
|
||||||
|
if conf['models']['backbone'] in backbone_mapping:
|
||||||
|
model = backbone_mapping[conf['models']['backbone']]().to(conf['base']['device'])
|
||||||
|
else:
|
||||||
|
raise ValueError('不支持该模型: {}'.format({conf['models']['backbone']}))
|
||||||
|
print('load model {} '.format(conf['models']['backbone']))
|
||||||
|
if torch.cuda.device_count() > 1 and conf['base']['distributed']:
|
||||||
|
model = nn.DataParallel(model).to(conf['base']['device'])
|
||||||
|
model.load_state_dict(torch.load(conf['models']['model_path'], map_location=conf['base']['device']))
|
||||||
|
if conf['models']['half']:
|
||||||
|
model.half()
|
||||||
|
first_param_dtype = next(model.parameters()).dtype
|
||||||
|
print("模型的第一个参数的数据类型: {}".format(first_param_dtype))
|
||||||
|
else:
|
||||||
|
model.load_state_dict(torch.load(conf['model']['model_path'], map_location=conf['base']['device']))
|
||||||
|
if conf.model_half:
|
||||||
|
model.half()
|
||||||
|
first_param_dtype = next(model.parameters()).dtype
|
||||||
|
print("模型的第一个参数的数据类型: {}".format(first_param_dtype))
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
model = init_model()
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
if not conf['data']['group_test']:
|
||||||
|
images = unique_image(conf['data']['test_list'])
|
||||||
|
images = [osp.join(conf['data']['test_dir'], img) for img in images]
|
||||||
|
groups = group_image(images, conf['data']['test_batch_size']) # 根据batch_size取图片
|
||||||
|
feature_dict = dict()
|
||||||
|
_, test_transform = get_transform(conf)
|
||||||
|
for group in groups:
|
||||||
|
d = featurize(group, test_transform, model, conf['base']['device'])
|
||||||
|
feature_dict.update(d)
|
||||||
|
accuracy, threshold = compute_accuracy(feature_dict, conf['data']['test_list'], conf['data']['test_dir'])
|
||||||
|
print(
|
||||||
|
"Test Model: {} Accuracy: {} Threshold: {}".format(conf['models']['model_path'], accuracy, threshold)
|
||||||
|
)
|
||||||
|
elif conf['data']['group_test']:
|
||||||
|
filename = conf['data']['test_group_json']
|
||||||
|
with open(filename, 'r', encoding='utf-8') as file:
|
||||||
|
content_list_read = json.load(file)
|
||||||
|
Similarity, Label = compute_group_accuracy(content_list_read)
|
||||||
|
compute_accuracy_recall(np.array(Similarity), np.array(Label))
|
||||||
|
# compute_group_accuracy(data_loaded)
|
0
tools/__init__.py
Normal file
0
tools/__init__.py
Normal file
BIN
tools/__pycache__/gift_data_pretreatment.cpython-38.pyc
Normal file
BIN
tools/__pycache__/gift_data_pretreatment.cpython-38.pyc
Normal file
Binary file not shown.
68
tools/dataset.py
Normal file
68
tools/dataset.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torchvision.datasets import ImageFolder
|
||||||
|
import torchvision.transforms.functional as F
|
||||||
|
import torchvision.transforms as T
|
||||||
|
# from config import config as conf
|
||||||
|
import torch
|
||||||
|
|
||||||
|
def pad_to_square(img):
|
||||||
|
w, h = img.size
|
||||||
|
max_wh = max(w, h)
|
||||||
|
padding = [(max_wh - w) // 2, (max_wh - h) // 2, (max_wh - w) // 2, (max_wh - h) // 2] # (left, top, right, bottom)
|
||||||
|
return F.pad(img, padding, fill=0, padding_mode='constant')
|
||||||
|
|
||||||
|
def get_transform(cfg):
|
||||||
|
train_transform = T.Compose([
|
||||||
|
T.Lambda(pad_to_square), # 补边
|
||||||
|
T.ToTensor(),
|
||||||
|
T.Resize((cfg['transform']['img_size'], cfg['transform']['img_size']), antialias=True),
|
||||||
|
# T.RandomCrop(img_size * 4 // 5),
|
||||||
|
T.RandomHorizontalFlip(p=cfg['transform']['RandomHorizontalFlip']),
|
||||||
|
T.RandomRotation(cfg['transform']['RandomRotation']),
|
||||||
|
T.ColorJitter(brightness=cfg['transform']['ColorJitter']),
|
||||||
|
T.ConvertImageDtype(torch.float32),
|
||||||
|
T.Normalize(mean=[cfg['transform']['img_mean']], std=[cfg['transform']['img_std']]),
|
||||||
|
])
|
||||||
|
test_transform = T.Compose([
|
||||||
|
# T.Lambda(pad_to_square), # 补边
|
||||||
|
T.ToTensor(),
|
||||||
|
T.Resize((cfg['transform']['img_size'], cfg['transform']['img_size']), antialias=True),
|
||||||
|
T.ConvertImageDtype(torch.float32),
|
||||||
|
T.Normalize(mean=[cfg['transform']['img_mean']], std=[cfg['transform']['img_std']]),
|
||||||
|
])
|
||||||
|
return train_transform, test_transform
|
||||||
|
|
||||||
|
def load_data(training=True, cfg=None):
|
||||||
|
train_transform, test_transform = get_transform(cfg)
|
||||||
|
if training:
|
||||||
|
dataroot = cfg['data']['data_train_dir']
|
||||||
|
transform = train_transform
|
||||||
|
# transform = conf.train_transform
|
||||||
|
batch_size = cfg['data']['train_batch_size']
|
||||||
|
else:
|
||||||
|
dataroot = cfg['data']['data_val_dir']
|
||||||
|
# transform = conf.test_transform
|
||||||
|
transform = test_transform
|
||||||
|
batch_size = cfg['data']['val_batch_size']
|
||||||
|
|
||||||
|
data = ImageFolder(dataroot, transform=transform)
|
||||||
|
class_num = len(data.classes)
|
||||||
|
loader = DataLoader(data,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
pin_memory=cfg['base']['pin_memory'],
|
||||||
|
num_workers=cfg['data']['num_workers'],
|
||||||
|
drop_last=True)
|
||||||
|
return loader, class_num
|
||||||
|
|
||||||
|
# def load_gift_data(action):
|
||||||
|
# train_data = ImageFolder(conf.train_gift_root, transform=conf.train_transform)
|
||||||
|
# train_dataset = DataLoader(train_data, batch_size=conf.train_gift_batchsize, shuffle=True,
|
||||||
|
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
|
||||||
|
# val_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
|
||||||
|
# val_dataset = DataLoader(val_data, batch_size=conf.val_gift_batchsize, shuffle=True,
|
||||||
|
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
|
||||||
|
# test_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
|
||||||
|
# test_dataset = DataLoader(test_data, batch_size=conf.test_gift_batchsize, shuffle=True,
|
||||||
|
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
|
||||||
|
# return train_dataset, val_dataset, test_dataset
|
10
tools/dataset.txt
Normal file
10
tools/dataset.txt
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
./quant_imgs/20179457_20240924-110903_back_addGood_b82d2842766e_80_15583929052_tid-8_fid-72_bid-3.jpg
|
||||||
|
./quant_imgs/6928926002103_20240309-195044_front_returnGood_70f75407ef0e_225_18120111822_14_01.jpg
|
||||||
|
./quant_imgs/6928926002103_20240309-212145_front_returnGood_70f75407ef0e_225_18120111822_11_01.jpg
|
||||||
|
./quant_imgs/6928947479083_20241017-133830_front_returnGood_5478c9a48b7e_10_13799009402_tid-1_fid-20_bid-1.jpg
|
||||||
|
./quant_imgs/6928947479083_20241018-110450_front_addGood_5478c9a48c28_165_13773168720_tid-6_fid-36_bid-1.jpg
|
||||||
|
./quant_imgs/6930044166421_20240117-141516_c6a23f41-5b16-44c6-a03e-c32c25763442_back_returnGood_6930044166421_17_01.jpg
|
||||||
|
./quant_imgs/6930044166421_20240308-150916_back_returnGood_70f75407ef0e_175_13815402763_7_01.jpg
|
||||||
|
./quant_imgs/6930044168920_20240117-165633_3303629b-5fbd-423b-913d-8a64c1aa51dc_front_addGood_6930044168920_26_01.jpg
|
||||||
|
./quant_imgs/6930058201507_20240305-175434_front_addGood_70f75407ef0e_95_18120111822_28_01.jpg
|
||||||
|
./quant_imgs/6930639267885_20241014-120446_back_addGood_5478c9a48c3e_135_13773168720_tid-5_fid-99_bid-0.jpg
|
112
tools/fp32comparefp16.py
Normal file
112
tools/fp32comparefp16.py
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from test_ori import group_image, init_model, featurize
|
||||||
|
from config import config as conf
|
||||||
|
import json
|
||||||
|
import os.path as osp
|
||||||
|
|
||||||
|
def compare_fp16_fp32(values_pf16, values_pf32, dataTest):
|
||||||
|
if dataTest:
|
||||||
|
norm_values_pf16 = torch.norm(values_pf16, p=2)
|
||||||
|
norm_values_pf32 = torch.norm(values_pf32, p=2)
|
||||||
|
euclidean_distance = torch.norm(norm_values_pf16 - norm_values_pf32, p=2)
|
||||||
|
print(f"欧几里得距离: {euclidean_distance}")
|
||||||
|
cosine_sim = torch.dot(values_pf16.float(), values_pf32) / (norm_values_pf16 * norm_values_pf32)
|
||||||
|
print(f"余弦相似度: {cosine_sim}")
|
||||||
|
else:
|
||||||
|
|
||||||
|
pass
|
||||||
|
def cosin_metric(x1, x2, fp32=True):
|
||||||
|
if fp32:
|
||||||
|
return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))
|
||||||
|
else:
|
||||||
|
x1_fp16 = x1.astype(np.float16)
|
||||||
|
x2_fp16 = x2.astype(np.float16)
|
||||||
|
# print(type(x1))
|
||||||
|
# pdb.set_trace()
|
||||||
|
return np.dot(x1_fp16, x2_fp16) / (np.linalg.norm(x1_fp16) * np.linalg.norm(x2_fp16))
|
||||||
|
def deal_group_pair(pairList1, pairList2):
|
||||||
|
one_similarity_fp16, one_similarity_fp32, allsimilarity_fp32, allsimilarity_fp16 = [], [], [], []
|
||||||
|
for pair1 in pairList1:
|
||||||
|
for pair2 in pairList2:
|
||||||
|
# similarity = cosin_metric(pair1.cpu().numpy(), pair2.cpu().numpy())
|
||||||
|
one_similarity_fp32.append(cosin_metric(pair1.cpu().numpy(), pair2.cpu().numpy(), True))
|
||||||
|
one_similarity_fp16.append(cosin_metric(pair1.cpu().numpy(), pair2.cpu().numpy(), False))
|
||||||
|
allsimilarity_fp32.append(one_similarity_fp32)
|
||||||
|
allsimilarity_fp16.append(one_similarity_fp16)
|
||||||
|
one_similarity_fp16, one_similarity_fp32 = [], []
|
||||||
|
return np.array(allsimilarity_fp32), np.array(allsimilarity_fp16)
|
||||||
|
|
||||||
|
def compute_group_accuracy(content_list_read, model):
|
||||||
|
allSimilarity, allLabel = [], []
|
||||||
|
Same, Cross = [], []
|
||||||
|
flag_same = True
|
||||||
|
flag_diff = True
|
||||||
|
for data_loaded in content_list_read:
|
||||||
|
one_group_list = []
|
||||||
|
try:
|
||||||
|
if (flag_same and str(data_loaded[-1]) == '1') or (flag_diff and str(data_loaded[-1]) == '0'):
|
||||||
|
for i in range(2):
|
||||||
|
images = [osp.join(conf.test_val, img) for img in data_loaded[i]]
|
||||||
|
group = group_image(images, conf.test_batch_size)
|
||||||
|
d = featurize(group[0], conf.test_transform, model, conf.device)
|
||||||
|
one_group_list.append(d.values())
|
||||||
|
if str(data_loaded[-1]) == '1':
|
||||||
|
flag_same = False
|
||||||
|
allsimilarity_fp32, allsimilarity_fp16 = deal_group_pair(one_group_list[0], one_group_list[1])
|
||||||
|
print('fp32 same-- >', allsimilarity_fp32)
|
||||||
|
print('fp16 same-- >', allsimilarity_fp16)
|
||||||
|
else:
|
||||||
|
flag_diff = False
|
||||||
|
allsimilarity_fp32, allsimilarity_fp16 = deal_group_pair(one_group_list[0], one_group_list[1])
|
||||||
|
print('fp32 diff-- >', allsimilarity_fp32)
|
||||||
|
print('fp16 diff-- >', allsimilarity_fp16)
|
||||||
|
except Exception as e:
|
||||||
|
continue
|
||||||
|
# print(allSimilarity)
|
||||||
|
# print(allLabel)
|
||||||
|
return allSimilarity, allLabel
|
||||||
|
def get_feature_list(imgPth):
|
||||||
|
imgs = get_files(imgPth)
|
||||||
|
group = group_image(imgs, conf.test_batch_size)
|
||||||
|
model = init_model()
|
||||||
|
model.eval()
|
||||||
|
fe = featurize(group[0], conf.test_transform, model, conf.device)
|
||||||
|
return fe
|
||||||
|
|
||||||
|
|
||||||
|
def get_files(imgPth):
|
||||||
|
imgsList = []
|
||||||
|
for img in os.walk(imgPth):
|
||||||
|
for img_name in img[2]:
|
||||||
|
img_path = os.sep.join([img[0], img_name])
|
||||||
|
imgsList.append(img_path)
|
||||||
|
return imgsList
|
||||||
|
import pdb
|
||||||
|
|
||||||
|
def compare(imgPth, group=False):
|
||||||
|
model = init_model()
|
||||||
|
model.eval()
|
||||||
|
if not group:
|
||||||
|
values_pf16, values_pf32 = [], []
|
||||||
|
fe = get_feature_list(imgPth)
|
||||||
|
# pdb.set_trace()
|
||||||
|
values_pf32 += [value.cpu() for value in fe.values()]
|
||||||
|
values_pf16 += [value.cpu().half() for value in fe.values()]
|
||||||
|
for value_pf16, value_pf32 in zip(values_pf16, values_pf32):
|
||||||
|
compare_fp16_fp32(value_pf16, value_pf32, dataTest=True)
|
||||||
|
else:
|
||||||
|
filename = conf.test_group_json
|
||||||
|
with open(filename, 'r', encoding='utf-8') as file:
|
||||||
|
content_list_read = json.load(file)
|
||||||
|
compute_group_accuracy(content_list_read, model)
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
imgPth = './data/test/inner/3701375401900'
|
||||||
|
compare(imgPth)
|
369
tools/gift_assessment.py
Normal file
369
tools/gift_assessment.py
Normal file
@ -0,0 +1,369 @@
|
|||||||
|
import os
|
||||||
|
import pdb
|
||||||
|
import shutil
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.path.append('../model')
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
from model.mlp import Net2, Net3, Net4
|
||||||
|
from model import resnet18
|
||||||
|
import torch
|
||||||
|
from gift_data_pretreatment import getFeatureList
|
||||||
|
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
|
|
||||||
|
def init_model(pkl_flag):
|
||||||
|
res_pth = r"../checkpoints/resnet18_1009/best.pth"
|
||||||
|
if pkl_flag:
|
||||||
|
gift_pth = r'../checkpoints/gift_model/action2/gift_v11.pth'
|
||||||
|
gift_model = Net3(pretrained=True, num_classes=1)
|
||||||
|
gift_model.load_state_dict(torch.load(gift_pth))
|
||||||
|
else:
|
||||||
|
gift_pth = r'../checkpoints/gift_model/action3/best.pth'
|
||||||
|
gift_model = Net4('resnet18', True, True) # 预训练模型
|
||||||
|
try:
|
||||||
|
print('>>multiple_cards load pre model <<')
|
||||||
|
gift_model.load_state_dict({k.replace('module.', ''): v for k, v in
|
||||||
|
torch.load(gift_pth,
|
||||||
|
map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu')).items()})
|
||||||
|
except Exception as e:
|
||||||
|
print('>> load pre model <<')
|
||||||
|
gift_model.load_state_dict(torch.load(gift_pth,
|
||||||
|
map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu')))
|
||||||
|
res_model = resnet18()
|
||||||
|
res_model.load_state_dict({k.replace('module.', ''): v for k, v in
|
||||||
|
torch.load(res_pth, map_location=torch.device(device)).items()})
|
||||||
|
return res_model, gift_model
|
||||||
|
|
||||||
|
|
||||||
|
def showHist(nongifts, gifts):
|
||||||
|
# Same = filtered_data[:, 1].astype(np.float32)
|
||||||
|
# Cross = filtered_data[:, 2].astype(np.float32)
|
||||||
|
|
||||||
|
fig, axs = plt.subplots(2, 1)
|
||||||
|
axs[0].hist(nongifts, bins=50, edgecolor='blue')
|
||||||
|
axs[0].set_xlim([-0.1, 1])
|
||||||
|
axs[0].set_title('nongifts')
|
||||||
|
|
||||||
|
axs[1].hist(gifts, bins=50, edgecolor='green')
|
||||||
|
axs[1].set_xlim([-0.1, 1])
|
||||||
|
axs[1].set_title('gifts')
|
||||||
|
# plt.savefig('plot.png')
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_precision_recall(nongift, gift, points):
|
||||||
|
precision, recall = [], []
|
||||||
|
for point in points:
|
||||||
|
TP = np.sum(gift > point)
|
||||||
|
FN = np.sum(gift < point)
|
||||||
|
FP = np.sum(nongift > point)
|
||||||
|
TN = np.sum(nongift < point)
|
||||||
|
if TP == 0:
|
||||||
|
precision.append(0)
|
||||||
|
recall.append(0)
|
||||||
|
else:
|
||||||
|
precision.append(TP / (TP + FP))
|
||||||
|
recall.append(TP / (TP + FN))
|
||||||
|
print("point >> {} TP>>{}, FP>>{}, TN>>{}, FN>>{}".format(point, TP, FP, TN, FN))
|
||||||
|
if point == 0.5:
|
||||||
|
print("point >> {} TP>>{}, FP>>{}, TN>>{}, FN>>{}".format(point, TP, FP, TN, FN))
|
||||||
|
return precision, recall
|
||||||
|
|
||||||
|
|
||||||
|
def showgrid(all_prec, all_recall, points):
|
||||||
|
plt.figure(figsize=(10, 6))
|
||||||
|
plt.plot(points[:-1], all_prec[:-1], color='blue', label='precision')
|
||||||
|
plt.plot(points[:-1], all_recall[:-1], color='red', label='recall')
|
||||||
|
plt.legend()
|
||||||
|
plt.xlabel('threshold')
|
||||||
|
# plt.ylabel('Similarity')
|
||||||
|
plt.grid(True, linestyle='--', alpha=0.5)
|
||||||
|
# plt.savefig('grid.png')
|
||||||
|
plt.show()
|
||||||
|
plt.close()
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def discriminate_action(roots): # 判断加购还是退购
|
||||||
|
pth = os.sep.join([roots, 'process.data'])
|
||||||
|
with open(pth, 'r') as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
for line in lines:
|
||||||
|
content = line.strip()
|
||||||
|
if 'weightValue' in content:
|
||||||
|
# print(content.split(":")[-1].split(',')[0])
|
||||||
|
if int(content.split(":")[-1].split(',')[0]) > 0:
|
||||||
|
return 'add'
|
||||||
|
else:
|
||||||
|
return 'return'
|
||||||
|
|
||||||
|
|
||||||
|
def median(lst):
|
||||||
|
sorted_lst = sorted(lst)
|
||||||
|
n = len(sorted_lst)
|
||||||
|
if n % 2 == 1:
|
||||||
|
# 如果列表长度是奇数,中位数是中间的那个元素
|
||||||
|
return sorted_lst[n // 2]
|
||||||
|
else:
|
||||||
|
# 如果列表长度是偶数,中位数是中间两个元素的平均值
|
||||||
|
mid1 = sorted_lst[(n // 2) - 1]
|
||||||
|
mid2 = sorted_lst[n // 2]
|
||||||
|
return (mid1 + mid2) / 2
|
||||||
|
|
||||||
|
|
||||||
|
def get_special_data(data, p):
|
||||||
|
# print(data)
|
||||||
|
length = len(data)
|
||||||
|
if length > 5:
|
||||||
|
if p == 'max':
|
||||||
|
return max(data[:round(length * 0.5)])
|
||||||
|
elif p == 'average':
|
||||||
|
return sum(data[:round(length * 0.5)]) / len(data[:round(length * 0.5)])
|
||||||
|
elif p == 'median':
|
||||||
|
return median(data[:round(length * 0.5)])
|
||||||
|
else:
|
||||||
|
return sum(data) / len(data)
|
||||||
|
|
||||||
|
|
||||||
|
def read_data_file(pth):
|
||||||
|
result = []
|
||||||
|
with open(pth, 'r') as data_file:
|
||||||
|
lines = data_file.readlines()
|
||||||
|
for line in lines:
|
||||||
|
if line.split(':')[0] == 'free_gift__result':
|
||||||
|
if '0_tracking_output.data' in pth:
|
||||||
|
result = line.split(':')[1].split(',')[:-1]
|
||||||
|
else:
|
||||||
|
result = line.split(':')[1].split(',')[:-2]
|
||||||
|
result = [float(i) for i in result]
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_tracking_data(pth):
|
||||||
|
result = []
|
||||||
|
with open(pth, 'r') as data_file:
|
||||||
|
lines = data_file.readlines()
|
||||||
|
for line in lines:
|
||||||
|
if len(line.split(',')) == 65:
|
||||||
|
result.append([float(item) for item in line.split(',')[:-1]])
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def clean_reurn_data(pth):
|
||||||
|
for roots, dirs, files in os.walk(pth):
|
||||||
|
# print(roots, dirs, files)
|
||||||
|
if len(dirs) == 0:
|
||||||
|
flag = discriminate_action(roots)
|
||||||
|
if flag == 'return':
|
||||||
|
shutil.rmtree(roots)
|
||||||
|
|
||||||
|
|
||||||
|
def get_gift_files(pth): # 测试后直接分析测试结果文件
|
||||||
|
add_special_output_0, return_special_output_0, return_special_output_1, add_special_output_1 = [], [], [], []
|
||||||
|
add_tracking_output_0, return_tracking_output_0, add_tracking_output_1, return_tracking_output_1 = [], [], [], []
|
||||||
|
for roots, dirs, files in os.walk(pth):
|
||||||
|
# print(roots, dirs, files)
|
||||||
|
if len(dirs) == 0:
|
||||||
|
flag = discriminate_action(roots)
|
||||||
|
for file in files:
|
||||||
|
if file == '0_tracking_output.data':
|
||||||
|
result = read_data_file(os.path.join(roots, file))
|
||||||
|
if not len(result) == 0:
|
||||||
|
if flag == 'add':
|
||||||
|
add_special_output_0.append(get_special_data(result, 'average')) # 加购后摄
|
||||||
|
else:
|
||||||
|
return_special_output_0.append(get_special_data(result, 'average')) # 退购后摄
|
||||||
|
if flag == 'add':
|
||||||
|
add_tracking_output_0 += read_data_file(os.path.join(roots, file))
|
||||||
|
else:
|
||||||
|
return_tracking_output_0 += read_data_file(os.path.join(roots, file))
|
||||||
|
elif file == '1_tracking_output.data':
|
||||||
|
result = read_data_file(os.path.join(roots, file))
|
||||||
|
if not len(result) == 0:
|
||||||
|
if flag == 'add':
|
||||||
|
add_special_output_1.append(get_special_data(result, 'average')) # 加购前摄
|
||||||
|
else:
|
||||||
|
return_special_output_1.append(get_special_data(result, 'average')) # 退购前摄
|
||||||
|
if flag == 'add':
|
||||||
|
add_tracking_output_1 += read_data_file(os.path.join(roots, file))
|
||||||
|
else:
|
||||||
|
return_tracking_output_1 += read_data_file(os.path.join(roots, file))
|
||||||
|
comprehensive_dicts = {"add_special_output_0": add_special_output_0,
|
||||||
|
"return_special_output_0": return_special_output_0,
|
||||||
|
"add_tracking_output_0": add_tracking_output_0,
|
||||||
|
"return_tracking_output_0": return_tracking_output_0,
|
||||||
|
"add_special_output_1": add_special_output_1,
|
||||||
|
"return_special_output_1": return_special_output_1,
|
||||||
|
"add_tracking_output_1": add_tracking_output_1,
|
||||||
|
"return_tracking_output_1": return_tracking_output_1,
|
||||||
|
}
|
||||||
|
# print(tracking_output_0, tracking_output_1)
|
||||||
|
showHist(np.array(comprehensive_dicts['add_tracking_output_0']),
|
||||||
|
np.array(comprehensive_dicts['add_tracking_output_1']))
|
||||||
|
# showHist(np.array(comprehensive_dicts['add_special_output_0']),
|
||||||
|
# np.array(comprehensive_dicts['add_special_output_1']))
|
||||||
|
return comprehensive_dicts
|
||||||
|
|
||||||
|
|
||||||
|
def get_feature_array(img_pth_lists, res_model, gift_model, pkl_flag=True):
|
||||||
|
features_np = []
|
||||||
|
if pkl_flag:
|
||||||
|
for img_lists in img_pth_lists:
|
||||||
|
# print(img_lists)
|
||||||
|
fe_nps = getFeatureList(None, img_lists, res_model)
|
||||||
|
# fe_nps.squeeze()
|
||||||
|
try:
|
||||||
|
fe_nps = fe_nps[0][:, 256:]
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
continue
|
||||||
|
fe_nps = torch.from_numpy(fe_nps)
|
||||||
|
fe_nps = fe_nps.view(fe_nps.shape[0], 64, 13, 13)
|
||||||
|
if len(fe_nps):
|
||||||
|
fe_np = gift_model(fe_nps)
|
||||||
|
fe_np = np.squeeze(fe_np.detach().numpy())
|
||||||
|
features_np.append(fe_np)
|
||||||
|
else:
|
||||||
|
for img_lists in img_pth_lists:
|
||||||
|
fe_nps = getFeatureList(None, img_lists, gift_model)
|
||||||
|
if len(fe_nps) > 0:
|
||||||
|
fe_nps = np.concatenate(fe_nps)
|
||||||
|
features_np.append(fe_nps)
|
||||||
|
return features_np
|
||||||
|
|
||||||
|
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
|
||||||
|
def create_gift_subimg_np(data_pth, pkl_flag):
|
||||||
|
gift_array_pth = os.path.join(data_pth, 'gift.pkl')
|
||||||
|
nongift_array_pth = os.path.join(data_pth, 'nongift.pkl')
|
||||||
|
res_model, gift_model = init_model(pkl_flag)
|
||||||
|
res_model = res_model.eval()
|
||||||
|
gift_model = gift_model.eval()
|
||||||
|
gift_img_pth_list, gift_lists, nongift_img_pth_list, nongift_lists = [], [], [], []
|
||||||
|
|
||||||
|
for root, dirs, files in os.walk(data_pth):
|
||||||
|
if ('commodity' in root and 'subimg' in root):
|
||||||
|
print("commodity >> {}".format(root))
|
||||||
|
for file in files:
|
||||||
|
nongift_img_pth_list.append(os.sep.join([root, file]))
|
||||||
|
nongift_lists.append(nongift_img_pth_list)
|
||||||
|
nongift_img_pth_list = []
|
||||||
|
elif ('Havegift' in root and 'subimg' in root):
|
||||||
|
print("Havegift >> {}".format(root))
|
||||||
|
for file in files:
|
||||||
|
gift_img_pth_list.append(os.sep.join([root, file]))
|
||||||
|
gift_lists.append(gift_img_pth_list)
|
||||||
|
gift_img_pth_list = []
|
||||||
|
nongift = get_feature_array(nongift_lists, res_model, gift_model, pkl_flag)
|
||||||
|
gift = get_feature_array(gift_lists, res_model, gift_model, pkl_flag)
|
||||||
|
with open(nongift_array_pth, 'wb') as file:
|
||||||
|
pickle.dump(nongift, file)
|
||||||
|
with open(gift_array_pth, 'wb') as file:
|
||||||
|
pickle.dump(gift, file)
|
||||||
|
|
||||||
|
|
||||||
|
def top_25_percent_mean(arr):
|
||||||
|
# 1. 对数组进行从高到低排序
|
||||||
|
sorted_arr = np.sort(arr)[::-1]
|
||||||
|
|
||||||
|
# 2. 计算数组长度的25%
|
||||||
|
top_25_percent_length = int(len(sorted_arr) * 0.25)
|
||||||
|
|
||||||
|
# 3. 取排序后数组的前25%元素
|
||||||
|
top_25_percent = sorted_arr[:top_25_percent_length]
|
||||||
|
|
||||||
|
# 4. 计算这些元素的平均值
|
||||||
|
mean_value = np.mean(top_25_percent)
|
||||||
|
|
||||||
|
return top_25_percent
|
||||||
|
|
||||||
|
|
||||||
|
def assess_gift_subimg(data_pth, pkl_flag=False): # 分析分割后子图,
|
||||||
|
points = (np.linspace(1, 100, 100)) / 100
|
||||||
|
gift_pkl_pth = os.path.join(data_pth, 'gift.pkl')
|
||||||
|
nongift_pkl_pth = os.path.join(data_pth, 'nongift.pkl')
|
||||||
|
if not os.path.exists(gift_pkl_pth):
|
||||||
|
create_gift_subimg_np(data_pth, pkl_flag)
|
||||||
|
with open(nongift_pkl_pth, 'rb') as f:
|
||||||
|
nongift = pickle.load(f)
|
||||||
|
with open(gift_pkl_pth, 'rb') as f:
|
||||||
|
gift = pickle.load(f)
|
||||||
|
# showHist(nongift.flatten(), gift.flatten())
|
||||||
|
|
||||||
|
'''
|
||||||
|
一分位均值
|
||||||
|
'''
|
||||||
|
nongift_mean = [np.mean(top_25_percent_mean(items)) for items in nongift]
|
||||||
|
gift_mean = [np.mean(top_25_percent_mean(items)) for items in gift]
|
||||||
|
'''
|
||||||
|
中位数
|
||||||
|
'''
|
||||||
|
# nongift_mean = [np.median(items) for items in nongift]
|
||||||
|
# gift_mean = [np.median(items) for items in gift] # 平均值
|
||||||
|
|
||||||
|
'''
|
||||||
|
全部结果
|
||||||
|
'''
|
||||||
|
# nongifts = [items for items in nongift]
|
||||||
|
# gifts = [items for items in gift]
|
||||||
|
# showHist(nongifts, gifts)
|
||||||
|
|
||||||
|
'''
|
||||||
|
平均值
|
||||||
|
'''
|
||||||
|
# nongift_mean = [np.mean(items) for items in nongift]
|
||||||
|
# gift_mean = [np.mean(items) for items in gift]
|
||||||
|
|
||||||
|
showHist(np.array(nongift_mean), np.array(gift_mean)) # 最大值
|
||||||
|
precision, recall = calculate_precision_recall(np.array(nongift_mean),
|
||||||
|
np.array(gift_mean),
|
||||||
|
points)
|
||||||
|
showgrid(precision, recall, points)
|
||||||
|
|
||||||
|
|
||||||
|
def get_comprehensive_dicts(data_pth):
|
||||||
|
gift_pth = r'../checkpoints/gift_model/action2/best.pth'
|
||||||
|
g_model = Net3(pretrained=True, num_classes=1)
|
||||||
|
g_model.load_state_dict(torch.load(gift_pth))
|
||||||
|
g_model.eval()
|
||||||
|
result = []
|
||||||
|
file_name = ['0_tracking_output.data',
|
||||||
|
'1_tracking_output.data']
|
||||||
|
for root, dirs, files in os.walk(data_pth):
|
||||||
|
if not len(dirs):
|
||||||
|
for file in files:
|
||||||
|
if file in file_name:
|
||||||
|
print(os.path.join(root, file))
|
||||||
|
result += get_tracking_data(os.path.join(root, file))
|
||||||
|
result = torch.from_numpy(np.array(result))
|
||||||
|
input = result.view(result.shape[0], 64, 1, 1)
|
||||||
|
input = input.to('cpu')
|
||||||
|
input = input.to(torch.float32)
|
||||||
|
ji = g_model(input)
|
||||||
|
print(ji)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241203赠品测试数据\\赠品\\images'
|
||||||
|
# pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241203赠品测试数据\\没有赠品的商品\\images'
|
||||||
|
# pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241203赠品测试数据\\同样的商品没有捆绑赠品\\images'
|
||||||
|
# pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241213赠品测试数据\\赠品'
|
||||||
|
# pth = r'C:\Users\HP\Desktop\zengpin\1227'
|
||||||
|
# get_gift_files(pth)
|
||||||
|
|
||||||
|
# 根据子图分析结果
|
||||||
|
pth = r'D:\Project\contrast_nettest\data\gift_test'
|
||||||
|
assess_gift_subimg(pth)
|
||||||
|
|
||||||
|
# 根据完整数据集分析结果
|
||||||
|
# pth = r'C:\Users\HP\Desktop\zengpin\1231'
|
||||||
|
# get_comprehensive_dicts(pth)
|
||||||
|
|
||||||
|
# 删除退购视频
|
||||||
|
# pth = r'C:\Users\HP\Desktop\gift_test\20241213\非赠品'
|
||||||
|
# clean_reurn_data(pth)
|
92
tools/gift_data_pretreatment.py
Normal file
92
tools/gift_data_pretreatment.py
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
import torch
|
||||||
|
from config import config as conf
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def convert_rgba_to_rgb(image_path, output_path=None):
|
||||||
|
"""
|
||||||
|
将给定路径的4通道PNG图像转换为3通道,并保存到指定输出路径。
|
||||||
|
|
||||||
|
:param image_path: 输入图像的路径
|
||||||
|
:param output_path: 转换后的图像保存路径
|
||||||
|
"""
|
||||||
|
# 打开图像
|
||||||
|
img = Image.open(image_path)
|
||||||
|
# 转换图像模式从RGBA到RGB
|
||||||
|
# .convert('RGB')会丢弃Alpha通道并转换为纯RGB图像
|
||||||
|
if img.mode == 'RGBA':
|
||||||
|
# 转换为RGB模式
|
||||||
|
img_rgb = img.convert('RGB')
|
||||||
|
# 保存转换后的图像
|
||||||
|
img_rgb.save(image_path)
|
||||||
|
# print(f"Image converted from RGBA to RGB and saved to {image_path}")
|
||||||
|
# else:
|
||||||
|
# # 如果已经是RGB或其他模式,直接保存
|
||||||
|
# img.save(image_path)
|
||||||
|
# print(f"Image already in {img.mode} mode, saved to {image_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_preprocess(images: list, actionModel=False) -> torch.Tensor:
|
||||||
|
res = []
|
||||||
|
for img in images:
|
||||||
|
try:
|
||||||
|
# print(img)
|
||||||
|
im = conf.test_transform(img) if actionModel else conf.test_transform(Image.open(img))
|
||||||
|
res.append(im)
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
data = torch.stack(res)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def inference(images, model, actionModel=False):
|
||||||
|
data = test_preprocess(images, actionModel)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
data = data.to(conf.device)
|
||||||
|
features = model(data)
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
def group_image(images, batch=64) -> list:
|
||||||
|
"""Group image paths by batch size"""
|
||||||
|
size = len(images)
|
||||||
|
res = []
|
||||||
|
for i in range(0, size, batch):
|
||||||
|
end = min(batch + i, size)
|
||||||
|
res.append(images[i:end])
|
||||||
|
return res
|
||||||
|
|
||||||
|
def normalize(queFeatList):
|
||||||
|
for num1 in range(len(queFeatList)):
|
||||||
|
for num2 in range(len(queFeatList[num1])):
|
||||||
|
queFeatList[num1][num2] = queFeatList[num1][num2] / np.linalg.norm(queFeatList[num1][num2])
|
||||||
|
return queFeatList
|
||||||
|
|
||||||
|
def getFeatureList(barList, imgList, model):
|
||||||
|
# featList = [[] for i in range(len(barList))]
|
||||||
|
# for index, feat in enumerate(imgList):
|
||||||
|
fe_nps = []
|
||||||
|
groups = group_image(imgList)
|
||||||
|
for group in groups:
|
||||||
|
feat_tensor = inference(group, model)
|
||||||
|
# for fe in feat_tensor:
|
||||||
|
if feat_tensor.device == 'cpu':
|
||||||
|
fe_np = feat_tensor.squeeze().detach().numpy()
|
||||||
|
# fe_np = fe_np[:, 256:]
|
||||||
|
# fe_np = fe_np.reshape(fe_np.shape[0], fe_np.shape[1], 1, 1)
|
||||||
|
else:
|
||||||
|
fe_np = feat_tensor.squeeze().detach().cpu().numpy()
|
||||||
|
# fe_np = fe_np[:, 256:]
|
||||||
|
# fe_np = fe_np[256:]
|
||||||
|
# fe_np = fe_np.reshape(fe_np.shape[0], fe_np.shape[1], 1, 1)
|
||||||
|
# fe_np = fe_np.reshape(1, fe_np.shape[0], 1, 1)
|
||||||
|
# print(fe_np)
|
||||||
|
|
||||||
|
fe_nps.append(fe_np)
|
||||||
|
# if fe_nps:
|
||||||
|
# merged_fe_np = np.concatenate(fe_nps, axis=0)
|
||||||
|
# else:
|
||||||
|
# merged_fe_np = np.array([]) #
|
||||||
|
# fe_list = normalize(fe_nps)
|
||||||
|
return fe_nps
|
118
tools/json_contrast.py
Normal file
118
tools/json_contrast.py
Normal file
@ -0,0 +1,118 @@
|
|||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
|
def showHist(same, cross):
|
||||||
|
Same = np.array(same)
|
||||||
|
Cross = np.array(cross)
|
||||||
|
|
||||||
|
fig, axs = plt.subplots(2, 1)
|
||||||
|
axs[0].hist(Same, bins=50, edgecolor='black')
|
||||||
|
axs[0].set_xlim([-0.1, 1])
|
||||||
|
axs[0].set_title('Same Barcode')
|
||||||
|
|
||||||
|
axs[1].hist(Cross, bins=50, edgecolor='black')
|
||||||
|
axs[1].set_xlim([-0.1, 1])
|
||||||
|
axs[1].set_title('Cross Barcode')
|
||||||
|
# plt.savefig('plot.png')
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct):
|
||||||
|
x = np.linspace(start=0, stop=1.0, num=50, endpoint=True).tolist()
|
||||||
|
plt.figure(figsize=(10, 6))
|
||||||
|
plt.plot(x, recall, color='red', label='recall:TP/TPFN')
|
||||||
|
plt.plot(x, recall_TN, color='black', label='recall_TN:TN/TNFP')
|
||||||
|
plt.plot(x, PrecisePos, color='blue', label='PrecisePos:TP/TPFN')
|
||||||
|
plt.plot(x, PreciseNeg, color='green', label='PreciseNeg:TN/TNFP')
|
||||||
|
plt.plot(x, Correct, color='m', label='Correct:(TN+TP)/(TPFN+TNFP)')
|
||||||
|
plt.legend()
|
||||||
|
plt.xlabel('threshold')
|
||||||
|
# plt.ylabel('Similarity')
|
||||||
|
plt.grid(True, linestyle='--', alpha=0.5)
|
||||||
|
plt.savefig('grid.png')
|
||||||
|
plt.show()
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
def compute_accuracy_recall(score, labels):
|
||||||
|
th = 0.1
|
||||||
|
squence = np.linspace(-1, 1, num=50)
|
||||||
|
recall, PrecisePos, PreciseNeg, recall_TN, Correct = [], [], [], [], []
|
||||||
|
Same = score[:len(score) // 2]
|
||||||
|
Cross = score[len(score) // 2:]
|
||||||
|
for th in squence:
|
||||||
|
t_score = (score > th)
|
||||||
|
t_labels = (labels == 1)
|
||||||
|
TP = np.sum(np.logical_and(t_score, t_labels))
|
||||||
|
FN = np.sum(np.logical_and(np.logical_not(t_score), t_labels))
|
||||||
|
f_score = (score < th)
|
||||||
|
f_labels = (labels == 0)
|
||||||
|
TN = np.sum(np.logical_and(f_score, f_labels))
|
||||||
|
FP = np.sum(np.logical_and(np.logical_not(f_score), f_labels))
|
||||||
|
print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))
|
||||||
|
|
||||||
|
PrecisePos.append(0 if TP / (TP + FP) == 'nan' else TP / (TP + FP))
|
||||||
|
PreciseNeg.append(0 if TN == 0 else TN / (TN + FN))
|
||||||
|
recall.append(0 if TP == 0 else TP / (TP + FN))
|
||||||
|
recall_TN.append(0 if TN == 0 else TN / (TN + FP))
|
||||||
|
Correct.append(0 if TP == 0 else (TP + TN) / (TP + FP + TN + FN))
|
||||||
|
|
||||||
|
showHist(Same, Cross)
|
||||||
|
showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct)
|
||||||
|
|
||||||
|
|
||||||
|
def get_similarity(features1, features2, n, m):
|
||||||
|
features1 = np.array(features1)
|
||||||
|
features2 = np.array(features2)
|
||||||
|
all_similarity = []
|
||||||
|
for feature1 in features1:
|
||||||
|
for feature2 in features2:
|
||||||
|
similarity = np.dot(feature1, feature2) / (np.linalg.norm(feature1) * np.linalg.norm(feature2))
|
||||||
|
all_similarity.append(similarity)
|
||||||
|
test_similarity = np.array(all_similarity)
|
||||||
|
np_all_array = np.array(all_similarity).reshape(len(features1), len(features2))
|
||||||
|
if n == 5 and m == 5:
|
||||||
|
print(all_similarity)
|
||||||
|
return np.mean(np_all_array), all_similarity
|
||||||
|
# return sum(all_similarity)/len(all_similarity), all_similarity
|
||||||
|
# return max(all_similarity), all_similarity
|
||||||
|
|
||||||
|
|
||||||
|
def deal_similarity(dicts):
|
||||||
|
all_similarity = []
|
||||||
|
similarity = []
|
||||||
|
same_barcode, diff_barcode = [], []
|
||||||
|
for n, (key1, value1) in enumerate(dicts.items()):
|
||||||
|
print('key1 >> {}'.format(key1))
|
||||||
|
for m, (key2, value2) in enumerate(dicts.items()):
|
||||||
|
print('key1 >> {} key2 >> {} peidui {}{}'.format(key1, key2, n, m))
|
||||||
|
max_similarity, some_similarity = get_similarity(value1, value2, n, m)
|
||||||
|
similarity.append(max_similarity)
|
||||||
|
if key1 == key2:
|
||||||
|
same_barcode += some_similarity
|
||||||
|
else:
|
||||||
|
diff_barcode += some_similarity
|
||||||
|
all_similarity.append(similarity)
|
||||||
|
similarity = []
|
||||||
|
all_similarity = np.array(all_similarity)
|
||||||
|
random.shuffle(diff_barcode)
|
||||||
|
same_list = [1] * len(same_barcode)
|
||||||
|
diff_list = [0] * len(same_barcode)
|
||||||
|
all_list = same_list + diff_list
|
||||||
|
all_score = same_barcode + diff_barcode[:len(same_barcode)]
|
||||||
|
compute_accuracy_recall(np.array(all_score), np.array(all_list))
|
||||||
|
print(all_similarity.shape)
|
||||||
|
|
||||||
|
|
||||||
|
with open('../search_library/data_zhanting.json', 'r') as file:
|
||||||
|
data = json.load(file)
|
||||||
|
dicts = {}
|
||||||
|
for dict in data['total']:
|
||||||
|
key = dict['key']
|
||||||
|
value = dict['value']
|
||||||
|
dicts[key] = value
|
||||||
|
deal_similarity(dicts)
|
63
tools/model_onnx_transform.py
Normal file
63
tools/model_onnx_transform.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
import pdb
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from model import resnet18
|
||||||
|
from config import config as conf
|
||||||
|
from collections import OrderedDict
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth'):
|
||||||
|
# 定义模型
|
||||||
|
if model_name == 'resnet18':
|
||||||
|
model = resnet18(scale=0.75)
|
||||||
|
|
||||||
|
print('model_name >>> {}'.format(model_name))
|
||||||
|
if conf.multiple_cards:
|
||||||
|
model = model.to(torch.device('cpu'))
|
||||||
|
checkpoint = torch.load(pretrained_weights)
|
||||||
|
new_state_dict = OrderedDict()
|
||||||
|
for k, v in checkpoint.items():
|
||||||
|
name = k[7:] # remove "module."
|
||||||
|
new_state_dict[name] = v
|
||||||
|
model.load_state_dict(new_state_dict)
|
||||||
|
else:
|
||||||
|
model.load_state_dict(torch.load(pretrained_weights, map_location=torch.device('cpu')))
|
||||||
|
# try:
|
||||||
|
# model.load_state_dict(torch.load(pretrained_weights, map_location=torch.device('cpu')))
|
||||||
|
# except Exception as e:
|
||||||
|
# print(e)
|
||||||
|
# # model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(pretrained_weights, map_location='cpu').items()})
|
||||||
|
# model = nn.DataParallel(model).to(conf.device)
|
||||||
|
# model.load_state_dict(torch.load(conf.test_model, map_location=torch.device('cpu')))
|
||||||
|
|
||||||
|
|
||||||
|
# 转换为ONNX
|
||||||
|
if model_name == 'gift_type2':
|
||||||
|
input_shape = [1, 64, 13, 13]
|
||||||
|
elif model_name == 'gift_type3':
|
||||||
|
input_shape = [1, 3, 224, 224]
|
||||||
|
else:
|
||||||
|
# 假设输入数据的大小是通道数*高度*宽度,例如3*224*224
|
||||||
|
input_shape = [1, 3, 224, 224]
|
||||||
|
|
||||||
|
img = cv2.imread('./dog_224x224.jpg')
|
||||||
|
|
||||||
|
output_file = pretrained_weights.replace('pth', 'onnx')
|
||||||
|
|
||||||
|
# 导出模型
|
||||||
|
torch.onnx.export(model,
|
||||||
|
torch.randn(input_shape),
|
||||||
|
output_file,
|
||||||
|
verbose=True,
|
||||||
|
input_names=['input'],
|
||||||
|
output_names=['output']) ##, optset_version=12
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
trace_model = torch.jit.trace(model, torch.randn(1, 3, 224, 224))
|
||||||
|
trace_model.save(output_file.replace('.onnx', '.pt'))
|
||||||
|
print(f"Model exported to {output_file}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tranform_onnx_model(model_name='resnet18', # ['resnet18', 'gift_type2', 'gift_type3'] #gift_type2指resnet18中间数据判断;gift3_type3指resnet原图计算推理
|
||||||
|
pretrained_weights='./checkpoints/resnet18_scale=1.0/best.pth')
|
186
tools/model_rknn_transform.py
Normal file
186
tools/model_rknn_transform.py
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
import os
|
||||||
|
import pdb
|
||||||
|
import urllib
|
||||||
|
import traceback
|
||||||
|
import time
|
||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
from config import config as conf
|
||||||
|
from rknn.api import RKNN
|
||||||
|
|
||||||
|
import config
|
||||||
|
|
||||||
|
# ONNX_MODEL = 'resnet50v2.onnx'
|
||||||
|
# RKNN_MODEL = 'resnet50v2.rknn'
|
||||||
|
ONNX_MODEL = 'checkpoints/resnet18_scale=1.0/best.onnx'
|
||||||
|
RKNN_MODEL = 'checkpoints/resnet18_scale=1.0/best.rknn'
|
||||||
|
|
||||||
|
|
||||||
|
# ONNX_MODEL = 'v3_small_0424.onnx'
|
||||||
|
# RKNN_MODEL = 'v3_small_0424.rknn'
|
||||||
|
|
||||||
|
def show_outputs(outputs):
|
||||||
|
# print('***************outputs', outputs)
|
||||||
|
output = outputs[0][0]
|
||||||
|
# print('len(outputs)',len(output), output)
|
||||||
|
output_sorted = sorted(output, reverse=True)
|
||||||
|
top5_str = 'resnet50v2\n-----TOP 5-----\n'
|
||||||
|
for i in range(5):
|
||||||
|
value = output_sorted[i]
|
||||||
|
index = np.where(output == value)
|
||||||
|
for j in range(len(index)):
|
||||||
|
if (i + j) >= 5:
|
||||||
|
break
|
||||||
|
if value > 0:
|
||||||
|
topi = '{}: {}\n'.format(index[j], value)
|
||||||
|
else:
|
||||||
|
topi = '-1: 0.0\n'
|
||||||
|
top5_str += topi
|
||||||
|
# pdb.set_trace()
|
||||||
|
print(top5_str)
|
||||||
|
|
||||||
|
|
||||||
|
def readable_speed(speed):
|
||||||
|
speed_bytes = float(speed)
|
||||||
|
speed_kbytes = speed_bytes / 1024
|
||||||
|
if speed_kbytes > 1024:
|
||||||
|
speed_mbytes = speed_kbytes / 1024
|
||||||
|
if speed_mbytes > 1024:
|
||||||
|
speed_gbytes = speed_mbytes / 1024
|
||||||
|
return "{:.2f} GB/s".format(speed_gbytes)
|
||||||
|
else:
|
||||||
|
return "{:.2f} MB/s".format(speed_mbytes)
|
||||||
|
else:
|
||||||
|
return "{:.2f} KB/s".format(speed_kbytes)
|
||||||
|
|
||||||
|
|
||||||
|
def show_progress(blocknum, blocksize, totalsize):
|
||||||
|
speed = (blocknum * blocksize) / (time.time() - start_time)
|
||||||
|
speed_str = " Speed: {}".format(readable_speed(speed))
|
||||||
|
recv_size = blocknum * blocksize
|
||||||
|
|
||||||
|
f = sys.stdout
|
||||||
|
progress = (recv_size / totalsize)
|
||||||
|
progress_str = "{:.2f}%".format(progress * 100)
|
||||||
|
n = round(progress * 50)
|
||||||
|
s = ('#' * n).ljust(50, '-')
|
||||||
|
f.write(progress_str.ljust(8, ' ') + '[' + s + ']' + speed_str)
|
||||||
|
f.flush()
|
||||||
|
f.write('\r\n')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
# Create RKNN object
|
||||||
|
rknn = RKNN(verbose=True)
|
||||||
|
|
||||||
|
# If resnet50v2 does not exist, download it.
|
||||||
|
# Download address:
|
||||||
|
# https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v2/resnet50v2.onnx
|
||||||
|
if not os.path.exists(ONNX_MODEL):
|
||||||
|
print('--> Download {}'.format(ONNX_MODEL))
|
||||||
|
url = 'https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v2/resnet50v2.onnx'
|
||||||
|
download_file = ONNX_MODEL
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
urllib.request.urlretrieve(url, download_file, show_progress)
|
||||||
|
except:
|
||||||
|
print('Download {} failed.'.format(download_file))
|
||||||
|
print(traceback.format_exc())
|
||||||
|
exit(-1)
|
||||||
|
print('done')
|
||||||
|
|
||||||
|
# pre-process config
|
||||||
|
print('--> config model')
|
||||||
|
# rknn.config(mean_values=[123.675, 116.28, 103.53], std_values=[58.82, 58.82, 58.82])
|
||||||
|
rknn.config(
|
||||||
|
mean_values=[[127.5, 127.5, 127.5]],
|
||||||
|
std_values=[[127.5, 127.5, 127.5]],
|
||||||
|
target_platform='rk3588',
|
||||||
|
model_pruning=False,
|
||||||
|
compress_weight=False,
|
||||||
|
single_core_mode=True)
|
||||||
|
# rknn.config(
|
||||||
|
# mean_values=[[127.5, 127.5, 127.5]], # 对于单通道图像,可以设置为 [[127.5]]
|
||||||
|
# std_values=[[127.5, 127.5, 127.5]], # 对于单通道图像,可以设置为 [[127.5]]
|
||||||
|
# target_platform='rk3588', # 设置目标平台
|
||||||
|
# # quantize_dtype='int8',
|
||||||
|
# # quantize_algo='normal',
|
||||||
|
# # output_optimize=False,
|
||||||
|
# # output_format='rknnb'
|
||||||
|
# )
|
||||||
|
print('done')
|
||||||
|
|
||||||
|
# Load model
|
||||||
|
print('--> Loading model')
|
||||||
|
ret = rknn.load_onnx(model=ONNX_MODEL)
|
||||||
|
if ret != 0:
|
||||||
|
print('Load model failed!')
|
||||||
|
exit(ret)
|
||||||
|
print('done')
|
||||||
|
|
||||||
|
# Build model
|
||||||
|
print('--> Building model')
|
||||||
|
ret = rknn.build(do_quantization=True, dataset='./dataset.txt')
|
||||||
|
# ret = rknn.build(do_quantization=False, dataset='./dataset.txt')
|
||||||
|
if ret != 0:
|
||||||
|
print('Build model failed!')
|
||||||
|
exit(ret)
|
||||||
|
print('done')
|
||||||
|
|
||||||
|
# Export rknn model
|
||||||
|
print('--> Export rknn model')
|
||||||
|
ret = rknn.export_rknn(RKNN_MODEL)
|
||||||
|
if ret != 0:
|
||||||
|
print('Export rknn model failed!')
|
||||||
|
exit(ret)
|
||||||
|
print('done')
|
||||||
|
|
||||||
|
# Set inputs
|
||||||
|
img = cv2.imread('./dog_224x224.jpg')
|
||||||
|
# img = cv2.imread('./data/gift_test/Havegift/20241213-161415-cb8e0762-f376-45d1-8f36-7dc070990fa5/subimg/cam1_9_tid2_fid(18, 33250169482).png')
|
||||||
|
# print('img', img)
|
||||||
|
# with open('pixel_values.txt', 'w') as file:
|
||||||
|
|
||||||
|
# for y in range(img.shape[0]):
|
||||||
|
# for x in range(img.shape[1]):
|
||||||
|
# b, g, r = img[y, x]
|
||||||
|
# file.write(f'{r},{g},{b}\n')
|
||||||
|
|
||||||
|
# img = cv2.imread('./810115161912_810115161912_20240131-145622_0da14e4d-a3da-499f-b512-2d4168ab1c87_front_addGood_70f75407b7ae_29_01.jpg')
|
||||||
|
img = cv2.resize(img, (224, 224))
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||||
|
|
||||||
|
# img = conf.test_transform(img)
|
||||||
|
# img = img.numpy()
|
||||||
|
# img = img.transpose(1, 2, 0)
|
||||||
|
|
||||||
|
# Init runtime environment
|
||||||
|
print('--> Init runtime environment')
|
||||||
|
ret = rknn.init_runtime()
|
||||||
|
# ret = rknn.init_runtime('rk3588')
|
||||||
|
if ret != 0:
|
||||||
|
print('Init runtime environment failed!')
|
||||||
|
exit(ret)
|
||||||
|
print('done')
|
||||||
|
|
||||||
|
# Inference
|
||||||
|
print('--> Running model')
|
||||||
|
T1 = time.time()
|
||||||
|
outputs = rknn.inference(inputs=[img])
|
||||||
|
# outputs = rknn.inference(inputs=img)
|
||||||
|
T2 = time.time()
|
||||||
|
print('消耗时间 >>> {}'.format(T2 - T1))
|
||||||
|
with open('result_0415_128.txt', 'a') as f:
|
||||||
|
f.write(str(outputs))
|
||||||
|
# pdb.set_trace()
|
||||||
|
print('***outputs', outputs)
|
||||||
|
np.save('./onnx_resnet50v2_0.npy', outputs[0])
|
||||||
|
x = outputs[0]
|
||||||
|
output = np.exp(x) / np.sum(np.exp(x))
|
||||||
|
outputs = [output]
|
||||||
|
show_outputs(outputs)
|
||||||
|
print('done')
|
||||||
|
|
||||||
|
rknn.release()
|
233
tools/operate_usearch.py
Normal file
233
tools/operate_usearch.py
Normal file
@ -0,0 +1,233 @@
|
|||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
from usearch.index import Index
|
||||||
|
import json
|
||||||
|
import struct
|
||||||
|
|
||||||
|
|
||||||
|
def create_index():
|
||||||
|
index = Index(
|
||||||
|
ndim=256,
|
||||||
|
metric='cos',
|
||||||
|
# dtype='f32',
|
||||||
|
dtype='f16',
|
||||||
|
connectivity=32,
|
||||||
|
expansion_add=40, # 128,
|
||||||
|
expansion_search=10, # 64,
|
||||||
|
multi=True
|
||||||
|
)
|
||||||
|
return index
|
||||||
|
|
||||||
|
|
||||||
|
def compare_feature(features1, features2, model='1'):
|
||||||
|
"""
|
||||||
|
:param model 比对策略
|
||||||
|
'0':模拟一个轨迹的图像(所有的图像、或者挑选的若干图像)与标准库,先求每个图片与标准库的最大值,再求所有图片对应最大值的均值
|
||||||
|
'1':带对比的所有相似度的均值
|
||||||
|
'2':比对1:1的最大值
|
||||||
|
:param feature1:
|
||||||
|
:param feature2:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
similarity_group, similarity_groups = [], []
|
||||||
|
if model == '0':
|
||||||
|
for feature1 in features1:
|
||||||
|
for feature2 in features2[0]:
|
||||||
|
similarity = np.dot(feature1, feature2) / (np.linalg.norm(feature1) * np.linalg.norm(feature2))
|
||||||
|
similarity_group.append(similarity)
|
||||||
|
similarity_groups.append(max(similarity_group))
|
||||||
|
similarity_group = []
|
||||||
|
return sum(similarity_groups) / len(similarity_groups)
|
||||||
|
|
||||||
|
elif model == '1':
|
||||||
|
feature2 = features2[0]
|
||||||
|
for feature1 in features1:
|
||||||
|
for num in range(len(feature2)):
|
||||||
|
similarity = np.dot(feature1, feature2[num]) / (
|
||||||
|
np.linalg.norm(feature1) * np.linalg.norm(feature2[num]))
|
||||||
|
similarity_group.append(similarity)
|
||||||
|
similarity_groups.append(sum(similarity_group) / len(similarity_group))
|
||||||
|
similarity_group = []
|
||||||
|
# return sum(similarity_groups)/len(similarity_groups), max(similarity_groups)
|
||||||
|
if len(similarity_groups) == 0:
|
||||||
|
return -1
|
||||||
|
return sum(similarity_groups) / len(similarity_groups)
|
||||||
|
elif model == '2':
|
||||||
|
feature2 = features2[0]
|
||||||
|
for feature1 in features1:
|
||||||
|
for num in range(len(feature2)):
|
||||||
|
similarity = np.dot(feature1, feature2[num]) / (
|
||||||
|
np.linalg.norm(feature1) * np.linalg.norm(feature2[num]))
|
||||||
|
similarity_group.append(similarity)
|
||||||
|
return max(similarity_group)
|
||||||
|
|
||||||
|
def get_barcode_feature(data):
|
||||||
|
barcode = data['key']
|
||||||
|
features = data['value']
|
||||||
|
return [barcode] * len(features), features
|
||||||
|
|
||||||
|
|
||||||
|
def analysis_file(file_path):
|
||||||
|
"""
|
||||||
|
:param file_path:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
barcodes, features = [], []
|
||||||
|
with open(file_path, 'r', encoding='utf-8') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
for dic in data['total']:
|
||||||
|
barcode, feature = get_barcode_feature(dic)
|
||||||
|
barcodes.append(barcode)
|
||||||
|
features.append(feature)
|
||||||
|
return barcodes, features
|
||||||
|
|
||||||
|
|
||||||
|
def create_base_index(index_file_pth=None,
|
||||||
|
barcodes=None,
|
||||||
|
features=None,
|
||||||
|
save_index_name=None):
|
||||||
|
index = create_index()
|
||||||
|
if index_file_pth is not None:
|
||||||
|
# save_index_name = index_file_pth.split('json')[0] + 'usearch'
|
||||||
|
save_index_name = index_file_pth.split('json')[0] + 'data'
|
||||||
|
barcodes, features = analysis_file(index_file_pth)
|
||||||
|
else:
|
||||||
|
assert barcodes is not None and features is not None, 'barcodes and features must be not None'
|
||||||
|
for barcode, feature in zip(barcodes, features):
|
||||||
|
try:
|
||||||
|
index.add(np.array(barcode), np.array(feature))
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
continue
|
||||||
|
index.save(save_index_name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_feature_index(index_file_pth=None,
|
||||||
|
barcodes=None):
|
||||||
|
assert index_file_pth is not None, 'index_file_pth must be not None'
|
||||||
|
index = Index.restore(index_file_pth, view=True)
|
||||||
|
feature_lists = index.get(np.array(barcodes))
|
||||||
|
print("memory {} size {}".format(index.memory_usage, index.size))
|
||||||
|
print("feature_lists {}".format(feature_lists))
|
||||||
|
return feature_lists
|
||||||
|
|
||||||
|
|
||||||
|
def search_in_index(query=None,
|
||||||
|
barcode=None, # barcode -> int or np.ndarray
|
||||||
|
index_name=None,
|
||||||
|
temp_index=False, # 是否为临时库
|
||||||
|
model='0',
|
||||||
|
):
|
||||||
|
if temp_index:
|
||||||
|
assert index_name is not None, 'index_name must be not None'
|
||||||
|
index = Index.restore(index_name, view=True)
|
||||||
|
if barcode is not None: # 1:1对比测试
|
||||||
|
feature_lists = index.get(np.array(barcode))
|
||||||
|
results = compare_feature(query, feature_lists)
|
||||||
|
else:
|
||||||
|
results = index.search(query, count=5)
|
||||||
|
return results
|
||||||
|
else: # 标准库
|
||||||
|
assert index_name is not None, 'index_name must be not None'
|
||||||
|
index = Index.restore(index_name, view=True)
|
||||||
|
if barcode is not None: # 1:1对比测试
|
||||||
|
feature_lists = index.get(np.array(barcode))
|
||||||
|
results = compare_feature(query, feature_lists, model)
|
||||||
|
else:
|
||||||
|
results = index.search(query, count=10)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def delete_index(index_name=None, key=None, index=None):
|
||||||
|
assert key is not None, 'key must be not None'
|
||||||
|
if index is None:
|
||||||
|
assert index_name is not None, 'index_name must be not None'
|
||||||
|
index = Index.restore(index_name, view=True)
|
||||||
|
index.remove(index_name)
|
||||||
|
else:
|
||||||
|
index.remove(key)
|
||||||
|
|
||||||
|
from scipy.spatial.distance import cdist
|
||||||
|
def compute_similarity_matrix(featurelists1, featurelists2):
|
||||||
|
"""计算图片之间的余弦相似度矩阵"""
|
||||||
|
# 计算所有向量对之间的余弦相似度
|
||||||
|
cosine_similarities = 1 - cdist(featurelists1, featurelists2, metric='cosine')
|
||||||
|
cosine_similarities = np.around(cosine_similarities, decimals=3)
|
||||||
|
return cosine_similarities
|
||||||
|
|
||||||
|
def check_usearch_json_diff(index_file_pth, json_file_pth):
|
||||||
|
json_features = None
|
||||||
|
feature_lists = get_feature_index(index_file_pth, ['6923644272159'])
|
||||||
|
with open(json_file_pth, 'r') as json_file:
|
||||||
|
json_data = json.load(json_file)
|
||||||
|
for data in json_data['total']:
|
||||||
|
if data['key'] == '6923644272159':
|
||||||
|
json_features = data['value']
|
||||||
|
json_features = np.array(json_features)
|
||||||
|
feature_lists = np.array(feature_lists[0])
|
||||||
|
compute_similarity_matrix(json_features, feature_lists)
|
||||||
|
|
||||||
|
|
||||||
|
def write_binary_file(filename, datas):
|
||||||
|
with open(filename, 'wb') as f:
|
||||||
|
# 先写入数据中的key数量(为C++读取提供便利)
|
||||||
|
key_count = len(datas)
|
||||||
|
f.write(struct.pack('I', key_count)) # 'I'代表无符号整型(4字节)
|
||||||
|
|
||||||
|
for data in datas:
|
||||||
|
key = data['key']
|
||||||
|
feats = data['value']
|
||||||
|
key_bytes = key.encode('utf-8')
|
||||||
|
key_len = len(key)
|
||||||
|
length_byte = struct.pack('<B', key_len)
|
||||||
|
f.write(length_byte)
|
||||||
|
# f.write(struct.pack('Q', len(key_bytes)))
|
||||||
|
f.write(key_bytes)
|
||||||
|
value_count = len(feats)
|
||||||
|
f.write(struct.pack('I', (value_count * 256)))
|
||||||
|
# 遍历字典,写入每个key及其对应的浮点数值列表
|
||||||
|
for values in feats:
|
||||||
|
# 写入每个浮点数值(保留小数点后六位)
|
||||||
|
for value in values:
|
||||||
|
# 使用'f'格式(单精度浮点,4字节),并四舍五入保留六位小数
|
||||||
|
value_half = np.float16(value)
|
||||||
|
# print(value_half.tobytes())
|
||||||
|
f.write(value_half.tobytes())
|
||||||
|
def create_binary_file(json_path, flag=True):
|
||||||
|
# 1. 打开JSON文件
|
||||||
|
with open(json_path, 'r', encoding='utf-8') as file:
|
||||||
|
# 2. 读取并解析JSON文件内容
|
||||||
|
data = json.load(file)
|
||||||
|
if flag:
|
||||||
|
for flag, values in data.items():
|
||||||
|
# 逐个写入values中的每个值,保留小数点后六位,每个值占一行
|
||||||
|
write_binary_file(index_file_pth.replace('json', 'bin'), values)
|
||||||
|
else:
|
||||||
|
write_binary_file(json_path.replace('.json', '.bin'), [data])
|
||||||
|
|
||||||
|
def create_binary_files(index_file_pth):
|
||||||
|
if os.path.isfile(index_file_pth):
|
||||||
|
create_binary_file(index_file_pth)
|
||||||
|
else:
|
||||||
|
for name in os.listdir(index_file_pth):
|
||||||
|
jsonpth = os.sep.join([index_file_pth, name])
|
||||||
|
create_binary_file(jsonpth, False)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# index_file_pth = '../data/feature_json' # 生成二进制文件 多文件
|
||||||
|
index_file_pth = '../search_library/yunhedian_30-04.json'
|
||||||
|
# create_base_index(index_file_pth) # 生成usearch文件
|
||||||
|
create_binary_files(index_file_pth) # 生成二进制文件 多文件
|
||||||
|
|
||||||
|
# index_file_pth = '../search_library/test_index_10_normal_0717.usearch'
|
||||||
|
# # index_file_pth = '../search_library/data_10_normal_0718.index'
|
||||||
|
# search_in_index(query='693', index_name=index_file_pth, barcode='6934024590466')
|
||||||
|
|
||||||
|
# # check index data file
|
||||||
|
# index_file_pth = '../search_library/data_zhanting.data'
|
||||||
|
# # # get_feature_index(index_file_pth, ['6901070602818'])
|
||||||
|
# get_feature_index(index_file_pth, ['6923644272159'])
|
||||||
|
|
||||||
|
# index_file_pth = '../search_library/data_zhanting.data'
|
||||||
|
# json_file_pth = '../search_library/data_zhanting.json'
|
||||||
|
# check_usearch_json_diff(index_file_pth, json_file_pth)
|
84
tools/threshold_partition.py
Normal file
84
tools/threshold_partition.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
'''
|
||||||
|
现场1:N测试,确定阈值
|
||||||
|
'''
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
def showHist(filtered_data):
|
||||||
|
Same = filtered_data[:, 1].astype(np.float32)
|
||||||
|
Cross = filtered_data[:, 2].astype(np.float32)
|
||||||
|
|
||||||
|
fig, axs = plt.subplots(2, 1)
|
||||||
|
axs[0].hist(Same, bins=50, edgecolor='black')
|
||||||
|
axs[0].set_xlim([-0.1, 1])
|
||||||
|
axs[0].set_title('first')
|
||||||
|
|
||||||
|
axs[1].hist(Cross, bins=50, edgecolor='black')
|
||||||
|
axs[1].set_xlim([-0.1, 1])
|
||||||
|
axs[1].set_title('second')
|
||||||
|
# plt.savefig('plot.png')
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def get_tartget_list(nested_list):
|
||||||
|
filtered_list = np.array(list(filter(lambda x: len(x) >= 2, nested_list))) # 去除无轨迹的数据
|
||||||
|
filtered_correct = filtered_list[filtered_list[:, 0] != 'wrong'] # 获取比对正确的时项
|
||||||
|
filtered_wrong = filtered_list[filtered_list[:, 0] == 'wrong'] # 获取比对错误的时项
|
||||||
|
showHist(filtered_correct)
|
||||||
|
# showHist(filtered_wrong)
|
||||||
|
print(filtered_list)
|
||||||
|
|
||||||
|
|
||||||
|
def deal_process(file_pth):
|
||||||
|
flag = False
|
||||||
|
event = file_pth.split('\\')[-2]
|
||||||
|
target_barcode = file_pth.split('\\')[-2].split('_')[-1]
|
||||||
|
temp_list = []
|
||||||
|
|
||||||
|
with open(file_pth, 'r') as f:
|
||||||
|
for line in f:
|
||||||
|
if 'oneToOne' in line:
|
||||||
|
flag = True
|
||||||
|
continue
|
||||||
|
if flag:
|
||||||
|
line = line.replace('\n', '')
|
||||||
|
comparison_data = line.split(',')
|
||||||
|
forecast_barcode = comparison_data[0]
|
||||||
|
value = comparison_data[-1].split(':')[-1]
|
||||||
|
if value == '':
|
||||||
|
break
|
||||||
|
if len(temp_list) == 0:
|
||||||
|
if forecast_barcode == target_barcode:
|
||||||
|
temp_list.append('correct')
|
||||||
|
else:
|
||||||
|
temp_list.append('wrong')
|
||||||
|
temp_list.append(float(value))
|
||||||
|
temp_list.append(event)
|
||||||
|
return temp_list
|
||||||
|
|
||||||
|
|
||||||
|
def anaylze_scratch(scratch_pth):
|
||||||
|
purchase, back = [], []
|
||||||
|
for root, dirs, files in os.walk(scratch_pth):
|
||||||
|
if len(root) > 0:
|
||||||
|
if len(root.split('_')) == 4: # 加购
|
||||||
|
process = os.path.join(root, 'process.data')
|
||||||
|
if not os.path.exists(process):
|
||||||
|
continue
|
||||||
|
purchase.append(deal_process(process))
|
||||||
|
elif len(root.split('_')) == 3:
|
||||||
|
process = os.path.join(root, 'process.data')
|
||||||
|
if not os.path.exists(process):
|
||||||
|
continue
|
||||||
|
back.append(deal_process(process))
|
||||||
|
# get_tartget_list(purchase)
|
||||||
|
get_tartget_list(back)
|
||||||
|
print(purchase)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# scratch_pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\展厅测试\\1108_展厅模型v800测试\\'
|
||||||
|
scratch_pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\展厅测试\\1120_展厅模型v801测试\\扫A放A\\'
|
||||||
|
anaylze_scratch(scratch_pth)
|
411
tools/write_feature_json.py
Normal file
411
tools/write_feature_json.py
Normal file
@ -0,0 +1,411 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
from tools.dataset import get_transform
|
||||||
|
from model import resnet18
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
import pandas as pd
|
||||||
|
from tqdm import tqdm
|
||||||
|
import yaml
|
||||||
|
import shutil
|
||||||
|
import struct
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureExtractor:
|
||||||
|
def __init__(self, conf):
|
||||||
|
self.conf = conf
|
||||||
|
self.model = self.initModel()
|
||||||
|
_, self.test_transform = get_transform(self.conf)
|
||||||
|
pass
|
||||||
|
|
||||||
|
def initModel(self, inference_model: Optional[str] = None) -> torch.nn.Module:
|
||||||
|
"""
|
||||||
|
Initialize and load the ResNet18 model for inference.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inference_model: Optional path to model weights. Uses conf.test_model if None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Loaded and configured PyTorch model in evaluation mode.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If model weights file is not found
|
||||||
|
RuntimeError: If model loading fails
|
||||||
|
"""
|
||||||
|
model_path = inference_model if inference_model else self.conf['models']['checkpoints']
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Verify model file exists
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
raise FileNotFoundError(f"Model weights file not found: {model_path}")
|
||||||
|
|
||||||
|
# Initialize model
|
||||||
|
model = resnet18().to(self.conf['base']['device'])
|
||||||
|
|
||||||
|
# Handle multi-GPU case
|
||||||
|
if conf['base']['distributed']:
|
||||||
|
model = torch.nn.DataParallel(model)
|
||||||
|
|
||||||
|
# Load weights
|
||||||
|
state_dict = torch.load(model_path, map_location=conf['base']['device'])
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
logger.info(f"Successfully loaded model from {model_path}")
|
||||||
|
return model
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize model: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def convert_rgba_to_rgb(self, image_path):
|
||||||
|
# 打开图像
|
||||||
|
img = Image.open(image_path)
|
||||||
|
# 转换图像模式从RGBA到RGB
|
||||||
|
# .convert('RGB')会丢弃Alpha通道并转换为纯RGB图像
|
||||||
|
if img.mode == 'RGBA':
|
||||||
|
# 转换为RGB模式
|
||||||
|
img_rgb = img.convert('RGB')
|
||||||
|
# 保存转换后的图像
|
||||||
|
img_rgb.save(image_path)
|
||||||
|
print(f"Image converted from RGBA to RGB and saved to {image_path}")
|
||||||
|
|
||||||
|
def test_preprocess(self, images: list, actionModel=False) -> torch.Tensor:
|
||||||
|
res = []
|
||||||
|
for img in images:
|
||||||
|
try:
|
||||||
|
im = self.test_transform(img) if actionModel else self.test_transform(Image.open(img))
|
||||||
|
res.append(im)
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
data = torch.stack(res)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def inference(self, images, model, actionModel=False):
|
||||||
|
data = self.test_preprocess(images, actionModel)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
data = data.to(conf['base']['device'])
|
||||||
|
features = model(data)
|
||||||
|
if conf['data']['half']:
|
||||||
|
features = features.half()
|
||||||
|
return features
|
||||||
|
|
||||||
|
def group_image(self, images, batch=64) -> list:
|
||||||
|
"""Group image paths by batch size"""
|
||||||
|
size = len(images)
|
||||||
|
res = []
|
||||||
|
for i in range(0, size, batch):
|
||||||
|
end = min(batch + i, size)
|
||||||
|
res.append(images[i:end])
|
||||||
|
return res
|
||||||
|
|
||||||
|
def getFeatureList(self, barList, imgList):
|
||||||
|
featList = [[] for _ in range(len(barList))]
|
||||||
|
|
||||||
|
for index, image_paths in enumerate(imgList):
|
||||||
|
try:
|
||||||
|
# Process images in batches
|
||||||
|
for batch in self.group_image(image_paths):
|
||||||
|
# Get features for batch
|
||||||
|
features = self.inference(batch, self.model)
|
||||||
|
|
||||||
|
# Process each feature in batch
|
||||||
|
for feat in features:
|
||||||
|
# Move to CPU and convert to numpy
|
||||||
|
feat_np = feat.squeeze().detach().cpu().numpy()
|
||||||
|
|
||||||
|
# Normalize first 256 dimensions
|
||||||
|
normalized = self.normalize_256(feat_np[:256])
|
||||||
|
|
||||||
|
# Combine with remaining dimensions
|
||||||
|
combined = np.concatenate([normalized, feat_np[256:]], axis=0)
|
||||||
|
|
||||||
|
featList[index].append(combined)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing images for index {index}: {str(e)}")
|
||||||
|
continue
|
||||||
|
return featList
|
||||||
|
|
||||||
|
def get_files(
|
||||||
|
self,
|
||||||
|
folder: str,
|
||||||
|
filter: Optional[List[str]] = None,
|
||||||
|
create_single_json: bool = False
|
||||||
|
) -> Dict[str, List[str]]:
|
||||||
|
"""
|
||||||
|
Recursively collect image files from directory structure.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
folder: Root directory to scan
|
||||||
|
filter: Optional list of barcodes to include
|
||||||
|
create_single_json: Whether to create individual JSON files per barcode
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping barcode names to lists of image paths
|
||||||
|
|
||||||
|
Example:
|
||||||
|
{
|
||||||
|
"barcode1": ["path/to/img1.jpg", "path/to/img2.jpg"],
|
||||||
|
"barcode2": ["path/to/img3.jpg"]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
file_dicts = {}
|
||||||
|
total_files = 0
|
||||||
|
feature_counts = []
|
||||||
|
barcode_count = 0
|
||||||
|
subclass = [str(i) for i in range(100)]
|
||||||
|
# Validate input directory
|
||||||
|
if not os.path.isdir(folder):
|
||||||
|
raise ValueError(f"Invalid directory: {folder}")
|
||||||
|
|
||||||
|
# Process each barcode directory
|
||||||
|
for root, dirs, files in tqdm(os.walk(folder), desc="Scanning directories"):
|
||||||
|
if not dirs: # Leaf directory (contains images)
|
||||||
|
basename = os.path.basename(root)
|
||||||
|
if basename in subclass:
|
||||||
|
ori_barcode = root.split('/')[-2]
|
||||||
|
barcode = root.split('/')[-2] + '_' + basename
|
||||||
|
else:
|
||||||
|
ori_barcode = basename
|
||||||
|
barcode = basename
|
||||||
|
# Apply filter if provided
|
||||||
|
if filter and ori_barcode not in filter:
|
||||||
|
continue
|
||||||
|
elif len(ori_barcode) > 13 or len(ori_barcode) < 8:
|
||||||
|
logger.warning(f"Skipping invalid barcode {ori_barcode}")
|
||||||
|
with open(conf['save']['error_barcodes'], 'a') as f:
|
||||||
|
f.write(ori_barcode + '\n')
|
||||||
|
f.close()
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Process image files
|
||||||
|
if files:
|
||||||
|
image_paths = self._process_image_files(root, files)
|
||||||
|
if not image_paths:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Update counters
|
||||||
|
barcode_count += 1
|
||||||
|
file_count = len(image_paths)
|
||||||
|
total_files += file_count
|
||||||
|
feature_counts.append(file_count)
|
||||||
|
|
||||||
|
# Handle output mode
|
||||||
|
if create_single_json:
|
||||||
|
self._process_single_barcode(barcode, image_paths)
|
||||||
|
else:
|
||||||
|
if barcode.split('_')[-1] == '0':
|
||||||
|
barcode = barcode.split('_')[0]
|
||||||
|
file_dicts[barcode] = image_paths
|
||||||
|
|
||||||
|
# # Log summary
|
||||||
|
# logger.info(f"Processed {barcode_count} barcodes with {total_files} total images")
|
||||||
|
# logger.debug(f"Image counts per barcode: {feature_counts}")
|
||||||
|
|
||||||
|
# Batch process if not creating individual JSONs
|
||||||
|
if not create_single_json and file_dicts:
|
||||||
|
self.createFeatureDict(
|
||||||
|
file_dicts,
|
||||||
|
create_single_json=False,
|
||||||
|
)
|
||||||
|
return file_dicts
|
||||||
|
|
||||||
|
def _process_image_files(self, root: str, files: List[str]) -> List[str]:
|
||||||
|
"""Process and validate image files in a directory."""
|
||||||
|
valid_paths = []
|
||||||
|
for filename in files:
|
||||||
|
file_path = os.path.join(root, filename)
|
||||||
|
try:
|
||||||
|
# Convert RGBA to RGB if needed
|
||||||
|
self.convert_rgba_to_rgb(file_path)
|
||||||
|
valid_paths.append(file_path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Skipping invalid image {file_path}: {str(e)}")
|
||||||
|
return valid_paths
|
||||||
|
|
||||||
|
def _process_single_barcode(self, barcode: str, image_paths: List[str]):
|
||||||
|
"""Process a single barcode and create individual JSON file."""
|
||||||
|
temp_dict = {barcode: image_paths}
|
||||||
|
self.createFeatureDict(
|
||||||
|
temp_dict,
|
||||||
|
create_single_json=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def normalize_256(self, queFeatList):
|
||||||
|
queFeatList = queFeatList / np.linalg.norm(queFeatList)
|
||||||
|
return queFeatList
|
||||||
|
|
||||||
|
def img2feature(
|
||||||
|
self,
|
||||||
|
imgs_dict: Dict[str, List[str]]
|
||||||
|
) -> Tuple[List[str], List[List[np.ndarray]]]:
|
||||||
|
"""
|
||||||
|
Extract features for all images in the dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
imgs_dict: Dictionary mapping barcodes to image paths
|
||||||
|
model: Pretrained feature extraction model
|
||||||
|
barcode_flag: Whether to include barcode info (unused)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple containing:
|
||||||
|
- List of barcode IDs
|
||||||
|
- List of feature lists (one per barcode)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If input dictionary is empty
|
||||||
|
RuntimeError: If feature extraction fails
|
||||||
|
"""
|
||||||
|
if not imgs_dict:
|
||||||
|
raise ValueError("No images provided for feature extraction")
|
||||||
|
|
||||||
|
try:
|
||||||
|
barcode_list = list(imgs_dict.keys())
|
||||||
|
image_list = list(imgs_dict.values())
|
||||||
|
feature_list = self.getFeatureList(barcode_list, image_list)
|
||||||
|
|
||||||
|
logger.info(f"Successfully extracted features for {len(barcode_list)} barcodes")
|
||||||
|
return barcode_list, feature_list
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Feature extraction failed: {str(e)}")
|
||||||
|
raise RuntimeError(f"Feature extraction failed: {str(e)}")
|
||||||
|
|
||||||
|
def createFeatureDict(self, imgs_dict,
|
||||||
|
create_single_json=False): # imgs->{barcode1:[img1_1...img1_n], barcode2:[img2_1...img2_n]}
|
||||||
|
dicts_all = {}
|
||||||
|
value_list = []
|
||||||
|
barcode_list, imgs_list = self.img2feature(imgs_dict)
|
||||||
|
for i in range(len(barcode_list)):
|
||||||
|
dicts = {}
|
||||||
|
|
||||||
|
imgs_list_ = []
|
||||||
|
for j in range(len(imgs_list[i])):
|
||||||
|
imgs_list_.append(imgs_list[i][j].tolist())
|
||||||
|
|
||||||
|
dicts['key'] = barcode_list[i]
|
||||||
|
truncated_imgs_list = [subarray[:256] for subarray in imgs_list_]
|
||||||
|
dicts['value'] = truncated_imgs_list
|
||||||
|
if create_single_json:
|
||||||
|
# json_path = os.path.join("./search_library/v8021_overseas/", str(barcode_list[i]) + '.json')
|
||||||
|
json_path = os.path.join(self.conf['save']['json_path'], str(barcode_list[i]) + '.json')
|
||||||
|
with open(json_path, 'w') as json_file:
|
||||||
|
json.dump(dicts, json_file)
|
||||||
|
else:
|
||||||
|
value_list.append(dicts)
|
||||||
|
if not create_single_json:
|
||||||
|
dicts_all['total'] = value_list
|
||||||
|
with open(self.conf['save']['json_bin'], 'w') as json_file:
|
||||||
|
json.dump(dicts_all, json_file)
|
||||||
|
self.create_binary_files(self.conf['save']['json_bin'])
|
||||||
|
|
||||||
|
def statisticsBarcodes(self, pth, filter=None):
|
||||||
|
feature_num = 0
|
||||||
|
feature_num_lists = []
|
||||||
|
nn = 0
|
||||||
|
with open(conf['save']['barcodes_statistics'], 'w', encoding='utf-8') as f:
|
||||||
|
for barcode in os.listdir(pth):
|
||||||
|
print("barcode length >> {}".format(len(barcode)))
|
||||||
|
if len(barcode) > 13 or len(barcode) < 8:
|
||||||
|
continue
|
||||||
|
if filter is not None:
|
||||||
|
f.writelines(barcode + '\n')
|
||||||
|
if barcode in filter:
|
||||||
|
print(barcode)
|
||||||
|
feature_num += len(os.listdir(os.path.join(pth, barcode)))
|
||||||
|
nn += 1
|
||||||
|
else:
|
||||||
|
print('barcode name >>{}'.format(barcode))
|
||||||
|
f.writelines(barcode + '\n')
|
||||||
|
feature_num += len(os.listdir(os.path.join(pth, barcode)))
|
||||||
|
feature_num_lists.append(feature_num)
|
||||||
|
print("特征总量: {}".format(feature_num))
|
||||||
|
print("barcode总量: {}".format(nn))
|
||||||
|
f.close()
|
||||||
|
|
||||||
|
def get_shop_barcodes(self, file_path):
|
||||||
|
if file_path:
|
||||||
|
df = pd.read_excel(file_path)
|
||||||
|
column_values = list(df.iloc[:, 6].values)
|
||||||
|
column_values = list(map(str, column_values))
|
||||||
|
return column_values
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def del_base_dir(self, pth):
|
||||||
|
for root, dirs, files in os.walk(pth):
|
||||||
|
if len(dirs) == 1:
|
||||||
|
if dirs[0] == 'base':
|
||||||
|
shutil.rmtree(os.path.join(root, dirs[0]))
|
||||||
|
|
||||||
|
def write_binary_file(self, filename, datas):
|
||||||
|
with open(filename, 'wb') as f:
|
||||||
|
# 先写入数据中的key数量(为C++读取提供便利)
|
||||||
|
key_count = len(datas)
|
||||||
|
f.write(struct.pack('I', key_count)) # 'I'代表无符号整型(4字节)
|
||||||
|
for data in datas:
|
||||||
|
key = data['key']
|
||||||
|
feats = data['value']
|
||||||
|
key_bytes = key.encode('utf-8')
|
||||||
|
key_len = len(key)
|
||||||
|
length_byte = struct.pack('<B', key_len)
|
||||||
|
f.write(length_byte)
|
||||||
|
# f.write(struct.pack('Q', len(key_bytes)))
|
||||||
|
f.write(key_bytes)
|
||||||
|
value_count = len(feats)
|
||||||
|
f.write(struct.pack('I', (value_count * 256)))
|
||||||
|
# 遍历字典,写入每个key及其对应的浮点数值列表
|
||||||
|
for values in feats:
|
||||||
|
# 写入每个浮点数值(保留小数点后六位)
|
||||||
|
for value in values:
|
||||||
|
# 使用'f'格式(单精度浮点,4字节),并四舍五入保留六位小数
|
||||||
|
value_half = np.float16(value)
|
||||||
|
# print(value_half.tobytes())
|
||||||
|
f.write(value_half.tobytes())
|
||||||
|
|
||||||
|
def create_binary_file(self, json_path, flag=True):
|
||||||
|
# 1. 打开JSON文件
|
||||||
|
with open(json_path, 'r', encoding='utf-8') as file:
|
||||||
|
# 2. 读取并解析JSON文件内容
|
||||||
|
data = json.load(file)
|
||||||
|
if flag:
|
||||||
|
for flag, values in data.items():
|
||||||
|
# 逐个写入values中的每个值,保留小数点后六位,每个值占一行
|
||||||
|
self.write_binary_file(self.conf['save']['json_bin'].replace('json', 'bin'), values)
|
||||||
|
else:
|
||||||
|
self.write_binary_file(json_path.replace('.json', '.bin'), [data])
|
||||||
|
|
||||||
|
def create_binary_files(self, index_file_pth):
|
||||||
|
if os.path.isfile(index_file_pth):
|
||||||
|
self.create_binary_file(index_file_pth)
|
||||||
|
else:
|
||||||
|
for name in os.listdir(index_file_pth):
|
||||||
|
jsonpth = os.sep.join([index_file_pth, name])
|
||||||
|
self.create_binary_file(jsonpth, False)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
with open('../configs/write_feature.yml', 'r') as f:
|
||||||
|
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
|
###将图片名称和模型推理特征向量字典存为json文件
|
||||||
|
# xlsx_pth = './shop_xlsx/曹家桥门店在售商品表.xlsx'
|
||||||
|
# xlsx_pth = None
|
||||||
|
# del_base_dir(mg_path)
|
||||||
|
|
||||||
|
extractor = FeatureExtractor(conf)
|
||||||
|
column_values = extractor.get_shop_barcodes(conf['data']['xlsx_pth'])
|
||||||
|
imgs_dict = extractor.get_files(conf['data']['img_dirs_path'],
|
||||||
|
filter=column_values,
|
||||||
|
create_single_json=False) # False
|
||||||
|
extractor.statisticsBarcodes(conf['data']['img_dirs_path'], column_values)
|
142
train_compare.py
Normal file
142
train_compare.py
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
import os
|
||||||
|
import os.path as osp
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from model.loss import FocalLoss
|
||||||
|
from tools.dataset import load_data
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from configs import trainer_tools
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
with open('configs/scatter.yml', 'r') as f:
|
||||||
|
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
|
|
||||||
|
# Data Setup
|
||||||
|
train_dataloader, class_num = load_data(training=True, cfg=conf)
|
||||||
|
val_dataloader, _ = load_data(training=False, cfg=conf)
|
||||||
|
|
||||||
|
tr_tools = trainer_tools(conf)
|
||||||
|
backbone_mapping = tr_tools.get_backbone()
|
||||||
|
metric_mapping = tr_tools.get_metric(class_num)
|
||||||
|
|
||||||
|
if conf['models']['backbone'] in backbone_mapping:
|
||||||
|
model = backbone_mapping[conf['models']['backbone']]().to(conf['base']['device'])
|
||||||
|
else:
|
||||||
|
raise ValueError('不支持该模型: {}'.format({conf['models']['backbone']}))
|
||||||
|
|
||||||
|
if conf['training']['metric'] in metric_mapping:
|
||||||
|
metric = metric_mapping[conf['training']['metric']]()
|
||||||
|
else:
|
||||||
|
raise ValueError('不支持的metric类型: {}'.format(conf['training']['metric']))
|
||||||
|
|
||||||
|
if torch.cuda.device_count() > 1 and conf['base']['distributed']:
|
||||||
|
print("Let's use", torch.cuda.device_count(), "GPUs!")
|
||||||
|
model = nn.DataParallel(model)
|
||||||
|
metric = nn.DataParallel(metric)
|
||||||
|
|
||||||
|
# Training Setup
|
||||||
|
if conf['training']['loss'] == 'focal_loss':
|
||||||
|
criterion = FocalLoss(gamma=2)
|
||||||
|
else:
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
|
||||||
|
optimizer_mapping = tr_tools.get_optimizer(model, metric)
|
||||||
|
if conf['training']['optimizer'] in optimizer_mapping:
|
||||||
|
optimizer = optimizer_mapping[conf['training']['optimizer']]()
|
||||||
|
scheduler = optim.lr_scheduler.StepLR(
|
||||||
|
optimizer,
|
||||||
|
step_size=conf['training']['lr_step'],
|
||||||
|
gamma=conf['training']['lr_decay']
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError('不支持的优化器类型: {}'.format(conf['training']['optimizer']))
|
||||||
|
|
||||||
|
# Checkpoints Setup
|
||||||
|
checkpoints = conf['training']['checkpoints']
|
||||||
|
os.makedirs(checkpoints, exist_ok=True)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
print('backbone>{} '.format(conf['models']['backbone']),
|
||||||
|
'metric>{} '.format(conf['training']['metric']),
|
||||||
|
'checkpoints>{} '.format(conf['training']['checkpoints']),
|
||||||
|
)
|
||||||
|
train_losses = []
|
||||||
|
val_losses = []
|
||||||
|
epochs = []
|
||||||
|
temp_loss = 100
|
||||||
|
if conf['training']['restore']:
|
||||||
|
print('load pretrain model: {}'.format(conf['training']['restore_model']))
|
||||||
|
model.load_state_dict(torch.load(conf['training']['restore_model'],
|
||||||
|
map_location=conf['base']['device']))
|
||||||
|
|
||||||
|
for e in range(conf['training']['epochs']):
|
||||||
|
train_loss = 0
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
for train_data, train_labels in tqdm(train_dataloader,
|
||||||
|
desc="Epoch {}/{}"
|
||||||
|
.format(e, conf['training']['epochs']),
|
||||||
|
ascii=True,
|
||||||
|
total=len(train_dataloader)):
|
||||||
|
train_data = train_data.to(conf['base']['device'])
|
||||||
|
train_labels = train_labels.to(conf['base']['device'])
|
||||||
|
|
||||||
|
train_embeddings = model(train_data).to(conf['base']['device']) # [256,512]
|
||||||
|
# pdb.set_trace()
|
||||||
|
|
||||||
|
if not conf['training']['metric'] == 'softmax':
|
||||||
|
thetas = metric(train_embeddings, train_labels) # [256,357]
|
||||||
|
else:
|
||||||
|
thetas = metric(train_embeddings)
|
||||||
|
tloss = criterion(thetas, train_labels)
|
||||||
|
optimizer.zero_grad()
|
||||||
|
tloss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
train_loss += tloss.item()
|
||||||
|
train_lossAvg = train_loss / len(train_dataloader)
|
||||||
|
train_losses.append(train_lossAvg)
|
||||||
|
epochs.append(e)
|
||||||
|
val_loss = 0
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
for val_data, val_labels in tqdm(val_dataloader, desc="val",
|
||||||
|
ascii=True, total=len(val_dataloader)):
|
||||||
|
val_data = val_data.to(conf['base']['device'])
|
||||||
|
val_labels = val_labels.to(conf['base']['device'])
|
||||||
|
val_embeddings = model(val_data).to(conf['base']['device'])
|
||||||
|
if not conf['training']['metric'] == 'softmax':
|
||||||
|
thetas = metric(val_embeddings, val_labels)
|
||||||
|
else:
|
||||||
|
thetas = metric(val_embeddings)
|
||||||
|
vloss = criterion(thetas, val_labels)
|
||||||
|
val_loss += vloss.item()
|
||||||
|
val_lossAvg = val_loss / len(val_dataloader)
|
||||||
|
val_losses.append(val_lossAvg)
|
||||||
|
if val_lossAvg < temp_loss:
|
||||||
|
if torch.cuda.device_count() > 1:
|
||||||
|
torch.save(model.state_dict(), osp.join(checkpoints, 'best.pth'))
|
||||||
|
else:
|
||||||
|
torch.save(model.state_dict(), osp.join(checkpoints, 'best.pth'))
|
||||||
|
temp_loss = val_lossAvg
|
||||||
|
|
||||||
|
scheduler.step()
|
||||||
|
current_lr = optimizer.param_groups[0]['lr']
|
||||||
|
log_info = ("Epoch {}/{}, train_loss: {}, val_loss: {} lr:{}"
|
||||||
|
.format(e, conf['training']['epochs'], train_lossAvg, val_lossAvg, current_lr))
|
||||||
|
print(log_info)
|
||||||
|
# 写入日志文件
|
||||||
|
with open(osp.join(conf['logging']['logging_dir']), 'a') as f:
|
||||||
|
f.write(log_info + '\n')
|
||||||
|
print("第%d个epoch的学习率:%f" % (e, current_lr))
|
||||||
|
if torch.cuda.device_count() > 1 and conf['base']['distributed']:
|
||||||
|
torch.save(model.module.state_dict(), osp.join(checkpoints, 'last.pth'))
|
||||||
|
else:
|
||||||
|
torch.save(model.state_dict(), osp.join(checkpoints, 'last.pth'))
|
||||||
|
plt.plot(epochs, train_losses, color='blue')
|
||||||
|
plt.plot(epochs, val_losses, color='red')
|
||||||
|
# plt.savefig('lossMobilenetv3.png')
|
||||||
|
plt.savefig('loss/mobilenetv3Large_2250_0316.png')
|
205
train_distill.py
Normal file
205
train_distill.py
Normal file
@ -0,0 +1,205 @@
|
|||||||
|
"""
|
||||||
|
ResNet50蒸馏训练ResNet18实现
|
||||||
|
学生网络使用ArcFace损失
|
||||||
|
支持单机双卡训练
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
|
from torch.cuda.amp import GradScaler
|
||||||
|
from model import resnet18, resnet50, ArcFace
|
||||||
|
from tqdm import tqdm
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from tools.dataset import load_data
|
||||||
|
# from config import config as conf
|
||||||
|
import yaml
|
||||||
|
import math
|
||||||
|
def setup(rank, world_size):
|
||||||
|
os.environ['MASTER_ADDR'] = '0.0.0.0'
|
||||||
|
os.environ['MASTER_PORT'] = '12355'
|
||||||
|
dist.init_process_group("nccl", rank=rank, world_size=world_size)
|
||||||
|
|
||||||
|
def cleanup():
|
||||||
|
dist.destroy_process_group()
|
||||||
|
|
||||||
|
class DistillTrainer:
|
||||||
|
def __init__(self, rank, world_size, conf):
|
||||||
|
self.rank = rank
|
||||||
|
self.world_size = world_size
|
||||||
|
self.device = torch.device(f'cuda:{rank}')
|
||||||
|
|
||||||
|
# 初始化模型
|
||||||
|
self.teacher = resnet50(pretrained=True, scale=conf['models']['channel_ratio']).to(self.device)
|
||||||
|
self.student = resnet18(pretrained=True, scale=conf['models']['student_channel_ratio']).to(self.device)
|
||||||
|
|
||||||
|
# 加载预训练教师模型
|
||||||
|
# teacher_path = os.path.join('checkpoints', 'resnet50_0519', 'best.pth')
|
||||||
|
teacher_path = conf['models']['teacher_model_path']
|
||||||
|
if os.path.exists(teacher_path):
|
||||||
|
teacher_state = torch.load(teacher_path, map_location=self.device)
|
||||||
|
new_state_dict = {}
|
||||||
|
for k, v in teacher_state.items():
|
||||||
|
if k.startswith('module.'):
|
||||||
|
new_state_dict[k[7:]] = v # 去除前7个字符'module.'
|
||||||
|
else:
|
||||||
|
new_state_dict[k] = v
|
||||||
|
# 加载处理后的状态字典
|
||||||
|
self.teacher.load_state_dict(new_state_dict, strict=False)
|
||||||
|
|
||||||
|
if self.rank == 0:
|
||||||
|
print(f"Successfully loaded teacher model from {teacher_path}")
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError(f"Teacher model weights not found at {teacher_path}")
|
||||||
|
|
||||||
|
# 数据加载
|
||||||
|
self.train_loader, num_classes = load_data(training=True, cfg=conf)
|
||||||
|
self.val_loader, _ = load_data(training=False, cfg=conf)
|
||||||
|
|
||||||
|
# ArcFace损失
|
||||||
|
self.metric = ArcFace(conf['base']['embedding_size'], num_classes).to(self.device)
|
||||||
|
|
||||||
|
# 分布式训练
|
||||||
|
if world_size > 1:
|
||||||
|
self.teacher = DDP(self.teacher, device_ids=[rank])
|
||||||
|
self.student = DDP(self.student, device_ids=[rank])
|
||||||
|
self.metric = DDP(self.metric, device_ids=[rank])
|
||||||
|
|
||||||
|
# 优化器
|
||||||
|
self.optimizer = torch.optim.SGD([
|
||||||
|
{'params': self.student.parameters()},
|
||||||
|
{'params': self.metric.parameters()}
|
||||||
|
], lr=conf['training']['lr'], momentum=0.9, weight_decay=5e-4)
|
||||||
|
|
||||||
|
self.scheduler = CosineAnnealingLR(self.optimizer, T_max=conf['training']['epochs'])
|
||||||
|
self.scaler = GradScaler()
|
||||||
|
|
||||||
|
# 损失函数
|
||||||
|
self.arcface_loss = nn.CrossEntropyLoss()
|
||||||
|
self.distill_loss = nn.KLDivLoss(reduction='batchmean')
|
||||||
|
self.conf = conf
|
||||||
|
|
||||||
|
def cosine_annealing(self, epoch, total_epochs, initial_weight, final_weight=0.1):
|
||||||
|
"""
|
||||||
|
余弦退火法动态调整蒸馏权重
|
||||||
|
参数:
|
||||||
|
epoch: 当前训练轮次
|
||||||
|
total_epochs: 总训练轮次
|
||||||
|
initial_weight: 初始蒸馏权重(如0.8)
|
||||||
|
final_weight: 最终蒸馏权重(如0.1)
|
||||||
|
返回:
|
||||||
|
当前轮次的蒸馏权重
|
||||||
|
"""
|
||||||
|
return final_weight + 0.5 * (initial_weight - final_weight) * (1 + math.cos(math.pi * epoch / total_epochs))
|
||||||
|
def train_epoch(self, epoch):
|
||||||
|
self.teacher.eval()
|
||||||
|
self.student.train()
|
||||||
|
|
||||||
|
if self.rank == 0:
|
||||||
|
print(f"\nTeacher network type: {type(self.teacher)}")
|
||||||
|
print(f"Student network type: {type(self.student)}")
|
||||||
|
|
||||||
|
total_loss = 0
|
||||||
|
for data, labels in tqdm(self.train_loader, desc=f"Epoch {epoch}"):
|
||||||
|
data = data.to(self.device)
|
||||||
|
labels = labels.to(self.device)
|
||||||
|
|
||||||
|
# with autocast():
|
||||||
|
# 教师输出
|
||||||
|
with torch.no_grad():
|
||||||
|
teacher_logits = self.teacher(data)
|
||||||
|
|
||||||
|
# 学生输出
|
||||||
|
student_features = self.student(data)
|
||||||
|
student_logits = self.metric(student_features, labels)
|
||||||
|
|
||||||
|
# 计算损失
|
||||||
|
arc_loss = self.arcface_loss(student_logits, labels)
|
||||||
|
distill_loss = self.distill_loss(
|
||||||
|
F.log_softmax(student_features / self.conf['training']['temperature'], dim=1),
|
||||||
|
F.softmax(teacher_logits / self.conf['training']['temperature'], dim=1)
|
||||||
|
) * (self.conf['training']['temperature'] ** 2) # 温度缩放后需要乘以T^2保持梯度规模
|
||||||
|
current_distill_weight = self.cosine_annealing(epoch, self.conf['training']['epochs'], self.conf['training']['distill_weight'])
|
||||||
|
loss = (1-current_distill_weight) * arc_loss + current_distill_weight * distill_loss
|
||||||
|
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
self.scaler.scale(loss).backward()
|
||||||
|
self.scaler.step(self.optimizer)
|
||||||
|
self.scaler.update()
|
||||||
|
|
||||||
|
total_loss += loss.item()
|
||||||
|
|
||||||
|
self.scheduler.step()
|
||||||
|
return total_loss / len(self.train_loader)
|
||||||
|
|
||||||
|
def validate(self):
|
||||||
|
self.student.eval()
|
||||||
|
total_loss = 0
|
||||||
|
correct = 0
|
||||||
|
total = 0
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for data, labels in self.val_loader:
|
||||||
|
data = data.to(self.device)
|
||||||
|
labels = labels.to(self.device)
|
||||||
|
|
||||||
|
features = self.student(data)
|
||||||
|
logits = self.metric(features, labels)
|
||||||
|
|
||||||
|
loss = self.arcface_loss(logits, labels)
|
||||||
|
total_loss += loss.item()
|
||||||
|
|
||||||
|
_, predicted = torch.max(logits.data, 1)
|
||||||
|
total += labels.size(0)
|
||||||
|
correct += (predicted == labels).sum().item()
|
||||||
|
|
||||||
|
return total_loss / len(self.val_loader), correct / total
|
||||||
|
|
||||||
|
def save_checkpoint(self, epoch, is_best=False):
|
||||||
|
if self.rank != 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
state = {
|
||||||
|
'epoch': epoch,
|
||||||
|
'student_state_dict': self.student.state_dict(),
|
||||||
|
'metric_state_dict': self.metric.state_dict(),
|
||||||
|
'optimizer_state_dict': self.optimizer.state_dict(),
|
||||||
|
}
|
||||||
|
|
||||||
|
filename = 'best.pth' if is_best else f'checkpoint_{epoch}.pth'
|
||||||
|
if not os.path.exists(self.conf['training']['checkpoints']):
|
||||||
|
os.makedirs(self.conf['training']['checkpoints'])
|
||||||
|
if filename != 'best.pth':
|
||||||
|
torch.save(state, os.path.join(self.conf['training']['checkpoints'], filename))
|
||||||
|
else:
|
||||||
|
torch.save(state['student_state_dict'], os.path.join(self.conf['training']['checkpoints'], filename))
|
||||||
|
|
||||||
|
def train(rank, world_size):
|
||||||
|
setup(rank, world_size)
|
||||||
|
with open('configs/distill.yml', 'r') as f:
|
||||||
|
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
|
trainer = DistillTrainer(rank, world_size, conf)
|
||||||
|
best_acc = 0
|
||||||
|
for epoch in range(conf['training']['epochs']):
|
||||||
|
train_loss = trainer.train_epoch(epoch)
|
||||||
|
val_loss, val_acc = trainer.validate()
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
|
||||||
|
|
||||||
|
if val_acc > best_acc:
|
||||||
|
best_acc = val_acc
|
||||||
|
trainer.save_checkpoint(epoch, is_best=True)
|
||||||
|
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
world_size = torch.cuda.device_count()
|
||||||
|
if world_size > 1:
|
||||||
|
mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
|
||||||
|
else:
|
||||||
|
train(0, 1)
|
Reference in New Issue
Block a user