rebuild
This commit is contained in:
11
.gitignore
vendored
Normal file
11
.gitignore
vendored
Normal file
@ -0,0 +1,11 @@
|
||||
*.pth
|
||||
blog/
|
||||
data/
|
||||
experiment/
|
||||
log/
|
||||
shop_xlsx/
|
||||
loss/
|
||||
checkpoints/
|
||||
search_library/
|
||||
quant_imgs/
|
||||
README.md
|
8
.idea/.gitignore
generated
vendored
Normal file
8
.idea/.gitignore
generated
vendored
Normal file
@ -0,0 +1,8 @@
|
||||
# 默认忽略的文件
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# 基于编辑器的 HTTP 客户端请求
|
||||
/httpRequests/
|
||||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
869
.idea/CopilotChatHistory.xml
generated
Normal file
869
.idea/CopilotChatHistory.xml
generated
Normal file
File diff suppressed because one or more lines are too long
6
.idea/CopilotSideBarWebPersist.xml
generated
Normal file
6
.idea/CopilotSideBarWebPersist.xml
generated
Normal file
@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="CopilotSideBarWebPersist">
|
||||
<option name="autoAddFileCloseState" value="true" />
|
||||
</component>
|
||||
</project>
|
19326
.idea/CopilotWebChatHistory.xml
generated
Normal file
19326
.idea/CopilotWebChatHistory.xml
generated
Normal file
File diff suppressed because one or more lines are too long
8
.idea/contrast_nettest.iml
generated
Normal file
8
.idea/contrast_nettest.iml
generated
Normal file
@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="服务器3-NV4090-env:py-contrast-nettest" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
114
.idea/deployment.xml
generated
Normal file
114
.idea/deployment.xml
generated
Normal file
@ -0,0 +1,114 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="PublishConfigData" autoUpload="Always" serverName="lc@192.168.10.89:22 password (6)" exclude=".svn;.cvs;.idea;.DS_Store;.git;.hg;*.hprof;*.pyc;*.jpg;*.mp4;data/" remoteFilesAllowedToDisappearOnAutoupload="false" confirmBeforeUploading="false">
|
||||
<option name="confirmBeforeUploading" value="false" />
|
||||
<serverData>
|
||||
<paths name="contrast_nettest">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping deploy="/contrast_nettest" local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="ieemoo0169@192.168.10.93:22 password">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping deploy="/home/ieemoo0169/contrast_nettest" local="$PROJECT_DIR$" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="lc@192.168.10.56:22 password">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="lc@192.168.10.89:22 password">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="lc@192.168.10.89:22 password (10)">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="lc@192.168.10.89:22 password (11)">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="lc@192.168.10.89:22 password (2)">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="lc@192.168.10.89:22 password (3)">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="lc@192.168.10.89:22 password (4)">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="lc@192.168.10.89:22 password (5)">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="lc@192.168.10.89:22 password (6)">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping deploy="/home/lc/contrast_nettest" local="$PROJECT_DIR$" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="lc@192.168.10.89:22 password (7)">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="lc@192.168.10.89:22 password (8)">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="lc@192.168.10.89:22 password (9)">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="yolov5">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
</serverData>
|
||||
<option name="myAutoUpload" value="ALWAYS" />
|
||||
</component>
|
||||
</project>
|
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
@ -0,0 +1,6 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
7
.idea/misc.xml
generated
Normal file
7
.idea/misc.xml
generated
Normal file
@ -0,0 +1,7 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="Remote Python 3.8.18 (sftp://lc@192.168.1.142:22/home/lc/project/miniconda3/envs/my_env/bin/python)" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="服务器3-NV4090-env:py-contrast-nettest" project-jdk-type="Python SDK" />
|
||||
</project>
|
8
.idea/modules.xml
generated
Normal file
8
.idea/modules.xml
generated
Normal file
@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/contrast_nettest.iml" filepath="$PROJECT_DIR$/.idea/contrast_nettest.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
8
.idea/sshConfigs.xml
generated
Normal file
8
.idea/sshConfigs.xml
generated
Normal file
@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="SshConfigs">
|
||||
<configs>
|
||||
<sshConfig authType="PASSWORD" connectionConfig="{"serverAliveInterval":300}" host="192.168.1.28" id="f9cd63ee-d39a-42a7-b369-1eb74d4f71ae" port="22" nameFormat="DESCRIPTIVE" username="ieemoo0169" useOpenSSHConfig="true" />
|
||||
</configs>
|
||||
</component>
|
||||
</project>
|
6
.idea/vcs.xml
generated
Normal file
6
.idea/vcs.xml
generated
Normal file
@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
14
.idea/webServers.xml
generated
Normal file
14
.idea/webServers.xml
generated
Normal file
@ -0,0 +1,14 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="WebServers">
|
||||
<option name="servers">
|
||||
<webServer id="422a5cdc-8aff-4e1f-9f9a-2377f5a31f0b" name="contrast_nettest">
|
||||
<fileTransfer rootFolder="/home/ieemoo0169" accessType="SFTP" host="192.168.1.28" port="22" sshConfigId="74dc3f38-9a9b-4eb8-ae6f-ed04cca88f27" sshConfig="ieemoo0169@192.168.1.28:22 password">
|
||||
<advancedOptions>
|
||||
<advancedOptions dataProtectionLevel="Private" passiveMode="true" shareSSLContext="true" />
|
||||
</advancedOptions>
|
||||
</fileTransfer>
|
||||
</webServer>
|
||||
</option>
|
||||
</component>
|
||||
</project>
|
9
.vscode/sftp.json
vendored
Normal file
9
.vscode/sftp.json
vendored
Normal file
@ -0,0 +1,9 @@
|
||||
{
|
||||
"name": "My Server",
|
||||
"host": "localhost",
|
||||
"protocol": "sftp",
|
||||
"port": 22,
|
||||
"username": "username",
|
||||
"remotePath": "/",
|
||||
"uploadOnSave": true
|
||||
}
|
BIN
__pycache__/config.cpython-38.pyc
Normal file
BIN
__pycache__/config.cpython-38.pyc
Normal file
Binary file not shown.
BIN
__pycache__/test_ori.cpython-38.pyc
Normal file
BIN
__pycache__/test_ori.cpython-38.pyc
Normal file
Binary file not shown.
122
config.py
Normal file
122
config.py
Normal file
@ -0,0 +1,122 @@
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
import torchvision.transforms.functional as F
|
||||
|
||||
|
||||
def pad_to_square(img):
|
||||
w, h = img.size
|
||||
max_wh = max(w, h)
|
||||
padding = [(max_wh - w) // 2, (max_wh - h) // 2, (max_wh - w) // 2, (max_wh - h) // 2] # (left, top, right, bottom)
|
||||
return F.pad(img, padding, fill=0, padding_mode='constant')
|
||||
|
||||
|
||||
class Config:
|
||||
# network settings
|
||||
backbone = 'resnet18' # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large,
|
||||
# mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5, vit_base]
|
||||
metric = 'arcface' # [cosface, arcface, softmax]
|
||||
cbam = False
|
||||
embedding_size = 256 # 256 # gift:2 contrast:256
|
||||
drop_ratio = 0.5
|
||||
img_size = 224
|
||||
multiple_cards = True # 多卡加载
|
||||
model_half = False # 模型半精度测试
|
||||
data_half = True # 数据半精度测试
|
||||
channel_ratio = 0.75 # 通道剪枝比例
|
||||
quantization_test = False # int8量化模型测试
|
||||
|
||||
# custom base_data settings
|
||||
custom_backbone = False # 迁移学习载入除最后一层的所有层
|
||||
custom_num_classes = 128 # 迁移学习的类别数量
|
||||
|
||||
# if quantization_test:
|
||||
# device = torch.device('cpu')
|
||||
# else:
|
||||
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
teacher = 'vit' # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large, mobilenet_v1,
|
||||
# PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5]
|
||||
|
||||
student = 'resnet'
|
||||
# data preprocess
|
||||
"""transforms.RandomCrop(size),
|
||||
transforms.RandomVerticalFlip(p=0.5),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
RandomRotate(15, 0.3),
|
||||
# RandomGaussianBlur()"""
|
||||
train_transform = T.Compose([
|
||||
T.Lambda(pad_to_square), # 补边
|
||||
T.ToTensor(),
|
||||
T.Resize((img_size, img_size), antialias=True),
|
||||
# T.RandomCrop(img_size * 4 // 5),
|
||||
T.RandomHorizontalFlip(p=0.5),
|
||||
T.RandomRotation(180),
|
||||
T.ColorJitter(brightness=0.5),
|
||||
T.ConvertImageDtype(torch.float32),
|
||||
T.Normalize(mean=[0.5], std=[0.5]),
|
||||
])
|
||||
test_transform = T.Compose([
|
||||
# T.Lambda(pad_to_square), # 补边
|
||||
T.ToTensor(),
|
||||
T.Resize((img_size, img_size), antialias=True),
|
||||
T.ConvertImageDtype(torch.float32),
|
||||
# T.Normalize(mean=[0,0,0], std=[255,255,255]),
|
||||
T.Normalize(mean=[0.5], std=[0.5]),
|
||||
])
|
||||
|
||||
# dataset
|
||||
train_root = '../data_center/scatter/train' # ['./data/2250_train/base_data', # './data/2000_train/base_data', './data/zhanting/base_data', './data/base_train/one_stage/train']
|
||||
test_root = '../data_center/scatter/val' # ["./data/2250_train/val", "./data/2000_train/val/", './data/zhanting/val', './data/base_train/one_stage/val']
|
||||
|
||||
# training settings
|
||||
checkpoints = "checkpoints/resnet18_scatter_6.2/" # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3]
|
||||
restore = True
|
||||
# restore_model = "checkpoints/renet18_2250_0315/best_resnet18_2250_0315.pth" # best_resnet18_1491_0306.pth
|
||||
restore_model = "checkpoints/resnet18_scatter_6.2/best.pth" # best_resnet18_1491_0306.pth
|
||||
|
||||
# test settings
|
||||
testbackbone = 'resnet18' # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5]
|
||||
|
||||
# test_val = "./data/2250_train"
|
||||
# test_list = "./data/2250_train/val_pair.txt"
|
||||
# test_group_json = "./data/2250_train/cross_same.json"
|
||||
|
||||
test_val = "../data_center/scatter/" # [../data_center/contrast_learning/model_test_data/val_2250]
|
||||
test_list = "../data_center/scatter/val_pair.txt" # [./data/test/public_single_pairs.txt]
|
||||
test_group_json = "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json" # [./data/2250_train/cross_same.json]
|
||||
# test_group_json = "./data/test/inner_group_pairs.json"
|
||||
|
||||
# test_model = "checkpoints/resnet18_scatter_6.2/best.pth"
|
||||
test_model = "checkpoints/resnet18_1009/best.pth"
|
||||
# test_model = "checkpoints/zhanting/inland/res_801.pth"
|
||||
# test_model = "checkpoints/resnet18_20250504/best.pth"
|
||||
# test_model = "checkpoints/resnet18_vit-base_20250430/best.pth"
|
||||
group_test = False
|
||||
# group_test = False
|
||||
|
||||
train_batch_size = 128 # 256
|
||||
test_batch_size = 128 # 256
|
||||
|
||||
epoch = 5 # 512
|
||||
optimizer = 'sgd' # ['sgd', 'adam', 'adamw']
|
||||
lr = 5e-3 # 1e-2
|
||||
lr_step = 10 # 10
|
||||
lr_decay = 0.98 # 0.98
|
||||
weight_decay = 5e-4
|
||||
loss = 'cross_entropy' # ['focal_loss', 'cross_entropy']
|
||||
log_path = './log'
|
||||
lr_min = 1e-6 # min lr
|
||||
|
||||
pin_memory = False # if memory is large, set it True to speed up a bit
|
||||
num_workers = 32 # 64
|
||||
compare = False # compare the result of different models
|
||||
|
||||
'''
|
||||
train_distill settings
|
||||
'''
|
||||
warmup_epochs = 3 # warmup_epoch
|
||||
distributed = True # distributed training
|
||||
teacher_path = "./checkpoints/resnet50_0519/best.pth"
|
||||
distill_weight = 0.8 # 蒸馏权重
|
||||
|
||||
config = Config()
|
1
configs/__init__.py
Normal file
1
configs/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .utils import trainer_tools
|
69
configs/compare.yml
Normal file
69
configs/compare.yml
Normal file
@ -0,0 +1,69 @@
|
||||
# configs/compare.yml
|
||||
# 专为模型训练对比设计的配置文件
|
||||
# 支持对比不同训练策略(如蒸馏vs独立训练)
|
||||
|
||||
# 基础配置
|
||||
base:
|
||||
experiment_name: "model_comparison" # 实验名称(用于结果保存目录)
|
||||
seed: 42 # 随机种子(保证可复现性)
|
||||
device: "cuda" # 训练设备(cuda/cpu)
|
||||
log_level: "info" # 日志级别(debug/info/warning/error)
|
||||
embedding_size: 256 # 特征维度
|
||||
pin_memory: true # 是否启用pin_memory
|
||||
distributed: true # 是否启用分布式训练
|
||||
|
||||
|
||||
# 模型配置
|
||||
models:
|
||||
backbone: 'resnet18'
|
||||
channel_ratio: 0.75
|
||||
|
||||
# 训练参数
|
||||
training:
|
||||
epochs: 600 # 总训练轮次
|
||||
batch_size: 128 # 批次大小
|
||||
lr: 0.001 # 初始学习率
|
||||
optimizer: "sgd" # 优化器类型
|
||||
metric: 'arcface' # 损失函数类型(可选:arcface/cosface/sphereface/softmax)
|
||||
loss: "cross_entropy" # 损失函数类型(可选:cross_entropy/cross_entropy_smooth/center_loss/center_loss_smooth/arcface/cosface/sphereface/softmax)
|
||||
lr_step: 10 # 学习率调整间隔(epoch)
|
||||
lr_decay: 0.98 # 学习率衰减率
|
||||
weight_decay: 0.0005 # 权重衰减
|
||||
scheduler: "cosine_annealing" # 学习率调度器(可选:cosine_annealing/step/none)
|
||||
num_workers: 32 # 数据加载线程数
|
||||
checkpoints: "./checkpoints/resnet18_test/" # 模型保存目录
|
||||
restore: false
|
||||
restore_model: "resnet18_test/epoch_600.pth" # 模型恢复路径
|
||||
|
||||
# 验证参数
|
||||
validation:
|
||||
num_workers: 32 # 数据加载线程数
|
||||
val_batch_size: 128 # 测试批次大小
|
||||
|
||||
# 数据配置
|
||||
data:
|
||||
dataset: "imagenet" # 数据集名称(示例用,可替换为实际数据集)
|
||||
train_batch_size: 128 # 训练批次大小
|
||||
val_batch_size: 128 # 验证批次大小
|
||||
num_workers: 32 # 数据加载线程数
|
||||
data_train_dir: "../data_center/contrast_learning/data_base/train" # 训练数据集根目录
|
||||
data_val_dir: "../data_center/contrast_learning/data_base/val" # 验证数据集根目录
|
||||
|
||||
transform:
|
||||
img_size: 224 # 图像尺寸
|
||||
img_mean: 0.5 # 图像均值
|
||||
img_std: 0.5 # 图像方差
|
||||
RandomHorizontalFlip: 0.5 # 随机水平翻转概率
|
||||
RandomRotation: 180 # 随机旋转角度
|
||||
ColorJitter: 0.5 # 随机颜色抖动强度
|
||||
|
||||
# 日志与监控
|
||||
logging:
|
||||
logging_dir: "./logs" # 日志保存目录
|
||||
tensorboard: true # 是否启用TensorBoard
|
||||
checkpoint_interval: 30 # 检查点保存间隔(epoch)
|
||||
|
||||
# 分布式训练(可选)
|
||||
distributed:
|
||||
enabled: false # 是否启用分布式训练
|
||||
backend: "nccl" # 分布式后端(nccl/gloo)
|
75
configs/distill.yml
Normal file
75
configs/distill.yml
Normal file
@ -0,0 +1,75 @@
|
||||
# configs/compare.yml
|
||||
# 专为模型训练对比设计的配置文件
|
||||
# 支持对比不同训练策略(如蒸馏vs独立训练)
|
||||
|
||||
# 基础配置
|
||||
base:
|
||||
experiment_name: "model_comparison" # 实验名称(用于结果保存目录)
|
||||
seed: 42 # 随机种子(保证可复现性)
|
||||
device: "cuda" # 训练设备(cuda/cpu)
|
||||
log_level: "info" # 日志级别(debug/info/warning/error)
|
||||
embedding_size: 256 # 特征维度
|
||||
pin_memory: true # 是否启用pin_memory
|
||||
distributed: true # 是否启用分布式训练
|
||||
|
||||
|
||||
# 模型配置
|
||||
models:
|
||||
backbone: 'resnet18'
|
||||
channel_ratio: 1.0 # 主干特征通道缩放比例(默认)
|
||||
student_channel_ratio: 0.75
|
||||
teacher_model_path: "./checkpoints/resnet50_0519/best.pth"
|
||||
|
||||
# 训练参数
|
||||
training:
|
||||
epochs: 600 # 总训练轮次
|
||||
batch_size: 128 # 批次大小
|
||||
lr: 0.001 # 初始学习率
|
||||
optimizer: "sgd" # 优化器类型
|
||||
metric: 'arcface' # 损失函数类型(可选:arcface/cosface/sphereface/softmax)
|
||||
loss: "cross_entropy" # 损失函数类型(可选:cross_entropy/cross_entropy_smooth/center_loss/center_loss_smooth/arcface/cosface/sphereface/softmax)
|
||||
lr_step: 10 # 学习率调整间隔(epoch)
|
||||
lr_decay: 0.98 # 学习率衰减率
|
||||
weight_decay: 0.0005 # 权重衰减
|
||||
scheduler: "cosine_annealing" # 学习率调度器(可选:cosine_annealing/step/none)
|
||||
num_workers: 32 # 数据加载线程数
|
||||
checkpoints: "./checkpoints/resnet18_test/" # 模型保存目录
|
||||
restore: false
|
||||
restore_model: "resnet18_test/epoch_600.pth" # 模型恢复路径
|
||||
distill_weight: 0.8 # 蒸馏损失权重
|
||||
temperature: 4 # 蒸馏温度
|
||||
|
||||
|
||||
|
||||
# 验证参数
|
||||
validation:
|
||||
num_workers: 32 # 数据加载线程数
|
||||
val_batch_size: 128 # 测试批次大小
|
||||
|
||||
# 数据配置
|
||||
data:
|
||||
dataset: "imagenet" # 数据集名称(示例用,可替换为实际数据集)
|
||||
train_batch_size: 128 # 训练批次大小
|
||||
val_batch_size: 100 # 验证批次大小
|
||||
num_workers: 4 # 数据加载线程数
|
||||
data_train_dir: "../data_center/contrast_learning/data_base/train" # 训练数据集根目录
|
||||
data_val_dir: "../data_center/contrast_learning/data_base/val" # 验证数据集根目录
|
||||
|
||||
transform:
|
||||
img_size: 224 # 图像尺寸
|
||||
img_mean: 0.5 # 图像均值
|
||||
img_std: 0.5 # 图像方差
|
||||
RandomHorizontalFlip: 0.5 # 随机水平翻转概率
|
||||
RandomRotation: 180 # 随机旋转角度
|
||||
ColorJitter: 0.5 # 随机颜色抖动强度
|
||||
|
||||
# 日志与监控
|
||||
logging:
|
||||
logging_dir: "./logs" # 日志保存目录
|
||||
tensorboard: true # 是否启用TensorBoard
|
||||
checkpoint_interval: 30 # 检查点保存间隔(epoch)
|
||||
|
||||
# 分布式训练(可选)
|
||||
distributed:
|
||||
enabled: false # 是否启用分布式训练
|
||||
backend: "nccl" # 分布式后端(nccl/gloo)
|
69
configs/scatter.yml
Normal file
69
configs/scatter.yml
Normal file
@ -0,0 +1,69 @@
|
||||
# configs/scatter.yml
|
||||
# 专为模型训练对比设计的配置文件
|
||||
# 支持对比不同训练策略(如蒸馏vs独立训练)
|
||||
|
||||
# 基础配置
|
||||
base:
|
||||
device: "cuda" # 训练设备(cuda/cpu)
|
||||
log_level: "info" # 日志级别(debug/info/warning/error)
|
||||
embedding_size: 256 # 特征维度
|
||||
pin_memory: true # 是否启用pin_memory
|
||||
distributed: true # 是否启用分布式训练
|
||||
|
||||
|
||||
# 模型配置
|
||||
models:
|
||||
backbone: 'resnet18'
|
||||
channel_ratio: 1.0
|
||||
|
||||
# 训练参数
|
||||
training:
|
||||
epochs: 300 # 总训练轮次
|
||||
batch_size: 64 # 批次大小
|
||||
lr: 0.005 # 初始学习率
|
||||
optimizer: "sgd" # 优化器类型
|
||||
metric: 'arcface' # 损失函数类型(可选:arcface/cosface/sphereface/softmax)
|
||||
loss: "cross_entropy" # 损失函数类型(可选:cross_entropy/cross_entropy_smooth/center_loss/center_loss_smooth/arcface/cosface/sphereface/softmax)
|
||||
lr_step: 10 # 学习率调整间隔(epoch)
|
||||
lr_decay: 0.98 # 学习率衰减率
|
||||
weight_decay: 0.0005 # 权重衰减
|
||||
scheduler: "cosine_annealing" # 学习率调度器(可选:cosine_annealing/step/none)
|
||||
num_workers: 32 # 数据加载线程数
|
||||
checkpoints: "./checkpoints/resnet18_scatter_6.2/" # 模型保存目录
|
||||
restore: True
|
||||
restore_model: "checkpoints/resnet18_scatter_6.2/best.pth" # 模型恢复路径
|
||||
|
||||
|
||||
|
||||
# 验证参数
|
||||
validation:
|
||||
num_workers: 32 # 数据加载线程数
|
||||
val_batch_size: 128 # 测试批次大小
|
||||
|
||||
# 数据配置
|
||||
data:
|
||||
dataset: "imagenet" # 数据集名称(示例用,可替换为实际数据集)
|
||||
train_batch_size: 128 # 训练批次大小
|
||||
val_batch_size: 100 # 验证批次大小
|
||||
num_workers: 32 # 数据加载线程数
|
||||
data_train_dir: "../data_center/scatter/train" # 训练数据集根目录
|
||||
data_val_dir: "../data_center/scatter/val" # 验证数据集根目录
|
||||
|
||||
transform:
|
||||
img_size: 224 # 图像尺寸
|
||||
img_mean: 0.5 # 图像均值
|
||||
img_std: 0.5 # 图像方差
|
||||
RandomHorizontalFlip: 0.5 # 随机水平翻转概率
|
||||
RandomRotation: 180 # 随机旋转角度
|
||||
ColorJitter: 0.5 # 随机颜色抖动强度
|
||||
|
||||
# 日志与监控
|
||||
logging:
|
||||
logging_dir: "./log/2025.6.2-scatter.txt" # 日志保存目录
|
||||
tensorboard: true # 是否启用TensorBoard
|
||||
checkpoint_interval: 30 # 检查点保存间隔(epoch)
|
||||
|
||||
# 分布式训练(可选)
|
||||
distributed:
|
||||
enabled: false # 是否启用分布式训练
|
||||
backend: "nccl" # 分布式后端(nccl/gloo)
|
41
configs/test.yml
Normal file
41
configs/test.yml
Normal file
@ -0,0 +1,41 @@
|
||||
# configs/test.yml
|
||||
# 专为模型训练对比设计的配置文件
|
||||
# 支持对比不同训练策略(如蒸馏vs独立训练)
|
||||
|
||||
# 基础配置
|
||||
base:
|
||||
device: "cuda" # 训练设备(cuda/cpu)
|
||||
log_level: "info" # 日志级别(debug/info/warning/error)
|
||||
embedding_size: 256 # 特征维度
|
||||
pin_memory: true # 是否启用pin_memory
|
||||
distributed: true # 是否启用分布式训练
|
||||
|
||||
# 模型配置
|
||||
models:
|
||||
backbone: 'resnet18'
|
||||
channel_ratio: 1.0
|
||||
model_path: "./checkpoints/resnet18_scatter_6.2/best.pth"
|
||||
half: false # 是否启用半精度测试(fp16)
|
||||
|
||||
# 数据配置
|
||||
data:
|
||||
group_test: False # 数据集名称(示例用,可替换为实际数据集)
|
||||
test_batch_size: 128 # 训练批次大小
|
||||
num_workers: 32 # 数据加载线程数
|
||||
test_dir: "../data_center/scatter/" # 验证数据集根目录
|
||||
test_group_json: "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json"
|
||||
test_list: "../data_center/scatter/val_pair.txt"
|
||||
|
||||
transform:
|
||||
img_size: 224 # 图像尺寸
|
||||
img_mean: 0.5 # 图像均值
|
||||
img_std: 0.5 # 图像方差
|
||||
RandomHorizontalFlip: 0.5 # 随机水平翻转概率
|
||||
RandomRotation: 180 # 随机旋转角度
|
||||
ColorJitter: 0.5 # 随机颜色抖动强度
|
||||
|
||||
save:
|
||||
save_dir: ""
|
||||
save_name: ""
|
||||
|
||||
|
56
configs/utils.py
Normal file
56
configs/utils.py
Normal file
@ -0,0 +1,56 @@
|
||||
from model import (resnet18, mobilevit_s, MobileNetV3_Small, MobileNetV3_Large, mobilenet_v1,
|
||||
PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5)
|
||||
from timm.models import vit_base_patch16_224 as vit_base_16
|
||||
from model.metric import ArcFace, CosFace
|
||||
import torch.optim as optim
|
||||
import torch.nn as nn
|
||||
import timm
|
||||
|
||||
|
||||
class trainer_tools:
|
||||
def __init__(self, conf):
|
||||
self.conf = conf
|
||||
|
||||
def get_backbone(self):
|
||||
backbone_mapping = {
|
||||
'resnet18': lambda: resnet18(scale=self.conf['models']['channel_ratio']),
|
||||
'mobilevit_s': lambda: mobilevit_s(),
|
||||
'mobilenetv3_small': lambda: MobileNetV3_Small(),
|
||||
'PPLCNET_x1_0': lambda: PPLCNET_x1_0(),
|
||||
'PPLCNET_x0_5': lambda: PPLCNET_x0_5(),
|
||||
'PPLCNET_x2_5': lambda: PPLCNET_x2_5(),
|
||||
'mobilenetv3_large': lambda: MobileNetV3_Large(),
|
||||
'vit_base': lambda: vit_base_16(pretrained=True),
|
||||
'efficientnet': lambda: timm.create_model('efficientnet_b0', pretrained=True,
|
||||
num_classes=self.conf.embedding_size)
|
||||
}
|
||||
return backbone_mapping
|
||||
|
||||
def get_metric(self, class_num):
|
||||
# 优化后的metric选择代码块,使用字典映射提高可读性和扩展性
|
||||
metric_mapping = {
|
||||
'arcface': lambda: ArcFace(self.conf['base']['embedding_size'], class_num).to(self.conf['base']['device']),
|
||||
'cosface': lambda: CosFace(self.conf['base']['embedding_size'], class_num).to(self.conf['base']['device']),
|
||||
'softmax': lambda: nn.Linear(self.conf['base']['embedding_size'], class_num).to(self.conf['base']['device'])
|
||||
}
|
||||
return metric_mapping
|
||||
|
||||
def get_optimizer(self, model, metric):
|
||||
optimizer_mapping = {
|
||||
'sgd': lambda: optim.SGD(
|
||||
[{'params': model.parameters()}, {'params': metric.parameters()}],
|
||||
lr=self.conf['training']['lr'],
|
||||
weight_decay=self.conf['training']['weight_decay']
|
||||
),
|
||||
'adam': lambda: optim.Adam(
|
||||
[{'params': model.parameters()}, {'params': metric.parameters()}],
|
||||
lr=self.conf['training']['lr'],
|
||||
weight_decay=self.conf['training']['weight_decay']
|
||||
),
|
||||
'adamw': lambda: optim.AdamW(
|
||||
[{'params': model.parameters()}, {'params': metric.parameters()}],
|
||||
lr=self.conf['training']['lr'],
|
||||
weight_decay=self.conf['training']['weight_decay']
|
||||
)
|
||||
}
|
||||
return optimizer_mapping
|
47
configs/write_feature.yml
Normal file
47
configs/write_feature.yml
Normal file
@ -0,0 +1,47 @@
|
||||
# configs/write_feature.yml
|
||||
# 专为模型训练对比设计的配置文件
|
||||
# 支持对比不同训练策略(如蒸馏vs独立训练)
|
||||
|
||||
# 基础配置
|
||||
base:
|
||||
device: "cuda" # 训练设备(cuda/cpu)
|
||||
log_level: "info" # 日志级别(debug/info/warning/error)
|
||||
embedding_size: 256 # 特征维度
|
||||
distributed: true # 是否启用分布式训练
|
||||
pin_memory: true # 是否启用pin_memory
|
||||
|
||||
# 模型配置
|
||||
models:
|
||||
backbone: 'resnet18'
|
||||
channel_ratio: 0.75
|
||||
checkpoints: "../checkpoints/resnet18_1009/best.pth"
|
||||
|
||||
# 数据配置
|
||||
data:
|
||||
train_batch_size: 128 # 训练批次大小
|
||||
test_batch_size: 128 # 验证批次大小
|
||||
num_workers: 32 # 数据加载线程数
|
||||
half: true # 是否启用半精度数据
|
||||
img_dirs_path: "/shareData/temp_data/comparison/Hangzhou_Yunhe/base_data/05-09"
|
||||
# img_dirs_path: "/home/lc/contrast_nettest/data/feature_json"
|
||||
xlsx_pth: false # 过滤商品, 默认None不进行过滤
|
||||
|
||||
transform:
|
||||
img_size: 224 # 图像尺寸
|
||||
img_mean: 0.5 # 图像均值
|
||||
img_std: 0.5 # 图像方差
|
||||
RandomHorizontalFlip: 0.5 # 随机水平翻转概率
|
||||
RandomRotation: 180 # 随机旋转角度
|
||||
ColorJitter: 0.5 # 随机颜色抖动强度
|
||||
|
||||
# 日志与监控
|
||||
logging:
|
||||
logging_dir: "./logs" # 日志保存目录
|
||||
tensorboard: true # 是否启用TensorBoard
|
||||
checkpoint_interval: 30 # 检查点保存间隔(epoch)
|
||||
|
||||
save:
|
||||
json_bin: "../search_library/yunhedian_05-09.json" # 保存整个json文件
|
||||
json_path: "../data/feature_json_compare/" # 保存单个json文件
|
||||
error_barcodes: "error_barcodes.txt"
|
||||
barcodes_statistics: "../search_library/barcodes_statistics.txt"
|
88
model/BAM.py
Normal file
88
model/BAM.py
Normal file
@ -0,0 +1,88 @@
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from torch.nn import init
|
||||
|
||||
|
||||
class Flatten(nn.Module):
|
||||
def forward(self, x):
|
||||
return x.view(x.shape[0], -1)
|
||||
|
||||
|
||||
class ChannelAttention(nn.Module):
|
||||
def __int__(self, channel, reduction, num_layers):
|
||||
super(ChannelAttention, self).__init__()
|
||||
self.avgpool = nn.AdaptiveAvgPool2d(1)
|
||||
gate_channels = [channel]
|
||||
gate_channels += [len(channel) // reduction] * num_layers
|
||||
gate_channels += [channel]
|
||||
|
||||
self.ca = nn.Sequential()
|
||||
self.ca.add_module('flatten', Flatten())
|
||||
for i in range(len(gate_channels) - 2):
|
||||
self.ca.add_module('', nn.Linear(gate_channels[i], gate_channels[i + 1]))
|
||||
self.ca.add_module('', nn.BatchNorm1d(gate_channels[i + 1]))
|
||||
self.ca.add_module('', nn.ReLU())
|
||||
self.ca.add_module('', nn.Linear(gate_channels[-2], gate_channels[-1]))
|
||||
|
||||
def forward(self, x):
|
||||
res = self.avgpool(x)
|
||||
res = self.ca(res)
|
||||
res = res.unsqueeze(-1).unsqueeze(-1).expand_as(x)
|
||||
return res
|
||||
|
||||
|
||||
class SpatialAttention(nn.Module):
|
||||
def __int__(self, channel, reduction=16, num_lay=3, dilation=2):
|
||||
super(SpatialAttention).__init__()
|
||||
self.sa = nn.Sequential()
|
||||
self.sa.add_module('', nn.Conv2d(kernel_size=1, in_channels=channel, out_channels=(channel // reduction) * 3))
|
||||
self.sa.add_module('', nn.BatchNorm2d(num_features=(channel // reduction)))
|
||||
self.sa.add_module('', nn.ReLU())
|
||||
for i in range(num_lay):
|
||||
self.sa.add_module('', nn.Conv2d(kernel_size=3,
|
||||
in_channels=(channel // reduction),
|
||||
out_channels=(channel // reduction),
|
||||
padding=1,
|
||||
dilation=2))
|
||||
self.sa.add_module('', nn.BatchNorm2d(channel // reduction))
|
||||
self.sa.add_module('', nn.ReLU())
|
||||
self.sa.add_module('', nn.Conv2d(channel // reduction, 1, kernel_size=1))
|
||||
|
||||
def forward(self, x):
|
||||
res = self.sa(x)
|
||||
res = res.expand_as(x)
|
||||
return res
|
||||
|
||||
|
||||
class BAMblock(nn.Module):
|
||||
def __init__(self, channel=512, reduction=16, dia_val=2):
|
||||
super(BAMblock, self).__init__()
|
||||
self.ca = ChannelAttention(channel, reduction)
|
||||
self.sa = SpatialAttention(channel, reduction, dia_val)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
init.kaiming_normal(m.weight, mode='fan_out')
|
||||
if m.bais is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
init.normal_(m.weight, std=0.001)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, _, _ = x.size()
|
||||
sa_out = self.sa(x)
|
||||
ca_out = self.ca(x)
|
||||
weight = self.sigmoid(sa_out + ca_out)
|
||||
out = (1 + weight) * x
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(512 // 14)
|
70
model/CBAM.py
Normal file
70
model/CBAM.py
Normal file
@ -0,0 +1,70 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.init as init
|
||||
|
||||
class channelAttention(nn.Module):
|
||||
def __init__(self, channel, reduction=16):
|
||||
super(channelAttention, self).__init__()
|
||||
self.Maxpooling = nn.AdaptiveMaxPool2d(1)
|
||||
self.Avepooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.ca = nn.Sequential()
|
||||
self.ca.add_module('conv1',nn.Conv2d(channel, channel//reduction, 1, bias=False))
|
||||
self.ca.add_module('Relu', nn.ReLU())
|
||||
self.ca.add_module('conv2',nn.Conv2d(channel//reduction, channel, 1, bias=False))
|
||||
self.sigmod = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
M_out = self.Maxpooling(x)
|
||||
A_out = self.Avepooling(x)
|
||||
M_out = self.ca(M_out)
|
||||
A_out = self.ca(A_out)
|
||||
out = self.sigmod(M_out+A_out)
|
||||
return out
|
||||
|
||||
class SpatialAttention(nn.Module):
|
||||
def __init__(self, kernel_size=7):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=kernel_size, padding=kernel_size // 2)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
max_result, _ = torch.max(x, dim=1, keepdim=True)
|
||||
avg_result = torch.mean(x, dim=1, keepdim=True)
|
||||
result = torch.cat([max_result, avg_result], dim=1)
|
||||
output = self.conv(result)
|
||||
output = self.sigmoid(output)
|
||||
return output
|
||||
|
||||
class CBAM(nn.Module):
|
||||
def __init__(self, channel, reduction=16, kernel_size=7):
|
||||
super().__init__()
|
||||
self.ca = channelAttention(channel, reduction)
|
||||
self.sa = SpatialAttention(kernel_size)
|
||||
|
||||
def init_weights(self):
|
||||
for m in self.modules():#权重初始化
|
||||
if isinstance(m, nn.Conv2d):
|
||||
init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
init.normal_(m.weight, std=0.001)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
# b,c_,_ = x.size()
|
||||
# residual = x
|
||||
out = x*self.ca(x)
|
||||
out = out*self.sa(out)
|
||||
return out
|
||||
|
||||
if __name__ == '__main__':
|
||||
input=torch.randn(50,512,7,7)
|
||||
kernel_size=input.shape[2]
|
||||
cbam = CBAM(channel=512,reduction=16,kernel_size=kernel_size)
|
||||
output=cbam(input)
|
||||
print(output.shape)
|
37
model/Tool.py
Normal file
37
model/Tool.py
Normal file
@ -0,0 +1,37 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class GeM(nn.Module):
|
||||
def __init__(self, p=3, eps=1e-6):
|
||||
super(GeM, self).__init__()
|
||||
self.p = nn.Parameter(torch.ones(1) * p)
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
return self.gem(x, p=self.p, eps=self.eps, stride=2)
|
||||
|
||||
def gem(self, x, p=3, eps=1e-6, stride=2):
|
||||
return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1)), stride=2).pow(1. / p)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + \
|
||||
'(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + \
|
||||
', ' + 'eps=' + str(self.eps) + ')'
|
||||
|
||||
|
||||
class TripletLoss(nn.Module):
|
||||
def __init__(self, margin):
|
||||
super(TripletLoss, self).__init__()
|
||||
self.margin = margin
|
||||
|
||||
def forward(self, anchor, positive, negative, size_average=True):
|
||||
distance_positive = (anchor - positive).pow(2).sum(1)
|
||||
distance_negative = (anchor - negative).pow(2).sum(1)
|
||||
losses = F.relu(distance_negative - distance_positive + self.margin)
|
||||
return losses.mean() if size_average else losses.sum()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('')
|
14
model/__init__.py
Normal file
14
model/__init__.py
Normal file
@ -0,0 +1,14 @@
|
||||
from .fmobilenet import FaceMobileNet
|
||||
# from .resnet_face import ResIRSE
|
||||
from .mobilevit import mobilevit_s
|
||||
from .metric import ArcFace, CosFace
|
||||
from .loss import FocalLoss
|
||||
from .resbam import resnet
|
||||
from .resnet_pre import resnet18, resnet34, resnet50, resnet14, CustomResNet18
|
||||
from .mobilenet_v2 import mobilenet_v2
|
||||
from .mobilenet_v3 import MobileNetV3_Small, MobileNetV3_Large
|
||||
# from .mobilenet_v1 import mobilenet_v1
|
||||
from .lcnet import PPLCNET_x0_25, PPLCNET_x0_35, PPLCNET_x0_5, PPLCNET_x0_75, PPLCNET_x1_0, PPLCNET_x1_5, PPLCNET_x2_0, \
|
||||
PPLCNET_x2_5
|
||||
from .vit import vit_base
|
||||
from .mlp import MLP
|
BIN
model/__pycache__/CBAM.cpython-38.pyc
Normal file
BIN
model/__pycache__/CBAM.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/Tool.cpython-38.pyc
Normal file
BIN
model/__pycache__/Tool.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
model/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/fmobilenet.cpython-38.pyc
Normal file
BIN
model/__pycache__/fmobilenet.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/lcnet.cpython-38.pyc
Normal file
BIN
model/__pycache__/lcnet.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/loss.cpython-38.pyc
Normal file
BIN
model/__pycache__/loss.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/metric.cpython-38.pyc
Normal file
BIN
model/__pycache__/metric.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/mlp.cpython-38.pyc
Normal file
BIN
model/__pycache__/mlp.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/mobilenet_v1.cpython-38.pyc
Normal file
BIN
model/__pycache__/mobilenet_v1.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/mobilenet_v2.cpython-38.pyc
Normal file
BIN
model/__pycache__/mobilenet_v2.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/mobilenet_v3.cpython-38.pyc
Normal file
BIN
model/__pycache__/mobilenet_v3.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/mobilevit.cpython-38.pyc
Normal file
BIN
model/__pycache__/mobilevit.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/resbam.cpython-38.pyc
Normal file
BIN
model/__pycache__/resbam.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/resnet_pre.cpython-38.pyc
Normal file
BIN
model/__pycache__/resnet_pre.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/utils.cpython-38.pyc
Normal file
BIN
model/__pycache__/utils.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/vit.cpython-38.pyc
Normal file
BIN
model/__pycache__/vit.cpython-38.pyc
Normal file
Binary file not shown.
142
model/benchmark.py
Normal file
142
model/benchmark.py
Normal file
@ -0,0 +1,142 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import time
|
||||
import numpy as np
|
||||
from resnet_attention import resnet18_cbam, resnet34_cbam, resnet50_cbam
|
||||
|
||||
# 设置随机种子以确保结果可复现
|
||||
torch.manual_seed(42)
|
||||
np.random.seed(42)
|
||||
|
||||
# 设备配置
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
print(f"测试设备: {device}")
|
||||
|
||||
# 测试参数
|
||||
batch_sizes = [1, 4, 8, 16]
|
||||
image_sizes = [224, 384, 512]
|
||||
num_runs = 100 # 每个配置运行的次数
|
||||
warmup_runs = 20 # 预热运行次数,排除启动开销
|
||||
|
||||
# 模型配置
|
||||
model_configs = {
|
||||
"resnet18": {
|
||||
"base_model": lambda: resnet18_cbam(use_cbam=False),
|
||||
"attention_model": lambda: resnet18_cbam(use_cbam=True)
|
||||
},
|
||||
"resnet34": {
|
||||
"base_model": lambda: resnet34_cbam(use_cbam=False),
|
||||
"attention_model": lambda: resnet34_cbam(use_cbam=True)
|
||||
},
|
||||
"resnet50": {
|
||||
"base_model": lambda: resnet50_cbam(use_cbam=False),
|
||||
"attention_model": lambda: resnet50_cbam(use_cbam=True)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# 基准测试函数
|
||||
def benchmark_model(model, input_size, batch_size, num_runs, warmup_runs):
|
||||
"""
|
||||
测试模型的推理性能
|
||||
|
||||
参数:
|
||||
- model: 待测试的模型
|
||||
- input_size: 输入图像尺寸
|
||||
- batch_size: 批次大小
|
||||
- num_runs: 测试运行次数
|
||||
- warmup_runs: 预热运行次数
|
||||
|
||||
返回:
|
||||
- 平均推理时间(毫秒)
|
||||
- 吞吐量(样本/秒)
|
||||
"""
|
||||
# 设置为评估模式
|
||||
model.eval()
|
||||
model.to(device)
|
||||
|
||||
# 创建随机输入
|
||||
input_tensor = torch.randn(batch_size, 3, input_size, input_size, device=device)
|
||||
|
||||
# 预热
|
||||
with torch.no_grad():
|
||||
for _ in range(warmup_runs):
|
||||
_ = model(input_tensor)
|
||||
if device.type == 'cuda':
|
||||
torch.cuda.synchronize() # 同步GPU操作
|
||||
|
||||
# 测量推理时间
|
||||
start_time = time.time()
|
||||
with torch.no_grad():
|
||||
for _ in range(num_runs):
|
||||
_ = model(input_tensor)
|
||||
if device.type == 'cuda':
|
||||
torch.cuda.synchronize() # 同步GPU操作
|
||||
end_time = time.time()
|
||||
|
||||
# 计算指标
|
||||
total_time = end_time - start_time
|
||||
avg_time_per_batch = total_time / num_runs * 1000 # 毫秒
|
||||
throughput = batch_size * num_runs / total_time # 样本/秒
|
||||
|
||||
return avg_time_per_batch, throughput
|
||||
|
||||
|
||||
# 运行测试
|
||||
results = {}
|
||||
|
||||
for model_name, config in model_configs.items():
|
||||
results[model_name] = {}
|
||||
|
||||
# 创建模型
|
||||
base_model = config["base_model"]()
|
||||
attention_model = config["attention_model"]()
|
||||
|
||||
# 计算参数量
|
||||
base_params = sum(p.numel() for p in base_model.parameters() if p.requires_grad)
|
||||
attention_params = sum(p.numel() for p in attention_model.parameters() if p.requires_grad)
|
||||
param_increase = (attention_params - base_params) / base_params * 100
|
||||
|
||||
print(f"\n测试模型: {model_name}")
|
||||
print(f" 基础参数量: {base_params / 1e6:.2f}M")
|
||||
print(f" 带注意力参数量: {attention_params / 1e6:.2f}M")
|
||||
print(f" 参数量增加: {param_increase:.2f}%")
|
||||
|
||||
for batch_size in batch_sizes:
|
||||
for image_size in image_sizes:
|
||||
key = f"batch_{batch_size}_size_{image_size}"
|
||||
results[model_name][key] = {}
|
||||
|
||||
# 测试基础模型
|
||||
base_time, base_throughput = benchmark_model(
|
||||
base_model, image_size, batch_size, num_runs, warmup_runs
|
||||
)
|
||||
|
||||
# 测试注意力模型
|
||||
attention_time, attention_throughput = benchmark_model(
|
||||
attention_model, image_size, batch_size, num_runs, warmup_runs
|
||||
)
|
||||
|
||||
# 计算增加的百分比
|
||||
time_increase = (attention_time - base_time) / base_time * 100
|
||||
throughput_decrease = (base_throughput - attention_throughput) / base_throughput * 100
|
||||
|
||||
results[model_name][key]["base_time"] = base_time
|
||||
results[model_name][key]["attention_time"] = attention_time
|
||||
results[model_name][key]["time_increase"] = time_increase
|
||||
results[model_name][key]["base_throughput"] = base_throughput
|
||||
results[model_name][key]["attention_throughput"] = attention_throughput
|
||||
results[model_name][key]["throughput_decrease"] = throughput_decrease
|
||||
|
||||
print(f" 配置: 批次大小={batch_size}, 图像尺寸={image_size}x{image_size}")
|
||||
print(f" 基础模型: 平均时间={base_time:.2f}ms, 吞吐量={base_throughput:.2f}样本/秒")
|
||||
print(f" 注意力模型: 平均时间={attention_time:.2f}ms, 吞吐量={attention_throughput:.2f}样本/秒")
|
||||
print(f" 时间增加: {time_increase:.2f}%, 吞吐量下降: {throughput_decrease:.2f}%")
|
||||
|
||||
# 保存结果
|
||||
import json
|
||||
|
||||
with open('benchmark_results.json', 'w') as f:
|
||||
json.dump(results, f, indent=2)
|
||||
|
||||
print("\n测试完成,结果已保存到 benchmark_results.json")
|
48
model/compare.py
Normal file
48
model/compare.py
Normal file
@ -0,0 +1,48 @@
|
||||
import torch
|
||||
from config import config as conf
|
||||
import torch.nn as nn
|
||||
import torchvision.models as models
|
||||
from model.resnet_pre import resnet18, resnet50
|
||||
# from model.vit import vit_base_patch16_224, vit_base_patch32_224
|
||||
|
||||
|
||||
class ContrastiveModel(nn.Module):
|
||||
def __init__(self, projection_dim, model_name, contraposition=False):
|
||||
super(ContrastiveModel, self).__init__()
|
||||
self.contraposition = contraposition
|
||||
self.base_model = self._get_model(model_name)
|
||||
if not self.contraposition:
|
||||
if 'vit' in model_name:
|
||||
dim_mlp = self.base_model.head.weight.shape[1]
|
||||
self.base_model.head = self._get_projection_layer(dim_mlp, projection_dim)
|
||||
else:
|
||||
dim_mlp = self.base_model.fc.weight.shape[1]
|
||||
self.base_model.fc = self._get_projection_layer(dim_mlp, projection_dim)
|
||||
# # 冻结除 FC 层之外的所有层
|
||||
# for name, param in self.base_model.named_parameters():
|
||||
# if 'fc' not in name:
|
||||
# param.requires_grad = False
|
||||
|
||||
def _get_projection_layer(self, dim_mlp, projection_dim):
|
||||
return nn.Sequential(
|
||||
nn.Linear(dim_mlp, dim_mlp),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(dim_mlp, projection_dim)
|
||||
)
|
||||
|
||||
def _get_model(self, model_name):
|
||||
base_model = None
|
||||
if model_name == 'resnet18':
|
||||
base_model = resnet18(pretrained=True)
|
||||
elif model_name == 'resnet50':
|
||||
base_model = resnet50(pretrained=True)
|
||||
# elif model_name == 'vit':
|
||||
# base_model = vit_base_patch32_224()
|
||||
return base_model
|
||||
def forward(self, x):
|
||||
assert self.base_model is not None, 'base_model is none'
|
||||
x = self.base_model(x)
|
||||
return x
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
182
model/distill.py
Normal file
182
model/distill.py
Normal file
@ -0,0 +1,182 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Module
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vit_pytorch.vit import ViT
|
||||
from vit_pytorch.t2t import T2TViT
|
||||
from vit_pytorch.efficient import ViT as EfficientViT
|
||||
|
||||
from einops import repeat
|
||||
from config import config as conf
|
||||
# helpers
|
||||
# Data Setup
|
||||
from tools.dataset import load_data
|
||||
train_dataloader, class_num = load_data(conf, training=True)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
|
||||
# classes
|
||||
|
||||
class DistillMixin:
|
||||
def forward(self, img, distill_token=None):
|
||||
distilling = exists(distill_token)
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b=b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
|
||||
if distilling:
|
||||
distill_tokens = repeat(distill_token, '1 n d -> b n d', b=b)
|
||||
x = torch.cat((x, distill_tokens), dim=1)
|
||||
|
||||
x = self._attend(x)
|
||||
|
||||
if distilling:
|
||||
x, distill_tokens = x[:, :-1], x[:, -1]
|
||||
|
||||
x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
x = self.to_latent(x)
|
||||
out = self.mlp_head(x)
|
||||
|
||||
if distilling:
|
||||
return out, distill_tokens
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class DistillableViT(DistillMixin, ViT):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(DistillableViT, self).__init__(*args, **kwargs)
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.dim = kwargs['dim']
|
||||
self.num_classes = kwargs['num_classes']
|
||||
|
||||
def to_vit(self):
|
||||
v = ViT(*self.args, **self.kwargs)
|
||||
v.load_state_dict(self.state_dict())
|
||||
return v
|
||||
|
||||
def _attend(self, x):
|
||||
x = self.dropout(x)
|
||||
x = self.transformer(x)
|
||||
return x
|
||||
|
||||
|
||||
class DistillableT2TViT(DistillMixin, T2TViT):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(DistillableT2TViT, self).__init__(*args, **kwargs)
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.dim = kwargs['dim']
|
||||
self.num_classes = kwargs['num_classes']
|
||||
|
||||
def to_vit(self):
|
||||
v = T2TViT(*self.args, **self.kwargs)
|
||||
v.load_state_dict(self.state_dict())
|
||||
return v
|
||||
|
||||
def _attend(self, x):
|
||||
x = self.dropout(x)
|
||||
x = self.transformer(x)
|
||||
return x
|
||||
|
||||
|
||||
class DistillableEfficientViT(DistillMixin, EfficientViT):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(DistillableEfficientViT, self).__init__(*args, **kwargs)
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.dim = kwargs['dim']
|
||||
self.num_classes = kwargs['num_classes']
|
||||
|
||||
|
||||
def to_vit(self):
|
||||
v = EfficientViT(*self.args, **self.kwargs)
|
||||
v.load_state_dict(self.state_dict())
|
||||
return v
|
||||
|
||||
def _attend(self, x):
|
||||
return self.transformer(x)
|
||||
|
||||
|
||||
# knowledge distillation wrapper
|
||||
|
||||
class DistillWrapper(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
teacher,
|
||||
student,
|
||||
temperature=1.,
|
||||
alpha=0.5,
|
||||
hard=False,
|
||||
mlp_layernorm=False
|
||||
):
|
||||
super().__init__()
|
||||
# assert (isinstance(student, (
|
||||
# DistillableViT, DistillableT2TViT, DistillableEfficientViT))), 'student must be a vision transformer'
|
||||
if isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT)):
|
||||
pass
|
||||
|
||||
self.teacher = teacher
|
||||
self.student = student
|
||||
|
||||
dim = conf.embedding_size # student.dim
|
||||
num_classes = class_num # class_num # student.num_classes
|
||||
self.temperature = temperature
|
||||
self.alpha = alpha
|
||||
self.hard = hard
|
||||
|
||||
self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
|
||||
# student is vit
|
||||
# self.distill_mlp = nn.Sequential(
|
||||
# nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),
|
||||
# nn.Linear(dim, num_classes)
|
||||
# )
|
||||
|
||||
# student is resnet
|
||||
self.distill_mlp = nn.Sequential(
|
||||
nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),
|
||||
nn.Linear(dim, num_classes).to(device)
|
||||
)
|
||||
|
||||
def forward(self, img, labels, temperature=None, alpha=None, **kwargs):
|
||||
|
||||
alpha = default(alpha, self.alpha)
|
||||
T = default(temperature, self.temperature)
|
||||
|
||||
with torch.no_grad():
|
||||
teacher_logits = self.teacher(img)
|
||||
teacher_logits = self.distill_mlp(teacher_logits) # teach is vit 初始化
|
||||
# student is vit
|
||||
# student_logits, distill_tokens = self.student(img, distill_token=self.distillation_token, **kwargs)
|
||||
# distill_logits = self.distill_mlp(distill_tokens)
|
||||
|
||||
# student is resnet
|
||||
student_logits = self.student(img)
|
||||
distill_logits = self.distill_mlp(student_logits)
|
||||
loss = F.cross_entropy(distill_logits, labels)
|
||||
# pdb.set_trace()
|
||||
if not self.hard:
|
||||
distill_loss = F.kl_div(
|
||||
F.log_softmax(distill_logits / T, dim=-1),
|
||||
F.softmax(teacher_logits / T, dim=-1).detach(),
|
||||
reduction='batchmean')
|
||||
distill_loss *= T ** 2
|
||||
else:
|
||||
teacher_labels = teacher_logits.argmax(dim=-1)
|
||||
distill_loss = F.cross_entropy(distill_logits, teacher_labels)
|
||||
# pdb.set_trace()
|
||||
return loss * (1 - alpha) + distill_loss * alpha
|
124
model/fmobilenet.py
Normal file
124
model/fmobilenet.py
Normal file
@ -0,0 +1,124 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
|
||||
class Flatten(nn.Module):
|
||||
def forward(self, x):
|
||||
return x.view(x.shape[0], -1)
|
||||
|
||||
class ConvBn(nn.Module):
|
||||
|
||||
def __init__(self, in_c, out_c, kernel=(1, 1), stride=1, padding=0, groups=1):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
|
||||
nn.BatchNorm2d(out_c)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class ConvBnPrelu(nn.Module):
|
||||
|
||||
def __init__(self, in_c, out_c, kernel=(1, 1), stride=1, padding=0, groups=1):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
ConvBn(in_c, out_c, kernel, stride, padding, groups),
|
||||
nn.PReLU(out_c)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class DepthWise(nn.Module):
|
||||
|
||||
def __init__(self, in_c, out_c, kernel=(3, 3), stride=2, padding=1, groups=1):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
ConvBnPrelu(in_c, groups, kernel=(1, 1), stride=1, padding=0),
|
||||
ConvBnPrelu(groups, groups, kernel=kernel, stride=stride, padding=padding, groups=groups),
|
||||
ConvBn(groups, out_c, kernel=(1, 1), stride=1, padding=0),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class DepthWiseRes(nn.Module):
|
||||
"""DepthWise with Residual"""
|
||||
|
||||
def __init__(self, in_c, out_c, kernel=(3, 3), stride=2, padding=1, groups=1):
|
||||
super().__init__()
|
||||
self.net = DepthWise(in_c, out_c, kernel, stride, padding, groups)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x) + x
|
||||
|
||||
|
||||
class MultiDepthWiseRes(nn.Module):
|
||||
|
||||
def __init__(self, num_block, channels, kernel=(3, 3), stride=1, padding=1, groups=1):
|
||||
super().__init__()
|
||||
|
||||
self.net = nn.Sequential(*[
|
||||
DepthWiseRes(channels, channels, kernel, stride, padding, groups)
|
||||
for _ in range(num_block)
|
||||
])
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class FaceMobileNet(nn.Module):
|
||||
|
||||
def __init__(self, embedding_size):
|
||||
super().__init__()
|
||||
self.conv1 = ConvBnPrelu(1, 64, kernel=(3, 3), stride=2, padding=1)
|
||||
self.conv2 = ConvBn(64, 64, kernel=(3, 3), stride=1, padding=1, groups=64)
|
||||
self.conv3 = DepthWise(64, 64, kernel=(3, 3), stride=2, padding=1, groups=128)
|
||||
self.conv4 = MultiDepthWiseRes(num_block=4, channels=64, kernel=3, stride=1, padding=1, groups=128)
|
||||
self.conv5 = DepthWise(64, 128, kernel=(3, 3), stride=2, padding=1, groups=256)
|
||||
self.conv6 = MultiDepthWiseRes(num_block=6, channels=128, kernel=(3, 3), stride=1, padding=1, groups=256)
|
||||
self.conv7 = DepthWise(128, 128, kernel=(3, 3), stride=2, padding=1, groups=512)
|
||||
self.conv8 = MultiDepthWiseRes(num_block=2, channels=128, kernel=(3, 3), stride=1, padding=1, groups=256)
|
||||
self.conv9 = ConvBnPrelu(128, 512, kernel=(1, 1))
|
||||
self.conv10 = ConvBn(512, 512, groups=512, kernel=(7, 7))
|
||||
self.flatten = Flatten()
|
||||
self.linear = nn.Linear(2048, embedding_size, bias=False)
|
||||
self.bn = nn.BatchNorm1d(embedding_size)
|
||||
|
||||
def forward(self, x):
|
||||
#print('x',x.shape)
|
||||
out = self.conv1(x)
|
||||
out = self.conv2(out)
|
||||
out = self.conv3(out)
|
||||
out = self.conv4(out)
|
||||
out = self.conv5(out)
|
||||
out = self.conv6(out)
|
||||
out = self.conv7(out)
|
||||
out = self.conv8(out)
|
||||
out = self.conv9(out)
|
||||
out = self.conv10(out)
|
||||
out = self.flatten(out)
|
||||
out = self.linear(out)
|
||||
out = self.bn(out)
|
||||
return out
|
||||
|
||||
if __name__ == "__main__":
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
x = Image.open("../samples/009.jpg").convert('L')
|
||||
x = x.resize((128, 128))
|
||||
x = np.asarray(x, dtype=np.float32)
|
||||
x = x[None, None, ...]
|
||||
x = torch.from_numpy(x)
|
||||
net = FaceMobileNet(512)
|
||||
net.eval()
|
||||
with torch.no_grad():
|
||||
out = net(x)
|
||||
print(out.shape)
|
233
model/lcnet.py
Normal file
233
model/lcnet.py
Normal file
@ -0,0 +1,233 @@
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import thop
|
||||
|
||||
# try:
|
||||
# import softpool_cuda
|
||||
# from SoftPool import soft_pool2d, SoftPool2d
|
||||
# except ImportError:
|
||||
# print('Please install SoftPool first: https://github.com/alexandrosstergiou/SoftPool')
|
||||
# exit(0)
|
||||
|
||||
NET_CONFIG = {
|
||||
# k, in_c, out_c, s, use_se
|
||||
"blocks2": [[3, 16, 32, 1, False]],
|
||||
"blocks3": [[3, 32, 64, 2, False], [3, 64, 64, 1, False]],
|
||||
"blocks4": [[3, 64, 128, 2, False], [3, 128, 128, 1, False]],
|
||||
"blocks5": [[3, 128, 256, 2, False], [5, 256, 256, 1, False],
|
||||
[5, 256, 256, 1, False], [5, 256, 256, 1, False],
|
||||
[5, 256, 256, 1, False], [5, 256, 256, 1, False]],
|
||||
"blocks6": [[5, 256, 512, 2, True], [5, 512, 512, 1, True]]
|
||||
}
|
||||
|
||||
|
||||
def autopad(k, p=None):
|
||||
if p is None:
|
||||
p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
|
||||
return p
|
||||
|
||||
|
||||
def make_divisible(v, divisor=8, min_value=None):
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class HardSwish(nn.Module):
|
||||
def __init__(self, inplace=True):
|
||||
super(HardSwish, self).__init__()
|
||||
self.relu6 = nn.ReLU6(inplace=inplace)
|
||||
|
||||
def forward(self, x):
|
||||
return x * self.relu6(x+3) / 6
|
||||
|
||||
|
||||
class HardSigmoid(nn.Module):
|
||||
def __init__(self, inplace=True):
|
||||
super(HardSigmoid, self).__init__()
|
||||
self.relu6 = nn.ReLU6(inplace=inplace)
|
||||
|
||||
def forward(self, x):
|
||||
return (self.relu6(x+3)) / 6
|
||||
|
||||
|
||||
class SELayer(nn.Module):
|
||||
def __init__(self, channel, reduction=16):
|
||||
super(SELayer, self).__init__()
|
||||
self.avgpool = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(channel, channel // reduction, bias=False),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(channel // reduction, channel, bias=False),
|
||||
HardSigmoid()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.size()
|
||||
y = self.avgpool(x).view(b, c)
|
||||
y = self.fc(y).view(b, c, 1, 1)
|
||||
return x * y.expand_as(x)
|
||||
|
||||
|
||||
class DepthwiseSeparable(nn.Module):
|
||||
def __init__(self, inp, oup, dw_size, stride, use_se=False):
|
||||
super(DepthwiseSeparable, self).__init__()
|
||||
self.use_se = use_se
|
||||
self.stride = stride
|
||||
self.inp = inp
|
||||
self.oup = oup
|
||||
self.dw_size = dw_size
|
||||
self.dw_sp = nn.Sequential(
|
||||
nn.Conv2d(self.inp, self.inp, kernel_size=self.dw_size, stride=self.stride,
|
||||
padding=autopad(self.dw_size, None), groups=self.inp, bias=False),
|
||||
nn.BatchNorm2d(self.inp),
|
||||
HardSwish(),
|
||||
|
||||
nn.Conv2d(self.inp, self.oup, kernel_size=1, stride=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(self.oup),
|
||||
HardSwish(),
|
||||
)
|
||||
self.se = SELayer(self.oup)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.dw_sp(x)
|
||||
if self.use_se:
|
||||
x = self.se(x)
|
||||
return x
|
||||
|
||||
|
||||
class PP_LCNet(nn.Module):
|
||||
def __init__(self, scale=1.0, class_num=256, class_expand=1280, dropout_prob=0.2):
|
||||
super(PP_LCNet, self).__init__()
|
||||
self.scale = scale
|
||||
self.conv1 = nn.Conv2d(3, out_channels=make_divisible(16 * self.scale),
|
||||
kernel_size=3, stride=2, padding=1, bias=False)
|
||||
# k, in_c, out_c, s, use_se inp, oup, dw_size, stride, use_se=False
|
||||
self.blocks2 = nn.Sequential(*[
|
||||
DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
|
||||
oup=make_divisible(out_c * self.scale),
|
||||
dw_size=k, stride=s, use_se=use_se)
|
||||
for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks2"])
|
||||
])
|
||||
|
||||
self.blocks3 = nn.Sequential(*[
|
||||
DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
|
||||
oup=make_divisible(out_c * self.scale),
|
||||
dw_size=k, stride=s, use_se=use_se)
|
||||
for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks3"])
|
||||
])
|
||||
|
||||
self.blocks4 = nn.Sequential(*[
|
||||
DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
|
||||
oup=make_divisible(out_c * self.scale),
|
||||
dw_size=k, stride=s, use_se=use_se)
|
||||
for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks4"])
|
||||
])
|
||||
# k, in_c, out_c, s, use_se inp, oup, dw_size, stride, use_se=False
|
||||
self.blocks5 = nn.Sequential(*[
|
||||
DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
|
||||
oup=make_divisible(out_c * self.scale),
|
||||
dw_size=k, stride=s, use_se=use_se)
|
||||
for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks5"])
|
||||
])
|
||||
|
||||
self.blocks6 = nn.Sequential(*[
|
||||
DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
|
||||
oup=make_divisible(out_c * self.scale),
|
||||
dw_size=k, stride=s, use_se=use_se)
|
||||
for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks6"])
|
||||
])
|
||||
|
||||
self.GAP = nn.AdaptiveAvgPool2d(1)
|
||||
|
||||
self.last_conv = nn.Conv2d(in_channels=make_divisible(NET_CONFIG["blocks6"][-1][2] * scale),
|
||||
out_channels=class_expand,
|
||||
kernel_size=1, stride=1, padding=0, bias=False)
|
||||
|
||||
self.hardswish = HardSwish()
|
||||
self.dropout = nn.Dropout(p=dropout_prob)
|
||||
|
||||
self.fc = nn.Linear(class_expand, class_num)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
# print(x.shape)
|
||||
x = self.blocks2(x)
|
||||
# print(x.shape)
|
||||
x = self.blocks3(x)
|
||||
# print(x.shape)
|
||||
x = self.blocks4(x)
|
||||
# print(x.shape)
|
||||
x = self.blocks5(x)
|
||||
# print(x.shape)
|
||||
x = self.blocks6(x)
|
||||
# print(x.shape)
|
||||
|
||||
x = self.GAP(x)
|
||||
x = self.last_conv(x)
|
||||
x = self.hardswish(x)
|
||||
x = self.dropout(x)
|
||||
x = torch.flatten(x, start_dim=1, end_dim=-1)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
|
||||
def PPLCNET_x0_25(**kwargs):
|
||||
model = PP_LCNet(scale=0.25, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def PPLCNET_x0_35(**kwargs):
|
||||
model = PP_LCNet(scale=0.35, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def PPLCNET_x0_5(**kwargs):
|
||||
model = PP_LCNet(scale=0.5, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def PPLCNET_x0_75(**kwargs):
|
||||
model = PP_LCNet(scale=0.75, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def PPLCNET_x1_0(**kwargs):
|
||||
model = PP_LCNet(scale=1.0, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def PPLCNET_x1_5(**kwargs):
|
||||
model = PP_LCNet(scale=1.5, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def PPLCNET_x2_0(**kwargs):
|
||||
model = PP_LCNet(scale=2.0, **kwargs)
|
||||
return model
|
||||
|
||||
def PPLCNET_x2_5(**kwargs):
|
||||
model = PP_LCNet(scale=2.5, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# input = torch.randn(1, 3, 640, 640)
|
||||
# model = PPLCNET_x2_5()
|
||||
# flops, params = thop.profile(model, inputs=(input,))
|
||||
# print('flops:', flops / 1000000000)
|
||||
# print('params:', params / 1000000)
|
||||
|
||||
model = PPLCNET_x1_0()
|
||||
# model_1 = PW_Conv(3, 16)
|
||||
input = torch.randn(2, 3, 256, 256)
|
||||
print(input.shape)
|
||||
output = model(input)
|
||||
print(output.shape) # [1, num_class]
|
||||
|
18
model/loss.py
Normal file
18
model/loss.py
Normal file
@ -0,0 +1,18 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class FocalLoss(nn.Module):
|
||||
|
||||
def __init__(self, gamma=2):
|
||||
super().__init__()
|
||||
self.gamma = gamma
|
||||
self.ce = torch.nn.CrossEntropyLoss()
|
||||
|
||||
def forward(self, input, target):
|
||||
|
||||
#print(f'theta {input.shape, input[0]}, target {target.shape, target}')
|
||||
logp = self.ce(input, target)
|
||||
p = torch.exp(-logp)
|
||||
loss = (1 - p) ** self.gamma * logp
|
||||
return loss.mean()
|
94
model/metric.py
Normal file
94
model/metric.py
Normal file
@ -0,0 +1,94 @@
|
||||
# Definition of ArcFace loss and CosFace loss
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ArcFace(nn.Module):
|
||||
|
||||
def __init__(self, embedding_size, class_num, s=30.0, m=0.50):
|
||||
"""ArcFace formula:
|
||||
cos(m + theta) = cos(m)cos(theta) - sin(m)sin(theta)
|
||||
Note that:
|
||||
0 <= m + theta <= Pi
|
||||
So if (m + theta) >= Pi, then theta >= Pi - m. In [0, Pi]
|
||||
we have:
|
||||
cos(theta) < cos(Pi - m)
|
||||
So we can use cos(Pi - m) as threshold to check whether
|
||||
(m + theta) go out of [0, Pi]
|
||||
|
||||
Args:
|
||||
embedding_size: usually 128, 256, 512 ...
|
||||
class_num: num of people when training
|
||||
s: scale, see normface https://arxiv.org/abs/1704.06369
|
||||
m: margin, see SphereFace, CosFace, and ArcFace paper
|
||||
"""
|
||||
super().__init__()
|
||||
self.in_features = embedding_size
|
||||
self.out_features = class_num
|
||||
self.s = s
|
||||
self.m = m
|
||||
self.weight = nn.Parameter(torch.FloatTensor(class_num, embedding_size))
|
||||
nn.init.xavier_uniform_(self.weight)
|
||||
|
||||
self.cos_m = math.cos(m)
|
||||
self.sin_m = math.sin(m)
|
||||
self.th = math.cos(math.pi - m)
|
||||
self.mm = math.sin(math.pi - m) * m
|
||||
|
||||
def forward(self, input, label):
|
||||
#print(f"embding {self.in_features}, class_num {self.out_features}, input {len(input)}, label {len(label)}")
|
||||
cosine = F.linear(F.normalize(input), F.normalize(self.weight))
|
||||
# print('F.normalize(input)',input.shape)
|
||||
# print('F.normalize(self.weight)',F.normalize(self.weight).shape)
|
||||
sine = ((1.0 - cosine.pow(2)).clamp(0, 1)).sqrt()
|
||||
phi = cosine * self.cos_m - sine * self.sin_m
|
||||
phi = torch.where(cosine > self.th, phi, cosine - self.mm) # drop to CosFace
|
||||
#print(f'consine {cosine.shape, cosine}, sine {sine.shape, sine}, phi {phi.shape, phi}')
|
||||
# update y_i by phi in cosine
|
||||
output = cosine * 1.0 # make backward works
|
||||
batch_size = len(output)
|
||||
output[range(batch_size), label] = phi[range(batch_size), label]
|
||||
# print(f'output {(output * self.s).shape}')
|
||||
# print(f'phi[range(batch_size), label] {phi[range(batch_size), label]}')
|
||||
return output * self.s
|
||||
|
||||
|
||||
class CosFace(nn.Module):
|
||||
|
||||
def __init__(self, in_features, out_features, s=30.0, m=0.40):
|
||||
"""
|
||||
Args:
|
||||
embedding_size: usually 128, 256, 512 ...
|
||||
class_num: num of people when training
|
||||
s: scale, see normface https://arxiv.org/abs/1704.06369
|
||||
m: margin, see SphereFace, CosFace, and ArcFace paper
|
||||
"""
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.s = s
|
||||
self.m = m
|
||||
self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
|
||||
nn.init.xavier_uniform_(self.weight)
|
||||
|
||||
def forward(self, input, label):
|
||||
cosine = F.linear(F.normalize(input), F.normalize(self.weight))
|
||||
phi = cosine - self.m
|
||||
output = cosine * 1.0 # make backward works
|
||||
batch_size = len(output)
|
||||
output[range(batch_size), label] = phi[range(batch_size), label]
|
||||
return output * self.s
|
||||
|
||||
class Distillation(nn.Module):
|
||||
def __init__(self, in_features, out_features, T=1.0):
|
||||
super(Distillation, self).__init__()
|
||||
self.T = T
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
|
||||
nn.init.xavier_uniform_(self.weight)
|
||||
def forward(self, input, labels):
|
||||
pass
|
274
model/mlp.py
Normal file
274
model/mlp.py
Normal file
@ -0,0 +1,274 @@
|
||||
import pdb
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.init as init
|
||||
from model.resnet_pre import resnet18, conv1x1, BasicBlock, load_state_dict_from_url, model_urls
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, input_dim=256, output_dim=1):
|
||||
super(MLP, self).__init__()
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.fc1 = nn.Linear(self.input_dim, 128) # 32
|
||||
self.fc2 = nn.Linear(128, 64)
|
||||
self.fc3 = nn.Linear(64, 32)
|
||||
self.fc4 = nn.Linear(32, 16)
|
||||
self.fc5 = nn.Linear(16, self.output_dim)
|
||||
self.relu = nn.ReLU()
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.dropout = nn.Dropout(0.5)
|
||||
self.bn1 = nn.BatchNorm1d(128)
|
||||
self.bn2 = nn.BatchNorm1d(64)
|
||||
self.bn3 = nn.BatchNorm1d(32)
|
||||
self.bn4 = nn.BatchNorm1d(16)
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
init.kaiming_normal_(m.weight)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.relu(self.bn1(x))
|
||||
x = self.fc2(x)
|
||||
x = self.relu(self.bn2(x))
|
||||
x = self.fc3(x)
|
||||
x = self.relu(self.bn3(x))
|
||||
x = self.fc4(x)
|
||||
x = self.relu(self.bn4(x))
|
||||
x = self.sigmoid(self.fc5(x))
|
||||
return x
|
||||
|
||||
|
||||
class Net2(nn.Module): # 该网络部署有风险,dnn推理有障碍
|
||||
def __init__(self, input_dim=960, output_dim=1):
|
||||
super(Net2, self).__init__()
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.conv1 = nn.Conv1d(1, 16, kernel_size=3, stride=1, padding=1)
|
||||
self.conv2 = nn.Conv1d(16, 32, kernel_size=3, stride=2, padding=1)
|
||||
# self.conv3 = nn.Conv1d(32, 64, kernel_size=3, stride=2, padding=1)
|
||||
# self.conv4 = nn.Conv1d(64, 64, kernel_size=5, stride=2, padding=1)
|
||||
self.maxPool1 = nn.MaxPool1d(kernel_size=3, stride=2)
|
||||
self.conv5 = nn.Conv1d(32, 64, kernel_size=5, stride=2, padding=1)
|
||||
self.maxPool2 = nn.MaxPool1d(kernel_size=3, stride=2)
|
||||
|
||||
self.avgPool = nn.AdaptiveAvgPool1d(1)
|
||||
self.MaxPool = nn.AdaptiveMaxPool1d(1)
|
||||
self.relu = nn.ReLU()
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.dropout = nn.Dropout(0.5)
|
||||
self.flatten = nn.Flatten()
|
||||
# self.conv6 = nn.Conv1d(128, 128, kernel_size=5, stride=2, padding=1)
|
||||
self.fc1 = nn.Linear(960, 128)
|
||||
self.fc21 = nn.Linear(960, 32)
|
||||
self.fc22 = nn.Linear(32, 128)
|
||||
self.fc3 = nn.Linear(128, 1)
|
||||
self.bn1 = nn.BatchNorm1d(16)
|
||||
self.bn2 = nn.BatchNorm1d(32)
|
||||
self.bn3 = nn.BatchNorm1d(64)
|
||||
self.bn4 = nn.BatchNorm1d(128)
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
init.kaiming_normal_(m.weight)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
"""1x1 convolution"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x) # 16
|
||||
x = self.relu(x)
|
||||
x = self.conv2(x) # 32
|
||||
x = self.relu(x)
|
||||
# x = self.conv3(x)
|
||||
# x = self.relu(x)
|
||||
# x = self.conv4(x) # 64
|
||||
# x = self.relu(x)
|
||||
# x = self.maxPool1(x)
|
||||
|
||||
x = self.conv5(x)
|
||||
x = self.relu(x)
|
||||
# x = self.conv6(x)
|
||||
# x = self.relu(x)
|
||||
# x = self.maxPool2(x)
|
||||
# x = self.MaxPool(x)
|
||||
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.dropout(x)
|
||||
x = self.flatten(x)
|
||||
|
||||
# pdb.set_trace()
|
||||
x1 = self.fc1(x)
|
||||
x2 = self.fc22(self.fc21(x))
|
||||
x = self.fc3(x1 + x2)
|
||||
x = self.sigmoid(x)
|
||||
return x
|
||||
|
||||
class Net3(nn.Module): # 目前较合适的网络结构,相较于Net2,Net3的输出结果更加准确
|
||||
def __init__(self, pretrained=True, progress=True, num_classes=1, scale=0.75):
|
||||
super(Net3, self).__init__()
|
||||
self.resnet18 = resnet18(pretrained=pretrained, progress=progress)
|
||||
|
||||
# Remove the last three layers (layer3, layer4, avgpool, fc)
|
||||
# self.resnet18.layer3 = nn.Identity()
|
||||
# self.resnet18.layer4 = nn.Identity()
|
||||
self.resnet18.avgpool = nn.Identity()
|
||||
self.resnet18.fc = nn.Identity()
|
||||
self.flatten = nn.Flatten()
|
||||
# Calculate the output size after layer2
|
||||
# Assuming input size is 224x224, layer2 will have output size of 56x56
|
||||
# So, the flattened size will be 128 * scale * 56 * 56
|
||||
self.flattened_size = int(128 * (56 * 56) * scale * scale)
|
||||
|
||||
# Add new layers for classification
|
||||
self.classifier = nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d((1, 1)),
|
||||
nn.Flatten(),
|
||||
nn.Linear(384, num_classes), # layer1, layer2 in_features=96 # layer1 in_features=48 #layer3 in_features=192
|
||||
# nn.ReLU(),
|
||||
nn.Dropout(0.6),
|
||||
# nn.Linear(256, num_classes),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.resnet18.layer1(x)
|
||||
x = self.resnet18.layer2(x)
|
||||
x = self.resnet18.layer3(x)
|
||||
x = self.resnet18.layer4(x)
|
||||
|
||||
# Debugging: Print the shape of the tensor before flattening
|
||||
# print("Shape before flattening:", x.shape)
|
||||
|
||||
# Ensure the tensor is flattened correctly
|
||||
# x = x.view(x.size(0), -1)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
class ResNet(nn.Module):
|
||||
def __init__(self, block, layers, num_classes=1, zero_init_residual=False,
|
||||
groups=1, width_per_group=64, replace_stride_with_dilation=None,
|
||||
norm_layer=None, scale=0.75):
|
||||
super(ResNet, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
self._norm_layer = norm_layer
|
||||
|
||||
self.inplanes = 64
|
||||
self.dilation = 1
|
||||
if replace_stride_with_dilation is None:
|
||||
# each element in the tuple indicates if we should replace
|
||||
# the 2x2 stride with a dilated convolution instead
|
||||
replace_stride_with_dilation = [False, False, False]
|
||||
if len(replace_stride_with_dilation) != 3:
|
||||
raise ValueError("replace_stride_with_dilation should be None "
|
||||
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
||||
self.groups = groups
|
||||
self.base_width = width_per_group
|
||||
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
|
||||
bias=False)
|
||||
self.bn1 = norm_layer(self.inplanes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
self.layer1 = self._make_layer(block, int(64 * scale), layers[0])
|
||||
self.layer2 = self._make_layer(block, int(128 * scale), layers[1], stride=2,
|
||||
dilate=replace_stride_with_dilation[0])
|
||||
self.layer3 = self._make_layer(block, int(256 * scale), layers[2], stride=2,
|
||||
dilate=replace_stride_with_dilation[1])
|
||||
self.layer4 = self._make_layer(block, int(512 * scale), layers[3], stride=2,
|
||||
dilate=replace_stride_with_dilation[2])
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(int(512 * block.expansion * scale), num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
||||
norm_layer = self._norm_layer
|
||||
downsample = None
|
||||
previous_dilation = self.dilation
|
||||
if dilate:
|
||||
self.dilation *= stride
|
||||
stride = 1
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
conv1x1(self.inplanes, planes * block.expansion, stride),
|
||||
norm_layer(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
|
||||
self.base_width, previous_dilation, norm_layer))
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes, groups=self.groups,
|
||||
base_width=self.base_width, dilation=self.dilation,
|
||||
norm_layer=norm_layer))
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _forward_impl(self, x):
|
||||
# See note [TorchScript super()]
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
x = self.avgpool(x)
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.fc(x)
|
||||
x = self.sigmoid(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
return self._forward_impl(x)
|
||||
|
||||
def Net4(arch, pretrained, progress, **kwargs):
|
||||
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
|
||||
if pretrained:
|
||||
state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
|
||||
src_state_dict = state_dict
|
||||
target_state_dict = model.state_dict()
|
||||
skip_keys = []
|
||||
# skip mismatch size tensors in case of pretraining
|
||||
for k in src_state_dict.keys():
|
||||
if k not in target_state_dict:
|
||||
continue
|
||||
if src_state_dict[k].size() != target_state_dict[k].size():
|
||||
skip_keys.append(k)
|
||||
for k in skip_keys:
|
||||
del src_state_dict[k]
|
||||
missing_keys, unexpected_keys = model.load_state_dict(src_state_dict, strict=False)
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
'''
|
||||
net2 = Net2()
|
||||
input_tensor = torch.randn(10, 1, 64)
|
||||
# 前向传播
|
||||
output_tensor = net2(input_tensor)
|
||||
# pdb.set_trace()
|
||||
print("输入张量形状:", input_tensor.shape)
|
||||
print("输出张量形状:", output_tensor.shape)
|
||||
'''
|
||||
|
||||
# model = Net3(pretrained=True, num_classes=1) # 预训练从resnet中间结果获取数据训练模型
|
||||
model = Net4('resnet18', True, True)
|
||||
input_tensor = torch.randn(1, 3, 224, 244) # Adjust batch size to 10
|
||||
output = model(input_tensor)
|
||||
print(output.shape) # Should be [10, 2]
|
148
model/mobilenet_v1.py
Normal file
148
model/mobilenet_v1.py
Normal file
@ -0,0 +1,148 @@
|
||||
# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
from typing import Callable, Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch import nn
|
||||
from torchvision.ops.misc import Conv2dNormActivation
|
||||
from config import config as conf
|
||||
|
||||
__all__ = [
|
||||
"MobileNetV1",
|
||||
"DepthWiseSeparableConv2d",
|
||||
"mobilenet_v1",
|
||||
]
|
||||
|
||||
|
||||
class MobileNetV1(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_classes: int = conf.embedding_size,
|
||||
) -> None:
|
||||
super(MobileNetV1, self).__init__()
|
||||
self.features = nn.Sequential(
|
||||
Conv2dNormActivation(3,
|
||||
32,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_layer=nn.BatchNorm2d,
|
||||
activation_layer=nn.ReLU,
|
||||
inplace=True,
|
||||
bias=False,
|
||||
),
|
||||
|
||||
DepthWiseSeparableConv2d(32, 64, 1),
|
||||
DepthWiseSeparableConv2d(64, 128, 2),
|
||||
DepthWiseSeparableConv2d(128, 128, 1),
|
||||
DepthWiseSeparableConv2d(128, 256, 2),
|
||||
DepthWiseSeparableConv2d(256, 256, 1),
|
||||
DepthWiseSeparableConv2d(256, 512, 2),
|
||||
DepthWiseSeparableConv2d(512, 512, 1),
|
||||
DepthWiseSeparableConv2d(512, 512, 1),
|
||||
DepthWiseSeparableConv2d(512, 512, 1),
|
||||
DepthWiseSeparableConv2d(512, 512, 1),
|
||||
DepthWiseSeparableConv2d(512, 512, 1),
|
||||
DepthWiseSeparableConv2d(512, 1024, 2),
|
||||
DepthWiseSeparableConv2d(1024, 1024, 1),
|
||||
)
|
||||
|
||||
self.avgpool = nn.AvgPool2d((7, 7))
|
||||
|
||||
self.classifier = nn.Linear(1024, num_classes)
|
||||
|
||||
# Initialize neural network weights
|
||||
self._initialize_weights()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
out = self._forward_impl(x)
|
||||
|
||||
return out
|
||||
|
||||
# Support torch.script function
|
||||
def _forward_impl(self, x: Tensor) -> Tensor:
|
||||
out = self.features(x)
|
||||
out = self.avgpool(out)
|
||||
out = torch.flatten(out, 1)
|
||||
out = self.classifier(out)
|
||||
|
||||
return out
|
||||
|
||||
def _initialize_weights(self) -> None:
|
||||
for module in self.modules():
|
||||
if isinstance(module, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.ones_(module.weight)
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Linear):
|
||||
nn.init.normal_(module.weight, 0, 0.01)
|
||||
nn.init.zeros_(module.bias)
|
||||
|
||||
|
||||
class DepthWiseSeparableConv2d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
stride: int,
|
||||
norm_layer: Optional[Callable[..., nn.Module]] = None
|
||||
) -> None:
|
||||
super(DepthWiseSeparableConv2d, self).__init__()
|
||||
self.stride = stride
|
||||
if stride not in [1, 2]:
|
||||
raise ValueError(f"stride should be 1 or 2 instead of {stride}")
|
||||
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
|
||||
self.conv = nn.Sequential(
|
||||
Conv2dNormActivation(in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
groups=in_channels,
|
||||
norm_layer=norm_layer,
|
||||
activation_layer=nn.ReLU,
|
||||
inplace=True,
|
||||
bias=False,
|
||||
),
|
||||
Conv2dNormActivation(in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
norm_layer=norm_layer,
|
||||
activation_layer=nn.ReLU,
|
||||
inplace=True,
|
||||
bias=False,
|
||||
),
|
||||
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
out = self.conv(x)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def mobilenet_v1(**kwargs: Any) -> MobileNetV1:
|
||||
model = MobileNetV1(**kwargs)
|
||||
|
||||
return model
|
200
model/mobilenet_v2.py
Normal file
200
model/mobilenet_v2.py
Normal file
@ -0,0 +1,200 @@
|
||||
from torch import nn
|
||||
from .utils import load_state_dict_from_url
|
||||
from config import config as conf
|
||||
|
||||
__all__ = ['MobileNetV2', 'mobilenet_v2']
|
||||
|
||||
|
||||
model_urls = {
|
||||
'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
|
||||
}
|
||||
|
||||
|
||||
def _make_divisible(v, divisor, min_value=None):
|
||||
"""
|
||||
This function is taken from the original tf repo.
|
||||
It ensures that all layers have a channel number that is divisible by 8
|
||||
It can be seen here:
|
||||
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
||||
:param v:
|
||||
:param divisor:
|
||||
:param min_value:
|
||||
:return:
|
||||
"""
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
# Make sure that round down does not go down by more than 10%.
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Sequential):
|
||||
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None):
|
||||
padding = (kernel_size - 1) // 2
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
super(ConvBNReLU, self).__init__(
|
||||
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
|
||||
norm_layer(out_planes),
|
||||
nn.ReLU6(inplace=True)
|
||||
)
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None):
|
||||
super(InvertedResidual, self).__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2]
|
||||
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
|
||||
hidden_dim = int(round(inp * expand_ratio))
|
||||
self.use_res_connect = self.stride == 1 and inp == oup
|
||||
|
||||
layers = []
|
||||
if expand_ratio != 1:
|
||||
# pw
|
||||
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
|
||||
layers.extend([
|
||||
# dw
|
||||
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
norm_layer(oup),
|
||||
])
|
||||
self.conv = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_res_connect:
|
||||
return x + self.conv(x)
|
||||
else:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class MobileNetV2(nn.Module):
|
||||
def __init__(self,
|
||||
num_classes=conf.embedding_size,
|
||||
width_mult=1.0,
|
||||
inverted_residual_setting=None,
|
||||
round_nearest=8,
|
||||
block=None,
|
||||
norm_layer=None):
|
||||
"""
|
||||
MobileNet V2 main class
|
||||
|
||||
Args:
|
||||
num_classes (int): Number of classes
|
||||
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
|
||||
inverted_residual_setting: Network structure
|
||||
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
|
||||
Set to 1 to turn off rounding
|
||||
block: Module specifying inverted residual building block for mobilenet
|
||||
norm_layer: Module specifying the normalization layer to use
|
||||
|
||||
"""
|
||||
super(MobileNetV2, self).__init__()
|
||||
|
||||
if block is None:
|
||||
block = InvertedResidual
|
||||
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
|
||||
input_channel = 32
|
||||
last_channel = 1280
|
||||
|
||||
if inverted_residual_setting is None:
|
||||
inverted_residual_setting = [
|
||||
# t, c, n, s
|
||||
[1, 16, 1, 1],
|
||||
[6, 24, 2, 2],
|
||||
[6, 32, 3, 2],
|
||||
[6, 64, 4, 2],
|
||||
[6, 96, 3, 1],
|
||||
[6, 160, 3, 2],
|
||||
[6, 320, 1, 1],
|
||||
]
|
||||
|
||||
# only check the first element, assuming user knows t,c,n,s are required
|
||||
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
|
||||
raise ValueError("inverted_residual_setting should be non-empty "
|
||||
"or a 4-element list, got {}".format(inverted_residual_setting))
|
||||
|
||||
# building first layer
|
||||
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
|
||||
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
|
||||
features = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
|
||||
# building inverted residual blocks
|
||||
for t, c, n, s in inverted_residual_setting:
|
||||
output_channel = _make_divisible(c * width_mult, round_nearest)
|
||||
for i in range(n):
|
||||
stride = s if i == 0 else 1
|
||||
features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
|
||||
input_channel = output_channel
|
||||
# building last several layers
|
||||
features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer))
|
||||
# make it nn.Sequential
|
||||
self.features = nn.Sequential(*features)
|
||||
|
||||
# building classifier
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(self.last_channel, num_classes),
|
||||
)
|
||||
|
||||
# weight initialization
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.ones_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
def _forward_impl(self, x):
|
||||
# This exists since TorchScript doesn't support inheritance, so the superclass method
|
||||
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
|
||||
x = self.features(x)
|
||||
# Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
|
||||
x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
return self._forward_impl(x)
|
||||
|
||||
|
||||
def mobilenet_v2(pretrained=True, progress=True, **kwargs):
|
||||
"""
|
||||
Constructs a MobileNetV2 architecture from
|
||||
`"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
model = MobileNetV2(**kwargs)
|
||||
if pretrained:
|
||||
state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
|
||||
progress=progress)
|
||||
src_state_dict = state_dict
|
||||
target_state_dict = model.state_dict()
|
||||
skip_keys = []
|
||||
# skip mismatch size tensors in case of pretraining
|
||||
for k in src_state_dict.keys():
|
||||
if k not in target_state_dict:
|
||||
continue
|
||||
if src_state_dict[k].size() != target_state_dict[k].size():
|
||||
skip_keys.append(k)
|
||||
for k in skip_keys:
|
||||
del src_state_dict[k]
|
||||
missing_keys, unexpected_keys = model.load_state_dict(src_state_dict, strict=False)
|
||||
#.load_state_dict(state_dict)
|
||||
return model
|
200
model/mobilenet_v3.py
Normal file
200
model/mobilenet_v3.py
Normal file
@ -0,0 +1,200 @@
|
||||
'''MobileNetV3 in PyTorch.
|
||||
|
||||
See the paper "Inverted Residuals and Linear Bottlenecks:
|
||||
Mobile Networks for Classification, Detection and Segmentation" for more details.
|
||||
'''
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import init
|
||||
from config import config as conf
|
||||
|
||||
|
||||
class hswish(nn.Module):
|
||||
def forward(self, x):
|
||||
out = x * F.relu6(x + 3, inplace=True) / 6
|
||||
return out
|
||||
|
||||
|
||||
class hsigmoid(nn.Module):
|
||||
def forward(self, x):
|
||||
out = F.relu6(x + 3, inplace=True) / 6
|
||||
return out
|
||||
|
||||
|
||||
class SeModule(nn.Module):
|
||||
def __init__(self, in_size, reduction=4):
|
||||
super(SeModule, self).__init__()
|
||||
self.se = nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d(1),
|
||||
nn.Conv2d(in_size, in_size // reduction, kernel_size=1, stride=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(in_size // reduction),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(in_size // reduction, in_size, kernel_size=1, stride=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(in_size),
|
||||
hsigmoid()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return x * self.se(x)
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
'''expand + depthwise + pointwise'''
|
||||
def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, semodule, stride):
|
||||
super(Block, self).__init__()
|
||||
self.stride = stride
|
||||
self.se = semodule
|
||||
|
||||
self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, stride=1, padding=0, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(expand_size)
|
||||
self.nolinear1 = nolinear
|
||||
self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size=kernel_size, stride=stride, padding=kernel_size//2, groups=expand_size, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(expand_size)
|
||||
self.nolinear2 = nolinear
|
||||
self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, stride=1, padding=0, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(out_size)
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride == 1 and in_size != out_size:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(out_size),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.nolinear1(self.bn1(self.conv1(x)))
|
||||
out = self.nolinear2(self.bn2(self.conv2(out)))
|
||||
out = self.bn3(self.conv3(out))
|
||||
if self.se != None:
|
||||
out = self.se(out)
|
||||
out = out + self.shortcut(x) if self.stride==1 else out
|
||||
return out
|
||||
|
||||
|
||||
class MobileNetV3_Large(nn.Module):
|
||||
def __init__(self, num_classes=conf.embedding_size):
|
||||
super(MobileNetV3_Large, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(16)
|
||||
self.hs1 = hswish()
|
||||
|
||||
self.bneck = nn.Sequential(
|
||||
Block(3, 16, 16, 16, nn.ReLU(inplace=True), None, 1),
|
||||
Block(3, 16, 64, 24, nn.ReLU(inplace=True), None, 2),
|
||||
Block(3, 24, 72, 24, nn.ReLU(inplace=True), None, 1),
|
||||
Block(5, 24, 72, 40, nn.ReLU(inplace=True), SeModule(40), 2),
|
||||
Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
|
||||
Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
|
||||
Block(3, 40, 240, 80, hswish(), None, 2),
|
||||
Block(3, 80, 200, 80, hswish(), None, 1),
|
||||
Block(3, 80, 184, 80, hswish(), None, 1),
|
||||
Block(3, 80, 184, 80, hswish(), None, 1),
|
||||
Block(3, 80, 480, 112, hswish(), SeModule(112), 1),
|
||||
Block(3, 112, 672, 112, hswish(), SeModule(112), 1),
|
||||
Block(5, 112, 672, 160, hswish(), SeModule(160), 1),
|
||||
Block(5, 160, 672, 160, hswish(), SeModule(160), 2),
|
||||
Block(5, 160, 960, 160, hswish(), SeModule(160), 1),
|
||||
)
|
||||
|
||||
|
||||
self.conv2 = nn.Conv2d(160, 960, kernel_size=1, stride=1, padding=0, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(960)
|
||||
self.hs2 = hswish()
|
||||
self.linear3 = nn.Linear(960, 1280)
|
||||
self.bn3 = nn.BatchNorm1d(1280)
|
||||
self.hs3 = hswish()
|
||||
self.linear4 = nn.Linear(1280, num_classes)
|
||||
self.init_params()
|
||||
|
||||
def init_params(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
init.normal_(m.weight, std=0.001)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.hs1(self.bn1(self.conv1(x)))
|
||||
out = self.bneck(out)
|
||||
out = self.hs2(self.bn2(self.conv2(out)))
|
||||
out = F.avg_pool2d(out, conf.img_size // 32)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.hs3(self.bn3(self.linear3(out)))
|
||||
out = self.linear4(out)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
class MobileNetV3_Small(nn.Module):
|
||||
def __init__(self, num_classes=conf.embedding_size):
|
||||
super(MobileNetV3_Small, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(16)
|
||||
self.hs1 = hswish()
|
||||
|
||||
self.bneck = nn.Sequential(
|
||||
Block(3, 16, 16, 16, nn.ReLU(inplace=True), SeModule(16), 2),
|
||||
Block(3, 16, 72, 24, nn.ReLU(inplace=True), None, 2),
|
||||
Block(3, 24, 88, 24, nn.ReLU(inplace=True), None, 1),
|
||||
Block(5, 24, 96, 40, hswish(), SeModule(40), 2),
|
||||
Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
|
||||
Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
|
||||
Block(5, 40, 120, 48, hswish(), SeModule(48), 1),
|
||||
Block(5, 48, 144, 48, hswish(), SeModule(48), 1),
|
||||
Block(5, 48, 288, 96, hswish(), SeModule(96), 2),
|
||||
Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
|
||||
Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
|
||||
)
|
||||
|
||||
|
||||
self.conv2 = nn.Conv2d(96, 576, kernel_size=1, stride=1, padding=0, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(576)
|
||||
self.hs2 = hswish()
|
||||
self.linear3 = nn.Linear(576, 1280)
|
||||
self.bn3 = nn.BatchNorm1d(1280)
|
||||
self.hs3 = hswish()
|
||||
self.linear4 = nn.Linear(1280, num_classes)
|
||||
self.init_params()
|
||||
|
||||
def init_params(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
init.normal_(m.weight, std=0.001)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.hs1(self.bn1(self.conv1(x)))
|
||||
out = self.bneck(out)
|
||||
out = self.hs2(self.bn2(self.conv2(out)))
|
||||
out = F.avg_pool2d(out, conf.img_size // 32)
|
||||
out = out.view(out.size(0), -1)
|
||||
|
||||
out = self.hs3(self.bn3(self.linear3(out)))
|
||||
out = self.linear4(out)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
def test():
|
||||
net = MobileNetV3_Small()
|
||||
x = torch.randn(2,3,224,224)
|
||||
y = net(x)
|
||||
print(y.size())
|
||||
|
||||
# test()
|
265
model/mobilevit.py
Normal file
265
model/mobilevit.py
Normal file
@ -0,0 +1,265 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from einops import rearrange
|
||||
from config import config as conf
|
||||
|
||||
|
||||
def conv_1x1_bn(inp, oup):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
nn.SiLU()
|
||||
)
|
||||
|
||||
|
||||
def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(inp, oup, kernal_size, stride, 1, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
nn.SiLU()
|
||||
)
|
||||
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim=-1)
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
qkv = self.to_qkv(x).chunk(3, dim=-1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
attn = self.attend(dots)
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b p h n d -> b p n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
|
||||
class MV2Block(nn.Module):
|
||||
def __init__(self, inp, oup, stride=1, expansion=4):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2]
|
||||
|
||||
hidden_dim = int(inp * expansion)
|
||||
self.use_res_connect = self.stride == 1 and inp == oup
|
||||
|
||||
if expansion == 1:
|
||||
self.conv = nn.Sequential(
|
||||
# dw
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.SiLU(),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
)
|
||||
else:
|
||||
self.conv = nn.Sequential(
|
||||
# pw
|
||||
nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.SiLU(),
|
||||
# dw
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.SiLU(),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_res_connect:
|
||||
return x + self.conv(x)
|
||||
else:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class MobileViTBlock(nn.Module):
|
||||
def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.ph, self.pw = patch_size
|
||||
|
||||
self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
|
||||
self.conv2 = conv_1x1_bn(channel, dim)
|
||||
|
||||
self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)
|
||||
|
||||
self.conv3 = conv_1x1_bn(dim, channel)
|
||||
self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
|
||||
|
||||
def forward(self, x):
|
||||
y = x.clone()
|
||||
|
||||
# Local representations
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
|
||||
# Global representations
|
||||
_, _, h, w = x.shape
|
||||
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
|
||||
x = self.transformer(x)
|
||||
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph,
|
||||
pw=self.pw)
|
||||
|
||||
# Fusion
|
||||
x = self.conv3(x)
|
||||
x = torch.cat((x, y), 1)
|
||||
x = self.conv4(x)
|
||||
return x
|
||||
|
||||
|
||||
class MobileViT(nn.Module):
|
||||
def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)):
|
||||
super().__init__()
|
||||
ih, iw = image_size
|
||||
ph, pw = patch_size
|
||||
assert ih % ph == 0 and iw % pw == 0
|
||||
|
||||
L = [2, 4, 3]
|
||||
|
||||
self.conv1 = conv_nxn_bn(3, channels[0], stride=2)
|
||||
|
||||
self.mv2 = nn.ModuleList([])
|
||||
self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
|
||||
self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
|
||||
self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
|
||||
self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion)) # Repeat
|
||||
self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
|
||||
self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
|
||||
self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))
|
||||
|
||||
self.mvit = nn.ModuleList([])
|
||||
self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2)))
|
||||
self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4)))
|
||||
self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4)))
|
||||
|
||||
self.conv2 = conv_1x1_bn(channels[-2], channels[-1])
|
||||
|
||||
self.pool = nn.AvgPool2d(ih // 32, 1)
|
||||
self.fc = nn.Linear(channels[-1], num_classes, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
#print('x',x.shape)
|
||||
x = self.conv1(x)
|
||||
x = self.mv2[0](x)
|
||||
|
||||
x = self.mv2[1](x)
|
||||
x = self.mv2[2](x)
|
||||
x = self.mv2[3](x) # Repeat
|
||||
|
||||
x = self.mv2[4](x)
|
||||
x = self.mvit[0](x)
|
||||
|
||||
x = self.mv2[5](x)
|
||||
x = self.mvit[1](x)
|
||||
|
||||
x = self.mv2[6](x)
|
||||
x = self.mvit[2](x)
|
||||
x = self.conv2(x)
|
||||
|
||||
|
||||
#print('pool_before',x.shape)
|
||||
x = self.pool(x).view(-1, x.shape[1])
|
||||
#print('self_pool',self.pool)
|
||||
#print('pool_after',x.shape)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
|
||||
def mobilevit_xxs():
|
||||
dims = [64, 80, 96]
|
||||
channels = [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320]
|
||||
return MobileViT((256, 256), dims, channels, num_classes=1000, expansion=2)
|
||||
|
||||
|
||||
def mobilevit_xs():
|
||||
dims = [96, 120, 144]
|
||||
channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384]
|
||||
return MobileViT((256, 256), dims, channels, num_classes=1000)
|
||||
|
||||
|
||||
def mobilevit_s():
|
||||
dims = [144, 192, 240]
|
||||
channels = [16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 640]
|
||||
return MobileViT((conf.img_size, conf.img_size), dims, channels, num_classes=conf.embedding_size)
|
||||
|
||||
|
||||
def count_parameters(model):
|
||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
img = torch.randn(5, 3, 256, 256)
|
||||
|
||||
vit = mobilevit_xxs()
|
||||
out = vit(img)
|
||||
print(out.shape)
|
||||
print(count_parameters(vit))
|
||||
|
||||
vit = mobilevit_xs()
|
||||
out = vit(img)
|
||||
print(out.shape)
|
||||
print(count_parameters(vit))
|
||||
|
||||
vit = mobilevit_s()
|
||||
out = vit(img)
|
||||
print(out.shape)
|
||||
print(count_parameters(vit))
|
412
model/quant_test_resnet.py
Normal file
412
model/quant_test_resnet.py
Normal file
@ -0,0 +1,412 @@
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch.nn as nn
|
||||
from .utils import load_state_dict_from_url
|
||||
from typing import Type, Any, Callable, Union, List, Optional
|
||||
|
||||
|
||||
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
||||
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
|
||||
'wide_resnet50_2', 'wide_resnet101_2']
|
||||
|
||||
|
||||
model_urls = {
|
||||
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
||||
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
||||
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
||||
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
||||
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
||||
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
|
||||
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
|
||||
'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
|
||||
'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
|
||||
}
|
||||
|
||||
|
||||
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
||||
|
||||
|
||||
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
|
||||
"""1x1 convolution"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion: int = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inplanes: int,
|
||||
planes: int,
|
||||
stride: int = 1,
|
||||
downsample: Optional[nn.Module] = None,
|
||||
groups: int = 1,
|
||||
base_width: int = 64,
|
||||
dilation: int = 1,
|
||||
norm_layer: Optional[Callable[..., nn.Module]] = None
|
||||
) -> None:
|
||||
super(BasicBlock, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
if groups != 1 or base_width != 64:
|
||||
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
||||
if dilation > 1:
|
||||
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
||||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
self.bn1 = norm_layer(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn2 = norm_layer(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class QuantizableBasicBlock(BasicBlock):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.add_relu = torch.nn.quantized.FloatFunctional()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out = self.add_relu.add_relu(out, identity)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
|
||||
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
|
||||
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
|
||||
# This variant is also known as ResNet V1.5 and improves accuracy according to
|
||||
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
|
||||
|
||||
expansion: int = 4
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inplanes: int,
|
||||
planes: int,
|
||||
stride: int = 1,
|
||||
downsample: Optional[nn.Module] = None,
|
||||
groups: int = 1,
|
||||
base_width: int = 64,
|
||||
dilation: int = 1,
|
||||
norm_layer: Optional[Callable[..., nn.Module]] = None
|
||||
) -> None:
|
||||
super(Bottleneck, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
width = int(planes * (base_width / 64.)) * groups
|
||||
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv1x1(inplanes, width)
|
||||
self.bn1 = norm_layer(width)
|
||||
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
||||
self.bn2 = norm_layer(width)
|
||||
self.conv3 = conv1x1(width, planes * self.expansion)
|
||||
self.bn3 = norm_layer(planes * self.expansion)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block: Type[Union[BasicBlock, Bottleneck]],
|
||||
layers: List[int],
|
||||
num_classes: int = 1000,
|
||||
zero_init_residual: bool = False,
|
||||
groups: int = 1,
|
||||
width_per_group: int = 64,
|
||||
replace_stride_with_dilation: Optional[List[bool]] = None,
|
||||
norm_layer: Optional[Callable[..., nn.Module]] = None
|
||||
) -> None:
|
||||
super(ResNet, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
self._norm_layer = norm_layer
|
||||
|
||||
self.inplanes = 64
|
||||
self.dilation = 1
|
||||
if replace_stride_with_dilation is None:
|
||||
# each element in the tuple indicates if we should replace
|
||||
# the 2x2 stride with a dilated convolution instead
|
||||
replace_stride_with_dilation = [False, False, False]
|
||||
if len(replace_stride_with_dilation) != 3:
|
||||
raise ValueError("replace_stride_with_dilation should be None "
|
||||
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
||||
self.groups = groups
|
||||
self.base_width = width_per_group
|
||||
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
|
||||
bias=False)
|
||||
self.bn1 = norm_layer(self.inplanes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
|
||||
dilate=replace_stride_with_dilation[0])
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
|
||||
dilate=replace_stride_with_dilation[1])
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
|
||||
dilate=replace_stride_with_dilation[2])
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
# Zero-initialize the last BN in each residual branch,
|
||||
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
||||
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, Bottleneck):
|
||||
nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
|
||||
elif isinstance(m, BasicBlock):
|
||||
nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
|
||||
|
||||
def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
|
||||
stride: int = 1, dilate: bool = False) -> nn.Sequential:
|
||||
norm_layer = self._norm_layer
|
||||
downsample = None
|
||||
previous_dilation = self.dilation
|
||||
if dilate:
|
||||
self.dilation *= stride
|
||||
stride = 1
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
conv1x1(self.inplanes, planes * block.expansion, stride),
|
||||
norm_layer(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
|
||||
self.base_width, previous_dilation, norm_layer))
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes, groups=self.groups,
|
||||
base_width=self.base_width, dilation=self.dilation,
|
||||
norm_layer=norm_layer))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _forward_impl(self, x: Tensor) -> Tensor:
|
||||
# See note [TorchScript super()]
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
x = self.avgpool(x)
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self._forward_impl(x)
|
||||
|
||||
|
||||
def _resnet(
|
||||
arch: str,
|
||||
block: Type[Union[BasicBlock, Bottleneck]],
|
||||
layers: List[int],
|
||||
pretrained: bool,
|
||||
progress: bool,
|
||||
**kwargs: Any
|
||||
) -> ResNet:
|
||||
model = ResNet(block, layers, **kwargs)
|
||||
if pretrained:
|
||||
state_dict = load_state_dict_from_url(model_urls[arch],
|
||||
progress=progress)
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
||||
|
||||
|
||||
def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""ResNet-18 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
# return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
|
||||
return _resnet('resnet18', QuantizableBasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""ResNet-34 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""ResNet-50 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""ResNet-101 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""ResNet-152 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""ResNeXt-50 32x4d model from
|
||||
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['groups'] = 32
|
||||
kwargs['width_per_group'] = 4
|
||||
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
|
||||
pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""ResNeXt-101 32x8d model from
|
||||
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['groups'] = 32
|
||||
kwargs['width_per_group'] = 8
|
||||
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
|
||||
pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""Wide ResNet-50-2 model from
|
||||
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
|
||||
|
||||
The model is the same as ResNet except for the bottleneck number of channels
|
||||
which is twice larger in every block. The number of channels in outer 1x1
|
||||
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['width_per_group'] = 64 * 2
|
||||
return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
|
||||
pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""Wide ResNet-101-2 model from
|
||||
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
|
||||
|
||||
The model is the same as ResNet except for the bottleneck number of channels
|
||||
which is twice larger in every block. The number of channels in outer 1x1
|
||||
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['width_per_group'] = 64 * 2
|
||||
return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
|
||||
pretrained, progress, **kwargs)
|
142
model/resbam.py
Normal file
142
model/resbam.py
Normal file
@ -0,0 +1,142 @@
|
||||
from model.CBAM import CBAM
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from model.Tool import GeM as gem
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inchannel, outchannel, stride=1, dowsample=None):
|
||||
# super(Bottleneck, self).__init__()
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(in_channels=inchannel, out_channels=outchannel, kernel_size=1, stride=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(outchannel)
|
||||
self.conv2 = nn.Conv2d(in_channels=outchannel, out_channels=outchannel, kernel_size=3, bias=False,
|
||||
stride=stride, padding=1)
|
||||
self.bn2 = nn.BatchNorm2d(outchannel)
|
||||
self.conv3 = nn.Conv2d(in_channels=outchannel, out_channels=outchannel * self.expansion, stride=1, bias=False,
|
||||
kernel_size=1)
|
||||
self.bn3 = nn.BatchNorm2d(outchannel * self.expansion)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = dowsample
|
||||
|
||||
def forward(self, x):
|
||||
self.identity = x
|
||||
# print('>>>>>>>>',type(x))
|
||||
if self.downsample is not None:
|
||||
# print('>>>>downsample>>>>', type(self.downsample))
|
||||
self.identity = self.downsample(x)
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
# print('>>>>out>>>identity',out.size(),self.identity.size())
|
||||
out = out + self.identity
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class resnet(nn.Module):
|
||||
def __init__(self, block=Bottleneck, block_num=[3, 4, 6, 3], num_class=1000):
|
||||
super().__init__()
|
||||
self.in_channel = 64
|
||||
self.conv1 = nn.Conv2d(in_channels=3,
|
||||
out_channels=self.in_channel,
|
||||
stride=2,
|
||||
kernel_size=7,
|
||||
padding=3,
|
||||
bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(self.in_channel)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.cbam = CBAM(self.in_channel)
|
||||
self.cbam1 = CBAM(2048)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = self._make_layer(block, 64, block_num[0], stride=1)
|
||||
self.layer2 = self._make_layer(block, 128, block_num[1], stride=2)
|
||||
self.layer3 = self._make_layer(block, 256, block_num[2], stride=2)
|
||||
self.layer4 = self._make_layer(block, 512, block_num[3], stride=2)
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.gem = gem()
|
||||
self.fc = nn.Linear(512 * block.expansion, num_class)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal(m.weight, mode='fan_out',
|
||||
nonlinearity='relu')
|
||||
if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
nn.init.constant_(m.bias, 1.0)
|
||||
|
||||
def _make_layer(self, block, channel, block_num, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.in_channel != channel * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(channel * block.expansion))
|
||||
layer = []
|
||||
layer.append(block(self.in_channel, channel, stride, downsample))
|
||||
self.in_channel = channel * block.expansion
|
||||
for _ in range(1, block_num):
|
||||
layer.append(block(self.in_channel, channel))
|
||||
return nn.Sequential(*layer)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
x = self.cbam(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
x = self.cbam1(x)
|
||||
# x = self.avgpool(x)
|
||||
x = self.gem(x)
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
|
||||
class TripletNet(nn.Module):
|
||||
def __init__(self, num_class, flag=True):
|
||||
super(TripletNet, self).__init__()
|
||||
self.initnet = rescbam(num_class)
|
||||
self.flag = flag
|
||||
|
||||
def forward(self, x1, x2=None, x3=None):
|
||||
if self.flag:
|
||||
output1 = self.initnet(x1)
|
||||
output2 = self.initnet(x2)
|
||||
output3 = self.initnet(x3)
|
||||
return output1, output2, output3
|
||||
else:
|
||||
output = self.initnet(x1)
|
||||
return output
|
||||
|
||||
|
||||
def rescbam(num_class):
|
||||
return resnet(block=Bottleneck, block_num=[3, 4, 6, 3], num_class=num_class)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
input1 = torch.randn(4, 3, 640, 640)
|
||||
input2 = torch.randn(4, 3, 640, 640)
|
||||
input3 = torch.randn(4, 3, 640, 640)
|
||||
|
||||
# rescbam测试
|
||||
# Resnet50 = rescbam(512)
|
||||
# output = Resnet50.forward(input1)
|
||||
# print(Resnet50)
|
||||
|
||||
# trnet测试
|
||||
trnet = TripletNet(512)
|
||||
output = trnet(input1, input2, input3)
|
||||
print(output)
|
189
model/resnet.py
Normal file
189
model/resnet.py
Normal file
@ -0,0 +1,189 @@
|
||||
"""resnet in pytorch
|
||||
|
||||
|
||||
|
||||
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
|
||||
|
||||
Deep Residual Learning for Image Recognition
|
||||
https://arxiv.org/abs/1512.03385v1
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from config import config as conf
|
||||
from CBAM import CBAM
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
"""Basic Block for resnet 18 and resnet 34
|
||||
|
||||
"""
|
||||
|
||||
#BasicBlock and BottleNeck block
|
||||
#have different output size
|
||||
#we use class attribute expansion
|
||||
#to distinct
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, in_channels, out_channels, stride=1):
|
||||
super().__init__()
|
||||
|
||||
#residual function
|
||||
self.residual_function = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels * BasicBlock.expansion)
|
||||
)
|
||||
|
||||
#shortcut
|
||||
self.shortcut = nn.Sequential()
|
||||
|
||||
#the shortcut output dimension is not the same with residual function
|
||||
#use 1*1 convolution to match the dimension
|
||||
if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(out_channels * BasicBlock.expansion)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
|
||||
|
||||
class BottleNeck(nn.Module):
|
||||
"""Residual block for resnet over 50 layers
|
||||
|
||||
"""
|
||||
expansion = 4
|
||||
def __init__(self, in_channels, out_channels, stride=1):
|
||||
super().__init__()
|
||||
self.residual_function = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels * BottleNeck.expansion),
|
||||
)
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
|
||||
if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels * BottleNeck.expansion)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
|
||||
|
||||
class ResNet(nn.Module):
|
||||
|
||||
def __init__(self, block, num_block, cbam = False, num_classes=conf.embedding_size):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = 64
|
||||
|
||||
# self.conv1 = nn.Sequential(
|
||||
# nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
|
||||
# nn.BatchNorm2d(64),
|
||||
# nn.ReLU(inplace=True))
|
||||
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv2d(3, 64,stride=2,kernel_size=7,padding=3,bias=False),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
|
||||
|
||||
self.cbam = CBAM(self.in_channels)
|
||||
|
||||
#we use a different inputsize than the original paper
|
||||
#so conv2_x's stride is 1
|
||||
self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
|
||||
self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
|
||||
self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
|
||||
self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
|
||||
self.cbam1 = CBAM(self.in_channels)
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal(m.weight,mode = 'fan_out',
|
||||
nonlinearity='relu')
|
||||
if isinstance(m, (nn.BatchNorm2d)):
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
nn.init.constant_(m.bias, 1.0)
|
||||
|
||||
def _make_layer(self, block, out_channels, num_blocks, stride):
|
||||
"""make resnet layers(by layer i didnt mean this 'layer' was the
|
||||
same as a neuron netowork layer, ex. conv layer), one layer may
|
||||
contain more than one residual block
|
||||
|
||||
Args:
|
||||
block: block type, basic block or bottle neck block
|
||||
out_channels: output depth channel number of this layer
|
||||
num_blocks: how many blocks per layer
|
||||
stride: the stride of the first block of this layer
|
||||
|
||||
Return:
|
||||
return a resnet layer
|
||||
"""
|
||||
|
||||
# we have num_block blocks per layer, the first block
|
||||
# could be 1 or 2, other blocks would always be 1
|
||||
strides = [stride] + [1] * (num_blocks - 1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
layers.append(block(self.in_channels, out_channels, stride))
|
||||
self.in_channels = out_channels * block.expansion
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
output = self.conv1(x)
|
||||
if cbam:
|
||||
output = self.cbam(x)
|
||||
output = self.conv2_x(output)
|
||||
output = self.conv3_x(output)
|
||||
output = self.conv4_x(output)
|
||||
output = self.conv5_x(output)
|
||||
if cbam:
|
||||
output = self.cbam1(x)
|
||||
print('pollBefore',output.shape)
|
||||
output = self.avg_pool(output)
|
||||
print('poolAfter',output.shape)
|
||||
output = output.view(output.size(0), -1)
|
||||
print('fcBefore',output.shape)
|
||||
output = self.fc(output)
|
||||
|
||||
return output
|
||||
|
||||
def resnet18(cbam = False):
|
||||
""" return a ResNet 18 object
|
||||
"""
|
||||
return ResNet(BasicBlock, [2, 2, 2, 2], cbam)
|
||||
|
||||
def resnet34():
|
||||
""" return a ResNet 34 object
|
||||
"""
|
||||
return ResNet(BasicBlock, [3, 4, 6, 3])
|
||||
|
||||
def resnet50():
|
||||
""" return a ResNet 50 object
|
||||
"""
|
||||
return ResNet(BottleNeck, [3, 4, 6, 3])
|
||||
|
||||
def resnet101():
|
||||
""" return a ResNet 101 object
|
||||
"""
|
||||
return ResNet(BottleNeck, [3, 4, 23, 3])
|
||||
|
||||
def resnet152():
|
||||
""" return a ResNet 152 object
|
||||
"""
|
||||
return ResNet(BottleNeck, [3, 8, 36, 3])
|
||||
|
||||
|
271
model/resnet_attention.py
Normal file
271
model/resnet_attention.py
Normal file
@ -0,0 +1,271 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ChannelAttention(nn.Module):
|
||||
"""通道注意力模块,通过全局平均池化和最大池化提取特征,经过MLP生成通道权重"""
|
||||
|
||||
def __init__(self, in_channels, reduction_ratio=16):
|
||||
super(ChannelAttention, self).__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.max_pool = nn.AdaptiveMaxPool2d(1)
|
||||
|
||||
# 共享的MLP层
|
||||
self.fc = nn.Sequential(
|
||||
nn.Conv2d(in_channels, in_channels // reduction_ratio, 1, bias=False),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels // reduction_ratio, in_channels, 1, bias=False)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
avg_out = self.fc(self.avg_pool(x))
|
||||
max_out = self.fc(self.max_pool(x))
|
||||
out = avg_out + max_out
|
||||
return torch.sigmoid(out)
|
||||
|
||||
|
||||
class SpatialAttention(nn.Module):
|
||||
"""空间注意力模块,通过通道维度的平均和最大值操作,生成空间权重"""
|
||||
|
||||
def __init__(self, kernel_size=7):
|
||||
super(SpatialAttention, self).__init__()
|
||||
self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
avg_out = torch.mean(x, dim=1, keepdim=True)
|
||||
max_out, _ = torch.max(x, dim=1, keepdim=True)
|
||||
out = torch.cat([avg_out, max_out], dim=1)
|
||||
out = self.conv(out)
|
||||
return torch.sigmoid(out)
|
||||
|
||||
|
||||
class CBAM(nn.Module):
|
||||
"""CBAM注意力模块,串联通道注意力和空间注意力"""
|
||||
|
||||
def __init__(self, in_channels, reduction_ratio=16, kernel_size=7):
|
||||
super(CBAM, self).__init__()
|
||||
self.channel_att = ChannelAttention(in_channels, reduction_ratio)
|
||||
self.spatial_att = SpatialAttention(kernel_size)
|
||||
|
||||
def forward(self, x):
|
||||
x = x * self.channel_att(x)
|
||||
x = x * self.spatial_att(x)
|
||||
return x
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
"""ResNet基础残差块,适用于ResNet18和ResNet34"""
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, in_channels, out_channels, stride=1, downsample=None, use_cbam=False):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(out_channels)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(out_channels)
|
||||
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
# 是否使用CBAM注意力机制
|
||||
self.use_cbam = use_cbam
|
||||
if use_cbam:
|
||||
self.cbam = CBAM(out_channels)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
# # 如果使用注意力机制,应用CBAM
|
||||
if self.use_cbam:
|
||||
out = self.cbam(out)
|
||||
|
||||
# 如果有下采样,调整shortcut连接
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
# 残差连接
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
"""ResNet瓶颈残差块,适用于ResNet50及更深的网络"""
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, in_channels, out_channels, stride=1, downsample=None, use_cbam=False):
|
||||
super(Bottleneck, self).__init__()
|
||||
# 1x1卷积降维
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(out_channels)
|
||||
# 3x3卷积
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(out_channels)
|
||||
# 1x1卷积升维
|
||||
self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
# 是否使用CBAM注意力机制
|
||||
self.use_cbam = use_cbam
|
||||
if use_cbam:
|
||||
self.cbam = CBAM(out_channels * self.expansion)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
# # 如果使用注意力机制,应用CBAM
|
||||
if self.use_cbam:
|
||||
out = self.cbam(out)
|
||||
|
||||
# 如果有下采样,调整shortcut连接
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
# 残差连接
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
"""集成了CBAM注意力机制的ResNet模型"""
|
||||
|
||||
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, use_cbam=True):
|
||||
super(ResNet, self).__init__()
|
||||
self.in_channels = 64
|
||||
self.use_cbam = use_cbam
|
||||
|
||||
# 初始卷积层
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
||||
self.cbam1 = CBAM(64)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
# 残差块层
|
||||
self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
||||
|
||||
self.cbam2 = CBAM(512)
|
||||
# 全局平均池化和分类器
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
# 初始化权重
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
# 零初始化最后一个BN层的权重,使残差分支初始为0
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, Bottleneck):
|
||||
nn.init.constant_(m.bn3.weight, 0)
|
||||
elif isinstance(m, BasicBlock):
|
||||
nn.init.constant_(m.bn2.weight, 0)
|
||||
|
||||
def _make_layer(self, block, out_channels, blocks, stride=1):
|
||||
downsample = None
|
||||
# 如果通道数不匹配或需要调整步长,创建下采样层
|
||||
if stride != 1 or self.in_channels != out_channels * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.in_channels, out_channels * block.expansion, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(out_channels * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
# 第一个块可能需要下采样
|
||||
layers.append(block(self.in_channels, out_channels, stride, downsample, use_cbam=self.use_cbam))
|
||||
self.in_channels = out_channels * block.expansion
|
||||
|
||||
# 添加剩余的块
|
||||
for _ in range(1, blocks):
|
||||
layers.append(block(self.in_channels, out_channels, use_cbam=self.use_cbam))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
# 特征提取
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
# if self.use_cbam:
|
||||
# x = self.cbam1(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
# if self.use_cbam:
|
||||
# x = self.cbam2(x)
|
||||
# 分类
|
||||
x = self.avgpool(x)
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
# 工厂函数,创建不同深度的ResNet模型
|
||||
def resnet18_cbam(pretrained=False, **kwargs):
|
||||
return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
|
||||
|
||||
|
||||
def resnet34_cbam(pretrained=False, **kwargs):
|
||||
return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
|
||||
|
||||
|
||||
def resnet50_cbam(pretrained=False, **kwargs):
|
||||
return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
|
||||
|
||||
|
||||
def resnet101_cbam(pretrained=False, **kwargs):
|
||||
return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
|
||||
|
||||
|
||||
def resnet152_cbam(pretrained=False, **kwargs):
|
||||
return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
|
||||
|
||||
|
||||
# 测试模型
|
||||
if __name__ == "__main__":
|
||||
# 创建一个带有CBAM注意力机制的ResNet50模型
|
||||
model = resnet50_cbam(num_classes=10)
|
||||
# 测试输入
|
||||
x = torch.randn(1, 3, 224, 224)
|
||||
y = model(x)
|
||||
print(f"输入形状: {x.shape}")
|
||||
print(f"输出形状: {y.shape}")
|
480
model/resnet_pre.py
Normal file
480
model/resnet_pre.py
Normal file
@ -0,0 +1,480 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from config import config as conf
|
||||
|
||||
try:
|
||||
from torch.hub import load_state_dict_from_url
|
||||
except ImportError:
|
||||
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
||||
# from .utils import load_state_dict_from_url
|
||||
|
||||
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
||||
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
|
||||
'wide_resnet50_2', 'wide_resnet101_2']
|
||||
|
||||
model_urls = {
|
||||
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
||||
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
||||
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
||||
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
||||
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
||||
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
|
||||
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
|
||||
'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
|
||||
'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
|
||||
}
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
||||
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
"""1x1 convolution"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
|
||||
class SpatialAttention(nn.Module):
|
||||
def __init__(self, kernel_size=7):
|
||||
super(SpatialAttention, self).__init__()
|
||||
|
||||
assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
|
||||
padding = 3 if kernel_size == 7 else 1
|
||||
|
||||
self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
avg_out = torch.mean(x, dim=1, keepdim=True)
|
||||
max_out, _ = torch.max(x, dim=1, keepdim=True)
|
||||
x = torch.cat([avg_out, max_out], dim=1)
|
||||
x = self.conv1(x)
|
||||
return self.sigmoid(x)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
||||
base_width=64, dilation=1, norm_layer=None, cam=False, bam=False):
|
||||
super(BasicBlock, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
if groups != 1 or base_width != 64:
|
||||
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
||||
if dilation > 1:
|
||||
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
||||
self.cam = cam
|
||||
self.bam = bam
|
||||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
self.bn1 = norm_layer(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn2 = norm_layer(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
if self.cam:
|
||||
if planes == 64:
|
||||
self.globalAvgPool = nn.AvgPool2d(56, stride=1)
|
||||
elif planes == 128:
|
||||
self.globalAvgPool = nn.AvgPool2d(28, stride=1)
|
||||
elif planes == 256:
|
||||
self.globalAvgPool = nn.AvgPool2d(14, stride=1)
|
||||
elif planes == 512:
|
||||
self.globalAvgPool = nn.AvgPool2d(7, stride=1)
|
||||
|
||||
self.fc1 = nn.Linear(in_features=planes, out_features=round(planes / 16))
|
||||
self.fc2 = nn.Linear(in_features=round(planes / 16), out_features=planes)
|
||||
self.sigmod = nn.Sigmoid()
|
||||
if self.bam:
|
||||
self.bam = SpatialAttention()
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
if self.cam:
|
||||
ori_out = self.globalAvgPool(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.fc1(out)
|
||||
out = self.relu(out)
|
||||
out = self.fc2(out)
|
||||
out = self.sigmod(out)
|
||||
out = out.view(out.size(0), out.size(-1), 1, 1)
|
||||
out = out * ori_out
|
||||
|
||||
if self.bam:
|
||||
out = out * self.bam(out)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
|
||||
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
|
||||
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
|
||||
# This variant is also known as ResNet V1.5 and improves accuracy according to
|
||||
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
|
||||
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
||||
base_width=64, dilation=1, norm_layer=None, cam=False, bam=False):
|
||||
super(Bottleneck, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
width = int(planes * (base_width / 64.)) * groups
|
||||
self.cam = cam
|
||||
self.bam = bam
|
||||
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv1x1(inplanes, width)
|
||||
self.bn1 = norm_layer(width)
|
||||
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
||||
self.bn2 = norm_layer(width)
|
||||
self.conv3 = conv1x1(width, planes * self.expansion)
|
||||
self.bn3 = norm_layer(planes * self.expansion)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
if self.cam:
|
||||
if planes == 64:
|
||||
self.globalAvgPool = nn.AvgPool2d(56, stride=1)
|
||||
elif planes == 128:
|
||||
self.globalAvgPool = nn.AvgPool2d(28, stride=1)
|
||||
elif planes == 256:
|
||||
self.globalAvgPool = nn.AvgPool2d(14, stride=1)
|
||||
elif planes == 512:
|
||||
self.globalAvgPool = nn.AvgPool2d(7, stride=1)
|
||||
|
||||
self.fc1 = nn.Linear(planes * self.expansion, round(planes / 4))
|
||||
self.fc2 = nn.Linear(round(planes / 4), planes * self.expansion)
|
||||
self.sigmod = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
if self.cam:
|
||||
ori_out = self.globalAvgPool(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.fc1(out)
|
||||
out = self.relu(out)
|
||||
out = self.fc2(out)
|
||||
out = self.sigmod(out)
|
||||
out = out.view(out.size(0), out.size(-1), 1, 1)
|
||||
out = out * ori_out
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
|
||||
def __init__(self, block, layers, num_classes=conf.embedding_size, zero_init_residual=False,
|
||||
groups=1, width_per_group=64, replace_stride_with_dilation=None,
|
||||
norm_layer=None, scale=conf.channel_ratio):
|
||||
super(ResNet, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
self._norm_layer = norm_layer
|
||||
print("ResNet scale: >>>>>>>>>> ", scale)
|
||||
self.inplanes = 64
|
||||
self.dilation = 1
|
||||
if replace_stride_with_dilation is None:
|
||||
# each element in the tuple indicates if we should replace
|
||||
# the 2x2 stride with a dilated convolution instead
|
||||
replace_stride_with_dilation = [False, False, False]
|
||||
if len(replace_stride_with_dilation) != 3:
|
||||
raise ValueError("replace_stride_with_dilation should be None "
|
||||
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
||||
self.groups = groups
|
||||
self.base_width = width_per_group
|
||||
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
|
||||
bias=False)
|
||||
self.bn1 = norm_layer(self.inplanes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.adaptiveMaxPool = nn.AdaptiveMaxPool2d((1, 1))
|
||||
self.maxpool2 = nn.Sequential(
|
||||
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
|
||||
nn.MaxPool2d(kernel_size=2, stride=1, padding=0)
|
||||
)
|
||||
self.layer1 = self._make_layer(block, int(64 * scale), layers[0])
|
||||
self.layer2 = self._make_layer(block, int(128 * scale), layers[1], stride=2,
|
||||
dilate=replace_stride_with_dilation[0])
|
||||
self.layer3 = self._make_layer(block, int(256 * scale), layers[2], stride=2,
|
||||
dilate=replace_stride_with_dilation[1])
|
||||
self.layer4 = self._make_layer(block, int(512 * scale), layers[3], stride=2,
|
||||
dilate=replace_stride_with_dilation[2])
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(int(512 * block.expansion * scale), num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
# Zero-initialize the last BN in each residual branch,
|
||||
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
||||
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, Bottleneck):
|
||||
nn.init.constant_(m.bn3.weight, 0)
|
||||
elif isinstance(m, BasicBlock):
|
||||
nn.init.constant_(m.bn2.weight, 0)
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
||||
norm_layer = self._norm_layer
|
||||
downsample = None
|
||||
previous_dilation = self.dilation
|
||||
if dilate:
|
||||
self.dilation *= stride
|
||||
stride = 1
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
conv1x1(self.inplanes, planes * block.expansion, stride),
|
||||
norm_layer(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
|
||||
self.base_width, previous_dilation, norm_layer))
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes, groups=self.groups,
|
||||
base_width=self.base_width, dilation=self.dilation,
|
||||
norm_layer=norm_layer))
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _forward_impl(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
x = self.avgpool(x)
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
return self._forward_impl(x)
|
||||
|
||||
|
||||
# def _resnet(arch, block, layers, pretrained, progress, **kwargs):
|
||||
# model = ResNet(block, layers, **kwargs)
|
||||
# if pretrained:
|
||||
# state_dict = load_state_dict_from_url(model_urls[arch],
|
||||
# progress=progress)
|
||||
# model.load_state_dict(state_dict, strict=False)
|
||||
# return model
|
||||
|
||||
class CustomResNet18(nn.Module):
|
||||
def __init__(self, model, num_classes=conf.custom_num_classes):
|
||||
super(CustomResNet18, self).__init__()
|
||||
self.custom_model = nn.Sequential(*list(model.children())[:-1])
|
||||
self.fc = nn.Linear(model.fc.in_features, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.custom_model(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
|
||||
model = ResNet(block, layers, **kwargs)
|
||||
if pretrained:
|
||||
state_dict = load_state_dict_from_url(model_urls[arch],
|
||||
progress=progress)
|
||||
|
||||
src_state_dict = state_dict
|
||||
target_state_dict = model.state_dict()
|
||||
skip_keys = []
|
||||
# skip mismatch size tensors in case of pretraining
|
||||
for k in src_state_dict.keys():
|
||||
if k not in target_state_dict:
|
||||
continue
|
||||
if src_state_dict[k].size() != target_state_dict[k].size():
|
||||
skip_keys.append(k)
|
||||
for k in skip_keys:
|
||||
del src_state_dict[k]
|
||||
missing_keys, unexpected_keys = model.load_state_dict(src_state_dict, strict=False)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def resnet14(pretrained=True, progress=True, **kwargs):
|
||||
r"""ResNet-14 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet18', BasicBlock, [2, 1, 1, 2], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet18(pretrained=True, progress=True, **kwargs):
|
||||
r"""ResNet-18 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
**kwargs: Additional arguments passed to ResNet, including:
|
||||
scale (float): Channel scaling ratio (default: conf.channel_ratio)
|
||||
"""
|
||||
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet34(pretrained=False, progress=True, **kwargs):
|
||||
r"""ResNet-34 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet50(pretrained=False, progress=True, **kwargs):
|
||||
r"""ResNet-50 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet101(pretrained=False, progress=True, **kwargs):
|
||||
r"""ResNet-101 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet152(pretrained=False, progress=True, **kwargs):
|
||||
r"""ResNet-152 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
|
||||
r"""ResNeXt-50 32x4d model from
|
||||
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['groups'] = 32
|
||||
kwargs['width_per_group'] = 4
|
||||
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
|
||||
pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
|
||||
r"""ResNeXt-101 32x8d model from
|
||||
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['groups'] = 32
|
||||
kwargs['width_per_group'] = 8
|
||||
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
|
||||
pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
|
||||
r"""Wide ResNet-50-2 model from
|
||||
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
|
||||
|
||||
The model is the same as ResNet except for the bottleneck number of channels
|
||||
which is twice larger in every block. The number of channels in outer 1x1
|
||||
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['width_per_group'] = 64 * 2
|
||||
return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
|
||||
pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
|
||||
r"""Wide ResNet-101-2 model from
|
||||
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
|
||||
|
||||
The model is the same as ResNet except for the bottleneck number of channels
|
||||
which is twice larger in every block. The number of channels in outer 1x1
|
||||
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['width_per_group'] = 64 * 2
|
||||
return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
|
||||
pretrained, progress, **kwargs)
|
4
model/utils.py
Normal file
4
model/utils.py
Normal file
@ -0,0 +1,4 @@
|
||||
try:
|
||||
from torch.hub import load_state_dict_from_url
|
||||
except ImportError:
|
||||
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
42
model/vit.py
Normal file
42
model/vit.py
Normal file
@ -0,0 +1,42 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from functools import partial, reduce
|
||||
from operator import mul
|
||||
|
||||
from timm.models.vision_transformer import VisionTransformer, _cfg
|
||||
|
||||
__all__ = [
|
||||
'vit_small',
|
||||
'vit_base',
|
||||
]
|
||||
|
||||
|
||||
def vit_small(**kwargs):
|
||||
model = VisionTransformer(
|
||||
patch_size=16, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, num_classes=256,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||
# model.default_cfg = _cfg()
|
||||
return model
|
||||
|
||||
|
||||
def vit_base(**kwargs):
|
||||
model = VisionTransformer(
|
||||
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, num_classes=256,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||
model.default_cfg = _cfg(num_classes=256)
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
img = torch.randn(8, 3, 224, 224)
|
||||
vit = vit_base()
|
||||
out = vit(img)
|
||||
print(out.shape)
|
||||
# print(count_parameters(vit))
|
331
test_ori.py
Normal file
331
test_ori.py
Normal file
@ -0,0 +1,331 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import os.path as osp
|
||||
from typing import Dict, List, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import json
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# from config import config as conf
|
||||
from tools.dataset import get_transform
|
||||
from configs import trainer_tools
|
||||
import yaml
|
||||
|
||||
with open('configs/test.yml', 'r') as f:
|
||||
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
# Constants from config
|
||||
embedding_size = conf["base"]["embedding_size"]
|
||||
img_size = conf["transform"]["img_size"]
|
||||
device = conf["base"]["device"]
|
||||
|
||||
def unique_image(pair_list: str) -> Set[str]:
|
||||
unique_images = set()
|
||||
try:
|
||||
with open(pair_list, 'r') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
img1, img2, _ = line.split()
|
||||
unique_images.update([img1, img2])
|
||||
except ValueError as e:
|
||||
print(f"Skipping malformed line: {line}")
|
||||
except IOError as e:
|
||||
print(f"Error reading pair list file: {e}")
|
||||
raise
|
||||
|
||||
return unique_images
|
||||
|
||||
|
||||
def group_image(images: Set[str], batch_size: int) -> List[List[str]]:
|
||||
"""
|
||||
Group image paths into batches of specified size.
|
||||
|
||||
Args:
|
||||
images: Set of image paths to group
|
||||
batch_size: Number of images per batch
|
||||
|
||||
Returns:
|
||||
List of batches, where each batch is a list of image paths
|
||||
"""
|
||||
image_list = list(images)
|
||||
num_images = len(image_list)
|
||||
batches = []
|
||||
|
||||
for i in range(0, num_images, batch_size):
|
||||
batch_end = min(i + batch_size, num_images)
|
||||
batches.append(image_list[i:batch_end])
|
||||
|
||||
return batches
|
||||
|
||||
|
||||
def _preprocess(images: list, transform) -> torch.Tensor:
|
||||
res = []
|
||||
for img in images:
|
||||
im = Image.open(img)
|
||||
im = transform(im)
|
||||
res.append(im)
|
||||
# data = torch.cat(res, dim=0) # shape: (batch, 128, 128)
|
||||
# data = data[:, None, :, :] # shape: (batch, 1, 128, 128)
|
||||
data = torch.stack(res)
|
||||
return data
|
||||
|
||||
|
||||
def test_preprocess(images: list, transform) -> torch.Tensor:
|
||||
res = []
|
||||
for img in images:
|
||||
im = Image.open(img)
|
||||
if im.mode == 'RGBA':
|
||||
im = im.convert('RGB')
|
||||
im = transform(im)
|
||||
res.append(im)
|
||||
data = torch.stack(res)
|
||||
return data
|
||||
|
||||
|
||||
def featurize(
|
||||
images: List[str],
|
||||
transform: callable,
|
||||
net: nn.Module,
|
||||
device: torch.device,
|
||||
train: bool = False
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
try:
|
||||
# Select appropriate preprocessing
|
||||
preprocess_fn = _preprocess if train else test_preprocess
|
||||
|
||||
# Preprocess and move to device
|
||||
data = preprocess_fn(images, transform)
|
||||
data = data.to(device)
|
||||
net = net.to(device)
|
||||
|
||||
# Extract features with automatic mixed precision
|
||||
with torch.no_grad():
|
||||
if conf['models']['half']:
|
||||
data = data.half()
|
||||
features = net(data)
|
||||
# Create path-to-feature mapping
|
||||
return {img: feature for img, feature in zip(images, features)}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in feature extraction: {e}")
|
||||
raise
|
||||
def cosin_metric(x1, x2):
|
||||
return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))
|
||||
def threshold_search(y_score, y_true):
|
||||
y_score = np.asarray(y_score)
|
||||
y_true = np.asarray(y_true)
|
||||
best_acc = 0
|
||||
best_th = 0
|
||||
for i in range(len(y_score)):
|
||||
th = y_score[i]
|
||||
y_test = (y_score >= th)
|
||||
acc = np.mean((y_test == y_true).astype(int))
|
||||
if acc > best_acc:
|
||||
best_acc = acc
|
||||
best_th = th
|
||||
return best_acc, best_th
|
||||
|
||||
|
||||
def showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct):
|
||||
x = np.linspace(start=0, stop=1.0, num=50, endpoint=True).tolist()
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.plot(x, recall, color='red', label='recall:TP/TPFN')
|
||||
plt.plot(x, recall_TN, color='black', label='recall_TN:TN/TNFP')
|
||||
plt.plot(x, PrecisePos, color='blue', label='PrecisePos:TP/TPFN')
|
||||
plt.plot(x, PreciseNeg, color='green', label='PreciseNeg:TN/TNFP')
|
||||
plt.plot(x, Correct, color='m', label='Correct:(TN+TP)/(TPFN+TNFP)')
|
||||
plt.legend()
|
||||
plt.xlabel('threshold')
|
||||
# plt.ylabel('Similarity')
|
||||
plt.grid(True, linestyle='--', alpha=0.5)
|
||||
plt.savefig('grid.png')
|
||||
plt.show()
|
||||
plt.close()
|
||||
|
||||
|
||||
def showHist(same, cross):
|
||||
Same = np.array(same)
|
||||
Cross = np.array(cross)
|
||||
|
||||
fig, axs = plt.subplots(2, 1)
|
||||
axs[0].hist(Same, bins=50, edgecolor='black')
|
||||
axs[0].set_xlim([-0.1, 1])
|
||||
axs[0].set_title('Same Barcode')
|
||||
|
||||
axs[1].hist(Cross, bins=50, edgecolor='black')
|
||||
axs[1].set_xlim([-0.1, 1])
|
||||
axs[1].set_title('Cross Barcode')
|
||||
plt.savefig('plot.png')
|
||||
|
||||
|
||||
def compute_accuracy_recall(score, labels):
|
||||
th = 0.1
|
||||
squence = np.linspace(-1, 1, num=50)
|
||||
recall, PrecisePos, PreciseNeg, recall_TN, Correct = [], [], [], [], []
|
||||
Same = score[:len(score) // 2]
|
||||
Cross = score[len(score) // 2:]
|
||||
for th in squence:
|
||||
t_score = (score > th)
|
||||
t_labels = (labels == 1)
|
||||
TP = np.sum(np.logical_and(t_score, t_labels))
|
||||
FN = np.sum(np.logical_and(np.logical_not(t_score), t_labels))
|
||||
f_score = (score < th)
|
||||
f_labels = (labels == 0)
|
||||
TN = np.sum(np.logical_and(f_score, f_labels))
|
||||
FP = np.sum(np.logical_and(np.logical_not(f_score), f_labels))
|
||||
print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))
|
||||
|
||||
PrecisePos.append(0 if TP / (TP + FP) == 'nan' else TP / (TP + FP))
|
||||
PreciseNeg.append(0 if TN == 0 else TN / (TN + FN))
|
||||
recall.append(0 if TP == 0 else TP / (TP + FN))
|
||||
recall_TN.append(0 if TN == 0 else TN / (TN + FP))
|
||||
Correct.append(0 if TP == 0 else (TP + TN) / (TP + FP + TN + FN))
|
||||
|
||||
showHist(Same, Cross)
|
||||
showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct)
|
||||
|
||||
|
||||
def compute_accuracy(
|
||||
feature_dict: Dict[str, torch.Tensor],
|
||||
pair_list: str,
|
||||
test_root: str
|
||||
) -> Tuple[float, float]:
|
||||
try:
|
||||
with open(pair_list, 'r') as f:
|
||||
pairs = f.readlines()
|
||||
except IOError as e:
|
||||
print(f"Error reading pair list: {e}")
|
||||
raise
|
||||
|
||||
similarities = []
|
||||
labels = []
|
||||
|
||||
for pair in pairs:
|
||||
pair = pair.strip()
|
||||
if not pair:
|
||||
continue
|
||||
|
||||
try:
|
||||
img1, img2, label = pair.split()
|
||||
img1_path = osp.join(test_root, img1)
|
||||
img2_path = osp.join(test_root, img2)
|
||||
|
||||
# Verify features exist
|
||||
if img1_path not in feature_dict or img2_path not in feature_dict:
|
||||
raise ValueError(f"Missing features for image pair: {img1_path}, {img2_path}")
|
||||
|
||||
# Get features and compute similarity
|
||||
feat1 = feature_dict[img1_path].cpu().numpy()
|
||||
feat2 = feature_dict[img2_path].cpu().numpy()
|
||||
similarity = cosin_metric(feat1, feat2)
|
||||
|
||||
similarities.append(similarity)
|
||||
labels.append(int(label))
|
||||
|
||||
except Exception as e:
|
||||
print(f"Skipping invalid pair: {pair}. Error: {e}")
|
||||
continue
|
||||
|
||||
# Find optimal threshold and accuracy
|
||||
accuracy, threshold = threshold_search(similarities, labels)
|
||||
compute_accuracy_recall(np.array(similarities), np.array(labels))
|
||||
|
||||
return accuracy, threshold
|
||||
|
||||
|
||||
def deal_group_pair(pairList1, pairList2):
|
||||
allsimilarity = []
|
||||
one_similarity = []
|
||||
for pair1 in pairList1:
|
||||
for pair2 in pairList2:
|
||||
similarity = cosin_metric(pair1.cpu().numpy(), pair2.cpu().numpy())
|
||||
one_similarity.append(similarity)
|
||||
allsimilarity.append(max(one_similarity)) # 最大值
|
||||
# allsimilarity.append(sum(one_similarity) / len(one_similarity)) # 均值
|
||||
# allsimilarity.append(statistics.median(one_similarity)) # 中位数
|
||||
# print(allsimilarity)
|
||||
# print(labels)
|
||||
return allsimilarity
|
||||
|
||||
|
||||
def compute_group_accuracy(content_list_read):
|
||||
allSimilarity, allLabel = [], []
|
||||
Same, Cross = [], []
|
||||
for data_loaded in content_list_read:
|
||||
print(data_loaded)
|
||||
one_group_list = []
|
||||
try:
|
||||
for i in range(2):
|
||||
images = [osp.join(conf.test_val, img) for img in data_loaded[i]]
|
||||
group = group_image(images, conf.test_batch_size)
|
||||
d = featurize(group[0], conf.test_transform, model, conf.device)
|
||||
one_group_list.append(d.values())
|
||||
if data_loaded[-1] == '1':
|
||||
similarity = deal_group_pair(one_group_list[0], one_group_list[1])
|
||||
Same.append(similarity)
|
||||
else:
|
||||
similarity = deal_group_pair(one_group_list[0], one_group_list[1])
|
||||
Cross.append(similarity)
|
||||
allLabel.append(data_loaded[-1])
|
||||
allSimilarity.extend(similarity)
|
||||
except Exception as e:
|
||||
continue
|
||||
# print(allSimilarity)
|
||||
# print(allLabel)
|
||||
return allSimilarity, allLabel
|
||||
|
||||
|
||||
def init_model():
|
||||
tr_tools = trainer_tools(conf)
|
||||
backbone_mapping = tr_tools.get_backbone()
|
||||
if conf['models']['backbone'] in backbone_mapping:
|
||||
model = backbone_mapping[conf['models']['backbone']]().to(conf['base']['device'])
|
||||
else:
|
||||
raise ValueError('不支持该模型: {}'.format({conf['models']['backbone']}))
|
||||
print('load model {} '.format(conf['models']['backbone']))
|
||||
if torch.cuda.device_count() > 1 and conf['base']['distributed']:
|
||||
model = nn.DataParallel(model).to(conf['base']['device'])
|
||||
model.load_state_dict(torch.load(conf['models']['model_path'], map_location=conf['base']['device']))
|
||||
if conf['models']['half']:
|
||||
model.half()
|
||||
first_param_dtype = next(model.parameters()).dtype
|
||||
print("模型的第一个参数的数据类型: {}".format(first_param_dtype))
|
||||
else:
|
||||
model.load_state_dict(torch.load(conf['model']['model_path'], map_location=conf['base']['device']))
|
||||
if conf.model_half:
|
||||
model.half()
|
||||
first_param_dtype = next(model.parameters()).dtype
|
||||
print("模型的第一个参数的数据类型: {}".format(first_param_dtype))
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = init_model()
|
||||
model.eval()
|
||||
|
||||
if not conf['data']['group_test']:
|
||||
images = unique_image(conf['data']['test_list'])
|
||||
images = [osp.join(conf['data']['test_dir'], img) for img in images]
|
||||
groups = group_image(images, conf['data']['test_batch_size']) # 根据batch_size取图片
|
||||
feature_dict = dict()
|
||||
_, test_transform = get_transform(conf)
|
||||
for group in groups:
|
||||
d = featurize(group, test_transform, model, conf['base']['device'])
|
||||
feature_dict.update(d)
|
||||
accuracy, threshold = compute_accuracy(feature_dict, conf['data']['test_list'], conf['data']['test_dir'])
|
||||
print(
|
||||
"Test Model: {} Accuracy: {} Threshold: {}".format(conf['models']['model_path'], accuracy, threshold)
|
||||
)
|
||||
elif conf['data']['group_test']:
|
||||
filename = conf['data']['test_group_json']
|
||||
with open(filename, 'r', encoding='utf-8') as file:
|
||||
content_list_read = json.load(file)
|
||||
Similarity, Label = compute_group_accuracy(content_list_read)
|
||||
compute_accuracy_recall(np.array(Similarity), np.array(Label))
|
||||
# compute_group_accuracy(data_loaded)
|
0
tools/__init__.py
Normal file
0
tools/__init__.py
Normal file
BIN
tools/__pycache__/gift_data_pretreatment.cpython-38.pyc
Normal file
BIN
tools/__pycache__/gift_data_pretreatment.cpython-38.pyc
Normal file
Binary file not shown.
68
tools/dataset.py
Normal file
68
tools/dataset.py
Normal file
@ -0,0 +1,68 @@
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.datasets import ImageFolder
|
||||
import torchvision.transforms.functional as F
|
||||
import torchvision.transforms as T
|
||||
# from config import config as conf
|
||||
import torch
|
||||
|
||||
def pad_to_square(img):
|
||||
w, h = img.size
|
||||
max_wh = max(w, h)
|
||||
padding = [(max_wh - w) // 2, (max_wh - h) // 2, (max_wh - w) // 2, (max_wh - h) // 2] # (left, top, right, bottom)
|
||||
return F.pad(img, padding, fill=0, padding_mode='constant')
|
||||
|
||||
def get_transform(cfg):
|
||||
train_transform = T.Compose([
|
||||
T.Lambda(pad_to_square), # 补边
|
||||
T.ToTensor(),
|
||||
T.Resize((cfg['transform']['img_size'], cfg['transform']['img_size']), antialias=True),
|
||||
# T.RandomCrop(img_size * 4 // 5),
|
||||
T.RandomHorizontalFlip(p=cfg['transform']['RandomHorizontalFlip']),
|
||||
T.RandomRotation(cfg['transform']['RandomRotation']),
|
||||
T.ColorJitter(brightness=cfg['transform']['ColorJitter']),
|
||||
T.ConvertImageDtype(torch.float32),
|
||||
T.Normalize(mean=[cfg['transform']['img_mean']], std=[cfg['transform']['img_std']]),
|
||||
])
|
||||
test_transform = T.Compose([
|
||||
# T.Lambda(pad_to_square), # 补边
|
||||
T.ToTensor(),
|
||||
T.Resize((cfg['transform']['img_size'], cfg['transform']['img_size']), antialias=True),
|
||||
T.ConvertImageDtype(torch.float32),
|
||||
T.Normalize(mean=[cfg['transform']['img_mean']], std=[cfg['transform']['img_std']]),
|
||||
])
|
||||
return train_transform, test_transform
|
||||
|
||||
def load_data(training=True, cfg=None):
|
||||
train_transform, test_transform = get_transform(cfg)
|
||||
if training:
|
||||
dataroot = cfg['data']['data_train_dir']
|
||||
transform = train_transform
|
||||
# transform = conf.train_transform
|
||||
batch_size = cfg['data']['train_batch_size']
|
||||
else:
|
||||
dataroot = cfg['data']['data_val_dir']
|
||||
# transform = conf.test_transform
|
||||
transform = test_transform
|
||||
batch_size = cfg['data']['val_batch_size']
|
||||
|
||||
data = ImageFolder(dataroot, transform=transform)
|
||||
class_num = len(data.classes)
|
||||
loader = DataLoader(data,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=cfg['base']['pin_memory'],
|
||||
num_workers=cfg['data']['num_workers'],
|
||||
drop_last=True)
|
||||
return loader, class_num
|
||||
|
||||
# def load_gift_data(action):
|
||||
# train_data = ImageFolder(conf.train_gift_root, transform=conf.train_transform)
|
||||
# train_dataset = DataLoader(train_data, batch_size=conf.train_gift_batchsize, shuffle=True,
|
||||
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
|
||||
# val_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
|
||||
# val_dataset = DataLoader(val_data, batch_size=conf.val_gift_batchsize, shuffle=True,
|
||||
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
|
||||
# test_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
|
||||
# test_dataset = DataLoader(test_data, batch_size=conf.test_gift_batchsize, shuffle=True,
|
||||
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
|
||||
# return train_dataset, val_dataset, test_dataset
|
10
tools/dataset.txt
Normal file
10
tools/dataset.txt
Normal file
@ -0,0 +1,10 @@
|
||||
./quant_imgs/20179457_20240924-110903_back_addGood_b82d2842766e_80_15583929052_tid-8_fid-72_bid-3.jpg
|
||||
./quant_imgs/6928926002103_20240309-195044_front_returnGood_70f75407ef0e_225_18120111822_14_01.jpg
|
||||
./quant_imgs/6928926002103_20240309-212145_front_returnGood_70f75407ef0e_225_18120111822_11_01.jpg
|
||||
./quant_imgs/6928947479083_20241017-133830_front_returnGood_5478c9a48b7e_10_13799009402_tid-1_fid-20_bid-1.jpg
|
||||
./quant_imgs/6928947479083_20241018-110450_front_addGood_5478c9a48c28_165_13773168720_tid-6_fid-36_bid-1.jpg
|
||||
./quant_imgs/6930044166421_20240117-141516_c6a23f41-5b16-44c6-a03e-c32c25763442_back_returnGood_6930044166421_17_01.jpg
|
||||
./quant_imgs/6930044166421_20240308-150916_back_returnGood_70f75407ef0e_175_13815402763_7_01.jpg
|
||||
./quant_imgs/6930044168920_20240117-165633_3303629b-5fbd-423b-913d-8a64c1aa51dc_front_addGood_6930044168920_26_01.jpg
|
||||
./quant_imgs/6930058201507_20240305-175434_front_addGood_70f75407ef0e_95_18120111822_28_01.jpg
|
||||
./quant_imgs/6930639267885_20241014-120446_back_addGood_5478c9a48c3e_135_13773168720_tid-5_fid-99_bid-0.jpg
|
112
tools/fp32comparefp16.py
Normal file
112
tools/fp32comparefp16.py
Normal file
@ -0,0 +1,112 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from test_ori import group_image, init_model, featurize
|
||||
from config import config as conf
|
||||
import json
|
||||
import os.path as osp
|
||||
|
||||
def compare_fp16_fp32(values_pf16, values_pf32, dataTest):
|
||||
if dataTest:
|
||||
norm_values_pf16 = torch.norm(values_pf16, p=2)
|
||||
norm_values_pf32 = torch.norm(values_pf32, p=2)
|
||||
euclidean_distance = torch.norm(norm_values_pf16 - norm_values_pf32, p=2)
|
||||
print(f"欧几里得距离: {euclidean_distance}")
|
||||
cosine_sim = torch.dot(values_pf16.float(), values_pf32) / (norm_values_pf16 * norm_values_pf32)
|
||||
print(f"余弦相似度: {cosine_sim}")
|
||||
else:
|
||||
|
||||
pass
|
||||
def cosin_metric(x1, x2, fp32=True):
|
||||
if fp32:
|
||||
return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))
|
||||
else:
|
||||
x1_fp16 = x1.astype(np.float16)
|
||||
x2_fp16 = x2.astype(np.float16)
|
||||
# print(type(x1))
|
||||
# pdb.set_trace()
|
||||
return np.dot(x1_fp16, x2_fp16) / (np.linalg.norm(x1_fp16) * np.linalg.norm(x2_fp16))
|
||||
def deal_group_pair(pairList1, pairList2):
|
||||
one_similarity_fp16, one_similarity_fp32, allsimilarity_fp32, allsimilarity_fp16 = [], [], [], []
|
||||
for pair1 in pairList1:
|
||||
for pair2 in pairList2:
|
||||
# similarity = cosin_metric(pair1.cpu().numpy(), pair2.cpu().numpy())
|
||||
one_similarity_fp32.append(cosin_metric(pair1.cpu().numpy(), pair2.cpu().numpy(), True))
|
||||
one_similarity_fp16.append(cosin_metric(pair1.cpu().numpy(), pair2.cpu().numpy(), False))
|
||||
allsimilarity_fp32.append(one_similarity_fp32)
|
||||
allsimilarity_fp16.append(one_similarity_fp16)
|
||||
one_similarity_fp16, one_similarity_fp32 = [], []
|
||||
return np.array(allsimilarity_fp32), np.array(allsimilarity_fp16)
|
||||
|
||||
def compute_group_accuracy(content_list_read, model):
|
||||
allSimilarity, allLabel = [], []
|
||||
Same, Cross = [], []
|
||||
flag_same = True
|
||||
flag_diff = True
|
||||
for data_loaded in content_list_read:
|
||||
one_group_list = []
|
||||
try:
|
||||
if (flag_same and str(data_loaded[-1]) == '1') or (flag_diff and str(data_loaded[-1]) == '0'):
|
||||
for i in range(2):
|
||||
images = [osp.join(conf.test_val, img) for img in data_loaded[i]]
|
||||
group = group_image(images, conf.test_batch_size)
|
||||
d = featurize(group[0], conf.test_transform, model, conf.device)
|
||||
one_group_list.append(d.values())
|
||||
if str(data_loaded[-1]) == '1':
|
||||
flag_same = False
|
||||
allsimilarity_fp32, allsimilarity_fp16 = deal_group_pair(one_group_list[0], one_group_list[1])
|
||||
print('fp32 same-- >', allsimilarity_fp32)
|
||||
print('fp16 same-- >', allsimilarity_fp16)
|
||||
else:
|
||||
flag_diff = False
|
||||
allsimilarity_fp32, allsimilarity_fp16 = deal_group_pair(one_group_list[0], one_group_list[1])
|
||||
print('fp32 diff-- >', allsimilarity_fp32)
|
||||
print('fp16 diff-- >', allsimilarity_fp16)
|
||||
except Exception as e:
|
||||
continue
|
||||
# print(allSimilarity)
|
||||
# print(allLabel)
|
||||
return allSimilarity, allLabel
|
||||
def get_feature_list(imgPth):
|
||||
imgs = get_files(imgPth)
|
||||
group = group_image(imgs, conf.test_batch_size)
|
||||
model = init_model()
|
||||
model.eval()
|
||||
fe = featurize(group[0], conf.test_transform, model, conf.device)
|
||||
return fe
|
||||
|
||||
|
||||
def get_files(imgPth):
|
||||
imgsList = []
|
||||
for img in os.walk(imgPth):
|
||||
for img_name in img[2]:
|
||||
img_path = os.sep.join([img[0], img_name])
|
||||
imgsList.append(img_path)
|
||||
return imgsList
|
||||
import pdb
|
||||
|
||||
def compare(imgPth, group=False):
|
||||
model = init_model()
|
||||
model.eval()
|
||||
if not group:
|
||||
values_pf16, values_pf32 = [], []
|
||||
fe = get_feature_list(imgPth)
|
||||
# pdb.set_trace()
|
||||
values_pf32 += [value.cpu() for value in fe.values()]
|
||||
values_pf16 += [value.cpu().half() for value in fe.values()]
|
||||
for value_pf16, value_pf32 in zip(values_pf16, values_pf32):
|
||||
compare_fp16_fp32(value_pf16, value_pf32, dataTest=True)
|
||||
else:
|
||||
filename = conf.test_group_json
|
||||
with open(filename, 'r', encoding='utf-8') as file:
|
||||
content_list_read = json.load(file)
|
||||
compute_group_accuracy(content_list_read, model)
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
imgPth = './data/test/inner/3701375401900'
|
||||
compare(imgPth)
|
369
tools/gift_assessment.py
Normal file
369
tools/gift_assessment.py
Normal file
@ -0,0 +1,369 @@
|
||||
import os
|
||||
import pdb
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
sys.path.append('../model')
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from model.mlp import Net2, Net3, Net4
|
||||
from model import resnet18
|
||||
import torch
|
||||
from gift_data_pretreatment import getFeatureList
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
|
||||
def init_model(pkl_flag):
|
||||
res_pth = r"../checkpoints/resnet18_1009/best.pth"
|
||||
if pkl_flag:
|
||||
gift_pth = r'../checkpoints/gift_model/action2/gift_v11.pth'
|
||||
gift_model = Net3(pretrained=True, num_classes=1)
|
||||
gift_model.load_state_dict(torch.load(gift_pth))
|
||||
else:
|
||||
gift_pth = r'../checkpoints/gift_model/action3/best.pth'
|
||||
gift_model = Net4('resnet18', True, True) # 预训练模型
|
||||
try:
|
||||
print('>>multiple_cards load pre model <<')
|
||||
gift_model.load_state_dict({k.replace('module.', ''): v for k, v in
|
||||
torch.load(gift_pth,
|
||||
map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu')).items()})
|
||||
except Exception as e:
|
||||
print('>> load pre model <<')
|
||||
gift_model.load_state_dict(torch.load(gift_pth,
|
||||
map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu')))
|
||||
res_model = resnet18()
|
||||
res_model.load_state_dict({k.replace('module.', ''): v for k, v in
|
||||
torch.load(res_pth, map_location=torch.device(device)).items()})
|
||||
return res_model, gift_model
|
||||
|
||||
|
||||
def showHist(nongifts, gifts):
|
||||
# Same = filtered_data[:, 1].astype(np.float32)
|
||||
# Cross = filtered_data[:, 2].astype(np.float32)
|
||||
|
||||
fig, axs = plt.subplots(2, 1)
|
||||
axs[0].hist(nongifts, bins=50, edgecolor='blue')
|
||||
axs[0].set_xlim([-0.1, 1])
|
||||
axs[0].set_title('nongifts')
|
||||
|
||||
axs[1].hist(gifts, bins=50, edgecolor='green')
|
||||
axs[1].set_xlim([-0.1, 1])
|
||||
axs[1].set_title('gifts')
|
||||
# plt.savefig('plot.png')
|
||||
plt.show()
|
||||
|
||||
|
||||
def calculate_precision_recall(nongift, gift, points):
|
||||
precision, recall = [], []
|
||||
for point in points:
|
||||
TP = np.sum(gift > point)
|
||||
FN = np.sum(gift < point)
|
||||
FP = np.sum(nongift > point)
|
||||
TN = np.sum(nongift < point)
|
||||
if TP == 0:
|
||||
precision.append(0)
|
||||
recall.append(0)
|
||||
else:
|
||||
precision.append(TP / (TP + FP))
|
||||
recall.append(TP / (TP + FN))
|
||||
print("point >> {} TP>>{}, FP>>{}, TN>>{}, FN>>{}".format(point, TP, FP, TN, FN))
|
||||
if point == 0.5:
|
||||
print("point >> {} TP>>{}, FP>>{}, TN>>{}, FN>>{}".format(point, TP, FP, TN, FN))
|
||||
return precision, recall
|
||||
|
||||
|
||||
def showgrid(all_prec, all_recall, points):
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.plot(points[:-1], all_prec[:-1], color='blue', label='precision')
|
||||
plt.plot(points[:-1], all_recall[:-1], color='red', label='recall')
|
||||
plt.legend()
|
||||
plt.xlabel('threshold')
|
||||
# plt.ylabel('Similarity')
|
||||
plt.grid(True, linestyle='--', alpha=0.5)
|
||||
# plt.savefig('grid.png')
|
||||
plt.show()
|
||||
plt.close()
|
||||
pass
|
||||
|
||||
|
||||
def discriminate_action(roots): # 判断加购还是退购
|
||||
pth = os.sep.join([roots, 'process.data'])
|
||||
with open(pth, 'r') as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
content = line.strip()
|
||||
if 'weightValue' in content:
|
||||
# print(content.split(":")[-1].split(',')[0])
|
||||
if int(content.split(":")[-1].split(',')[0]) > 0:
|
||||
return 'add'
|
||||
else:
|
||||
return 'return'
|
||||
|
||||
|
||||
def median(lst):
|
||||
sorted_lst = sorted(lst)
|
||||
n = len(sorted_lst)
|
||||
if n % 2 == 1:
|
||||
# 如果列表长度是奇数,中位数是中间的那个元素
|
||||
return sorted_lst[n // 2]
|
||||
else:
|
||||
# 如果列表长度是偶数,中位数是中间两个元素的平均值
|
||||
mid1 = sorted_lst[(n // 2) - 1]
|
||||
mid2 = sorted_lst[n // 2]
|
||||
return (mid1 + mid2) / 2
|
||||
|
||||
|
||||
def get_special_data(data, p):
|
||||
# print(data)
|
||||
length = len(data)
|
||||
if length > 5:
|
||||
if p == 'max':
|
||||
return max(data[:round(length * 0.5)])
|
||||
elif p == 'average':
|
||||
return sum(data[:round(length * 0.5)]) / len(data[:round(length * 0.5)])
|
||||
elif p == 'median':
|
||||
return median(data[:round(length * 0.5)])
|
||||
else:
|
||||
return sum(data) / len(data)
|
||||
|
||||
|
||||
def read_data_file(pth):
|
||||
result = []
|
||||
with open(pth, 'r') as data_file:
|
||||
lines = data_file.readlines()
|
||||
for line in lines:
|
||||
if line.split(':')[0] == 'free_gift__result':
|
||||
if '0_tracking_output.data' in pth:
|
||||
result = line.split(':')[1].split(',')[:-1]
|
||||
else:
|
||||
result = line.split(':')[1].split(',')[:-2]
|
||||
result = [float(i) for i in result]
|
||||
return result
|
||||
|
||||
|
||||
def get_tracking_data(pth):
|
||||
result = []
|
||||
with open(pth, 'r') as data_file:
|
||||
lines = data_file.readlines()
|
||||
for line in lines:
|
||||
if len(line.split(',')) == 65:
|
||||
result.append([float(item) for item in line.split(',')[:-1]])
|
||||
return result
|
||||
|
||||
|
||||
def clean_reurn_data(pth):
|
||||
for roots, dirs, files in os.walk(pth):
|
||||
# print(roots, dirs, files)
|
||||
if len(dirs) == 0:
|
||||
flag = discriminate_action(roots)
|
||||
if flag == 'return':
|
||||
shutil.rmtree(roots)
|
||||
|
||||
|
||||
def get_gift_files(pth): # 测试后直接分析测试结果文件
|
||||
add_special_output_0, return_special_output_0, return_special_output_1, add_special_output_1 = [], [], [], []
|
||||
add_tracking_output_0, return_tracking_output_0, add_tracking_output_1, return_tracking_output_1 = [], [], [], []
|
||||
for roots, dirs, files in os.walk(pth):
|
||||
# print(roots, dirs, files)
|
||||
if len(dirs) == 0:
|
||||
flag = discriminate_action(roots)
|
||||
for file in files:
|
||||
if file == '0_tracking_output.data':
|
||||
result = read_data_file(os.path.join(roots, file))
|
||||
if not len(result) == 0:
|
||||
if flag == 'add':
|
||||
add_special_output_0.append(get_special_data(result, 'average')) # 加购后摄
|
||||
else:
|
||||
return_special_output_0.append(get_special_data(result, 'average')) # 退购后摄
|
||||
if flag == 'add':
|
||||
add_tracking_output_0 += read_data_file(os.path.join(roots, file))
|
||||
else:
|
||||
return_tracking_output_0 += read_data_file(os.path.join(roots, file))
|
||||
elif file == '1_tracking_output.data':
|
||||
result = read_data_file(os.path.join(roots, file))
|
||||
if not len(result) == 0:
|
||||
if flag == 'add':
|
||||
add_special_output_1.append(get_special_data(result, 'average')) # 加购前摄
|
||||
else:
|
||||
return_special_output_1.append(get_special_data(result, 'average')) # 退购前摄
|
||||
if flag == 'add':
|
||||
add_tracking_output_1 += read_data_file(os.path.join(roots, file))
|
||||
else:
|
||||
return_tracking_output_1 += read_data_file(os.path.join(roots, file))
|
||||
comprehensive_dicts = {"add_special_output_0": add_special_output_0,
|
||||
"return_special_output_0": return_special_output_0,
|
||||
"add_tracking_output_0": add_tracking_output_0,
|
||||
"return_tracking_output_0": return_tracking_output_0,
|
||||
"add_special_output_1": add_special_output_1,
|
||||
"return_special_output_1": return_special_output_1,
|
||||
"add_tracking_output_1": add_tracking_output_1,
|
||||
"return_tracking_output_1": return_tracking_output_1,
|
||||
}
|
||||
# print(tracking_output_0, tracking_output_1)
|
||||
showHist(np.array(comprehensive_dicts['add_tracking_output_0']),
|
||||
np.array(comprehensive_dicts['add_tracking_output_1']))
|
||||
# showHist(np.array(comprehensive_dicts['add_special_output_0']),
|
||||
# np.array(comprehensive_dicts['add_special_output_1']))
|
||||
return comprehensive_dicts
|
||||
|
||||
|
||||
def get_feature_array(img_pth_lists, res_model, gift_model, pkl_flag=True):
|
||||
features_np = []
|
||||
if pkl_flag:
|
||||
for img_lists in img_pth_lists:
|
||||
# print(img_lists)
|
||||
fe_nps = getFeatureList(None, img_lists, res_model)
|
||||
# fe_nps.squeeze()
|
||||
try:
|
||||
fe_nps = fe_nps[0][:, 256:]
|
||||
except Exception as e:
|
||||
print(e)
|
||||
continue
|
||||
fe_nps = torch.from_numpy(fe_nps)
|
||||
fe_nps = fe_nps.view(fe_nps.shape[0], 64, 13, 13)
|
||||
if len(fe_nps):
|
||||
fe_np = gift_model(fe_nps)
|
||||
fe_np = np.squeeze(fe_np.detach().numpy())
|
||||
features_np.append(fe_np)
|
||||
else:
|
||||
for img_lists in img_pth_lists:
|
||||
fe_nps = getFeatureList(None, img_lists, gift_model)
|
||||
if len(fe_nps) > 0:
|
||||
fe_nps = np.concatenate(fe_nps)
|
||||
features_np.append(fe_nps)
|
||||
return features_np
|
||||
|
||||
|
||||
import pickle
|
||||
|
||||
|
||||
def create_gift_subimg_np(data_pth, pkl_flag):
|
||||
gift_array_pth = os.path.join(data_pth, 'gift.pkl')
|
||||
nongift_array_pth = os.path.join(data_pth, 'nongift.pkl')
|
||||
res_model, gift_model = init_model(pkl_flag)
|
||||
res_model = res_model.eval()
|
||||
gift_model = gift_model.eval()
|
||||
gift_img_pth_list, gift_lists, nongift_img_pth_list, nongift_lists = [], [], [], []
|
||||
|
||||
for root, dirs, files in os.walk(data_pth):
|
||||
if ('commodity' in root and 'subimg' in root):
|
||||
print("commodity >> {}".format(root))
|
||||
for file in files:
|
||||
nongift_img_pth_list.append(os.sep.join([root, file]))
|
||||
nongift_lists.append(nongift_img_pth_list)
|
||||
nongift_img_pth_list = []
|
||||
elif ('Havegift' in root and 'subimg' in root):
|
||||
print("Havegift >> {}".format(root))
|
||||
for file in files:
|
||||
gift_img_pth_list.append(os.sep.join([root, file]))
|
||||
gift_lists.append(gift_img_pth_list)
|
||||
gift_img_pth_list = []
|
||||
nongift = get_feature_array(nongift_lists, res_model, gift_model, pkl_flag)
|
||||
gift = get_feature_array(gift_lists, res_model, gift_model, pkl_flag)
|
||||
with open(nongift_array_pth, 'wb') as file:
|
||||
pickle.dump(nongift, file)
|
||||
with open(gift_array_pth, 'wb') as file:
|
||||
pickle.dump(gift, file)
|
||||
|
||||
|
||||
def top_25_percent_mean(arr):
|
||||
# 1. 对数组进行从高到低排序
|
||||
sorted_arr = np.sort(arr)[::-1]
|
||||
|
||||
# 2. 计算数组长度的25%
|
||||
top_25_percent_length = int(len(sorted_arr) * 0.25)
|
||||
|
||||
# 3. 取排序后数组的前25%元素
|
||||
top_25_percent = sorted_arr[:top_25_percent_length]
|
||||
|
||||
# 4. 计算这些元素的平均值
|
||||
mean_value = np.mean(top_25_percent)
|
||||
|
||||
return top_25_percent
|
||||
|
||||
|
||||
def assess_gift_subimg(data_pth, pkl_flag=False): # 分析分割后子图,
|
||||
points = (np.linspace(1, 100, 100)) / 100
|
||||
gift_pkl_pth = os.path.join(data_pth, 'gift.pkl')
|
||||
nongift_pkl_pth = os.path.join(data_pth, 'nongift.pkl')
|
||||
if not os.path.exists(gift_pkl_pth):
|
||||
create_gift_subimg_np(data_pth, pkl_flag)
|
||||
with open(nongift_pkl_pth, 'rb') as f:
|
||||
nongift = pickle.load(f)
|
||||
with open(gift_pkl_pth, 'rb') as f:
|
||||
gift = pickle.load(f)
|
||||
# showHist(nongift.flatten(), gift.flatten())
|
||||
|
||||
'''
|
||||
一分位均值
|
||||
'''
|
||||
nongift_mean = [np.mean(top_25_percent_mean(items)) for items in nongift]
|
||||
gift_mean = [np.mean(top_25_percent_mean(items)) for items in gift]
|
||||
'''
|
||||
中位数
|
||||
'''
|
||||
# nongift_mean = [np.median(items) for items in nongift]
|
||||
# gift_mean = [np.median(items) for items in gift] # 平均值
|
||||
|
||||
'''
|
||||
全部结果
|
||||
'''
|
||||
# nongifts = [items for items in nongift]
|
||||
# gifts = [items for items in gift]
|
||||
# showHist(nongifts, gifts)
|
||||
|
||||
'''
|
||||
平均值
|
||||
'''
|
||||
# nongift_mean = [np.mean(items) for items in nongift]
|
||||
# gift_mean = [np.mean(items) for items in gift]
|
||||
|
||||
showHist(np.array(nongift_mean), np.array(gift_mean)) # 最大值
|
||||
precision, recall = calculate_precision_recall(np.array(nongift_mean),
|
||||
np.array(gift_mean),
|
||||
points)
|
||||
showgrid(precision, recall, points)
|
||||
|
||||
|
||||
def get_comprehensive_dicts(data_pth):
|
||||
gift_pth = r'../checkpoints/gift_model/action2/best.pth'
|
||||
g_model = Net3(pretrained=True, num_classes=1)
|
||||
g_model.load_state_dict(torch.load(gift_pth))
|
||||
g_model.eval()
|
||||
result = []
|
||||
file_name = ['0_tracking_output.data',
|
||||
'1_tracking_output.data']
|
||||
for root, dirs, files in os.walk(data_pth):
|
||||
if not len(dirs):
|
||||
for file in files:
|
||||
if file in file_name:
|
||||
print(os.path.join(root, file))
|
||||
result += get_tracking_data(os.path.join(root, file))
|
||||
result = torch.from_numpy(np.array(result))
|
||||
input = result.view(result.shape[0], 64, 1, 1)
|
||||
input = input.to('cpu')
|
||||
input = input.to(torch.float32)
|
||||
ji = g_model(input)
|
||||
print(ji)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241203赠品测试数据\\赠品\\images'
|
||||
# pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241203赠品测试数据\\没有赠品的商品\\images'
|
||||
# pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241203赠品测试数据\\同样的商品没有捆绑赠品\\images'
|
||||
# pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241213赠品测试数据\\赠品'
|
||||
# pth = r'C:\Users\HP\Desktop\zengpin\1227'
|
||||
# get_gift_files(pth)
|
||||
|
||||
# 根据子图分析结果
|
||||
pth = r'D:\Project\contrast_nettest\data\gift_test'
|
||||
assess_gift_subimg(pth)
|
||||
|
||||
# 根据完整数据集分析结果
|
||||
# pth = r'C:\Users\HP\Desktop\zengpin\1231'
|
||||
# get_comprehensive_dicts(pth)
|
||||
|
||||
# 删除退购视频
|
||||
# pth = r'C:\Users\HP\Desktop\gift_test\20241213\非赠品'
|
||||
# clean_reurn_data(pth)
|
92
tools/gift_data_pretreatment.py
Normal file
92
tools/gift_data_pretreatment.py
Normal file
@ -0,0 +1,92 @@
|
||||
import torch
|
||||
from config import config as conf
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
def convert_rgba_to_rgb(image_path, output_path=None):
|
||||
"""
|
||||
将给定路径的4通道PNG图像转换为3通道,并保存到指定输出路径。
|
||||
|
||||
:param image_path: 输入图像的路径
|
||||
:param output_path: 转换后的图像保存路径
|
||||
"""
|
||||
# 打开图像
|
||||
img = Image.open(image_path)
|
||||
# 转换图像模式从RGBA到RGB
|
||||
# .convert('RGB')会丢弃Alpha通道并转换为纯RGB图像
|
||||
if img.mode == 'RGBA':
|
||||
# 转换为RGB模式
|
||||
img_rgb = img.convert('RGB')
|
||||
# 保存转换后的图像
|
||||
img_rgb.save(image_path)
|
||||
# print(f"Image converted from RGBA to RGB and saved to {image_path}")
|
||||
# else:
|
||||
# # 如果已经是RGB或其他模式,直接保存
|
||||
# img.save(image_path)
|
||||
# print(f"Image already in {img.mode} mode, saved to {image_path}")
|
||||
|
||||
|
||||
def test_preprocess(images: list, actionModel=False) -> torch.Tensor:
|
||||
res = []
|
||||
for img in images:
|
||||
try:
|
||||
# print(img)
|
||||
im = conf.test_transform(img) if actionModel else conf.test_transform(Image.open(img))
|
||||
res.append(im)
|
||||
except:
|
||||
continue
|
||||
data = torch.stack(res)
|
||||
return data
|
||||
|
||||
|
||||
def inference(images, model, actionModel=False):
|
||||
data = test_preprocess(images, actionModel)
|
||||
if torch.cuda.is_available():
|
||||
data = data.to(conf.device)
|
||||
features = model(data)
|
||||
return features
|
||||
|
||||
|
||||
def group_image(images, batch=64) -> list:
|
||||
"""Group image paths by batch size"""
|
||||
size = len(images)
|
||||
res = []
|
||||
for i in range(0, size, batch):
|
||||
end = min(batch + i, size)
|
||||
res.append(images[i:end])
|
||||
return res
|
||||
|
||||
def normalize(queFeatList):
|
||||
for num1 in range(len(queFeatList)):
|
||||
for num2 in range(len(queFeatList[num1])):
|
||||
queFeatList[num1][num2] = queFeatList[num1][num2] / np.linalg.norm(queFeatList[num1][num2])
|
||||
return queFeatList
|
||||
|
||||
def getFeatureList(barList, imgList, model):
|
||||
# featList = [[] for i in range(len(barList))]
|
||||
# for index, feat in enumerate(imgList):
|
||||
fe_nps = []
|
||||
groups = group_image(imgList)
|
||||
for group in groups:
|
||||
feat_tensor = inference(group, model)
|
||||
# for fe in feat_tensor:
|
||||
if feat_tensor.device == 'cpu':
|
||||
fe_np = feat_tensor.squeeze().detach().numpy()
|
||||
# fe_np = fe_np[:, 256:]
|
||||
# fe_np = fe_np.reshape(fe_np.shape[0], fe_np.shape[1], 1, 1)
|
||||
else:
|
||||
fe_np = feat_tensor.squeeze().detach().cpu().numpy()
|
||||
# fe_np = fe_np[:, 256:]
|
||||
# fe_np = fe_np[256:]
|
||||
# fe_np = fe_np.reshape(fe_np.shape[0], fe_np.shape[1], 1, 1)
|
||||
# fe_np = fe_np.reshape(1, fe_np.shape[0], 1, 1)
|
||||
# print(fe_np)
|
||||
|
||||
fe_nps.append(fe_np)
|
||||
# if fe_nps:
|
||||
# merged_fe_np = np.concatenate(fe_nps, axis=0)
|
||||
# else:
|
||||
# merged_fe_np = np.array([]) #
|
||||
# fe_list = normalize(fe_nps)
|
||||
return fe_nps
|
118
tools/json_contrast.py
Normal file
118
tools/json_contrast.py
Normal file
@ -0,0 +1,118 @@
|
||||
import json
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
|
||||
def showHist(same, cross):
|
||||
Same = np.array(same)
|
||||
Cross = np.array(cross)
|
||||
|
||||
fig, axs = plt.subplots(2, 1)
|
||||
axs[0].hist(Same, bins=50, edgecolor='black')
|
||||
axs[0].set_xlim([-0.1, 1])
|
||||
axs[0].set_title('Same Barcode')
|
||||
|
||||
axs[1].hist(Cross, bins=50, edgecolor='black')
|
||||
axs[1].set_xlim([-0.1, 1])
|
||||
axs[1].set_title('Cross Barcode')
|
||||
# plt.savefig('plot.png')
|
||||
plt.show()
|
||||
|
||||
|
||||
def showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct):
|
||||
x = np.linspace(start=0, stop=1.0, num=50, endpoint=True).tolist()
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.plot(x, recall, color='red', label='recall:TP/TPFN')
|
||||
plt.plot(x, recall_TN, color='black', label='recall_TN:TN/TNFP')
|
||||
plt.plot(x, PrecisePos, color='blue', label='PrecisePos:TP/TPFN')
|
||||
plt.plot(x, PreciseNeg, color='green', label='PreciseNeg:TN/TNFP')
|
||||
plt.plot(x, Correct, color='m', label='Correct:(TN+TP)/(TPFN+TNFP)')
|
||||
plt.legend()
|
||||
plt.xlabel('threshold')
|
||||
# plt.ylabel('Similarity')
|
||||
plt.grid(True, linestyle='--', alpha=0.5)
|
||||
plt.savefig('grid.png')
|
||||
plt.show()
|
||||
plt.close()
|
||||
|
||||
|
||||
def compute_accuracy_recall(score, labels):
|
||||
th = 0.1
|
||||
squence = np.linspace(-1, 1, num=50)
|
||||
recall, PrecisePos, PreciseNeg, recall_TN, Correct = [], [], [], [], []
|
||||
Same = score[:len(score) // 2]
|
||||
Cross = score[len(score) // 2:]
|
||||
for th in squence:
|
||||
t_score = (score > th)
|
||||
t_labels = (labels == 1)
|
||||
TP = np.sum(np.logical_and(t_score, t_labels))
|
||||
FN = np.sum(np.logical_and(np.logical_not(t_score), t_labels))
|
||||
f_score = (score < th)
|
||||
f_labels = (labels == 0)
|
||||
TN = np.sum(np.logical_and(f_score, f_labels))
|
||||
FP = np.sum(np.logical_and(np.logical_not(f_score), f_labels))
|
||||
print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))
|
||||
|
||||
PrecisePos.append(0 if TP / (TP + FP) == 'nan' else TP / (TP + FP))
|
||||
PreciseNeg.append(0 if TN == 0 else TN / (TN + FN))
|
||||
recall.append(0 if TP == 0 else TP / (TP + FN))
|
||||
recall_TN.append(0 if TN == 0 else TN / (TN + FP))
|
||||
Correct.append(0 if TP == 0 else (TP + TN) / (TP + FP + TN + FN))
|
||||
|
||||
showHist(Same, Cross)
|
||||
showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct)
|
||||
|
||||
|
||||
def get_similarity(features1, features2, n, m):
|
||||
features1 = np.array(features1)
|
||||
features2 = np.array(features2)
|
||||
all_similarity = []
|
||||
for feature1 in features1:
|
||||
for feature2 in features2:
|
||||
similarity = np.dot(feature1, feature2) / (np.linalg.norm(feature1) * np.linalg.norm(feature2))
|
||||
all_similarity.append(similarity)
|
||||
test_similarity = np.array(all_similarity)
|
||||
np_all_array = np.array(all_similarity).reshape(len(features1), len(features2))
|
||||
if n == 5 and m == 5:
|
||||
print(all_similarity)
|
||||
return np.mean(np_all_array), all_similarity
|
||||
# return sum(all_similarity)/len(all_similarity), all_similarity
|
||||
# return max(all_similarity), all_similarity
|
||||
|
||||
|
||||
def deal_similarity(dicts):
|
||||
all_similarity = []
|
||||
similarity = []
|
||||
same_barcode, diff_barcode = [], []
|
||||
for n, (key1, value1) in enumerate(dicts.items()):
|
||||
print('key1 >> {}'.format(key1))
|
||||
for m, (key2, value2) in enumerate(dicts.items()):
|
||||
print('key1 >> {} key2 >> {} peidui {}{}'.format(key1, key2, n, m))
|
||||
max_similarity, some_similarity = get_similarity(value1, value2, n, m)
|
||||
similarity.append(max_similarity)
|
||||
if key1 == key2:
|
||||
same_barcode += some_similarity
|
||||
else:
|
||||
diff_barcode += some_similarity
|
||||
all_similarity.append(similarity)
|
||||
similarity = []
|
||||
all_similarity = np.array(all_similarity)
|
||||
random.shuffle(diff_barcode)
|
||||
same_list = [1] * len(same_barcode)
|
||||
diff_list = [0] * len(same_barcode)
|
||||
all_list = same_list + diff_list
|
||||
all_score = same_barcode + diff_barcode[:len(same_barcode)]
|
||||
compute_accuracy_recall(np.array(all_score), np.array(all_list))
|
||||
print(all_similarity.shape)
|
||||
|
||||
|
||||
with open('../search_library/data_zhanting.json', 'r') as file:
|
||||
data = json.load(file)
|
||||
dicts = {}
|
||||
for dict in data['total']:
|
||||
key = dict['key']
|
||||
value = dict['value']
|
||||
dicts[key] = value
|
||||
deal_similarity(dicts)
|
63
tools/model_onnx_transform.py
Normal file
63
tools/model_onnx_transform.py
Normal file
@ -0,0 +1,63 @@
|
||||
import pdb
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from model import resnet18
|
||||
from config import config as conf
|
||||
from collections import OrderedDict
|
||||
import cv2
|
||||
|
||||
def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth'):
|
||||
# 定义模型
|
||||
if model_name == 'resnet18':
|
||||
model = resnet18(scale=0.75)
|
||||
|
||||
print('model_name >>> {}'.format(model_name))
|
||||
if conf.multiple_cards:
|
||||
model = model.to(torch.device('cpu'))
|
||||
checkpoint = torch.load(pretrained_weights)
|
||||
new_state_dict = OrderedDict()
|
||||
for k, v in checkpoint.items():
|
||||
name = k[7:] # remove "module."
|
||||
new_state_dict[name] = v
|
||||
model.load_state_dict(new_state_dict)
|
||||
else:
|
||||
model.load_state_dict(torch.load(pretrained_weights, map_location=torch.device('cpu')))
|
||||
# try:
|
||||
# model.load_state_dict(torch.load(pretrained_weights, map_location=torch.device('cpu')))
|
||||
# except Exception as e:
|
||||
# print(e)
|
||||
# # model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(pretrained_weights, map_location='cpu').items()})
|
||||
# model = nn.DataParallel(model).to(conf.device)
|
||||
# model.load_state_dict(torch.load(conf.test_model, map_location=torch.device('cpu')))
|
||||
|
||||
|
||||
# 转换为ONNX
|
||||
if model_name == 'gift_type2':
|
||||
input_shape = [1, 64, 13, 13]
|
||||
elif model_name == 'gift_type3':
|
||||
input_shape = [1, 3, 224, 224]
|
||||
else:
|
||||
# 假设输入数据的大小是通道数*高度*宽度,例如3*224*224
|
||||
input_shape = [1, 3, 224, 224]
|
||||
|
||||
img = cv2.imread('./dog_224x224.jpg')
|
||||
|
||||
output_file = pretrained_weights.replace('pth', 'onnx')
|
||||
|
||||
# 导出模型
|
||||
torch.onnx.export(model,
|
||||
torch.randn(input_shape),
|
||||
output_file,
|
||||
verbose=True,
|
||||
input_names=['input'],
|
||||
output_names=['output']) ##, optset_version=12
|
||||
|
||||
model.eval()
|
||||
trace_model = torch.jit.trace(model, torch.randn(1, 3, 224, 224))
|
||||
trace_model.save(output_file.replace('.onnx', '.pt'))
|
||||
print(f"Model exported to {output_file}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tranform_onnx_model(model_name='resnet18', # ['resnet18', 'gift_type2', 'gift_type3'] #gift_type2指resnet18中间数据判断;gift3_type3指resnet原图计算推理
|
||||
pretrained_weights='./checkpoints/resnet18_scale=1.0/best.pth')
|
186
tools/model_rknn_transform.py
Normal file
186
tools/model_rknn_transform.py
Normal file
@ -0,0 +1,186 @@
|
||||
import os
|
||||
import pdb
|
||||
import urllib
|
||||
import traceback
|
||||
import time
|
||||
import sys
|
||||
import numpy as np
|
||||
import cv2
|
||||
from config import config as conf
|
||||
from rknn.api import RKNN
|
||||
|
||||
import config
|
||||
|
||||
# ONNX_MODEL = 'resnet50v2.onnx'
|
||||
# RKNN_MODEL = 'resnet50v2.rknn'
|
||||
ONNX_MODEL = 'checkpoints/resnet18_scale=1.0/best.onnx'
|
||||
RKNN_MODEL = 'checkpoints/resnet18_scale=1.0/best.rknn'
|
||||
|
||||
|
||||
# ONNX_MODEL = 'v3_small_0424.onnx'
|
||||
# RKNN_MODEL = 'v3_small_0424.rknn'
|
||||
|
||||
def show_outputs(outputs):
|
||||
# print('***************outputs', outputs)
|
||||
output = outputs[0][0]
|
||||
# print('len(outputs)',len(output), output)
|
||||
output_sorted = sorted(output, reverse=True)
|
||||
top5_str = 'resnet50v2\n-----TOP 5-----\n'
|
||||
for i in range(5):
|
||||
value = output_sorted[i]
|
||||
index = np.where(output == value)
|
||||
for j in range(len(index)):
|
||||
if (i + j) >= 5:
|
||||
break
|
||||
if value > 0:
|
||||
topi = '{}: {}\n'.format(index[j], value)
|
||||
else:
|
||||
topi = '-1: 0.0\n'
|
||||
top5_str += topi
|
||||
# pdb.set_trace()
|
||||
print(top5_str)
|
||||
|
||||
|
||||
def readable_speed(speed):
|
||||
speed_bytes = float(speed)
|
||||
speed_kbytes = speed_bytes / 1024
|
||||
if speed_kbytes > 1024:
|
||||
speed_mbytes = speed_kbytes / 1024
|
||||
if speed_mbytes > 1024:
|
||||
speed_gbytes = speed_mbytes / 1024
|
||||
return "{:.2f} GB/s".format(speed_gbytes)
|
||||
else:
|
||||
return "{:.2f} MB/s".format(speed_mbytes)
|
||||
else:
|
||||
return "{:.2f} KB/s".format(speed_kbytes)
|
||||
|
||||
|
||||
def show_progress(blocknum, blocksize, totalsize):
|
||||
speed = (blocknum * blocksize) / (time.time() - start_time)
|
||||
speed_str = " Speed: {}".format(readable_speed(speed))
|
||||
recv_size = blocknum * blocksize
|
||||
|
||||
f = sys.stdout
|
||||
progress = (recv_size / totalsize)
|
||||
progress_str = "{:.2f}%".format(progress * 100)
|
||||
n = round(progress * 50)
|
||||
s = ('#' * n).ljust(50, '-')
|
||||
f.write(progress_str.ljust(8, ' ') + '[' + s + ']' + speed_str)
|
||||
f.flush()
|
||||
f.write('\r\n')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# Create RKNN object
|
||||
rknn = RKNN(verbose=True)
|
||||
|
||||
# If resnet50v2 does not exist, download it.
|
||||
# Download address:
|
||||
# https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v2/resnet50v2.onnx
|
||||
if not os.path.exists(ONNX_MODEL):
|
||||
print('--> Download {}'.format(ONNX_MODEL))
|
||||
url = 'https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v2/resnet50v2.onnx'
|
||||
download_file = ONNX_MODEL
|
||||
try:
|
||||
start_time = time.time()
|
||||
urllib.request.urlretrieve(url, download_file, show_progress)
|
||||
except:
|
||||
print('Download {} failed.'.format(download_file))
|
||||
print(traceback.format_exc())
|
||||
exit(-1)
|
||||
print('done')
|
||||
|
||||
# pre-process config
|
||||
print('--> config model')
|
||||
# rknn.config(mean_values=[123.675, 116.28, 103.53], std_values=[58.82, 58.82, 58.82])
|
||||
rknn.config(
|
||||
mean_values=[[127.5, 127.5, 127.5]],
|
||||
std_values=[[127.5, 127.5, 127.5]],
|
||||
target_platform='rk3588',
|
||||
model_pruning=False,
|
||||
compress_weight=False,
|
||||
single_core_mode=True)
|
||||
# rknn.config(
|
||||
# mean_values=[[127.5, 127.5, 127.5]], # 对于单通道图像,可以设置为 [[127.5]]
|
||||
# std_values=[[127.5, 127.5, 127.5]], # 对于单通道图像,可以设置为 [[127.5]]
|
||||
# target_platform='rk3588', # 设置目标平台
|
||||
# # quantize_dtype='int8',
|
||||
# # quantize_algo='normal',
|
||||
# # output_optimize=False,
|
||||
# # output_format='rknnb'
|
||||
# )
|
||||
print('done')
|
||||
|
||||
# Load model
|
||||
print('--> Loading model')
|
||||
ret = rknn.load_onnx(model=ONNX_MODEL)
|
||||
if ret != 0:
|
||||
print('Load model failed!')
|
||||
exit(ret)
|
||||
print('done')
|
||||
|
||||
# Build model
|
||||
print('--> Building model')
|
||||
ret = rknn.build(do_quantization=True, dataset='./dataset.txt')
|
||||
# ret = rknn.build(do_quantization=False, dataset='./dataset.txt')
|
||||
if ret != 0:
|
||||
print('Build model failed!')
|
||||
exit(ret)
|
||||
print('done')
|
||||
|
||||
# Export rknn model
|
||||
print('--> Export rknn model')
|
||||
ret = rknn.export_rknn(RKNN_MODEL)
|
||||
if ret != 0:
|
||||
print('Export rknn model failed!')
|
||||
exit(ret)
|
||||
print('done')
|
||||
|
||||
# Set inputs
|
||||
img = cv2.imread('./dog_224x224.jpg')
|
||||
# img = cv2.imread('./data/gift_test/Havegift/20241213-161415-cb8e0762-f376-45d1-8f36-7dc070990fa5/subimg/cam1_9_tid2_fid(18, 33250169482).png')
|
||||
# print('img', img)
|
||||
# with open('pixel_values.txt', 'w') as file:
|
||||
|
||||
# for y in range(img.shape[0]):
|
||||
# for x in range(img.shape[1]):
|
||||
# b, g, r = img[y, x]
|
||||
# file.write(f'{r},{g},{b}\n')
|
||||
|
||||
# img = cv2.imread('./810115161912_810115161912_20240131-145622_0da14e4d-a3da-499f-b512-2d4168ab1c87_front_addGood_70f75407b7ae_29_01.jpg')
|
||||
img = cv2.resize(img, (224, 224))
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# img = conf.test_transform(img)
|
||||
# img = img.numpy()
|
||||
# img = img.transpose(1, 2, 0)
|
||||
|
||||
# Init runtime environment
|
||||
print('--> Init runtime environment')
|
||||
ret = rknn.init_runtime()
|
||||
# ret = rknn.init_runtime('rk3588')
|
||||
if ret != 0:
|
||||
print('Init runtime environment failed!')
|
||||
exit(ret)
|
||||
print('done')
|
||||
|
||||
# Inference
|
||||
print('--> Running model')
|
||||
T1 = time.time()
|
||||
outputs = rknn.inference(inputs=[img])
|
||||
# outputs = rknn.inference(inputs=img)
|
||||
T2 = time.time()
|
||||
print('消耗时间 >>> {}'.format(T2 - T1))
|
||||
with open('result_0415_128.txt', 'a') as f:
|
||||
f.write(str(outputs))
|
||||
# pdb.set_trace()
|
||||
print('***outputs', outputs)
|
||||
np.save('./onnx_resnet50v2_0.npy', outputs[0])
|
||||
x = outputs[0]
|
||||
output = np.exp(x) / np.sum(np.exp(x))
|
||||
outputs = [output]
|
||||
show_outputs(outputs)
|
||||
print('done')
|
||||
|
||||
rknn.release()
|
233
tools/operate_usearch.py
Normal file
233
tools/operate_usearch.py
Normal file
@ -0,0 +1,233 @@
|
||||
import os
|
||||
import numpy as np
|
||||
from usearch.index import Index
|
||||
import json
|
||||
import struct
|
||||
|
||||
|
||||
def create_index():
|
||||
index = Index(
|
||||
ndim=256,
|
||||
metric='cos',
|
||||
# dtype='f32',
|
||||
dtype='f16',
|
||||
connectivity=32,
|
||||
expansion_add=40, # 128,
|
||||
expansion_search=10, # 64,
|
||||
multi=True
|
||||
)
|
||||
return index
|
||||
|
||||
|
||||
def compare_feature(features1, features2, model='1'):
|
||||
"""
|
||||
:param model 比对策略
|
||||
'0':模拟一个轨迹的图像(所有的图像、或者挑选的若干图像)与标准库,先求每个图片与标准库的最大值,再求所有图片对应最大值的均值
|
||||
'1':带对比的所有相似度的均值
|
||||
'2':比对1:1的最大值
|
||||
:param feature1:
|
||||
:param feature2:
|
||||
:return:
|
||||
"""
|
||||
similarity_group, similarity_groups = [], []
|
||||
if model == '0':
|
||||
for feature1 in features1:
|
||||
for feature2 in features2[0]:
|
||||
similarity = np.dot(feature1, feature2) / (np.linalg.norm(feature1) * np.linalg.norm(feature2))
|
||||
similarity_group.append(similarity)
|
||||
similarity_groups.append(max(similarity_group))
|
||||
similarity_group = []
|
||||
return sum(similarity_groups) / len(similarity_groups)
|
||||
|
||||
elif model == '1':
|
||||
feature2 = features2[0]
|
||||
for feature1 in features1:
|
||||
for num in range(len(feature2)):
|
||||
similarity = np.dot(feature1, feature2[num]) / (
|
||||
np.linalg.norm(feature1) * np.linalg.norm(feature2[num]))
|
||||
similarity_group.append(similarity)
|
||||
similarity_groups.append(sum(similarity_group) / len(similarity_group))
|
||||
similarity_group = []
|
||||
# return sum(similarity_groups)/len(similarity_groups), max(similarity_groups)
|
||||
if len(similarity_groups) == 0:
|
||||
return -1
|
||||
return sum(similarity_groups) / len(similarity_groups)
|
||||
elif model == '2':
|
||||
feature2 = features2[0]
|
||||
for feature1 in features1:
|
||||
for num in range(len(feature2)):
|
||||
similarity = np.dot(feature1, feature2[num]) / (
|
||||
np.linalg.norm(feature1) * np.linalg.norm(feature2[num]))
|
||||
similarity_group.append(similarity)
|
||||
return max(similarity_group)
|
||||
|
||||
def get_barcode_feature(data):
|
||||
barcode = data['key']
|
||||
features = data['value']
|
||||
return [barcode] * len(features), features
|
||||
|
||||
|
||||
def analysis_file(file_path):
|
||||
"""
|
||||
:param file_path:
|
||||
:return:
|
||||
"""
|
||||
barcodes, features = [], []
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
for dic in data['total']:
|
||||
barcode, feature = get_barcode_feature(dic)
|
||||
barcodes.append(barcode)
|
||||
features.append(feature)
|
||||
return barcodes, features
|
||||
|
||||
|
||||
def create_base_index(index_file_pth=None,
|
||||
barcodes=None,
|
||||
features=None,
|
||||
save_index_name=None):
|
||||
index = create_index()
|
||||
if index_file_pth is not None:
|
||||
# save_index_name = index_file_pth.split('json')[0] + 'usearch'
|
||||
save_index_name = index_file_pth.split('json')[0] + 'data'
|
||||
barcodes, features = analysis_file(index_file_pth)
|
||||
else:
|
||||
assert barcodes is not None and features is not None, 'barcodes and features must be not None'
|
||||
for barcode, feature in zip(barcodes, features):
|
||||
try:
|
||||
index.add(np.array(barcode), np.array(feature))
|
||||
except Exception as e:
|
||||
print(e)
|
||||
continue
|
||||
index.save(save_index_name)
|
||||
|
||||
|
||||
def get_feature_index(index_file_pth=None,
|
||||
barcodes=None):
|
||||
assert index_file_pth is not None, 'index_file_pth must be not None'
|
||||
index = Index.restore(index_file_pth, view=True)
|
||||
feature_lists = index.get(np.array(barcodes))
|
||||
print("memory {} size {}".format(index.memory_usage, index.size))
|
||||
print("feature_lists {}".format(feature_lists))
|
||||
return feature_lists
|
||||
|
||||
|
||||
def search_in_index(query=None,
|
||||
barcode=None, # barcode -> int or np.ndarray
|
||||
index_name=None,
|
||||
temp_index=False, # 是否为临时库
|
||||
model='0',
|
||||
):
|
||||
if temp_index:
|
||||
assert index_name is not None, 'index_name must be not None'
|
||||
index = Index.restore(index_name, view=True)
|
||||
if barcode is not None: # 1:1对比测试
|
||||
feature_lists = index.get(np.array(barcode))
|
||||
results = compare_feature(query, feature_lists)
|
||||
else:
|
||||
results = index.search(query, count=5)
|
||||
return results
|
||||
else: # 标准库
|
||||
assert index_name is not None, 'index_name must be not None'
|
||||
index = Index.restore(index_name, view=True)
|
||||
if barcode is not None: # 1:1对比测试
|
||||
feature_lists = index.get(np.array(barcode))
|
||||
results = compare_feature(query, feature_lists, model)
|
||||
else:
|
||||
results = index.search(query, count=10)
|
||||
return results
|
||||
|
||||
|
||||
def delete_index(index_name=None, key=None, index=None):
|
||||
assert key is not None, 'key must be not None'
|
||||
if index is None:
|
||||
assert index_name is not None, 'index_name must be not None'
|
||||
index = Index.restore(index_name, view=True)
|
||||
index.remove(index_name)
|
||||
else:
|
||||
index.remove(key)
|
||||
|
||||
from scipy.spatial.distance import cdist
|
||||
def compute_similarity_matrix(featurelists1, featurelists2):
|
||||
"""计算图片之间的余弦相似度矩阵"""
|
||||
# 计算所有向量对之间的余弦相似度
|
||||
cosine_similarities = 1 - cdist(featurelists1, featurelists2, metric='cosine')
|
||||
cosine_similarities = np.around(cosine_similarities, decimals=3)
|
||||
return cosine_similarities
|
||||
|
||||
def check_usearch_json_diff(index_file_pth, json_file_pth):
|
||||
json_features = None
|
||||
feature_lists = get_feature_index(index_file_pth, ['6923644272159'])
|
||||
with open(json_file_pth, 'r') as json_file:
|
||||
json_data = json.load(json_file)
|
||||
for data in json_data['total']:
|
||||
if data['key'] == '6923644272159':
|
||||
json_features = data['value']
|
||||
json_features = np.array(json_features)
|
||||
feature_lists = np.array(feature_lists[0])
|
||||
compute_similarity_matrix(json_features, feature_lists)
|
||||
|
||||
|
||||
def write_binary_file(filename, datas):
|
||||
with open(filename, 'wb') as f:
|
||||
# 先写入数据中的key数量(为C++读取提供便利)
|
||||
key_count = len(datas)
|
||||
f.write(struct.pack('I', key_count)) # 'I'代表无符号整型(4字节)
|
||||
|
||||
for data in datas:
|
||||
key = data['key']
|
||||
feats = data['value']
|
||||
key_bytes = key.encode('utf-8')
|
||||
key_len = len(key)
|
||||
length_byte = struct.pack('<B', key_len)
|
||||
f.write(length_byte)
|
||||
# f.write(struct.pack('Q', len(key_bytes)))
|
||||
f.write(key_bytes)
|
||||
value_count = len(feats)
|
||||
f.write(struct.pack('I', (value_count * 256)))
|
||||
# 遍历字典,写入每个key及其对应的浮点数值列表
|
||||
for values in feats:
|
||||
# 写入每个浮点数值(保留小数点后六位)
|
||||
for value in values:
|
||||
# 使用'f'格式(单精度浮点,4字节),并四舍五入保留六位小数
|
||||
value_half = np.float16(value)
|
||||
# print(value_half.tobytes())
|
||||
f.write(value_half.tobytes())
|
||||
def create_binary_file(json_path, flag=True):
|
||||
# 1. 打开JSON文件
|
||||
with open(json_path, 'r', encoding='utf-8') as file:
|
||||
# 2. 读取并解析JSON文件内容
|
||||
data = json.load(file)
|
||||
if flag:
|
||||
for flag, values in data.items():
|
||||
# 逐个写入values中的每个值,保留小数点后六位,每个值占一行
|
||||
write_binary_file(index_file_pth.replace('json', 'bin'), values)
|
||||
else:
|
||||
write_binary_file(json_path.replace('.json', '.bin'), [data])
|
||||
|
||||
def create_binary_files(index_file_pth):
|
||||
if os.path.isfile(index_file_pth):
|
||||
create_binary_file(index_file_pth)
|
||||
else:
|
||||
for name in os.listdir(index_file_pth):
|
||||
jsonpth = os.sep.join([index_file_pth, name])
|
||||
create_binary_file(jsonpth, False)
|
||||
|
||||
if __name__ == '__main__':
|
||||
# index_file_pth = '../data/feature_json' # 生成二进制文件 多文件
|
||||
index_file_pth = '../search_library/yunhedian_30-04.json'
|
||||
# create_base_index(index_file_pth) # 生成usearch文件
|
||||
create_binary_files(index_file_pth) # 生成二进制文件 多文件
|
||||
|
||||
# index_file_pth = '../search_library/test_index_10_normal_0717.usearch'
|
||||
# # index_file_pth = '../search_library/data_10_normal_0718.index'
|
||||
# search_in_index(query='693', index_name=index_file_pth, barcode='6934024590466')
|
||||
|
||||
# # check index data file
|
||||
# index_file_pth = '../search_library/data_zhanting.data'
|
||||
# # # get_feature_index(index_file_pth, ['6901070602818'])
|
||||
# get_feature_index(index_file_pth, ['6923644272159'])
|
||||
|
||||
# index_file_pth = '../search_library/data_zhanting.data'
|
||||
# json_file_pth = '../search_library/data_zhanting.json'
|
||||
# check_usearch_json_diff(index_file_pth, json_file_pth)
|
84
tools/threshold_partition.py
Normal file
84
tools/threshold_partition.py
Normal file
@ -0,0 +1,84 @@
|
||||
'''
|
||||
现场1:N测试,确定阈值
|
||||
'''
|
||||
import os
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def showHist(filtered_data):
|
||||
Same = filtered_data[:, 1].astype(np.float32)
|
||||
Cross = filtered_data[:, 2].astype(np.float32)
|
||||
|
||||
fig, axs = plt.subplots(2, 1)
|
||||
axs[0].hist(Same, bins=50, edgecolor='black')
|
||||
axs[0].set_xlim([-0.1, 1])
|
||||
axs[0].set_title('first')
|
||||
|
||||
axs[1].hist(Cross, bins=50, edgecolor='black')
|
||||
axs[1].set_xlim([-0.1, 1])
|
||||
axs[1].set_title('second')
|
||||
# plt.savefig('plot.png')
|
||||
plt.show()
|
||||
|
||||
|
||||
def get_tartget_list(nested_list):
|
||||
filtered_list = np.array(list(filter(lambda x: len(x) >= 2, nested_list))) # 去除无轨迹的数据
|
||||
filtered_correct = filtered_list[filtered_list[:, 0] != 'wrong'] # 获取比对正确的时项
|
||||
filtered_wrong = filtered_list[filtered_list[:, 0] == 'wrong'] # 获取比对错误的时项
|
||||
showHist(filtered_correct)
|
||||
# showHist(filtered_wrong)
|
||||
print(filtered_list)
|
||||
|
||||
|
||||
def deal_process(file_pth):
|
||||
flag = False
|
||||
event = file_pth.split('\\')[-2]
|
||||
target_barcode = file_pth.split('\\')[-2].split('_')[-1]
|
||||
temp_list = []
|
||||
|
||||
with open(file_pth, 'r') as f:
|
||||
for line in f:
|
||||
if 'oneToOne' in line:
|
||||
flag = True
|
||||
continue
|
||||
if flag:
|
||||
line = line.replace('\n', '')
|
||||
comparison_data = line.split(',')
|
||||
forecast_barcode = comparison_data[0]
|
||||
value = comparison_data[-1].split(':')[-1]
|
||||
if value == '':
|
||||
break
|
||||
if len(temp_list) == 0:
|
||||
if forecast_barcode == target_barcode:
|
||||
temp_list.append('correct')
|
||||
else:
|
||||
temp_list.append('wrong')
|
||||
temp_list.append(float(value))
|
||||
temp_list.append(event)
|
||||
return temp_list
|
||||
|
||||
|
||||
def anaylze_scratch(scratch_pth):
|
||||
purchase, back = [], []
|
||||
for root, dirs, files in os.walk(scratch_pth):
|
||||
if len(root) > 0:
|
||||
if len(root.split('_')) == 4: # 加购
|
||||
process = os.path.join(root, 'process.data')
|
||||
if not os.path.exists(process):
|
||||
continue
|
||||
purchase.append(deal_process(process))
|
||||
elif len(root.split('_')) == 3:
|
||||
process = os.path.join(root, 'process.data')
|
||||
if not os.path.exists(process):
|
||||
continue
|
||||
back.append(deal_process(process))
|
||||
# get_tartget_list(purchase)
|
||||
get_tartget_list(back)
|
||||
print(purchase)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# scratch_pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\展厅测试\\1108_展厅模型v800测试\\'
|
||||
scratch_pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\展厅测试\\1120_展厅模型v801测试\\扫A放A\\'
|
||||
anaylze_scratch(scratch_pth)
|
411
tools/write_feature_json.py
Normal file
411
tools/write_feature_json.py
Normal file
@ -0,0 +1,411 @@
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
import numpy as np
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from tools.dataset import get_transform
|
||||
from model import resnet18
|
||||
import torch
|
||||
from PIL import Image
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
import yaml
|
||||
import shutil
|
||||
import struct
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FeatureExtractor:
|
||||
def __init__(self, conf):
|
||||
self.conf = conf
|
||||
self.model = self.initModel()
|
||||
_, self.test_transform = get_transform(self.conf)
|
||||
pass
|
||||
|
||||
def initModel(self, inference_model: Optional[str] = None) -> torch.nn.Module:
|
||||
"""
|
||||
Initialize and load the ResNet18 model for inference.
|
||||
|
||||
Args:
|
||||
inference_model: Optional path to model weights. Uses conf.test_model if None.
|
||||
|
||||
Returns:
|
||||
Loaded and configured PyTorch model in evaluation mode.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If model weights file is not found
|
||||
RuntimeError: If model loading fails
|
||||
"""
|
||||
model_path = inference_model if inference_model else self.conf['models']['checkpoints']
|
||||
|
||||
try:
|
||||
# Verify model file exists
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(f"Model weights file not found: {model_path}")
|
||||
|
||||
# Initialize model
|
||||
model = resnet18().to(self.conf['base']['device'])
|
||||
|
||||
# Handle multi-GPU case
|
||||
if conf['base']['distributed']:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Load weights
|
||||
state_dict = torch.load(model_path, map_location=conf['base']['device'])
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
model.eval()
|
||||
logger.info(f"Successfully loaded model from {model_path}")
|
||||
return model
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize model: {str(e)}")
|
||||
raise
|
||||
|
||||
def convert_rgba_to_rgb(self, image_path):
|
||||
# 打开图像
|
||||
img = Image.open(image_path)
|
||||
# 转换图像模式从RGBA到RGB
|
||||
# .convert('RGB')会丢弃Alpha通道并转换为纯RGB图像
|
||||
if img.mode == 'RGBA':
|
||||
# 转换为RGB模式
|
||||
img_rgb = img.convert('RGB')
|
||||
# 保存转换后的图像
|
||||
img_rgb.save(image_path)
|
||||
print(f"Image converted from RGBA to RGB and saved to {image_path}")
|
||||
|
||||
def test_preprocess(self, images: list, actionModel=False) -> torch.Tensor:
|
||||
res = []
|
||||
for img in images:
|
||||
try:
|
||||
im = self.test_transform(img) if actionModel else self.test_transform(Image.open(img))
|
||||
res.append(im)
|
||||
except:
|
||||
continue
|
||||
data = torch.stack(res)
|
||||
return data
|
||||
|
||||
def inference(self, images, model, actionModel=False):
|
||||
data = self.test_preprocess(images, actionModel)
|
||||
if torch.cuda.is_available():
|
||||
data = data.to(conf['base']['device'])
|
||||
features = model(data)
|
||||
if conf['data']['half']:
|
||||
features = features.half()
|
||||
return features
|
||||
|
||||
def group_image(self, images, batch=64) -> list:
|
||||
"""Group image paths by batch size"""
|
||||
size = len(images)
|
||||
res = []
|
||||
for i in range(0, size, batch):
|
||||
end = min(batch + i, size)
|
||||
res.append(images[i:end])
|
||||
return res
|
||||
|
||||
def getFeatureList(self, barList, imgList):
|
||||
featList = [[] for _ in range(len(barList))]
|
||||
|
||||
for index, image_paths in enumerate(imgList):
|
||||
try:
|
||||
# Process images in batches
|
||||
for batch in self.group_image(image_paths):
|
||||
# Get features for batch
|
||||
features = self.inference(batch, self.model)
|
||||
|
||||
# Process each feature in batch
|
||||
for feat in features:
|
||||
# Move to CPU and convert to numpy
|
||||
feat_np = feat.squeeze().detach().cpu().numpy()
|
||||
|
||||
# Normalize first 256 dimensions
|
||||
normalized = self.normalize_256(feat_np[:256])
|
||||
|
||||
# Combine with remaining dimensions
|
||||
combined = np.concatenate([normalized, feat_np[256:]], axis=0)
|
||||
|
||||
featList[index].append(combined)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing images for index {index}: {str(e)}")
|
||||
continue
|
||||
return featList
|
||||
|
||||
def get_files(
|
||||
self,
|
||||
folder: str,
|
||||
filter: Optional[List[str]] = None,
|
||||
create_single_json: bool = False
|
||||
) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Recursively collect image files from directory structure.
|
||||
|
||||
Args:
|
||||
folder: Root directory to scan
|
||||
filter: Optional list of barcodes to include
|
||||
create_single_json: Whether to create individual JSON files per barcode
|
||||
|
||||
Returns:
|
||||
Dictionary mapping barcode names to lists of image paths
|
||||
|
||||
Example:
|
||||
{
|
||||
"barcode1": ["path/to/img1.jpg", "path/to/img2.jpg"],
|
||||
"barcode2": ["path/to/img3.jpg"]
|
||||
}
|
||||
"""
|
||||
file_dicts = {}
|
||||
total_files = 0
|
||||
feature_counts = []
|
||||
barcode_count = 0
|
||||
subclass = [str(i) for i in range(100)]
|
||||
# Validate input directory
|
||||
if not os.path.isdir(folder):
|
||||
raise ValueError(f"Invalid directory: {folder}")
|
||||
|
||||
# Process each barcode directory
|
||||
for root, dirs, files in tqdm(os.walk(folder), desc="Scanning directories"):
|
||||
if not dirs: # Leaf directory (contains images)
|
||||
basename = os.path.basename(root)
|
||||
if basename in subclass:
|
||||
ori_barcode = root.split('/')[-2]
|
||||
barcode = root.split('/')[-2] + '_' + basename
|
||||
else:
|
||||
ori_barcode = basename
|
||||
barcode = basename
|
||||
# Apply filter if provided
|
||||
if filter and ori_barcode not in filter:
|
||||
continue
|
||||
elif len(ori_barcode) > 13 or len(ori_barcode) < 8:
|
||||
logger.warning(f"Skipping invalid barcode {ori_barcode}")
|
||||
with open(conf['save']['error_barcodes'], 'a') as f:
|
||||
f.write(ori_barcode + '\n')
|
||||
f.close()
|
||||
continue
|
||||
|
||||
# Process image files
|
||||
if files:
|
||||
image_paths = self._process_image_files(root, files)
|
||||
if not image_paths:
|
||||
continue
|
||||
|
||||
# Update counters
|
||||
barcode_count += 1
|
||||
file_count = len(image_paths)
|
||||
total_files += file_count
|
||||
feature_counts.append(file_count)
|
||||
|
||||
# Handle output mode
|
||||
if create_single_json:
|
||||
self._process_single_barcode(barcode, image_paths)
|
||||
else:
|
||||
if barcode.split('_')[-1] == '0':
|
||||
barcode = barcode.split('_')[0]
|
||||
file_dicts[barcode] = image_paths
|
||||
|
||||
# # Log summary
|
||||
# logger.info(f"Processed {barcode_count} barcodes with {total_files} total images")
|
||||
# logger.debug(f"Image counts per barcode: {feature_counts}")
|
||||
|
||||
# Batch process if not creating individual JSONs
|
||||
if not create_single_json and file_dicts:
|
||||
self.createFeatureDict(
|
||||
file_dicts,
|
||||
create_single_json=False,
|
||||
)
|
||||
return file_dicts
|
||||
|
||||
def _process_image_files(self, root: str, files: List[str]) -> List[str]:
|
||||
"""Process and validate image files in a directory."""
|
||||
valid_paths = []
|
||||
for filename in files:
|
||||
file_path = os.path.join(root, filename)
|
||||
try:
|
||||
# Convert RGBA to RGB if needed
|
||||
self.convert_rgba_to_rgb(file_path)
|
||||
valid_paths.append(file_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Skipping invalid image {file_path}: {str(e)}")
|
||||
return valid_paths
|
||||
|
||||
def _process_single_barcode(self, barcode: str, image_paths: List[str]):
|
||||
"""Process a single barcode and create individual JSON file."""
|
||||
temp_dict = {barcode: image_paths}
|
||||
self.createFeatureDict(
|
||||
temp_dict,
|
||||
create_single_json=True,
|
||||
)
|
||||
|
||||
def normalize_256(self, queFeatList):
|
||||
queFeatList = queFeatList / np.linalg.norm(queFeatList)
|
||||
return queFeatList
|
||||
|
||||
def img2feature(
|
||||
self,
|
||||
imgs_dict: Dict[str, List[str]]
|
||||
) -> Tuple[List[str], List[List[np.ndarray]]]:
|
||||
"""
|
||||
Extract features for all images in the dictionary.
|
||||
|
||||
Args:
|
||||
imgs_dict: Dictionary mapping barcodes to image paths
|
||||
model: Pretrained feature extraction model
|
||||
barcode_flag: Whether to include barcode info (unused)
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- List of barcode IDs
|
||||
- List of feature lists (one per barcode)
|
||||
|
||||
Raises:
|
||||
ValueError: If input dictionary is empty
|
||||
RuntimeError: If feature extraction fails
|
||||
"""
|
||||
if not imgs_dict:
|
||||
raise ValueError("No images provided for feature extraction")
|
||||
|
||||
try:
|
||||
barcode_list = list(imgs_dict.keys())
|
||||
image_list = list(imgs_dict.values())
|
||||
feature_list = self.getFeatureList(barcode_list, image_list)
|
||||
|
||||
logger.info(f"Successfully extracted features for {len(barcode_list)} barcodes")
|
||||
return barcode_list, feature_list
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Feature extraction failed: {str(e)}")
|
||||
raise RuntimeError(f"Feature extraction failed: {str(e)}")
|
||||
|
||||
def createFeatureDict(self, imgs_dict,
|
||||
create_single_json=False): # imgs->{barcode1:[img1_1...img1_n], barcode2:[img2_1...img2_n]}
|
||||
dicts_all = {}
|
||||
value_list = []
|
||||
barcode_list, imgs_list = self.img2feature(imgs_dict)
|
||||
for i in range(len(barcode_list)):
|
||||
dicts = {}
|
||||
|
||||
imgs_list_ = []
|
||||
for j in range(len(imgs_list[i])):
|
||||
imgs_list_.append(imgs_list[i][j].tolist())
|
||||
|
||||
dicts['key'] = barcode_list[i]
|
||||
truncated_imgs_list = [subarray[:256] for subarray in imgs_list_]
|
||||
dicts['value'] = truncated_imgs_list
|
||||
if create_single_json:
|
||||
# json_path = os.path.join("./search_library/v8021_overseas/", str(barcode_list[i]) + '.json')
|
||||
json_path = os.path.join(self.conf['save']['json_path'], str(barcode_list[i]) + '.json')
|
||||
with open(json_path, 'w') as json_file:
|
||||
json.dump(dicts, json_file)
|
||||
else:
|
||||
value_list.append(dicts)
|
||||
if not create_single_json:
|
||||
dicts_all['total'] = value_list
|
||||
with open(self.conf['save']['json_bin'], 'w') as json_file:
|
||||
json.dump(dicts_all, json_file)
|
||||
self.create_binary_files(self.conf['save']['json_bin'])
|
||||
|
||||
def statisticsBarcodes(self, pth, filter=None):
|
||||
feature_num = 0
|
||||
feature_num_lists = []
|
||||
nn = 0
|
||||
with open(conf['save']['barcodes_statistics'], 'w', encoding='utf-8') as f:
|
||||
for barcode in os.listdir(pth):
|
||||
print("barcode length >> {}".format(len(barcode)))
|
||||
if len(barcode) > 13 or len(barcode) < 8:
|
||||
continue
|
||||
if filter is not None:
|
||||
f.writelines(barcode + '\n')
|
||||
if barcode in filter:
|
||||
print(barcode)
|
||||
feature_num += len(os.listdir(os.path.join(pth, barcode)))
|
||||
nn += 1
|
||||
else:
|
||||
print('barcode name >>{}'.format(barcode))
|
||||
f.writelines(barcode + '\n')
|
||||
feature_num += len(os.listdir(os.path.join(pth, barcode)))
|
||||
feature_num_lists.append(feature_num)
|
||||
print("特征总量: {}".format(feature_num))
|
||||
print("barcode总量: {}".format(nn))
|
||||
f.close()
|
||||
|
||||
def get_shop_barcodes(self, file_path):
|
||||
if file_path:
|
||||
df = pd.read_excel(file_path)
|
||||
column_values = list(df.iloc[:, 6].values)
|
||||
column_values = list(map(str, column_values))
|
||||
return column_values
|
||||
else:
|
||||
return None
|
||||
|
||||
def del_base_dir(self, pth):
|
||||
for root, dirs, files in os.walk(pth):
|
||||
if len(dirs) == 1:
|
||||
if dirs[0] == 'base':
|
||||
shutil.rmtree(os.path.join(root, dirs[0]))
|
||||
|
||||
def write_binary_file(self, filename, datas):
|
||||
with open(filename, 'wb') as f:
|
||||
# 先写入数据中的key数量(为C++读取提供便利)
|
||||
key_count = len(datas)
|
||||
f.write(struct.pack('I', key_count)) # 'I'代表无符号整型(4字节)
|
||||
for data in datas:
|
||||
key = data['key']
|
||||
feats = data['value']
|
||||
key_bytes = key.encode('utf-8')
|
||||
key_len = len(key)
|
||||
length_byte = struct.pack('<B', key_len)
|
||||
f.write(length_byte)
|
||||
# f.write(struct.pack('Q', len(key_bytes)))
|
||||
f.write(key_bytes)
|
||||
value_count = len(feats)
|
||||
f.write(struct.pack('I', (value_count * 256)))
|
||||
# 遍历字典,写入每个key及其对应的浮点数值列表
|
||||
for values in feats:
|
||||
# 写入每个浮点数值(保留小数点后六位)
|
||||
for value in values:
|
||||
# 使用'f'格式(单精度浮点,4字节),并四舍五入保留六位小数
|
||||
value_half = np.float16(value)
|
||||
# print(value_half.tobytes())
|
||||
f.write(value_half.tobytes())
|
||||
|
||||
def create_binary_file(self, json_path, flag=True):
|
||||
# 1. 打开JSON文件
|
||||
with open(json_path, 'r', encoding='utf-8') as file:
|
||||
# 2. 读取并解析JSON文件内容
|
||||
data = json.load(file)
|
||||
if flag:
|
||||
for flag, values in data.items():
|
||||
# 逐个写入values中的每个值,保留小数点后六位,每个值占一行
|
||||
self.write_binary_file(self.conf['save']['json_bin'].replace('json', 'bin'), values)
|
||||
else:
|
||||
self.write_binary_file(json_path.replace('.json', '.bin'), [data])
|
||||
|
||||
def create_binary_files(self, index_file_pth):
|
||||
if os.path.isfile(index_file_pth):
|
||||
self.create_binary_file(index_file_pth)
|
||||
else:
|
||||
for name in os.listdir(index_file_pth):
|
||||
jsonpth = os.sep.join([index_file_pth, name])
|
||||
self.create_binary_file(jsonpth, False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with open('../configs/write_feature.yml', 'r') as f:
|
||||
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||
###将图片名称和模型推理特征向量字典存为json文件
|
||||
# xlsx_pth = './shop_xlsx/曹家桥门店在售商品表.xlsx'
|
||||
# xlsx_pth = None
|
||||
# del_base_dir(mg_path)
|
||||
|
||||
extractor = FeatureExtractor(conf)
|
||||
column_values = extractor.get_shop_barcodes(conf['data']['xlsx_pth'])
|
||||
imgs_dict = extractor.get_files(conf['data']['img_dirs_path'],
|
||||
filter=column_values,
|
||||
create_single_json=False) # False
|
||||
extractor.statisticsBarcodes(conf['data']['img_dirs_path'], column_values)
|
142
train_compare.py
Normal file
142
train_compare.py
Normal file
@ -0,0 +1,142 @@
|
||||
import os
|
||||
import os.path as osp
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from tqdm import tqdm
|
||||
|
||||
from model.loss import FocalLoss
|
||||
from tools.dataset import load_data
|
||||
import matplotlib.pyplot as plt
|
||||
from configs import trainer_tools
|
||||
import yaml
|
||||
|
||||
with open('configs/scatter.yml', 'r') as f:
|
||||
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
# Data Setup
|
||||
train_dataloader, class_num = load_data(training=True, cfg=conf)
|
||||
val_dataloader, _ = load_data(training=False, cfg=conf)
|
||||
|
||||
tr_tools = trainer_tools(conf)
|
||||
backbone_mapping = tr_tools.get_backbone()
|
||||
metric_mapping = tr_tools.get_metric(class_num)
|
||||
|
||||
if conf['models']['backbone'] in backbone_mapping:
|
||||
model = backbone_mapping[conf['models']['backbone']]().to(conf['base']['device'])
|
||||
else:
|
||||
raise ValueError('不支持该模型: {}'.format({conf['models']['backbone']}))
|
||||
|
||||
if conf['training']['metric'] in metric_mapping:
|
||||
metric = metric_mapping[conf['training']['metric']]()
|
||||
else:
|
||||
raise ValueError('不支持的metric类型: {}'.format(conf['training']['metric']))
|
||||
|
||||
if torch.cuda.device_count() > 1 and conf['base']['distributed']:
|
||||
print("Let's use", torch.cuda.device_count(), "GPUs!")
|
||||
model = nn.DataParallel(model)
|
||||
metric = nn.DataParallel(metric)
|
||||
|
||||
# Training Setup
|
||||
if conf['training']['loss'] == 'focal_loss':
|
||||
criterion = FocalLoss(gamma=2)
|
||||
else:
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
optimizer_mapping = tr_tools.get_optimizer(model, metric)
|
||||
if conf['training']['optimizer'] in optimizer_mapping:
|
||||
optimizer = optimizer_mapping[conf['training']['optimizer']]()
|
||||
scheduler = optim.lr_scheduler.StepLR(
|
||||
optimizer,
|
||||
step_size=conf['training']['lr_step'],
|
||||
gamma=conf['training']['lr_decay']
|
||||
)
|
||||
else:
|
||||
raise ValueError('不支持的优化器类型: {}'.format(conf['training']['optimizer']))
|
||||
|
||||
# Checkpoints Setup
|
||||
checkpoints = conf['training']['checkpoints']
|
||||
os.makedirs(checkpoints, exist_ok=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('backbone>{} '.format(conf['models']['backbone']),
|
||||
'metric>{} '.format(conf['training']['metric']),
|
||||
'checkpoints>{} '.format(conf['training']['checkpoints']),
|
||||
)
|
||||
train_losses = []
|
||||
val_losses = []
|
||||
epochs = []
|
||||
temp_loss = 100
|
||||
if conf['training']['restore']:
|
||||
print('load pretrain model: {}'.format(conf['training']['restore_model']))
|
||||
model.load_state_dict(torch.load(conf['training']['restore_model'],
|
||||
map_location=conf['base']['device']))
|
||||
|
||||
for e in range(conf['training']['epochs']):
|
||||
train_loss = 0
|
||||
model.train()
|
||||
|
||||
for train_data, train_labels in tqdm(train_dataloader,
|
||||
desc="Epoch {}/{}"
|
||||
.format(e, conf['training']['epochs']),
|
||||
ascii=True,
|
||||
total=len(train_dataloader)):
|
||||
train_data = train_data.to(conf['base']['device'])
|
||||
train_labels = train_labels.to(conf['base']['device'])
|
||||
|
||||
train_embeddings = model(train_data).to(conf['base']['device']) # [256,512]
|
||||
# pdb.set_trace()
|
||||
|
||||
if not conf['training']['metric'] == 'softmax':
|
||||
thetas = metric(train_embeddings, train_labels) # [256,357]
|
||||
else:
|
||||
thetas = metric(train_embeddings)
|
||||
tloss = criterion(thetas, train_labels)
|
||||
optimizer.zero_grad()
|
||||
tloss.backward()
|
||||
optimizer.step()
|
||||
train_loss += tloss.item()
|
||||
train_lossAvg = train_loss / len(train_dataloader)
|
||||
train_losses.append(train_lossAvg)
|
||||
epochs.append(e)
|
||||
val_loss = 0
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for val_data, val_labels in tqdm(val_dataloader, desc="val",
|
||||
ascii=True, total=len(val_dataloader)):
|
||||
val_data = val_data.to(conf['base']['device'])
|
||||
val_labels = val_labels.to(conf['base']['device'])
|
||||
val_embeddings = model(val_data).to(conf['base']['device'])
|
||||
if not conf['training']['metric'] == 'softmax':
|
||||
thetas = metric(val_embeddings, val_labels)
|
||||
else:
|
||||
thetas = metric(val_embeddings)
|
||||
vloss = criterion(thetas, val_labels)
|
||||
val_loss += vloss.item()
|
||||
val_lossAvg = val_loss / len(val_dataloader)
|
||||
val_losses.append(val_lossAvg)
|
||||
if val_lossAvg < temp_loss:
|
||||
if torch.cuda.device_count() > 1:
|
||||
torch.save(model.state_dict(), osp.join(checkpoints, 'best.pth'))
|
||||
else:
|
||||
torch.save(model.state_dict(), osp.join(checkpoints, 'best.pth'))
|
||||
temp_loss = val_lossAvg
|
||||
|
||||
scheduler.step()
|
||||
current_lr = optimizer.param_groups[0]['lr']
|
||||
log_info = ("Epoch {}/{}, train_loss: {}, val_loss: {} lr:{}"
|
||||
.format(e, conf['training']['epochs'], train_lossAvg, val_lossAvg, current_lr))
|
||||
print(log_info)
|
||||
# 写入日志文件
|
||||
with open(osp.join(conf['logging']['logging_dir']), 'a') as f:
|
||||
f.write(log_info + '\n')
|
||||
print("第%d个epoch的学习率:%f" % (e, current_lr))
|
||||
if torch.cuda.device_count() > 1 and conf['base']['distributed']:
|
||||
torch.save(model.module.state_dict(), osp.join(checkpoints, 'last.pth'))
|
||||
else:
|
||||
torch.save(model.state_dict(), osp.join(checkpoints, 'last.pth'))
|
||||
plt.plot(epochs, train_losses, color='blue')
|
||||
plt.plot(epochs, val_losses, color='red')
|
||||
# plt.savefig('lossMobilenetv3.png')
|
||||
plt.savefig('loss/mobilenetv3Large_2250_0316.png')
|
205
train_distill.py
Normal file
205
train_distill.py
Normal file
@ -0,0 +1,205 @@
|
||||
"""
|
||||
ResNet50蒸馏训练ResNet18实现
|
||||
学生网络使用ArcFace损失
|
||||
支持单机双卡训练
|
||||
"""
|
||||
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from torch.cuda.amp import GradScaler
|
||||
from model import resnet18, resnet50, ArcFace
|
||||
from tqdm import tqdm
|
||||
import torch.nn.functional as F
|
||||
from tools.dataset import load_data
|
||||
# from config import config as conf
|
||||
import yaml
|
||||
import math
|
||||
def setup(rank, world_size):
|
||||
os.environ['MASTER_ADDR'] = '0.0.0.0'
|
||||
os.environ['MASTER_PORT'] = '12355'
|
||||
dist.init_process_group("nccl", rank=rank, world_size=world_size)
|
||||
|
||||
def cleanup():
|
||||
dist.destroy_process_group()
|
||||
|
||||
class DistillTrainer:
|
||||
def __init__(self, rank, world_size, conf):
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
self.device = torch.device(f'cuda:{rank}')
|
||||
|
||||
# 初始化模型
|
||||
self.teacher = resnet50(pretrained=True, scale=conf['models']['channel_ratio']).to(self.device)
|
||||
self.student = resnet18(pretrained=True, scale=conf['models']['student_channel_ratio']).to(self.device)
|
||||
|
||||
# 加载预训练教师模型
|
||||
# teacher_path = os.path.join('checkpoints', 'resnet50_0519', 'best.pth')
|
||||
teacher_path = conf['models']['teacher_model_path']
|
||||
if os.path.exists(teacher_path):
|
||||
teacher_state = torch.load(teacher_path, map_location=self.device)
|
||||
new_state_dict = {}
|
||||
for k, v in teacher_state.items():
|
||||
if k.startswith('module.'):
|
||||
new_state_dict[k[7:]] = v # 去除前7个字符'module.'
|
||||
else:
|
||||
new_state_dict[k] = v
|
||||
# 加载处理后的状态字典
|
||||
self.teacher.load_state_dict(new_state_dict, strict=False)
|
||||
|
||||
if self.rank == 0:
|
||||
print(f"Successfully loaded teacher model from {teacher_path}")
|
||||
else:
|
||||
raise FileNotFoundError(f"Teacher model weights not found at {teacher_path}")
|
||||
|
||||
# 数据加载
|
||||
self.train_loader, num_classes = load_data(training=True, cfg=conf)
|
||||
self.val_loader, _ = load_data(training=False, cfg=conf)
|
||||
|
||||
# ArcFace损失
|
||||
self.metric = ArcFace(conf['base']['embedding_size'], num_classes).to(self.device)
|
||||
|
||||
# 分布式训练
|
||||
if world_size > 1:
|
||||
self.teacher = DDP(self.teacher, device_ids=[rank])
|
||||
self.student = DDP(self.student, device_ids=[rank])
|
||||
self.metric = DDP(self.metric, device_ids=[rank])
|
||||
|
||||
# 优化器
|
||||
self.optimizer = torch.optim.SGD([
|
||||
{'params': self.student.parameters()},
|
||||
{'params': self.metric.parameters()}
|
||||
], lr=conf['training']['lr'], momentum=0.9, weight_decay=5e-4)
|
||||
|
||||
self.scheduler = CosineAnnealingLR(self.optimizer, T_max=conf['training']['epochs'])
|
||||
self.scaler = GradScaler()
|
||||
|
||||
# 损失函数
|
||||
self.arcface_loss = nn.CrossEntropyLoss()
|
||||
self.distill_loss = nn.KLDivLoss(reduction='batchmean')
|
||||
self.conf = conf
|
||||
|
||||
def cosine_annealing(self, epoch, total_epochs, initial_weight, final_weight=0.1):
|
||||
"""
|
||||
余弦退火法动态调整蒸馏权重
|
||||
参数:
|
||||
epoch: 当前训练轮次
|
||||
total_epochs: 总训练轮次
|
||||
initial_weight: 初始蒸馏权重(如0.8)
|
||||
final_weight: 最终蒸馏权重(如0.1)
|
||||
返回:
|
||||
当前轮次的蒸馏权重
|
||||
"""
|
||||
return final_weight + 0.5 * (initial_weight - final_weight) * (1 + math.cos(math.pi * epoch / total_epochs))
|
||||
def train_epoch(self, epoch):
|
||||
self.teacher.eval()
|
||||
self.student.train()
|
||||
|
||||
if self.rank == 0:
|
||||
print(f"\nTeacher network type: {type(self.teacher)}")
|
||||
print(f"Student network type: {type(self.student)}")
|
||||
|
||||
total_loss = 0
|
||||
for data, labels in tqdm(self.train_loader, desc=f"Epoch {epoch}"):
|
||||
data = data.to(self.device)
|
||||
labels = labels.to(self.device)
|
||||
|
||||
# with autocast():
|
||||
# 教师输出
|
||||
with torch.no_grad():
|
||||
teacher_logits = self.teacher(data)
|
||||
|
||||
# 学生输出
|
||||
student_features = self.student(data)
|
||||
student_logits = self.metric(student_features, labels)
|
||||
|
||||
# 计算损失
|
||||
arc_loss = self.arcface_loss(student_logits, labels)
|
||||
distill_loss = self.distill_loss(
|
||||
F.log_softmax(student_features / self.conf['training']['temperature'], dim=1),
|
||||
F.softmax(teacher_logits / self.conf['training']['temperature'], dim=1)
|
||||
) * (self.conf['training']['temperature'] ** 2) # 温度缩放后需要乘以T^2保持梯度规模
|
||||
current_distill_weight = self.cosine_annealing(epoch, self.conf['training']['epochs'], self.conf['training']['distill_weight'])
|
||||
loss = (1-current_distill_weight) * arc_loss + current_distill_weight * distill_loss
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
self.scaler.scale(loss).backward()
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
|
||||
total_loss += loss.item()
|
||||
|
||||
self.scheduler.step()
|
||||
return total_loss / len(self.train_loader)
|
||||
|
||||
def validate(self):
|
||||
self.student.eval()
|
||||
total_loss = 0
|
||||
correct = 0
|
||||
total = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for data, labels in self.val_loader:
|
||||
data = data.to(self.device)
|
||||
labels = labels.to(self.device)
|
||||
|
||||
features = self.student(data)
|
||||
logits = self.metric(features, labels)
|
||||
|
||||
loss = self.arcface_loss(logits, labels)
|
||||
total_loss += loss.item()
|
||||
|
||||
_, predicted = torch.max(logits.data, 1)
|
||||
total += labels.size(0)
|
||||
correct += (predicted == labels).sum().item()
|
||||
|
||||
return total_loss / len(self.val_loader), correct / total
|
||||
|
||||
def save_checkpoint(self, epoch, is_best=False):
|
||||
if self.rank != 0:
|
||||
return
|
||||
|
||||
state = {
|
||||
'epoch': epoch,
|
||||
'student_state_dict': self.student.state_dict(),
|
||||
'metric_state_dict': self.metric.state_dict(),
|
||||
'optimizer_state_dict': self.optimizer.state_dict(),
|
||||
}
|
||||
|
||||
filename = 'best.pth' if is_best else f'checkpoint_{epoch}.pth'
|
||||
if not os.path.exists(self.conf['training']['checkpoints']):
|
||||
os.makedirs(self.conf['training']['checkpoints'])
|
||||
if filename != 'best.pth':
|
||||
torch.save(state, os.path.join(self.conf['training']['checkpoints'], filename))
|
||||
else:
|
||||
torch.save(state['student_state_dict'], os.path.join(self.conf['training']['checkpoints'], filename))
|
||||
|
||||
def train(rank, world_size):
|
||||
setup(rank, world_size)
|
||||
with open('configs/distill.yml', 'r') as f:
|
||||
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||
trainer = DistillTrainer(rank, world_size, conf)
|
||||
best_acc = 0
|
||||
for epoch in range(conf['training']['epochs']):
|
||||
train_loss = trainer.train_epoch(epoch)
|
||||
val_loss, val_acc = trainer.validate()
|
||||
|
||||
if rank == 0:
|
||||
print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
|
||||
|
||||
if val_acc > best_acc:
|
||||
best_acc = val_acc
|
||||
trainer.save_checkpoint(epoch, is_best=True)
|
||||
|
||||
cleanup()
|
||||
|
||||
if __name__ == '__main__':
|
||||
world_size = torch.cuda.device_count()
|
||||
if world_size > 1:
|
||||
mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
|
||||
else:
|
||||
train(0, 1)
|
Reference in New Issue
Block a user