+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..a1acbf0
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..cda59cb
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/sshConfigs.xml b/.idea/sshConfigs.xml
new file mode 100644
index 0000000..683c220
--- /dev/null
+++ b/.idea/sshConfigs.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/webServers.xml b/.idea/webServers.xml
new file mode 100644
index 0000000..e1e9824
--- /dev/null
+++ b/.idea/webServers.xml
@@ -0,0 +1,14 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.vscode/sftp.json b/.vscode/sftp.json
new file mode 100644
index 0000000..f733c42
--- /dev/null
+++ b/.vscode/sftp.json
@@ -0,0 +1,9 @@
+{
+ "name": "My Server",
+ "host": "localhost",
+ "protocol": "sftp",
+ "port": 22,
+ "username": "username",
+ "remotePath": "/",
+ "uploadOnSave": true
+}
diff --git a/__pycache__/config.cpython-38.pyc b/__pycache__/config.cpython-38.pyc
new file mode 100644
index 0000000..22df908
Binary files /dev/null and b/__pycache__/config.cpython-38.pyc differ
diff --git a/__pycache__/test_ori.cpython-38.pyc b/__pycache__/test_ori.cpython-38.pyc
new file mode 100644
index 0000000..eea10c9
Binary files /dev/null and b/__pycache__/test_ori.cpython-38.pyc differ
diff --git a/config.py b/config.py
new file mode 100644
index 0000000..d0fc3f0
--- /dev/null
+++ b/config.py
@@ -0,0 +1,122 @@
+import torch
+import torchvision.transforms as T
+import torchvision.transforms.functional as F
+
+
+def pad_to_square(img):
+ w, h = img.size
+ max_wh = max(w, h)
+ padding = [(max_wh - w) // 2, (max_wh - h) // 2, (max_wh - w) // 2, (max_wh - h) // 2] # (left, top, right, bottom)
+ return F.pad(img, padding, fill=0, padding_mode='constant')
+
+
+class Config:
+ # network settings
+ backbone = 'resnet18' # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large,
+ # mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5, vit_base]
+ metric = 'arcface' # [cosface, arcface, softmax]
+ cbam = False
+ embedding_size = 256 # 256 # gift:2 contrast:256
+ drop_ratio = 0.5
+ img_size = 224
+ multiple_cards = True # 多卡加载
+ model_half = False # 模型半精度测试
+ data_half = True # 数据半精度测试
+ channel_ratio = 0.75 # 通道剪枝比例
+ quantization_test = False # int8量化模型测试
+
+ # custom base_data settings
+ custom_backbone = False # 迁移学习载入除最后一层的所有层
+ custom_num_classes = 128 # 迁移学习的类别数量
+
+ # if quantization_test:
+ # device = torch.device('cpu')
+ # else:
+ # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ teacher = 'vit' # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large, mobilenet_v1,
+ # PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5]
+
+ student = 'resnet'
+ # data preprocess
+ """transforms.RandomCrop(size),
+ transforms.RandomVerticalFlip(p=0.5),
+ transforms.RandomHorizontalFlip(),
+ RandomRotate(15, 0.3),
+ # RandomGaussianBlur()"""
+ train_transform = T.Compose([
+ T.Lambda(pad_to_square), # 补边
+ T.ToTensor(),
+ T.Resize((img_size, img_size), antialias=True),
+ # T.RandomCrop(img_size * 4 // 5),
+ T.RandomHorizontalFlip(p=0.5),
+ T.RandomRotation(180),
+ T.ColorJitter(brightness=0.5),
+ T.ConvertImageDtype(torch.float32),
+ T.Normalize(mean=[0.5], std=[0.5]),
+ ])
+ test_transform = T.Compose([
+ # T.Lambda(pad_to_square), # 补边
+ T.ToTensor(),
+ T.Resize((img_size, img_size), antialias=True),
+ T.ConvertImageDtype(torch.float32),
+ # T.Normalize(mean=[0,0,0], std=[255,255,255]),
+ T.Normalize(mean=[0.5], std=[0.5]),
+ ])
+
+ # dataset
+ train_root = '../data_center/scatter/train' # ['./data/2250_train/base_data', # './data/2000_train/base_data', './data/zhanting/base_data', './data/base_train/one_stage/train']
+ test_root = '../data_center/scatter/val' # ["./data/2250_train/val", "./data/2000_train/val/", './data/zhanting/val', './data/base_train/one_stage/val']
+
+ # training settings
+ checkpoints = "checkpoints/resnet18_scatter_6.2/" # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3]
+ restore = True
+ # restore_model = "checkpoints/renet18_2250_0315/best_resnet18_2250_0315.pth" # best_resnet18_1491_0306.pth
+ restore_model = "checkpoints/resnet18_scatter_6.2/best.pth" # best_resnet18_1491_0306.pth
+
+ # test settings
+ testbackbone = 'resnet18' # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5]
+
+ # test_val = "./data/2250_train"
+ # test_list = "./data/2250_train/val_pair.txt"
+ # test_group_json = "./data/2250_train/cross_same.json"
+
+ test_val = "../data_center/scatter/" # [../data_center/contrast_learning/model_test_data/val_2250]
+ test_list = "../data_center/scatter/val_pair.txt" # [./data/test/public_single_pairs.txt]
+ test_group_json = "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json" # [./data/2250_train/cross_same.json]
+ # test_group_json = "./data/test/inner_group_pairs.json"
+
+ # test_model = "checkpoints/resnet18_scatter_6.2/best.pth"
+ test_model = "checkpoints/resnet18_1009/best.pth"
+ # test_model = "checkpoints/zhanting/inland/res_801.pth"
+ # test_model = "checkpoints/resnet18_20250504/best.pth"
+ # test_model = "checkpoints/resnet18_vit-base_20250430/best.pth"
+ group_test = False
+ # group_test = False
+
+ train_batch_size = 128 # 256
+ test_batch_size = 128 # 256
+
+ epoch = 5 # 512
+ optimizer = 'sgd' # ['sgd', 'adam', 'adamw']
+ lr = 5e-3 # 1e-2
+ lr_step = 10 # 10
+ lr_decay = 0.98 # 0.98
+ weight_decay = 5e-4
+ loss = 'cross_entropy' # ['focal_loss', 'cross_entropy']
+ log_path = './log'
+ lr_min = 1e-6 # min lr
+
+ pin_memory = False # if memory is large, set it True to speed up a bit
+ num_workers = 32 # 64
+ compare = False # compare the result of different models
+
+ '''
+ train_distill settings
+ '''
+ warmup_epochs = 3 # warmup_epoch
+ distributed = True # distributed training
+ teacher_path = "./checkpoints/resnet50_0519/best.pth"
+ distill_weight = 0.8 # 蒸馏权重
+
+config = Config()
\ No newline at end of file
diff --git a/configs/__init__.py b/configs/__init__.py
new file mode 100644
index 0000000..18ef40f
--- /dev/null
+++ b/configs/__init__.py
@@ -0,0 +1 @@
+from .utils import trainer_tools
\ No newline at end of file
diff --git a/configs/compare.yml b/configs/compare.yml
new file mode 100644
index 0000000..a03261a
--- /dev/null
+++ b/configs/compare.yml
@@ -0,0 +1,69 @@
+# configs/compare.yml
+# 专为模型训练对比设计的配置文件
+# 支持对比不同训练策略(如蒸馏vs独立训练)
+
+# 基础配置
+base:
+ experiment_name: "model_comparison" # 实验名称(用于结果保存目录)
+ seed: 42 # 随机种子(保证可复现性)
+ device: "cuda" # 训练设备(cuda/cpu)
+ log_level: "info" # 日志级别(debug/info/warning/error)
+ embedding_size: 256 # 特征维度
+ pin_memory: true # 是否启用pin_memory
+ distributed: true # 是否启用分布式训练
+
+
+# 模型配置
+models:
+ backbone: 'resnet18'
+ channel_ratio: 0.75
+
+# 训练参数
+training:
+ epochs: 600 # 总训练轮次
+ batch_size: 128 # 批次大小
+ lr: 0.001 # 初始学习率
+ optimizer: "sgd" # 优化器类型
+ metric: 'arcface' # 损失函数类型(可选:arcface/cosface/sphereface/softmax)
+ loss: "cross_entropy" # 损失函数类型(可选:cross_entropy/cross_entropy_smooth/center_loss/center_loss_smooth/arcface/cosface/sphereface/softmax)
+ lr_step: 10 # 学习率调整间隔(epoch)
+ lr_decay: 0.98 # 学习率衰减率
+ weight_decay: 0.0005 # 权重衰减
+ scheduler: "cosine_annealing" # 学习率调度器(可选:cosine_annealing/step/none)
+ num_workers: 32 # 数据加载线程数
+ checkpoints: "./checkpoints/resnet18_test/" # 模型保存目录
+ restore: false
+ restore_model: "resnet18_test/epoch_600.pth" # 模型恢复路径
+
+# 验证参数
+validation:
+ num_workers: 32 # 数据加载线程数
+ val_batch_size: 128 # 测试批次大小
+
+# 数据配置
+data:
+ dataset: "imagenet" # 数据集名称(示例用,可替换为实际数据集)
+ train_batch_size: 128 # 训练批次大小
+ val_batch_size: 128 # 验证批次大小
+ num_workers: 32 # 数据加载线程数
+ data_train_dir: "../data_center/contrast_learning/data_base/train" # 训练数据集根目录
+ data_val_dir: "../data_center/contrast_learning/data_base/val" # 验证数据集根目录
+
+transform:
+ img_size: 224 # 图像尺寸
+ img_mean: 0.5 # 图像均值
+ img_std: 0.5 # 图像方差
+ RandomHorizontalFlip: 0.5 # 随机水平翻转概率
+ RandomRotation: 180 # 随机旋转角度
+ ColorJitter: 0.5 # 随机颜色抖动强度
+
+# 日志与监控
+logging:
+ logging_dir: "./logs" # 日志保存目录
+ tensorboard: true # 是否启用TensorBoard
+ checkpoint_interval: 30 # 检查点保存间隔(epoch)
+
+# 分布式训练(可选)
+distributed:
+ enabled: false # 是否启用分布式训练
+ backend: "nccl" # 分布式后端(nccl/gloo)
diff --git a/configs/distill.yml b/configs/distill.yml
new file mode 100644
index 0000000..8332c16
--- /dev/null
+++ b/configs/distill.yml
@@ -0,0 +1,75 @@
+# configs/compare.yml
+# 专为模型训练对比设计的配置文件
+# 支持对比不同训练策略(如蒸馏vs独立训练)
+
+# 基础配置
+base:
+ experiment_name: "model_comparison" # 实验名称(用于结果保存目录)
+ seed: 42 # 随机种子(保证可复现性)
+ device: "cuda" # 训练设备(cuda/cpu)
+ log_level: "info" # 日志级别(debug/info/warning/error)
+ embedding_size: 256 # 特征维度
+ pin_memory: true # 是否启用pin_memory
+ distributed: true # 是否启用分布式训练
+
+
+# 模型配置
+models:
+ backbone: 'resnet18'
+ channel_ratio: 1.0 # 主干特征通道缩放比例(默认)
+ student_channel_ratio: 0.75
+ teacher_model_path: "./checkpoints/resnet50_0519/best.pth"
+
+# 训练参数
+training:
+ epochs: 600 # 总训练轮次
+ batch_size: 128 # 批次大小
+ lr: 0.001 # 初始学习率
+ optimizer: "sgd" # 优化器类型
+ metric: 'arcface' # 损失函数类型(可选:arcface/cosface/sphereface/softmax)
+ loss: "cross_entropy" # 损失函数类型(可选:cross_entropy/cross_entropy_smooth/center_loss/center_loss_smooth/arcface/cosface/sphereface/softmax)
+ lr_step: 10 # 学习率调整间隔(epoch)
+ lr_decay: 0.98 # 学习率衰减率
+ weight_decay: 0.0005 # 权重衰减
+ scheduler: "cosine_annealing" # 学习率调度器(可选:cosine_annealing/step/none)
+ num_workers: 32 # 数据加载线程数
+ checkpoints: "./checkpoints/resnet18_test/" # 模型保存目录
+ restore: false
+ restore_model: "resnet18_test/epoch_600.pth" # 模型恢复路径
+ distill_weight: 0.8 # 蒸馏损失权重
+ temperature: 4 # 蒸馏温度
+
+
+
+# 验证参数
+validation:
+ num_workers: 32 # 数据加载线程数
+ val_batch_size: 128 # 测试批次大小
+
+# 数据配置
+data:
+ dataset: "imagenet" # 数据集名称(示例用,可替换为实际数据集)
+ train_batch_size: 128 # 训练批次大小
+ val_batch_size: 100 # 验证批次大小
+ num_workers: 4 # 数据加载线程数
+ data_train_dir: "../data_center/contrast_learning/data_base/train" # 训练数据集根目录
+ data_val_dir: "../data_center/contrast_learning/data_base/val" # 验证数据集根目录
+
+transform:
+ img_size: 224 # 图像尺寸
+ img_mean: 0.5 # 图像均值
+ img_std: 0.5 # 图像方差
+ RandomHorizontalFlip: 0.5 # 随机水平翻转概率
+ RandomRotation: 180 # 随机旋转角度
+ ColorJitter: 0.5 # 随机颜色抖动强度
+
+# 日志与监控
+logging:
+ logging_dir: "./logs" # 日志保存目录
+ tensorboard: true # 是否启用TensorBoard
+ checkpoint_interval: 30 # 检查点保存间隔(epoch)
+
+# 分布式训练(可选)
+distributed:
+ enabled: false # 是否启用分布式训练
+ backend: "nccl" # 分布式后端(nccl/gloo)
diff --git a/configs/scatter.yml b/configs/scatter.yml
new file mode 100644
index 0000000..7612e64
--- /dev/null
+++ b/configs/scatter.yml
@@ -0,0 +1,69 @@
+# configs/scatter.yml
+# 专为模型训练对比设计的配置文件
+# 支持对比不同训练策略(如蒸馏vs独立训练)
+
+# 基础配置
+base:
+ device: "cuda" # 训练设备(cuda/cpu)
+ log_level: "info" # 日志级别(debug/info/warning/error)
+ embedding_size: 256 # 特征维度
+ pin_memory: true # 是否启用pin_memory
+ distributed: true # 是否启用分布式训练
+
+
+# 模型配置
+models:
+ backbone: 'resnet18'
+ channel_ratio: 1.0
+
+# 训练参数
+training:
+ epochs: 300 # 总训练轮次
+ batch_size: 64 # 批次大小
+ lr: 0.005 # 初始学习率
+ optimizer: "sgd" # 优化器类型
+ metric: 'arcface' # 损失函数类型(可选:arcface/cosface/sphereface/softmax)
+ loss: "cross_entropy" # 损失函数类型(可选:cross_entropy/cross_entropy_smooth/center_loss/center_loss_smooth/arcface/cosface/sphereface/softmax)
+ lr_step: 10 # 学习率调整间隔(epoch)
+ lr_decay: 0.98 # 学习率衰减率
+ weight_decay: 0.0005 # 权重衰减
+ scheduler: "cosine_annealing" # 学习率调度器(可选:cosine_annealing/step/none)
+ num_workers: 32 # 数据加载线程数
+ checkpoints: "./checkpoints/resnet18_scatter_6.2/" # 模型保存目录
+ restore: True
+ restore_model: "checkpoints/resnet18_scatter_6.2/best.pth" # 模型恢复路径
+
+
+
+# 验证参数
+validation:
+ num_workers: 32 # 数据加载线程数
+ val_batch_size: 128 # 测试批次大小
+
+# 数据配置
+data:
+ dataset: "imagenet" # 数据集名称(示例用,可替换为实际数据集)
+ train_batch_size: 128 # 训练批次大小
+ val_batch_size: 100 # 验证批次大小
+ num_workers: 32 # 数据加载线程数
+ data_train_dir: "../data_center/scatter/train" # 训练数据集根目录
+ data_val_dir: "../data_center/scatter/val" # 验证数据集根目录
+
+transform:
+ img_size: 224 # 图像尺寸
+ img_mean: 0.5 # 图像均值
+ img_std: 0.5 # 图像方差
+ RandomHorizontalFlip: 0.5 # 随机水平翻转概率
+ RandomRotation: 180 # 随机旋转角度
+ ColorJitter: 0.5 # 随机颜色抖动强度
+
+# 日志与监控
+logging:
+ logging_dir: "./log/2025.6.2-scatter.txt" # 日志保存目录
+ tensorboard: true # 是否启用TensorBoard
+ checkpoint_interval: 30 # 检查点保存间隔(epoch)
+
+# 分布式训练(可选)
+distributed:
+ enabled: false # 是否启用分布式训练
+ backend: "nccl" # 分布式后端(nccl/gloo)
diff --git a/configs/test.yml b/configs/test.yml
new file mode 100644
index 0000000..cb10797
--- /dev/null
+++ b/configs/test.yml
@@ -0,0 +1,41 @@
+# configs/test.yml
+# 专为模型训练对比设计的配置文件
+# 支持对比不同训练策略(如蒸馏vs独立训练)
+
+# 基础配置
+base:
+ device: "cuda" # 训练设备(cuda/cpu)
+ log_level: "info" # 日志级别(debug/info/warning/error)
+ embedding_size: 256 # 特征维度
+ pin_memory: true # 是否启用pin_memory
+ distributed: true # 是否启用分布式训练
+
+# 模型配置
+models:
+ backbone: 'resnet18'
+ channel_ratio: 1.0
+ model_path: "./checkpoints/resnet18_scatter_6.2/best.pth"
+ half: false # 是否启用半精度测试(fp16)
+
+# 数据配置
+data:
+ group_test: False # 数据集名称(示例用,可替换为实际数据集)
+ test_batch_size: 128 # 训练批次大小
+ num_workers: 32 # 数据加载线程数
+ test_dir: "../data_center/scatter/" # 验证数据集根目录
+ test_group_json: "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json"
+ test_list: "../data_center/scatter/val_pair.txt"
+
+transform:
+ img_size: 224 # 图像尺寸
+ img_mean: 0.5 # 图像均值
+ img_std: 0.5 # 图像方差
+ RandomHorizontalFlip: 0.5 # 随机水平翻转概率
+ RandomRotation: 180 # 随机旋转角度
+ ColorJitter: 0.5 # 随机颜色抖动强度
+
+save:
+ save_dir: ""
+ save_name: ""
+
+
diff --git a/configs/utils.py b/configs/utils.py
new file mode 100644
index 0000000..899294f
--- /dev/null
+++ b/configs/utils.py
@@ -0,0 +1,56 @@
+from model import (resnet18, mobilevit_s, MobileNetV3_Small, MobileNetV3_Large, mobilenet_v1,
+ PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5)
+from timm.models import vit_base_patch16_224 as vit_base_16
+from model.metric import ArcFace, CosFace
+import torch.optim as optim
+import torch.nn as nn
+import timm
+
+
+class trainer_tools:
+ def __init__(self, conf):
+ self.conf = conf
+
+ def get_backbone(self):
+ backbone_mapping = {
+ 'resnet18': lambda: resnet18(scale=self.conf['models']['channel_ratio']),
+ 'mobilevit_s': lambda: mobilevit_s(),
+ 'mobilenetv3_small': lambda: MobileNetV3_Small(),
+ 'PPLCNET_x1_0': lambda: PPLCNET_x1_0(),
+ 'PPLCNET_x0_5': lambda: PPLCNET_x0_5(),
+ 'PPLCNET_x2_5': lambda: PPLCNET_x2_5(),
+ 'mobilenetv3_large': lambda: MobileNetV3_Large(),
+ 'vit_base': lambda: vit_base_16(pretrained=True),
+ 'efficientnet': lambda: timm.create_model('efficientnet_b0', pretrained=True,
+ num_classes=self.conf.embedding_size)
+ }
+ return backbone_mapping
+
+ def get_metric(self, class_num):
+ # 优化后的metric选择代码块,使用字典映射提高可读性和扩展性
+ metric_mapping = {
+ 'arcface': lambda: ArcFace(self.conf['base']['embedding_size'], class_num).to(self.conf['base']['device']),
+ 'cosface': lambda: CosFace(self.conf['base']['embedding_size'], class_num).to(self.conf['base']['device']),
+ 'softmax': lambda: nn.Linear(self.conf['base']['embedding_size'], class_num).to(self.conf['base']['device'])
+ }
+ return metric_mapping
+
+ def get_optimizer(self, model, metric):
+ optimizer_mapping = {
+ 'sgd': lambda: optim.SGD(
+ [{'params': model.parameters()}, {'params': metric.parameters()}],
+ lr=self.conf['training']['lr'],
+ weight_decay=self.conf['training']['weight_decay']
+ ),
+ 'adam': lambda: optim.Adam(
+ [{'params': model.parameters()}, {'params': metric.parameters()}],
+ lr=self.conf['training']['lr'],
+ weight_decay=self.conf['training']['weight_decay']
+ ),
+ 'adamw': lambda: optim.AdamW(
+ [{'params': model.parameters()}, {'params': metric.parameters()}],
+ lr=self.conf['training']['lr'],
+ weight_decay=self.conf['training']['weight_decay']
+ )
+ }
+ return optimizer_mapping
diff --git a/configs/write_feature.yml b/configs/write_feature.yml
new file mode 100644
index 0000000..fdf7d77
--- /dev/null
+++ b/configs/write_feature.yml
@@ -0,0 +1,47 @@
+# configs/write_feature.yml
+# 专为模型训练对比设计的配置文件
+# 支持对比不同训练策略(如蒸馏vs独立训练)
+
+# 基础配置
+base:
+ device: "cuda" # 训练设备(cuda/cpu)
+ log_level: "info" # 日志级别(debug/info/warning/error)
+ embedding_size: 256 # 特征维度
+ distributed: true # 是否启用分布式训练
+ pin_memory: true # 是否启用pin_memory
+
+# 模型配置
+models:
+ backbone: 'resnet18'
+ channel_ratio: 0.75
+ checkpoints: "../checkpoints/resnet18_1009/best.pth"
+
+# 数据配置
+data:
+ train_batch_size: 128 # 训练批次大小
+ test_batch_size: 128 # 验证批次大小
+ num_workers: 32 # 数据加载线程数
+ half: true # 是否启用半精度数据
+ img_dirs_path: "/shareData/temp_data/comparison/Hangzhou_Yunhe/base_data/05-09"
+# img_dirs_path: "/home/lc/contrast_nettest/data/feature_json"
+ xlsx_pth: false # 过滤商品, 默认None不进行过滤
+
+transform:
+ img_size: 224 # 图像尺寸
+ img_mean: 0.5 # 图像均值
+ img_std: 0.5 # 图像方差
+ RandomHorizontalFlip: 0.5 # 随机水平翻转概率
+ RandomRotation: 180 # 随机旋转角度
+ ColorJitter: 0.5 # 随机颜色抖动强度
+
+# 日志与监控
+logging:
+ logging_dir: "./logs" # 日志保存目录
+ tensorboard: true # 是否启用TensorBoard
+ checkpoint_interval: 30 # 检查点保存间隔(epoch)
+
+save:
+ json_bin: "../search_library/yunhedian_05-09.json" # 保存整个json文件
+ json_path: "../data/feature_json_compare/" # 保存单个json文件
+ error_barcodes: "error_barcodes.txt"
+ barcodes_statistics: "../search_library/barcodes_statistics.txt"
\ No newline at end of file
diff --git a/model/BAM.py b/model/BAM.py
new file mode 100644
index 0000000..4ac61ae
--- /dev/null
+++ b/model/BAM.py
@@ -0,0 +1,88 @@
+import torch.nn as nn
+import torchvision
+from torch.nn import init
+
+
+class Flatten(nn.Module):
+ def forward(self, x):
+ return x.view(x.shape[0], -1)
+
+
+class ChannelAttention(nn.Module):
+ def __int__(self, channel, reduction, num_layers):
+ super(ChannelAttention, self).__init__()
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
+ gate_channels = [channel]
+ gate_channels += [len(channel) // reduction] * num_layers
+ gate_channels += [channel]
+
+ self.ca = nn.Sequential()
+ self.ca.add_module('flatten', Flatten())
+ for i in range(len(gate_channels) - 2):
+ self.ca.add_module('', nn.Linear(gate_channels[i], gate_channels[i + 1]))
+ self.ca.add_module('', nn.BatchNorm1d(gate_channels[i + 1]))
+ self.ca.add_module('', nn.ReLU())
+ self.ca.add_module('', nn.Linear(gate_channels[-2], gate_channels[-1]))
+
+ def forward(self, x):
+ res = self.avgpool(x)
+ res = self.ca(res)
+ res = res.unsqueeze(-1).unsqueeze(-1).expand_as(x)
+ return res
+
+
+class SpatialAttention(nn.Module):
+ def __int__(self, channel, reduction=16, num_lay=3, dilation=2):
+ super(SpatialAttention).__init__()
+ self.sa = nn.Sequential()
+ self.sa.add_module('', nn.Conv2d(kernel_size=1, in_channels=channel, out_channels=(channel // reduction) * 3))
+ self.sa.add_module('', nn.BatchNorm2d(num_features=(channel // reduction)))
+ self.sa.add_module('', nn.ReLU())
+ for i in range(num_lay):
+ self.sa.add_module('', nn.Conv2d(kernel_size=3,
+ in_channels=(channel // reduction),
+ out_channels=(channel // reduction),
+ padding=1,
+ dilation=2))
+ self.sa.add_module('', nn.BatchNorm2d(channel // reduction))
+ self.sa.add_module('', nn.ReLU())
+ self.sa.add_module('', nn.Conv2d(channel // reduction, 1, kernel_size=1))
+
+ def forward(self, x):
+ res = self.sa(x)
+ res = res.expand_as(x)
+ return res
+
+
+class BAMblock(nn.Module):
+ def __init__(self, channel=512, reduction=16, dia_val=2):
+ super(BAMblock, self).__init__()
+ self.ca = ChannelAttention(channel, reduction)
+ self.sa = SpatialAttention(channel, reduction, dia_val)
+ self.sigmoid = nn.Sigmoid()
+
+ def init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ init.kaiming_normal(m.weight, mode='fan_out')
+ if m.bais is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d):
+ init.constant_(m.weight, 1)
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ init.normal_(m.weight, std=0.001)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ b, c, _, _ = x.size()
+ sa_out = self.sa(x)
+ ca_out = self.ca(x)
+ weight = self.sigmoid(sa_out + ca_out)
+ out = (1 + weight) * x
+ return out
+
+
+if __name__ == "__main__":
+ print(512 // 14)
diff --git a/model/CBAM.py b/model/CBAM.py
new file mode 100644
index 0000000..69747e0
--- /dev/null
+++ b/model/CBAM.py
@@ -0,0 +1,70 @@
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+
+class channelAttention(nn.Module):
+ def __init__(self, channel, reduction=16):
+ super(channelAttention, self).__init__()
+ self.Maxpooling = nn.AdaptiveMaxPool2d(1)
+ self.Avepooling = nn.AdaptiveAvgPool2d(1)
+ self.ca = nn.Sequential()
+ self.ca.add_module('conv1',nn.Conv2d(channel, channel//reduction, 1, bias=False))
+ self.ca.add_module('Relu', nn.ReLU())
+ self.ca.add_module('conv2',nn.Conv2d(channel//reduction, channel, 1, bias=False))
+ self.sigmod = nn.Sigmoid()
+
+ def forward(self, x):
+ M_out = self.Maxpooling(x)
+ A_out = self.Avepooling(x)
+ M_out = self.ca(M_out)
+ A_out = self.ca(A_out)
+ out = self.sigmod(M_out+A_out)
+ return out
+
+class SpatialAttention(nn.Module):
+ def __init__(self, kernel_size=7):
+ super().__init__()
+ self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=kernel_size, padding=kernel_size // 2)
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, x):
+ max_result, _ = torch.max(x, dim=1, keepdim=True)
+ avg_result = torch.mean(x, dim=1, keepdim=True)
+ result = torch.cat([max_result, avg_result], dim=1)
+ output = self.conv(result)
+ output = self.sigmoid(output)
+ return output
+
+class CBAM(nn.Module):
+ def __init__(self, channel, reduction=16, kernel_size=7):
+ super().__init__()
+ self.ca = channelAttention(channel, reduction)
+ self.sa = SpatialAttention(kernel_size)
+
+ def init_weights(self):
+ for m in self.modules():#权重初始化
+ if isinstance(m, nn.Conv2d):
+ init.kaiming_normal_(m.weight, mode='fan_out')
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d):
+ init.constant_(m.weight, 1)
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ init.normal_(m.weight, std=0.001)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ # b,c_,_ = x.size()
+ # residual = x
+ out = x*self.ca(x)
+ out = out*self.sa(out)
+ return out
+
+if __name__ == '__main__':
+ input=torch.randn(50,512,7,7)
+ kernel_size=input.shape[2]
+ cbam = CBAM(channel=512,reduction=16,kernel_size=kernel_size)
+ output=cbam(input)
+ print(output.shape)
diff --git a/model/Tool.py b/model/Tool.py
new file mode 100644
index 0000000..3c65931
--- /dev/null
+++ b/model/Tool.py
@@ -0,0 +1,37 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class GeM(nn.Module):
+ def __init__(self, p=3, eps=1e-6):
+ super(GeM, self).__init__()
+ self.p = nn.Parameter(torch.ones(1) * p)
+ self.eps = eps
+
+ def forward(self, x):
+ return self.gem(x, p=self.p, eps=self.eps, stride=2)
+
+ def gem(self, x, p=3, eps=1e-6, stride=2):
+ return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1)), stride=2).pow(1. / p)
+
+ def __repr__(self):
+ return self.__class__.__name__ + \
+ '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + \
+ ', ' + 'eps=' + str(self.eps) + ')'
+
+
+class TripletLoss(nn.Module):
+ def __init__(self, margin):
+ super(TripletLoss, self).__init__()
+ self.margin = margin
+
+ def forward(self, anchor, positive, negative, size_average=True):
+ distance_positive = (anchor - positive).pow(2).sum(1)
+ distance_negative = (anchor - negative).pow(2).sum(1)
+ losses = F.relu(distance_negative - distance_positive + self.margin)
+ return losses.mean() if size_average else losses.sum()
+
+
+if __name__ == '__main__':
+ print('')
diff --git a/model/__init__.py b/model/__init__.py
new file mode 100644
index 0000000..fef1029
--- /dev/null
+++ b/model/__init__.py
@@ -0,0 +1,14 @@
+from .fmobilenet import FaceMobileNet
+# from .resnet_face import ResIRSE
+from .mobilevit import mobilevit_s
+from .metric import ArcFace, CosFace
+from .loss import FocalLoss
+from .resbam import resnet
+from .resnet_pre import resnet18, resnet34, resnet50, resnet14, CustomResNet18
+from .mobilenet_v2 import mobilenet_v2
+from .mobilenet_v3 import MobileNetV3_Small, MobileNetV3_Large
+# from .mobilenet_v1 import mobilenet_v1
+from .lcnet import PPLCNET_x0_25, PPLCNET_x0_35, PPLCNET_x0_5, PPLCNET_x0_75, PPLCNET_x1_0, PPLCNET_x1_5, PPLCNET_x2_0, \
+ PPLCNET_x2_5
+from .vit import vit_base
+from .mlp import MLP
\ No newline at end of file
diff --git a/model/__pycache__/CBAM.cpython-38.pyc b/model/__pycache__/CBAM.cpython-38.pyc
new file mode 100644
index 0000000..fb7929f
Binary files /dev/null and b/model/__pycache__/CBAM.cpython-38.pyc differ
diff --git a/model/__pycache__/Tool.cpython-38.pyc b/model/__pycache__/Tool.cpython-38.pyc
new file mode 100644
index 0000000..3aadd9d
Binary files /dev/null and b/model/__pycache__/Tool.cpython-38.pyc differ
diff --git a/model/__pycache__/__init__.cpython-38.pyc b/model/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..66cdfbc
Binary files /dev/null and b/model/__pycache__/__init__.cpython-38.pyc differ
diff --git a/model/__pycache__/fmobilenet.cpython-38.pyc b/model/__pycache__/fmobilenet.cpython-38.pyc
new file mode 100644
index 0000000..07cee90
Binary files /dev/null and b/model/__pycache__/fmobilenet.cpython-38.pyc differ
diff --git a/model/__pycache__/lcnet.cpython-38.pyc b/model/__pycache__/lcnet.cpython-38.pyc
new file mode 100644
index 0000000..3646a89
Binary files /dev/null and b/model/__pycache__/lcnet.cpython-38.pyc differ
diff --git a/model/__pycache__/loss.cpython-38.pyc b/model/__pycache__/loss.cpython-38.pyc
new file mode 100644
index 0000000..2c845d3
Binary files /dev/null and b/model/__pycache__/loss.cpython-38.pyc differ
diff --git a/model/__pycache__/metric.cpython-38.pyc b/model/__pycache__/metric.cpython-38.pyc
new file mode 100644
index 0000000..d6a88d8
Binary files /dev/null and b/model/__pycache__/metric.cpython-38.pyc differ
diff --git a/model/__pycache__/mlp.cpython-38.pyc b/model/__pycache__/mlp.cpython-38.pyc
new file mode 100644
index 0000000..4998c73
Binary files /dev/null and b/model/__pycache__/mlp.cpython-38.pyc differ
diff --git a/model/__pycache__/mobilenet_v1.cpython-38.pyc b/model/__pycache__/mobilenet_v1.cpython-38.pyc
new file mode 100644
index 0000000..772951e
Binary files /dev/null and b/model/__pycache__/mobilenet_v1.cpython-38.pyc differ
diff --git a/model/__pycache__/mobilenet_v2.cpython-38.pyc b/model/__pycache__/mobilenet_v2.cpython-38.pyc
new file mode 100644
index 0000000..746f2e0
Binary files /dev/null and b/model/__pycache__/mobilenet_v2.cpython-38.pyc differ
diff --git a/model/__pycache__/mobilenet_v3.cpython-38.pyc b/model/__pycache__/mobilenet_v3.cpython-38.pyc
new file mode 100644
index 0000000..69305b7
Binary files /dev/null and b/model/__pycache__/mobilenet_v3.cpython-38.pyc differ
diff --git a/model/__pycache__/mobilevit.cpython-38.pyc b/model/__pycache__/mobilevit.cpython-38.pyc
new file mode 100644
index 0000000..00d53cc
Binary files /dev/null and b/model/__pycache__/mobilevit.cpython-38.pyc differ
diff --git a/model/__pycache__/resbam.cpython-38.pyc b/model/__pycache__/resbam.cpython-38.pyc
new file mode 100644
index 0000000..5869c88
Binary files /dev/null and b/model/__pycache__/resbam.cpython-38.pyc differ
diff --git a/model/__pycache__/resnet_pre.cpython-38.pyc b/model/__pycache__/resnet_pre.cpython-38.pyc
new file mode 100644
index 0000000..50807eb
Binary files /dev/null and b/model/__pycache__/resnet_pre.cpython-38.pyc differ
diff --git a/model/__pycache__/utils.cpython-38.pyc b/model/__pycache__/utils.cpython-38.pyc
new file mode 100644
index 0000000..2e67c67
Binary files /dev/null and b/model/__pycache__/utils.cpython-38.pyc differ
diff --git a/model/__pycache__/vit.cpython-38.pyc b/model/__pycache__/vit.cpython-38.pyc
new file mode 100644
index 0000000..227e029
Binary files /dev/null and b/model/__pycache__/vit.cpython-38.pyc differ
diff --git a/model/benchmark.py b/model/benchmark.py
new file mode 100644
index 0000000..1ab19fb
--- /dev/null
+++ b/model/benchmark.py
@@ -0,0 +1,142 @@
+import torch
+import torch.nn as nn
+import time
+import numpy as np
+from resnet_attention import resnet18_cbam, resnet34_cbam, resnet50_cbam
+
+# 设置随机种子以确保结果可复现
+torch.manual_seed(42)
+np.random.seed(42)
+
+# 设备配置
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+print(f"测试设备: {device}")
+
+# 测试参数
+batch_sizes = [1, 4, 8, 16]
+image_sizes = [224, 384, 512]
+num_runs = 100 # 每个配置运行的次数
+warmup_runs = 20 # 预热运行次数,排除启动开销
+
+# 模型配置
+model_configs = {
+ "resnet18": {
+ "base_model": lambda: resnet18_cbam(use_cbam=False),
+ "attention_model": lambda: resnet18_cbam(use_cbam=True)
+ },
+ "resnet34": {
+ "base_model": lambda: resnet34_cbam(use_cbam=False),
+ "attention_model": lambda: resnet34_cbam(use_cbam=True)
+ },
+ "resnet50": {
+ "base_model": lambda: resnet50_cbam(use_cbam=False),
+ "attention_model": lambda: resnet50_cbam(use_cbam=True)
+ }
+}
+
+
+# 基准测试函数
+def benchmark_model(model, input_size, batch_size, num_runs, warmup_runs):
+ """
+ 测试模型的推理性能
+
+ 参数:
+ - model: 待测试的模型
+ - input_size: 输入图像尺寸
+ - batch_size: 批次大小
+ - num_runs: 测试运行次数
+ - warmup_runs: 预热运行次数
+
+ 返回:
+ - 平均推理时间(毫秒)
+ - 吞吐量(样本/秒)
+ """
+ # 设置为评估模式
+ model.eval()
+ model.to(device)
+
+ # 创建随机输入
+ input_tensor = torch.randn(batch_size, 3, input_size, input_size, device=device)
+
+ # 预热
+ with torch.no_grad():
+ for _ in range(warmup_runs):
+ _ = model(input_tensor)
+ if device.type == 'cuda':
+ torch.cuda.synchronize() # 同步GPU操作
+
+ # 测量推理时间
+ start_time = time.time()
+ with torch.no_grad():
+ for _ in range(num_runs):
+ _ = model(input_tensor)
+ if device.type == 'cuda':
+ torch.cuda.synchronize() # 同步GPU操作
+ end_time = time.time()
+
+ # 计算指标
+ total_time = end_time - start_time
+ avg_time_per_batch = total_time / num_runs * 1000 # 毫秒
+ throughput = batch_size * num_runs / total_time # 样本/秒
+
+ return avg_time_per_batch, throughput
+
+
+# 运行测试
+results = {}
+
+for model_name, config in model_configs.items():
+ results[model_name] = {}
+
+ # 创建模型
+ base_model = config["base_model"]()
+ attention_model = config["attention_model"]()
+
+ # 计算参数量
+ base_params = sum(p.numel() for p in base_model.parameters() if p.requires_grad)
+ attention_params = sum(p.numel() for p in attention_model.parameters() if p.requires_grad)
+ param_increase = (attention_params - base_params) / base_params * 100
+
+ print(f"\n测试模型: {model_name}")
+ print(f" 基础参数量: {base_params / 1e6:.2f}M")
+ print(f" 带注意力参数量: {attention_params / 1e6:.2f}M")
+ print(f" 参数量增加: {param_increase:.2f}%")
+
+ for batch_size in batch_sizes:
+ for image_size in image_sizes:
+ key = f"batch_{batch_size}_size_{image_size}"
+ results[model_name][key] = {}
+
+ # 测试基础模型
+ base_time, base_throughput = benchmark_model(
+ base_model, image_size, batch_size, num_runs, warmup_runs
+ )
+
+ # 测试注意力模型
+ attention_time, attention_throughput = benchmark_model(
+ attention_model, image_size, batch_size, num_runs, warmup_runs
+ )
+
+ # 计算增加的百分比
+ time_increase = (attention_time - base_time) / base_time * 100
+ throughput_decrease = (base_throughput - attention_throughput) / base_throughput * 100
+
+ results[model_name][key]["base_time"] = base_time
+ results[model_name][key]["attention_time"] = attention_time
+ results[model_name][key]["time_increase"] = time_increase
+ results[model_name][key]["base_throughput"] = base_throughput
+ results[model_name][key]["attention_throughput"] = attention_throughput
+ results[model_name][key]["throughput_decrease"] = throughput_decrease
+
+ print(f" 配置: 批次大小={batch_size}, 图像尺寸={image_size}x{image_size}")
+ print(f" 基础模型: 平均时间={base_time:.2f}ms, 吞吐量={base_throughput:.2f}样本/秒")
+ print(f" 注意力模型: 平均时间={attention_time:.2f}ms, 吞吐量={attention_throughput:.2f}样本/秒")
+ print(f" 时间增加: {time_increase:.2f}%, 吞吐量下降: {throughput_decrease:.2f}%")
+
+# 保存结果
+import json
+
+with open('benchmark_results.json', 'w') as f:
+ json.dump(results, f, indent=2)
+
+print("\n测试完成,结果已保存到 benchmark_results.json")
diff --git a/model/compare.py b/model/compare.py
new file mode 100644
index 0000000..a92a497
--- /dev/null
+++ b/model/compare.py
@@ -0,0 +1,48 @@
+import torch
+from config import config as conf
+import torch.nn as nn
+import torchvision.models as models
+from model.resnet_pre import resnet18, resnet50
+# from model.vit import vit_base_patch16_224, vit_base_patch32_224
+
+
+class ContrastiveModel(nn.Module):
+ def __init__(self, projection_dim, model_name, contraposition=False):
+ super(ContrastiveModel, self).__init__()
+ self.contraposition = contraposition
+ self.base_model = self._get_model(model_name)
+ if not self.contraposition:
+ if 'vit' in model_name:
+ dim_mlp = self.base_model.head.weight.shape[1]
+ self.base_model.head = self._get_projection_layer(dim_mlp, projection_dim)
+ else:
+ dim_mlp = self.base_model.fc.weight.shape[1]
+ self.base_model.fc = self._get_projection_layer(dim_mlp, projection_dim)
+ # # 冻结除 FC 层之外的所有层
+ # for name, param in self.base_model.named_parameters():
+ # if 'fc' not in name:
+ # param.requires_grad = False
+
+ def _get_projection_layer(self, dim_mlp, projection_dim):
+ return nn.Sequential(
+ nn.Linear(dim_mlp, dim_mlp),
+ nn.ReLU(inplace=True),
+ nn.Linear(dim_mlp, projection_dim)
+ )
+
+ def _get_model(self, model_name):
+ base_model = None
+ if model_name == 'resnet18':
+ base_model = resnet18(pretrained=True)
+ elif model_name == 'resnet50':
+ base_model = resnet50(pretrained=True)
+ # elif model_name == 'vit':
+ # base_model = vit_base_patch32_224()
+ return base_model
+ def forward(self, x):
+ assert self.base_model is not None, 'base_model is none'
+ x = self.base_model(x)
+ return x
+
+if __name__ == '__main__':
+ pass
\ No newline at end of file
diff --git a/model/distill.py b/model/distill.py
new file mode 100644
index 0000000..1246be5
--- /dev/null
+++ b/model/distill.py
@@ -0,0 +1,182 @@
+import torch
+from torch import nn
+from torch.nn import Module
+import torch.nn.functional as F
+
+from vit_pytorch.vit import ViT
+from vit_pytorch.t2t import T2TViT
+from vit_pytorch.efficient import ViT as EfficientViT
+
+from einops import repeat
+from config import config as conf
+# helpers
+# Data Setup
+from tools.dataset import load_data
+train_dataloader, class_num = load_data(conf, training=True)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+def exists(val):
+ return val is not None
+
+
+def default(val, d):
+ return val if exists(val) else d
+
+
+# classes
+
+class DistillMixin:
+ def forward(self, img, distill_token=None):
+ distilling = exists(distill_token)
+ x = self.to_patch_embedding(img)
+ b, n, _ = x.shape
+
+ cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b=b)
+ x = torch.cat((cls_tokens, x), dim=1)
+ x += self.pos_embedding[:, :(n + 1)]
+
+ if distilling:
+ distill_tokens = repeat(distill_token, '1 n d -> b n d', b=b)
+ x = torch.cat((x, distill_tokens), dim=1)
+
+ x = self._attend(x)
+
+ if distilling:
+ x, distill_tokens = x[:, :-1], x[:, -1]
+
+ x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
+
+ x = self.to_latent(x)
+ out = self.mlp_head(x)
+
+ if distilling:
+ return out, distill_tokens
+
+ return out
+
+
+class DistillableViT(DistillMixin, ViT):
+ def __init__(self, *args, **kwargs):
+ super(DistillableViT, self).__init__(*args, **kwargs)
+ self.args = args
+ self.kwargs = kwargs
+ self.dim = kwargs['dim']
+ self.num_classes = kwargs['num_classes']
+
+ def to_vit(self):
+ v = ViT(*self.args, **self.kwargs)
+ v.load_state_dict(self.state_dict())
+ return v
+
+ def _attend(self, x):
+ x = self.dropout(x)
+ x = self.transformer(x)
+ return x
+
+
+class DistillableT2TViT(DistillMixin, T2TViT):
+ def __init__(self, *args, **kwargs):
+ super(DistillableT2TViT, self).__init__(*args, **kwargs)
+ self.args = args
+ self.kwargs = kwargs
+ self.dim = kwargs['dim']
+ self.num_classes = kwargs['num_classes']
+
+ def to_vit(self):
+ v = T2TViT(*self.args, **self.kwargs)
+ v.load_state_dict(self.state_dict())
+ return v
+
+ def _attend(self, x):
+ x = self.dropout(x)
+ x = self.transformer(x)
+ return x
+
+
+class DistillableEfficientViT(DistillMixin, EfficientViT):
+ def __init__(self, *args, **kwargs):
+ super(DistillableEfficientViT, self).__init__(*args, **kwargs)
+ self.args = args
+ self.kwargs = kwargs
+ self.dim = kwargs['dim']
+ self.num_classes = kwargs['num_classes']
+
+
+ def to_vit(self):
+ v = EfficientViT(*self.args, **self.kwargs)
+ v.load_state_dict(self.state_dict())
+ return v
+
+ def _attend(self, x):
+ return self.transformer(x)
+
+
+# knowledge distillation wrapper
+
+class DistillWrapper(Module):
+ def __init__(
+ self,
+ *,
+ teacher,
+ student,
+ temperature=1.,
+ alpha=0.5,
+ hard=False,
+ mlp_layernorm=False
+ ):
+ super().__init__()
+ # assert (isinstance(student, (
+ # DistillableViT, DistillableT2TViT, DistillableEfficientViT))), 'student must be a vision transformer'
+ if isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT)):
+ pass
+
+ self.teacher = teacher
+ self.student = student
+
+ dim = conf.embedding_size # student.dim
+ num_classes = class_num # class_num # student.num_classes
+ self.temperature = temperature
+ self.alpha = alpha
+ self.hard = hard
+
+ self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
+
+ # student is vit
+ # self.distill_mlp = nn.Sequential(
+ # nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),
+ # nn.Linear(dim, num_classes)
+ # )
+
+ # student is resnet
+ self.distill_mlp = nn.Sequential(
+ nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),
+ nn.Linear(dim, num_classes).to(device)
+ )
+
+ def forward(self, img, labels, temperature=None, alpha=None, **kwargs):
+
+ alpha = default(alpha, self.alpha)
+ T = default(temperature, self.temperature)
+
+ with torch.no_grad():
+ teacher_logits = self.teacher(img)
+ teacher_logits = self.distill_mlp(teacher_logits) # teach is vit 初始化
+ # student is vit
+ # student_logits, distill_tokens = self.student(img, distill_token=self.distillation_token, **kwargs)
+ # distill_logits = self.distill_mlp(distill_tokens)
+
+ # student is resnet
+ student_logits = self.student(img)
+ distill_logits = self.distill_mlp(student_logits)
+ loss = F.cross_entropy(distill_logits, labels)
+ # pdb.set_trace()
+ if not self.hard:
+ distill_loss = F.kl_div(
+ F.log_softmax(distill_logits / T, dim=-1),
+ F.softmax(teacher_logits / T, dim=-1).detach(),
+ reduction='batchmean')
+ distill_loss *= T ** 2
+ else:
+ teacher_labels = teacher_logits.argmax(dim=-1)
+ distill_loss = F.cross_entropy(distill_logits, teacher_labels)
+ # pdb.set_trace()
+ return loss * (1 - alpha) + distill_loss * alpha
\ No newline at end of file
diff --git a/model/fmobilenet.py b/model/fmobilenet.py
new file mode 100644
index 0000000..2e38a44
--- /dev/null
+++ b/model/fmobilenet.py
@@ -0,0 +1,124 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+
+class Flatten(nn.Module):
+ def forward(self, x):
+ return x.view(x.shape[0], -1)
+
+class ConvBn(nn.Module):
+
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=1, padding=0, groups=1):
+ super().__init__()
+ self.net = nn.Sequential(
+ nn.Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
+ nn.BatchNorm2d(out_c)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class ConvBnPrelu(nn.Module):
+
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=1, padding=0, groups=1):
+ super().__init__()
+ self.net = nn.Sequential(
+ ConvBn(in_c, out_c, kernel, stride, padding, groups),
+ nn.PReLU(out_c)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class DepthWise(nn.Module):
+
+ def __init__(self, in_c, out_c, kernel=(3, 3), stride=2, padding=1, groups=1):
+ super().__init__()
+ self.net = nn.Sequential(
+ ConvBnPrelu(in_c, groups, kernel=(1, 1), stride=1, padding=0),
+ ConvBnPrelu(groups, groups, kernel=kernel, stride=stride, padding=padding, groups=groups),
+ ConvBn(groups, out_c, kernel=(1, 1), stride=1, padding=0),
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class DepthWiseRes(nn.Module):
+ """DepthWise with Residual"""
+
+ def __init__(self, in_c, out_c, kernel=(3, 3), stride=2, padding=1, groups=1):
+ super().__init__()
+ self.net = DepthWise(in_c, out_c, kernel, stride, padding, groups)
+
+ def forward(self, x):
+ return self.net(x) + x
+
+
+class MultiDepthWiseRes(nn.Module):
+
+ def __init__(self, num_block, channels, kernel=(3, 3), stride=1, padding=1, groups=1):
+ super().__init__()
+
+ self.net = nn.Sequential(*[
+ DepthWiseRes(channels, channels, kernel, stride, padding, groups)
+ for _ in range(num_block)
+ ])
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class FaceMobileNet(nn.Module):
+
+ def __init__(self, embedding_size):
+ super().__init__()
+ self.conv1 = ConvBnPrelu(1, 64, kernel=(3, 3), stride=2, padding=1)
+ self.conv2 = ConvBn(64, 64, kernel=(3, 3), stride=1, padding=1, groups=64)
+ self.conv3 = DepthWise(64, 64, kernel=(3, 3), stride=2, padding=1, groups=128)
+ self.conv4 = MultiDepthWiseRes(num_block=4, channels=64, kernel=3, stride=1, padding=1, groups=128)
+ self.conv5 = DepthWise(64, 128, kernel=(3, 3), stride=2, padding=1, groups=256)
+ self.conv6 = MultiDepthWiseRes(num_block=6, channels=128, kernel=(3, 3), stride=1, padding=1, groups=256)
+ self.conv7 = DepthWise(128, 128, kernel=(3, 3), stride=2, padding=1, groups=512)
+ self.conv8 = MultiDepthWiseRes(num_block=2, channels=128, kernel=(3, 3), stride=1, padding=1, groups=256)
+ self.conv9 = ConvBnPrelu(128, 512, kernel=(1, 1))
+ self.conv10 = ConvBn(512, 512, groups=512, kernel=(7, 7))
+ self.flatten = Flatten()
+ self.linear = nn.Linear(2048, embedding_size, bias=False)
+ self.bn = nn.BatchNorm1d(embedding_size)
+
+ def forward(self, x):
+ #print('x',x.shape)
+ out = self.conv1(x)
+ out = self.conv2(out)
+ out = self.conv3(out)
+ out = self.conv4(out)
+ out = self.conv5(out)
+ out = self.conv6(out)
+ out = self.conv7(out)
+ out = self.conv8(out)
+ out = self.conv9(out)
+ out = self.conv10(out)
+ out = self.flatten(out)
+ out = self.linear(out)
+ out = self.bn(out)
+ return out
+
+if __name__ == "__main__":
+ from PIL import Image
+ import numpy as np
+
+ x = Image.open("../samples/009.jpg").convert('L')
+ x = x.resize((128, 128))
+ x = np.asarray(x, dtype=np.float32)
+ x = x[None, None, ...]
+ x = torch.from_numpy(x)
+ net = FaceMobileNet(512)
+ net.eval()
+ with torch.no_grad():
+ out = net(x)
+ print(out.shape)
diff --git a/model/lcnet.py b/model/lcnet.py
new file mode 100644
index 0000000..c085c46
--- /dev/null
+++ b/model/lcnet.py
@@ -0,0 +1,233 @@
+import os
+import torch
+import torch.nn as nn
+import thop
+
+# try:
+# import softpool_cuda
+# from SoftPool import soft_pool2d, SoftPool2d
+# except ImportError:
+# print('Please install SoftPool first: https://github.com/alexandrosstergiou/SoftPool')
+# exit(0)
+
+NET_CONFIG = {
+ # k, in_c, out_c, s, use_se
+ "blocks2": [[3, 16, 32, 1, False]],
+ "blocks3": [[3, 32, 64, 2, False], [3, 64, 64, 1, False]],
+ "blocks4": [[3, 64, 128, 2, False], [3, 128, 128, 1, False]],
+ "blocks5": [[3, 128, 256, 2, False], [5, 256, 256, 1, False],
+ [5, 256, 256, 1, False], [5, 256, 256, 1, False],
+ [5, 256, 256, 1, False], [5, 256, 256, 1, False]],
+ "blocks6": [[5, 256, 512, 2, True], [5, 512, 512, 1, True]]
+}
+
+
+def autopad(k, p=None):
+ if p is None:
+ p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
+ return p
+
+
+def make_divisible(v, divisor=8, min_value=None):
+ if min_value is None:
+ min_value = divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ if new_v < 0.9 * v:
+ new_v += divisor
+ return new_v
+
+
+class HardSwish(nn.Module):
+ def __init__(self, inplace=True):
+ super(HardSwish, self).__init__()
+ self.relu6 = nn.ReLU6(inplace=inplace)
+
+ def forward(self, x):
+ return x * self.relu6(x+3) / 6
+
+
+class HardSigmoid(nn.Module):
+ def __init__(self, inplace=True):
+ super(HardSigmoid, self).__init__()
+ self.relu6 = nn.ReLU6(inplace=inplace)
+
+ def forward(self, x):
+ return (self.relu6(x+3)) / 6
+
+
+class SELayer(nn.Module):
+ def __init__(self, channel, reduction=16):
+ super(SELayer, self).__init__()
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
+ self.fc = nn.Sequential(
+ nn.Linear(channel, channel // reduction, bias=False),
+ nn.ReLU(inplace=True),
+ nn.Linear(channel // reduction, channel, bias=False),
+ HardSigmoid()
+ )
+
+ def forward(self, x):
+ b, c, h, w = x.size()
+ y = self.avgpool(x).view(b, c)
+ y = self.fc(y).view(b, c, 1, 1)
+ return x * y.expand_as(x)
+
+
+class DepthwiseSeparable(nn.Module):
+ def __init__(self, inp, oup, dw_size, stride, use_se=False):
+ super(DepthwiseSeparable, self).__init__()
+ self.use_se = use_se
+ self.stride = stride
+ self.inp = inp
+ self.oup = oup
+ self.dw_size = dw_size
+ self.dw_sp = nn.Sequential(
+ nn.Conv2d(self.inp, self.inp, kernel_size=self.dw_size, stride=self.stride,
+ padding=autopad(self.dw_size, None), groups=self.inp, bias=False),
+ nn.BatchNorm2d(self.inp),
+ HardSwish(),
+
+ nn.Conv2d(self.inp, self.oup, kernel_size=1, stride=1, padding=0, bias=False),
+ nn.BatchNorm2d(self.oup),
+ HardSwish(),
+ )
+ self.se = SELayer(self.oup)
+
+ def forward(self, x):
+ x = self.dw_sp(x)
+ if self.use_se:
+ x = self.se(x)
+ return x
+
+
+class PP_LCNet(nn.Module):
+ def __init__(self, scale=1.0, class_num=256, class_expand=1280, dropout_prob=0.2):
+ super(PP_LCNet, self).__init__()
+ self.scale = scale
+ self.conv1 = nn.Conv2d(3, out_channels=make_divisible(16 * self.scale),
+ kernel_size=3, stride=2, padding=1, bias=False)
+ # k, in_c, out_c, s, use_se inp, oup, dw_size, stride, use_se=False
+ self.blocks2 = nn.Sequential(*[
+ DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
+ oup=make_divisible(out_c * self.scale),
+ dw_size=k, stride=s, use_se=use_se)
+ for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks2"])
+ ])
+
+ self.blocks3 = nn.Sequential(*[
+ DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
+ oup=make_divisible(out_c * self.scale),
+ dw_size=k, stride=s, use_se=use_se)
+ for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks3"])
+ ])
+
+ self.blocks4 = nn.Sequential(*[
+ DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
+ oup=make_divisible(out_c * self.scale),
+ dw_size=k, stride=s, use_se=use_se)
+ for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks4"])
+ ])
+ # k, in_c, out_c, s, use_se inp, oup, dw_size, stride, use_se=False
+ self.blocks5 = nn.Sequential(*[
+ DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
+ oup=make_divisible(out_c * self.scale),
+ dw_size=k, stride=s, use_se=use_se)
+ for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks5"])
+ ])
+
+ self.blocks6 = nn.Sequential(*[
+ DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
+ oup=make_divisible(out_c * self.scale),
+ dw_size=k, stride=s, use_se=use_se)
+ for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks6"])
+ ])
+
+ self.GAP = nn.AdaptiveAvgPool2d(1)
+
+ self.last_conv = nn.Conv2d(in_channels=make_divisible(NET_CONFIG["blocks6"][-1][2] * scale),
+ out_channels=class_expand,
+ kernel_size=1, stride=1, padding=0, bias=False)
+
+ self.hardswish = HardSwish()
+ self.dropout = nn.Dropout(p=dropout_prob)
+
+ self.fc = nn.Linear(class_expand, class_num)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ # print(x.shape)
+ x = self.blocks2(x)
+ # print(x.shape)
+ x = self.blocks3(x)
+ # print(x.shape)
+ x = self.blocks4(x)
+ # print(x.shape)
+ x = self.blocks5(x)
+ # print(x.shape)
+ x = self.blocks6(x)
+ # print(x.shape)
+
+ x = self.GAP(x)
+ x = self.last_conv(x)
+ x = self.hardswish(x)
+ x = self.dropout(x)
+ x = torch.flatten(x, start_dim=1, end_dim=-1)
+ x = self.fc(x)
+ return x
+
+
+def PPLCNET_x0_25(**kwargs):
+ model = PP_LCNet(scale=0.25, **kwargs)
+ return model
+
+
+def PPLCNET_x0_35(**kwargs):
+ model = PP_LCNet(scale=0.35, **kwargs)
+ return model
+
+
+def PPLCNET_x0_5(**kwargs):
+ model = PP_LCNet(scale=0.5, **kwargs)
+ return model
+
+
+def PPLCNET_x0_75(**kwargs):
+ model = PP_LCNet(scale=0.75, **kwargs)
+ return model
+
+
+def PPLCNET_x1_0(**kwargs):
+ model = PP_LCNet(scale=1.0, **kwargs)
+ return model
+
+
+def PPLCNET_x1_5(**kwargs):
+ model = PP_LCNet(scale=1.5, **kwargs)
+ return model
+
+
+def PPLCNET_x2_0(**kwargs):
+ model = PP_LCNet(scale=2.0, **kwargs)
+ return model
+
+def PPLCNET_x2_5(**kwargs):
+ model = PP_LCNet(scale=2.5, **kwargs)
+ return model
+
+
+
+
+if __name__ == '__main__':
+ # input = torch.randn(1, 3, 640, 640)
+ # model = PPLCNET_x2_5()
+ # flops, params = thop.profile(model, inputs=(input,))
+ # print('flops:', flops / 1000000000)
+ # print('params:', params / 1000000)
+
+ model = PPLCNET_x1_0()
+ # model_1 = PW_Conv(3, 16)
+ input = torch.randn(2, 3, 256, 256)
+ print(input.shape)
+ output = model(input)
+ print(output.shape) # [1, num_class]
+
diff --git a/model/loss.py b/model/loss.py
new file mode 100644
index 0000000..8f40c5c
--- /dev/null
+++ b/model/loss.py
@@ -0,0 +1,18 @@
+import torch
+import torch.nn as nn
+
+
+class FocalLoss(nn.Module):
+
+ def __init__(self, gamma=2):
+ super().__init__()
+ self.gamma = gamma
+ self.ce = torch.nn.CrossEntropyLoss()
+
+ def forward(self, input, target):
+
+ #print(f'theta {input.shape, input[0]}, target {target.shape, target}')
+ logp = self.ce(input, target)
+ p = torch.exp(-logp)
+ loss = (1 - p) ** self.gamma * logp
+ return loss.mean()
\ No newline at end of file
diff --git a/model/metric.py b/model/metric.py
new file mode 100644
index 0000000..791b3a4
--- /dev/null
+++ b/model/metric.py
@@ -0,0 +1,94 @@
+# Definition of ArcFace loss and CosFace loss
+
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ArcFace(nn.Module):
+
+ def __init__(self, embedding_size, class_num, s=30.0, m=0.50):
+ """ArcFace formula:
+ cos(m + theta) = cos(m)cos(theta) - sin(m)sin(theta)
+ Note that:
+ 0 <= m + theta <= Pi
+ So if (m + theta) >= Pi, then theta >= Pi - m. In [0, Pi]
+ we have:
+ cos(theta) < cos(Pi - m)
+ So we can use cos(Pi - m) as threshold to check whether
+ (m + theta) go out of [0, Pi]
+
+ Args:
+ embedding_size: usually 128, 256, 512 ...
+ class_num: num of people when training
+ s: scale, see normface https://arxiv.org/abs/1704.06369
+ m: margin, see SphereFace, CosFace, and ArcFace paper
+ """
+ super().__init__()
+ self.in_features = embedding_size
+ self.out_features = class_num
+ self.s = s
+ self.m = m
+ self.weight = nn.Parameter(torch.FloatTensor(class_num, embedding_size))
+ nn.init.xavier_uniform_(self.weight)
+
+ self.cos_m = math.cos(m)
+ self.sin_m = math.sin(m)
+ self.th = math.cos(math.pi - m)
+ self.mm = math.sin(math.pi - m) * m
+
+ def forward(self, input, label):
+ #print(f"embding {self.in_features}, class_num {self.out_features}, input {len(input)}, label {len(label)}")
+ cosine = F.linear(F.normalize(input), F.normalize(self.weight))
+ # print('F.normalize(input)',input.shape)
+ # print('F.normalize(self.weight)',F.normalize(self.weight).shape)
+ sine = ((1.0 - cosine.pow(2)).clamp(0, 1)).sqrt()
+ phi = cosine * self.cos_m - sine * self.sin_m
+ phi = torch.where(cosine > self.th, phi, cosine - self.mm) # drop to CosFace
+ #print(f'consine {cosine.shape, cosine}, sine {sine.shape, sine}, phi {phi.shape, phi}')
+ # update y_i by phi in cosine
+ output = cosine * 1.0 # make backward works
+ batch_size = len(output)
+ output[range(batch_size), label] = phi[range(batch_size), label]
+ # print(f'output {(output * self.s).shape}')
+ # print(f'phi[range(batch_size), label] {phi[range(batch_size), label]}')
+ return output * self.s
+
+
+class CosFace(nn.Module):
+
+ def __init__(self, in_features, out_features, s=30.0, m=0.40):
+ """
+ Args:
+ embedding_size: usually 128, 256, 512 ...
+ class_num: num of people when training
+ s: scale, see normface https://arxiv.org/abs/1704.06369
+ m: margin, see SphereFace, CosFace, and ArcFace paper
+ """
+ super().__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.s = s
+ self.m = m
+ self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
+ nn.init.xavier_uniform_(self.weight)
+
+ def forward(self, input, label):
+ cosine = F.linear(F.normalize(input), F.normalize(self.weight))
+ phi = cosine - self.m
+ output = cosine * 1.0 # make backward works
+ batch_size = len(output)
+ output[range(batch_size), label] = phi[range(batch_size), label]
+ return output * self.s
+
+class Distillation(nn.Module):
+ def __init__(self, in_features, out_features, T=1.0):
+ super(Distillation, self).__init__()
+ self.T = T
+ self.in_features = in_features
+ self.out_features = out_features
+ self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
+ nn.init.xavier_uniform_(self.weight)
+ def forward(self, input, labels):
+ pass
\ No newline at end of file
diff --git a/model/mlp.py b/model/mlp.py
new file mode 100644
index 0000000..544250e
--- /dev/null
+++ b/model/mlp.py
@@ -0,0 +1,274 @@
+import pdb
+
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+from model.resnet_pre import resnet18, conv1x1, BasicBlock, load_state_dict_from_url, model_urls
+
+class MLP(nn.Module):
+ def __init__(self, input_dim=256, output_dim=1):
+ super(MLP, self).__init__()
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.fc1 = nn.Linear(self.input_dim, 128) # 32
+ self.fc2 = nn.Linear(128, 64)
+ self.fc3 = nn.Linear(64, 32)
+ self.fc4 = nn.Linear(32, 16)
+ self.fc5 = nn.Linear(16, self.output_dim)
+ self.relu = nn.ReLU()
+ self.sigmoid = nn.Sigmoid()
+ self.dropout = nn.Dropout(0.5)
+ self.bn1 = nn.BatchNorm1d(128)
+ self.bn2 = nn.BatchNorm1d(64)
+ self.bn3 = nn.BatchNorm1d(32)
+ self.bn4 = nn.BatchNorm1d(16)
+ for m in self.modules():
+ if isinstance(m, nn.Linear):
+ init.kaiming_normal_(m.weight)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.relu(self.bn1(x))
+ x = self.fc2(x)
+ x = self.relu(self.bn2(x))
+ x = self.fc3(x)
+ x = self.relu(self.bn3(x))
+ x = self.fc4(x)
+ x = self.relu(self.bn4(x))
+ x = self.sigmoid(self.fc5(x))
+ return x
+
+
+class Net2(nn.Module): # 该网络部署有风险,dnn推理有障碍
+ def __init__(self, input_dim=960, output_dim=1):
+ super(Net2, self).__init__()
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.conv1 = nn.Conv1d(1, 16, kernel_size=3, stride=1, padding=1)
+ self.conv2 = nn.Conv1d(16, 32, kernel_size=3, stride=2, padding=1)
+ # self.conv3 = nn.Conv1d(32, 64, kernel_size=3, stride=2, padding=1)
+ # self.conv4 = nn.Conv1d(64, 64, kernel_size=5, stride=2, padding=1)
+ self.maxPool1 = nn.MaxPool1d(kernel_size=3, stride=2)
+ self.conv5 = nn.Conv1d(32, 64, kernel_size=5, stride=2, padding=1)
+ self.maxPool2 = nn.MaxPool1d(kernel_size=3, stride=2)
+
+ self.avgPool = nn.AdaptiveAvgPool1d(1)
+ self.MaxPool = nn.AdaptiveMaxPool1d(1)
+ self.relu = nn.ReLU()
+ self.sigmoid = nn.Sigmoid()
+ self.dropout = nn.Dropout(0.5)
+ self.flatten = nn.Flatten()
+ # self.conv6 = nn.Conv1d(128, 128, kernel_size=5, stride=2, padding=1)
+ self.fc1 = nn.Linear(960, 128)
+ self.fc21 = nn.Linear(960, 32)
+ self.fc22 = nn.Linear(32, 128)
+ self.fc3 = nn.Linear(128, 1)
+ self.bn1 = nn.BatchNorm1d(16)
+ self.bn2 = nn.BatchNorm1d(32)
+ self.bn3 = nn.BatchNorm1d(64)
+ self.bn4 = nn.BatchNorm1d(128)
+ for m in self.modules():
+ if isinstance(m, nn.Linear):
+ init.kaiming_normal_(m.weight)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+
+ def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+ def forward(self, x):
+ x = self.conv1(x) # 16
+ x = self.relu(x)
+ x = self.conv2(x) # 32
+ x = self.relu(x)
+ # x = self.conv3(x)
+ # x = self.relu(x)
+ # x = self.conv4(x) # 64
+ # x = self.relu(x)
+ # x = self.maxPool1(x)
+
+ x = self.conv5(x)
+ x = self.relu(x)
+ # x = self.conv6(x)
+ # x = self.relu(x)
+ # x = self.maxPool2(x)
+ # x = self.MaxPool(x)
+
+ x = x.view(x.size(0), -1)
+ x = self.dropout(x)
+ x = self.flatten(x)
+
+ # pdb.set_trace()
+ x1 = self.fc1(x)
+ x2 = self.fc22(self.fc21(x))
+ x = self.fc3(x1 + x2)
+ x = self.sigmoid(x)
+ return x
+
+class Net3(nn.Module): # 目前较合适的网络结构,相较于Net2,Net3的输出结果更加准确
+ def __init__(self, pretrained=True, progress=True, num_classes=1, scale=0.75):
+ super(Net3, self).__init__()
+ self.resnet18 = resnet18(pretrained=pretrained, progress=progress)
+
+ # Remove the last three layers (layer3, layer4, avgpool, fc)
+ # self.resnet18.layer3 = nn.Identity()
+ # self.resnet18.layer4 = nn.Identity()
+ self.resnet18.avgpool = nn.Identity()
+ self.resnet18.fc = nn.Identity()
+ self.flatten = nn.Flatten()
+ # Calculate the output size after layer2
+ # Assuming input size is 224x224, layer2 will have output size of 56x56
+ # So, the flattened size will be 128 * scale * 56 * 56
+ self.flattened_size = int(128 * (56 * 56) * scale * scale)
+
+ # Add new layers for classification
+ self.classifier = nn.Sequential(
+ nn.AdaptiveAvgPool2d((1, 1)),
+ nn.Flatten(),
+ nn.Linear(384, num_classes), # layer1, layer2 in_features=96 # layer1 in_features=48 #layer3 in_features=192
+ # nn.ReLU(),
+ nn.Dropout(0.6),
+ # nn.Linear(256, num_classes),
+ nn.Sigmoid()
+ )
+
+ def forward(self, x):
+ x = self.resnet18.layer1(x)
+ x = self.resnet18.layer2(x)
+ x = self.resnet18.layer3(x)
+ x = self.resnet18.layer4(x)
+
+ # Debugging: Print the shape of the tensor before flattening
+ # print("Shape before flattening:", x.shape)
+
+ # Ensure the tensor is flattened correctly
+ # x = x.view(x.size(0), -1)
+ x = self.classifier(x)
+ return x
+
+class ResNet(nn.Module):
+ def __init__(self, block, layers, num_classes=1, zero_init_residual=False,
+ groups=1, width_per_group=64, replace_stride_with_dilation=None,
+ norm_layer=None, scale=0.75):
+ super(ResNet, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ self._norm_layer = norm_layer
+
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ # each element in the tuple indicates if we should replace
+ # the 2x2 stride with a dilated convolution instead
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError("replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
+ bias=False)
+ self.bn1 = norm_layer(self.inplanes)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ self.layer1 = self._make_layer(block, int(64 * scale), layers[0])
+ self.layer2 = self._make_layer(block, int(128 * scale), layers[1], stride=2,
+ dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block, int(256 * scale), layers[2], stride=2,
+ dilate=replace_stride_with_dilation[1])
+ self.layer4 = self._make_layer(block, int(512 * scale), layers[3], stride=2,
+ dilate=replace_stride_with_dilation[2])
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+ self.fc = nn.Linear(int(512 * block.expansion * scale), num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+ self.sigmoid = nn.Sigmoid()
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+ norm_layer = self._norm_layer
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ norm_layer(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+ self.base_width, previous_dilation, norm_layer))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes, groups=self.groups,
+ base_width=self.base_width, dilation=self.dilation,
+ norm_layer=norm_layer))
+ return nn.Sequential(*layers)
+
+ def _forward_impl(self, x):
+ # See note [TorchScript super()]
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ x = self.avgpool(x)
+ x = torch.flatten(x, 1)
+ x = self.fc(x)
+ x = self.sigmoid(x)
+ return x
+
+ def forward(self, x):
+ return self._forward_impl(x)
+
+def Net4(arch, pretrained, progress, **kwargs):
+ model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
+ if pretrained:
+ state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
+ src_state_dict = state_dict
+ target_state_dict = model.state_dict()
+ skip_keys = []
+ # skip mismatch size tensors in case of pretraining
+ for k in src_state_dict.keys():
+ if k not in target_state_dict:
+ continue
+ if src_state_dict[k].size() != target_state_dict[k].size():
+ skip_keys.append(k)
+ for k in skip_keys:
+ del src_state_dict[k]
+ missing_keys, unexpected_keys = model.load_state_dict(src_state_dict, strict=False)
+ return model
+
+
+if __name__ == '__main__':
+ '''
+ net2 = Net2()
+ input_tensor = torch.randn(10, 1, 64)
+ # 前向传播
+ output_tensor = net2(input_tensor)
+ # pdb.set_trace()
+ print("输入张量形状:", input_tensor.shape)
+ print("输出张量形状:", output_tensor.shape)
+ '''
+
+ # model = Net3(pretrained=True, num_classes=1) # 预训练从resnet中间结果获取数据训练模型
+ model = Net4('resnet18', True, True)
+ input_tensor = torch.randn(1, 3, 224, 244) # Adjust batch size to 10
+ output = model(input_tensor)
+ print(output.shape) # Should be [10, 2]
\ No newline at end of file
diff --git a/model/mobilenet_v1.py b/model/mobilenet_v1.py
new file mode 100644
index 0000000..1262d9e
--- /dev/null
+++ b/model/mobilenet_v1.py
@@ -0,0 +1,148 @@
+# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from typing import Callable, Any, Optional
+
+import torch
+from torch import Tensor
+from torch import nn
+from torchvision.ops.misc import Conv2dNormActivation
+from config import config as conf
+
+__all__ = [
+ "MobileNetV1",
+ "DepthWiseSeparableConv2d",
+ "mobilenet_v1",
+]
+
+
+class MobileNetV1(nn.Module):
+
+ def __init__(
+ self,
+ num_classes: int = conf.embedding_size,
+ ) -> None:
+ super(MobileNetV1, self).__init__()
+ self.features = nn.Sequential(
+ Conv2dNormActivation(3,
+ 32,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ norm_layer=nn.BatchNorm2d,
+ activation_layer=nn.ReLU,
+ inplace=True,
+ bias=False,
+ ),
+
+ DepthWiseSeparableConv2d(32, 64, 1),
+ DepthWiseSeparableConv2d(64, 128, 2),
+ DepthWiseSeparableConv2d(128, 128, 1),
+ DepthWiseSeparableConv2d(128, 256, 2),
+ DepthWiseSeparableConv2d(256, 256, 1),
+ DepthWiseSeparableConv2d(256, 512, 2),
+ DepthWiseSeparableConv2d(512, 512, 1),
+ DepthWiseSeparableConv2d(512, 512, 1),
+ DepthWiseSeparableConv2d(512, 512, 1),
+ DepthWiseSeparableConv2d(512, 512, 1),
+ DepthWiseSeparableConv2d(512, 512, 1),
+ DepthWiseSeparableConv2d(512, 1024, 2),
+ DepthWiseSeparableConv2d(1024, 1024, 1),
+ )
+
+ self.avgpool = nn.AvgPool2d((7, 7))
+
+ self.classifier = nn.Linear(1024, num_classes)
+
+ # Initialize neural network weights
+ self._initialize_weights()
+
+ def forward(self, x: Tensor) -> Tensor:
+ out = self._forward_impl(x)
+
+ return out
+
+ # Support torch.script function
+ def _forward_impl(self, x: Tensor) -> Tensor:
+ out = self.features(x)
+ out = self.avgpool(out)
+ out = torch.flatten(out, 1)
+ out = self.classifier(out)
+
+ return out
+
+ def _initialize_weights(self) -> None:
+ for module in self.modules():
+ if isinstance(module, nn.Conv2d):
+ nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.ones_(module.weight)
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, nn.Linear):
+ nn.init.normal_(module.weight, 0, 0.01)
+ nn.init.zeros_(module.bias)
+
+
+class DepthWiseSeparableConv2d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ stride: int,
+ norm_layer: Optional[Callable[..., nn.Module]] = None
+ ) -> None:
+ super(DepthWiseSeparableConv2d, self).__init__()
+ self.stride = stride
+ if stride not in [1, 2]:
+ raise ValueError(f"stride should be 1 or 2 instead of {stride}")
+
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+
+ self.conv = nn.Sequential(
+ Conv2dNormActivation(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ groups=in_channels,
+ norm_layer=norm_layer,
+ activation_layer=nn.ReLU,
+ inplace=True,
+ bias=False,
+ ),
+ Conv2dNormActivation(in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ norm_layer=norm_layer,
+ activation_layer=nn.ReLU,
+ inplace=True,
+ bias=False,
+ ),
+
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ out = self.conv(x)
+
+ return out
+
+
+def mobilenet_v1(**kwargs: Any) -> MobileNetV1:
+ model = MobileNetV1(**kwargs)
+
+ return model
diff --git a/model/mobilenet_v2.py b/model/mobilenet_v2.py
new file mode 100644
index 0000000..d62f0cd
--- /dev/null
+++ b/model/mobilenet_v2.py
@@ -0,0 +1,200 @@
+from torch import nn
+from .utils import load_state_dict_from_url
+from config import config as conf
+
+__all__ = ['MobileNetV2', 'mobilenet_v2']
+
+
+model_urls = {
+ 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
+}
+
+
+def _make_divisible(v, divisor, min_value=None):
+ """
+ This function is taken from the original tf repo.
+ It ensures that all layers have a channel number that is divisible by 8
+ It can be seen here:
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+ :param v:
+ :param divisor:
+ :param min_value:
+ :return:
+ """
+ if min_value is None:
+ min_value = divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ # Make sure that round down does not go down by more than 10%.
+ if new_v < 0.9 * v:
+ new_v += divisor
+ return new_v
+
+
+class ConvBNReLU(nn.Sequential):
+ def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None):
+ padding = (kernel_size - 1) // 2
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ super(ConvBNReLU, self).__init__(
+ nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
+ norm_layer(out_planes),
+ nn.ReLU6(inplace=True)
+ )
+
+
+class InvertedResidual(nn.Module):
+ def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None):
+ super(InvertedResidual, self).__init__()
+ self.stride = stride
+ assert stride in [1, 2]
+
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+
+ hidden_dim = int(round(inp * expand_ratio))
+ self.use_res_connect = self.stride == 1 and inp == oup
+
+ layers = []
+ if expand_ratio != 1:
+ # pw
+ layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
+ layers.extend([
+ # dw
+ ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer),
+ # pw-linear
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+ norm_layer(oup),
+ ])
+ self.conv = nn.Sequential(*layers)
+
+ def forward(self, x):
+ if self.use_res_connect:
+ return x + self.conv(x)
+ else:
+ return self.conv(x)
+
+
+class MobileNetV2(nn.Module):
+ def __init__(self,
+ num_classes=conf.embedding_size,
+ width_mult=1.0,
+ inverted_residual_setting=None,
+ round_nearest=8,
+ block=None,
+ norm_layer=None):
+ """
+ MobileNet V2 main class
+
+ Args:
+ num_classes (int): Number of classes
+ width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
+ inverted_residual_setting: Network structure
+ round_nearest (int): Round the number of channels in each layer to be a multiple of this number
+ Set to 1 to turn off rounding
+ block: Module specifying inverted residual building block for mobilenet
+ norm_layer: Module specifying the normalization layer to use
+
+ """
+ super(MobileNetV2, self).__init__()
+
+ if block is None:
+ block = InvertedResidual
+
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+
+ input_channel = 32
+ last_channel = 1280
+
+ if inverted_residual_setting is None:
+ inverted_residual_setting = [
+ # t, c, n, s
+ [1, 16, 1, 1],
+ [6, 24, 2, 2],
+ [6, 32, 3, 2],
+ [6, 64, 4, 2],
+ [6, 96, 3, 1],
+ [6, 160, 3, 2],
+ [6, 320, 1, 1],
+ ]
+
+ # only check the first element, assuming user knows t,c,n,s are required
+ if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
+ raise ValueError("inverted_residual_setting should be non-empty "
+ "or a 4-element list, got {}".format(inverted_residual_setting))
+
+ # building first layer
+ input_channel = _make_divisible(input_channel * width_mult, round_nearest)
+ self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
+ features = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
+ # building inverted residual blocks
+ for t, c, n, s in inverted_residual_setting:
+ output_channel = _make_divisible(c * width_mult, round_nearest)
+ for i in range(n):
+ stride = s if i == 0 else 1
+ features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
+ input_channel = output_channel
+ # building last several layers
+ features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer))
+ # make it nn.Sequential
+ self.features = nn.Sequential(*features)
+
+ # building classifier
+ self.classifier = nn.Sequential(
+ nn.Dropout(0.2),
+ nn.Linear(self.last_channel, num_classes),
+ )
+
+ # weight initialization
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out')
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.ones_(m.weight)
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, 0, 0.01)
+ nn.init.zeros_(m.bias)
+
+ def _forward_impl(self, x):
+ # This exists since TorchScript doesn't support inheritance, so the superclass method
+ # (this one) needs to have a name other than `forward` that can be accessed in a subclass
+ x = self.features(x)
+ # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
+ x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1)
+ x = self.classifier(x)
+ return x
+
+ def forward(self, x):
+ return self._forward_impl(x)
+
+
+def mobilenet_v2(pretrained=True, progress=True, **kwargs):
+ """
+ Constructs a MobileNetV2 architecture from
+ `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ model = MobileNetV2(**kwargs)
+ if pretrained:
+ state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
+ progress=progress)
+ src_state_dict = state_dict
+ target_state_dict = model.state_dict()
+ skip_keys = []
+ # skip mismatch size tensors in case of pretraining
+ for k in src_state_dict.keys():
+ if k not in target_state_dict:
+ continue
+ if src_state_dict[k].size() != target_state_dict[k].size():
+ skip_keys.append(k)
+ for k in skip_keys:
+ del src_state_dict[k]
+ missing_keys, unexpected_keys = model.load_state_dict(src_state_dict, strict=False)
+ #.load_state_dict(state_dict)
+ return model
diff --git a/model/mobilenet_v3.py b/model/mobilenet_v3.py
new file mode 100644
index 0000000..d69a5a0
--- /dev/null
+++ b/model/mobilenet_v3.py
@@ -0,0 +1,200 @@
+'''MobileNetV3 in PyTorch.
+
+See the paper "Inverted Residuals and Linear Bottlenecks:
+Mobile Networks for Classification, Detection and Segmentation" for more details.
+'''
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import init
+from config import config as conf
+
+
+class hswish(nn.Module):
+ def forward(self, x):
+ out = x * F.relu6(x + 3, inplace=True) / 6
+ return out
+
+
+class hsigmoid(nn.Module):
+ def forward(self, x):
+ out = F.relu6(x + 3, inplace=True) / 6
+ return out
+
+
+class SeModule(nn.Module):
+ def __init__(self, in_size, reduction=4):
+ super(SeModule, self).__init__()
+ self.se = nn.Sequential(
+ nn.AdaptiveAvgPool2d(1),
+ nn.Conv2d(in_size, in_size // reduction, kernel_size=1, stride=1, padding=0, bias=False),
+ nn.BatchNorm2d(in_size // reduction),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(in_size // reduction, in_size, kernel_size=1, stride=1, padding=0, bias=False),
+ nn.BatchNorm2d(in_size),
+ hsigmoid()
+ )
+
+ def forward(self, x):
+ return x * self.se(x)
+
+
+class Block(nn.Module):
+ '''expand + depthwise + pointwise'''
+ def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, semodule, stride):
+ super(Block, self).__init__()
+ self.stride = stride
+ self.se = semodule
+
+ self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, stride=1, padding=0, bias=False)
+ self.bn1 = nn.BatchNorm2d(expand_size)
+ self.nolinear1 = nolinear
+ self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size=kernel_size, stride=stride, padding=kernel_size//2, groups=expand_size, bias=False)
+ self.bn2 = nn.BatchNorm2d(expand_size)
+ self.nolinear2 = nolinear
+ self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, stride=1, padding=0, bias=False)
+ self.bn3 = nn.BatchNorm2d(out_size)
+
+ self.shortcut = nn.Sequential()
+ if stride == 1 and in_size != out_size:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0, bias=False),
+ nn.BatchNorm2d(out_size),
+ )
+
+ def forward(self, x):
+ out = self.nolinear1(self.bn1(self.conv1(x)))
+ out = self.nolinear2(self.bn2(self.conv2(out)))
+ out = self.bn3(self.conv3(out))
+ if self.se != None:
+ out = self.se(out)
+ out = out + self.shortcut(x) if self.stride==1 else out
+ return out
+
+
+class MobileNetV3_Large(nn.Module):
+ def __init__(self, num_classes=conf.embedding_size):
+ super(MobileNetV3_Large, self).__init__()
+ self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(16)
+ self.hs1 = hswish()
+
+ self.bneck = nn.Sequential(
+ Block(3, 16, 16, 16, nn.ReLU(inplace=True), None, 1),
+ Block(3, 16, 64, 24, nn.ReLU(inplace=True), None, 2),
+ Block(3, 24, 72, 24, nn.ReLU(inplace=True), None, 1),
+ Block(5, 24, 72, 40, nn.ReLU(inplace=True), SeModule(40), 2),
+ Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
+ Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
+ Block(3, 40, 240, 80, hswish(), None, 2),
+ Block(3, 80, 200, 80, hswish(), None, 1),
+ Block(3, 80, 184, 80, hswish(), None, 1),
+ Block(3, 80, 184, 80, hswish(), None, 1),
+ Block(3, 80, 480, 112, hswish(), SeModule(112), 1),
+ Block(3, 112, 672, 112, hswish(), SeModule(112), 1),
+ Block(5, 112, 672, 160, hswish(), SeModule(160), 1),
+ Block(5, 160, 672, 160, hswish(), SeModule(160), 2),
+ Block(5, 160, 960, 160, hswish(), SeModule(160), 1),
+ )
+
+
+ self.conv2 = nn.Conv2d(160, 960, kernel_size=1, stride=1, padding=0, bias=False)
+ self.bn2 = nn.BatchNorm2d(960)
+ self.hs2 = hswish()
+ self.linear3 = nn.Linear(960, 1280)
+ self.bn3 = nn.BatchNorm1d(1280)
+ self.hs3 = hswish()
+ self.linear4 = nn.Linear(1280, num_classes)
+ self.init_params()
+
+ def init_params(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ init.kaiming_normal_(m.weight, mode='fan_out')
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d):
+ init.constant_(m.weight, 1)
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ init.normal_(m.weight, std=0.001)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ out = self.hs1(self.bn1(self.conv1(x)))
+ out = self.bneck(out)
+ out = self.hs2(self.bn2(self.conv2(out)))
+ out = F.avg_pool2d(out, conf.img_size // 32)
+ out = out.view(out.size(0), -1)
+ out = self.hs3(self.bn3(self.linear3(out)))
+ out = self.linear4(out)
+ return out
+
+
+
+class MobileNetV3_Small(nn.Module):
+ def __init__(self, num_classes=conf.embedding_size):
+ super(MobileNetV3_Small, self).__init__()
+ self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(16)
+ self.hs1 = hswish()
+
+ self.bneck = nn.Sequential(
+ Block(3, 16, 16, 16, nn.ReLU(inplace=True), SeModule(16), 2),
+ Block(3, 16, 72, 24, nn.ReLU(inplace=True), None, 2),
+ Block(3, 24, 88, 24, nn.ReLU(inplace=True), None, 1),
+ Block(5, 24, 96, 40, hswish(), SeModule(40), 2),
+ Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
+ Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
+ Block(5, 40, 120, 48, hswish(), SeModule(48), 1),
+ Block(5, 48, 144, 48, hswish(), SeModule(48), 1),
+ Block(5, 48, 288, 96, hswish(), SeModule(96), 2),
+ Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
+ Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
+ )
+
+
+ self.conv2 = nn.Conv2d(96, 576, kernel_size=1, stride=1, padding=0, bias=False)
+ self.bn2 = nn.BatchNorm2d(576)
+ self.hs2 = hswish()
+ self.linear3 = nn.Linear(576, 1280)
+ self.bn3 = nn.BatchNorm1d(1280)
+ self.hs3 = hswish()
+ self.linear4 = nn.Linear(1280, num_classes)
+ self.init_params()
+
+ def init_params(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ init.kaiming_normal_(m.weight, mode='fan_out')
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d):
+ init.constant_(m.weight, 1)
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ init.normal_(m.weight, std=0.001)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ out = self.hs1(self.bn1(self.conv1(x)))
+ out = self.bneck(out)
+ out = self.hs2(self.bn2(self.conv2(out)))
+ out = F.avg_pool2d(out, conf.img_size // 32)
+ out = out.view(out.size(0), -1)
+
+ out = self.hs3(self.bn3(self.linear3(out)))
+ out = self.linear4(out)
+ return out
+
+
+
+def test():
+ net = MobileNetV3_Small()
+ x = torch.randn(2,3,224,224)
+ y = net(x)
+ print(y.size())
+
+# test()
\ No newline at end of file
diff --git a/model/mobilevit.py b/model/mobilevit.py
new file mode 100644
index 0000000..f371ee9
--- /dev/null
+++ b/model/mobilevit.py
@@ -0,0 +1,265 @@
+import torch
+import torch.nn as nn
+
+from einops import rearrange
+from config import config as conf
+
+
+def conv_1x1_bn(inp, oup):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(oup),
+ nn.SiLU()
+ )
+
+
+def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, kernal_size, stride, 1, bias=False),
+ nn.BatchNorm2d(oup),
+ nn.SiLU()
+ )
+
+
+class PreNorm(nn.Module):
+ def __init__(self, dim, fn):
+ super().__init__()
+ self.norm = nn.LayerNorm(dim)
+ self.fn = fn
+
+ def forward(self, x, **kwargs):
+ return self.fn(self.norm(x), **kwargs)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, hidden_dim, dropout=0.):
+ super().__init__()
+ self.net = nn.Sequential(
+ nn.Linear(dim, hidden_dim),
+ nn.SiLU(),
+ nn.Dropout(dropout),
+ nn.Linear(hidden_dim, dim),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
+ super().__init__()
+ inner_dim = dim_head * heads
+ project_out = not (heads == 1 and dim_head == dim)
+
+ self.heads = heads
+ self.scale = dim_head ** -0.5
+
+ self.attend = nn.Softmax(dim=-1)
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, dim),
+ nn.Dropout(dropout)
+ ) if project_out else nn.Identity()
+
+ def forward(self, x):
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
+ q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
+
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+ attn = self.attend(dots)
+ out = torch.matmul(attn, v)
+ out = rearrange(out, 'b p h n d -> b p n (h d)')
+ return self.to_out(out)
+
+
+class Transformer(nn.Module):
+ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
+ super().__init__()
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(nn.ModuleList([
+ PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
+ PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
+ ]))
+
+ def forward(self, x):
+ for attn, ff in self.layers:
+ x = attn(x) + x
+ x = ff(x) + x
+ return x
+
+
+class MV2Block(nn.Module):
+ def __init__(self, inp, oup, stride=1, expansion=4):
+ super().__init__()
+ self.stride = stride
+ assert stride in [1, 2]
+
+ hidden_dim = int(inp * expansion)
+ self.use_res_connect = self.stride == 1 and inp == oup
+
+ if expansion == 1:
+ self.conv = nn.Sequential(
+ # dw
+ nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
+ nn.BatchNorm2d(hidden_dim),
+ nn.SiLU(),
+ # pw-linear
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(oup),
+ )
+ else:
+ self.conv = nn.Sequential(
+ # pw
+ nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(hidden_dim),
+ nn.SiLU(),
+ # dw
+ nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
+ nn.BatchNorm2d(hidden_dim),
+ nn.SiLU(),
+ # pw-linear
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(oup),
+ )
+
+ def forward(self, x):
+ if self.use_res_connect:
+ return x + self.conv(x)
+ else:
+ return self.conv(x)
+
+
+class MobileViTBlock(nn.Module):
+ def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
+ super().__init__()
+ self.ph, self.pw = patch_size
+
+ self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
+ self.conv2 = conv_1x1_bn(channel, dim)
+
+ self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)
+
+ self.conv3 = conv_1x1_bn(dim, channel)
+ self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
+
+ def forward(self, x):
+ y = x.clone()
+
+ # Local representations
+ x = self.conv1(x)
+ x = self.conv2(x)
+
+ # Global representations
+ _, _, h, w = x.shape
+ x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
+ x = self.transformer(x)
+ x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph,
+ pw=self.pw)
+
+ # Fusion
+ x = self.conv3(x)
+ x = torch.cat((x, y), 1)
+ x = self.conv4(x)
+ return x
+
+
+class MobileViT(nn.Module):
+ def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)):
+ super().__init__()
+ ih, iw = image_size
+ ph, pw = patch_size
+ assert ih % ph == 0 and iw % pw == 0
+
+ L = [2, 4, 3]
+
+ self.conv1 = conv_nxn_bn(3, channels[0], stride=2)
+
+ self.mv2 = nn.ModuleList([])
+ self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
+ self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
+ self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
+ self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion)) # Repeat
+ self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
+ self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
+ self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))
+
+ self.mvit = nn.ModuleList([])
+ self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2)))
+ self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4)))
+ self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4)))
+
+ self.conv2 = conv_1x1_bn(channels[-2], channels[-1])
+
+ self.pool = nn.AvgPool2d(ih // 32, 1)
+ self.fc = nn.Linear(channels[-1], num_classes, bias=False)
+
+ def forward(self, x):
+ #print('x',x.shape)
+ x = self.conv1(x)
+ x = self.mv2[0](x)
+
+ x = self.mv2[1](x)
+ x = self.mv2[2](x)
+ x = self.mv2[3](x) # Repeat
+
+ x = self.mv2[4](x)
+ x = self.mvit[0](x)
+
+ x = self.mv2[5](x)
+ x = self.mvit[1](x)
+
+ x = self.mv2[6](x)
+ x = self.mvit[2](x)
+ x = self.conv2(x)
+
+
+ #print('pool_before',x.shape)
+ x = self.pool(x).view(-1, x.shape[1])
+ #print('self_pool',self.pool)
+ #print('pool_after',x.shape)
+ x = self.fc(x)
+ return x
+
+
+def mobilevit_xxs():
+ dims = [64, 80, 96]
+ channels = [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320]
+ return MobileViT((256, 256), dims, channels, num_classes=1000, expansion=2)
+
+
+def mobilevit_xs():
+ dims = [96, 120, 144]
+ channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384]
+ return MobileViT((256, 256), dims, channels, num_classes=1000)
+
+
+def mobilevit_s():
+ dims = [144, 192, 240]
+ channels = [16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 640]
+ return MobileViT((conf.img_size, conf.img_size), dims, channels, num_classes=conf.embedding_size)
+
+
+def count_parameters(model):
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+if __name__ == '__main__':
+ img = torch.randn(5, 3, 256, 256)
+
+ vit = mobilevit_xxs()
+ out = vit(img)
+ print(out.shape)
+ print(count_parameters(vit))
+
+ vit = mobilevit_xs()
+ out = vit(img)
+ print(out.shape)
+ print(count_parameters(vit))
+
+ vit = mobilevit_s()
+ out = vit(img)
+ print(out.shape)
+ print(count_parameters(vit))
diff --git a/model/quant_test_resnet.py b/model/quant_test_resnet.py
new file mode 100644
index 0000000..12a1b80
--- /dev/null
+++ b/model/quant_test_resnet.py
@@ -0,0 +1,412 @@
+import torch
+from torch import Tensor
+import torch.nn as nn
+from .utils import load_state_dict_from_url
+from typing import Type, Any, Callable, Union, List, Optional
+
+
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+ 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
+ 'wide_resnet50_2', 'wide_resnet101_2']
+
+
+model_urls = {
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+ 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+ 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+ 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
+ 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
+}
+
+
+def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion: int = 1
+
+ def __init__(
+ self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ downsample: Optional[nn.Module] = None,
+ groups: int = 1,
+ base_width: int = 64,
+ dilation: int = 1,
+ norm_layer: Optional[Callable[..., nn.Module]] = None
+ ) -> None:
+ super(BasicBlock, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ if groups != 1 or base_width != 64:
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = norm_layer(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = norm_layer(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x: Tensor) -> Tensor:
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class QuantizableBasicBlock(BasicBlock):
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ super().__init__(*args, **kwargs)
+ self.add_relu = torch.nn.quantized.FloatFunctional()
+
+ def forward(self, x: Tensor) -> Tensor:
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out = self.add_relu.add_relu(out, identity)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
+ # while original implementation places the stride at the first 1x1 convolution(self.conv1)
+ # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
+ # This variant is also known as ResNet V1.5 and improves accuracy according to
+ # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
+
+ expansion: int = 4
+
+ def __init__(
+ self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ downsample: Optional[nn.Module] = None,
+ groups: int = 1,
+ base_width: int = 64,
+ dilation: int = 1,
+ norm_layer: Optional[Callable[..., nn.Module]] = None
+ ) -> None:
+ super(Bottleneck, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ width = int(planes * (base_width / 64.)) * groups
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv1x1(inplanes, width)
+ self.bn1 = norm_layer(width)
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
+ self.bn2 = norm_layer(width)
+ self.conv3 = conv1x1(width, planes * self.expansion)
+ self.bn3 = norm_layer(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x: Tensor) -> Tensor:
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+
+ def __init__(
+ self,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ layers: List[int],
+ num_classes: int = 1000,
+ zero_init_residual: bool = False,
+ groups: int = 1,
+ width_per_group: int = 64,
+ replace_stride_with_dilation: Optional[List[bool]] = None,
+ norm_layer: Optional[Callable[..., nn.Module]] = None
+ ) -> None:
+ super(ResNet, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ self._norm_layer = norm_layer
+
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ # each element in the tuple indicates if we should replace
+ # the 2x2 stride with a dilated convolution instead
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError("replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
+ bias=False)
+ self.bn1 = norm_layer(self.inplanes)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+ dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+ dilate=replace_stride_with_dilation[1])
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+ dilate=replace_stride_with_dilation[2])
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ # Zero-initialize the last BN in each residual branch,
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
+ elif isinstance(m, BasicBlock):
+ nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
+
+ def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
+ stride: int = 1, dilate: bool = False) -> nn.Sequential:
+ norm_layer = self._norm_layer
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ norm_layer(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+ self.base_width, previous_dilation, norm_layer))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes, groups=self.groups,
+ base_width=self.base_width, dilation=self.dilation,
+ norm_layer=norm_layer))
+
+ return nn.Sequential(*layers)
+
+ def _forward_impl(self, x: Tensor) -> Tensor:
+ # See note [TorchScript super()]
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ x = self.avgpool(x)
+ x = torch.flatten(x, 1)
+ x = self.fc(x)
+
+ return x
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self._forward_impl(x)
+
+
+def _resnet(
+ arch: str,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ layers: List[int],
+ pretrained: bool,
+ progress: bool,
+ **kwargs: Any
+) -> ResNet:
+ model = ResNet(block, layers, **kwargs)
+ if pretrained:
+ state_dict = load_state_dict_from_url(model_urls[arch],
+ progress=progress)
+ model.load_state_dict(state_dict)
+ return model
+
+
+def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-18 model from
+ `"Deep Residual Learning for Image Recognition" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ # return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
+ return _resnet('resnet18', QuantizableBasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
+
+
+def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-34 model from
+ `"Deep Residual Learning for Image Recognition" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-50 model from
+ `"Deep Residual Learning for Image Recognition" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-101 model from
+ `"Deep Residual Learning for Image Recognition" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-152 model from
+ `"Deep Residual Learning for Image Recognition" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNeXt-50 32x4d model from
+ `"Aggregated Residual Transformation for Deep Neural Networks" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['groups'] = 32
+ kwargs['width_per_group'] = 4
+ return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
+ pretrained, progress, **kwargs)
+
+
+def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNeXt-101 32x8d model from
+ `"Aggregated Residual Transformation for Deep Neural Networks" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['groups'] = 32
+ kwargs['width_per_group'] = 8
+ return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
+ pretrained, progress, **kwargs)
+
+
+def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""Wide ResNet-50-2 model from
+ `"Wide Residual Networks" `_.
+
+ The model is the same as ResNet except for the bottleneck number of channels
+ which is twice larger in every block. The number of channels in outer 1x1
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 2
+ return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
+ pretrained, progress, **kwargs)
+
+
+def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""Wide ResNet-101-2 model from
+ `"Wide Residual Networks" `_.
+
+ The model is the same as ResNet except for the bottleneck number of channels
+ which is twice larger in every block. The number of channels in outer 1x1
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 2
+ return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
+ pretrained, progress, **kwargs)
diff --git a/model/resbam.py b/model/resbam.py
new file mode 100644
index 0000000..21395c3
--- /dev/null
+++ b/model/resbam.py
@@ -0,0 +1,142 @@
+from model.CBAM import CBAM
+import torch
+import torch.nn as nn
+from model.Tool import GeM as gem
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inchannel, outchannel, stride=1, dowsample=None):
+ # super(Bottleneck, self).__init__()
+ super().__init__()
+ self.conv1 = nn.Conv2d(in_channels=inchannel, out_channels=outchannel, kernel_size=1, stride=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(outchannel)
+ self.conv2 = nn.Conv2d(in_channels=outchannel, out_channels=outchannel, kernel_size=3, bias=False,
+ stride=stride, padding=1)
+ self.bn2 = nn.BatchNorm2d(outchannel)
+ self.conv3 = nn.Conv2d(in_channels=outchannel, out_channels=outchannel * self.expansion, stride=1, bias=False,
+ kernel_size=1)
+ self.bn3 = nn.BatchNorm2d(outchannel * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = dowsample
+
+ def forward(self, x):
+ self.identity = x
+ # print('>>>>>>>>',type(x))
+ if self.downsample is not None:
+ # print('>>>>downsample>>>>', type(self.downsample))
+ self.identity = self.downsample(x)
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+ out = self.conv3(out)
+ out = self.bn3(out)
+ # print('>>>>out>>>identity',out.size(),self.identity.size())
+ out = out + self.identity
+ out = self.relu(out)
+ return out
+
+
+class resnet(nn.Module):
+ def __init__(self, block=Bottleneck, block_num=[3, 4, 6, 3], num_class=1000):
+ super().__init__()
+ self.in_channel = 64
+ self.conv1 = nn.Conv2d(in_channels=3,
+ out_channels=self.in_channel,
+ stride=2,
+ kernel_size=7,
+ padding=3,
+ bias=False)
+ self.bn1 = nn.BatchNorm2d(self.in_channel)
+ self.relu = nn.ReLU(inplace=True)
+ self.cbam = CBAM(self.in_channel)
+ self.cbam1 = CBAM(2048)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, block_num[0], stride=1)
+ self.layer2 = self._make_layer(block, 128, block_num[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, block_num[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, block_num[3], stride=2)
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+ self.gem = gem()
+ self.fc = nn.Linear(512 * block.expansion, num_class)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal(m.weight, mode='fan_out',
+ nonlinearity='relu')
+ if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1.0)
+ nn.init.constant_(m.bias, 1.0)
+
+ def _make_layer(self, block, channel, block_num, stride=1):
+ downsample = None
+ if stride != 1 or self.in_channel != channel * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(channel * block.expansion))
+ layer = []
+ layer.append(block(self.in_channel, channel, stride, downsample))
+ self.in_channel = channel * block.expansion
+ for _ in range(1, block_num):
+ layer.append(block(self.in_channel, channel))
+ return nn.Sequential(*layer)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ x = self.cbam(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ x = self.cbam1(x)
+ # x = self.avgpool(x)
+ x = self.gem(x)
+ x = torch.flatten(x, 1)
+ x = self.fc(x)
+ return x
+
+
+class TripletNet(nn.Module):
+ def __init__(self, num_class, flag=True):
+ super(TripletNet, self).__init__()
+ self.initnet = rescbam(num_class)
+ self.flag = flag
+
+ def forward(self, x1, x2=None, x3=None):
+ if self.flag:
+ output1 = self.initnet(x1)
+ output2 = self.initnet(x2)
+ output3 = self.initnet(x3)
+ return output1, output2, output3
+ else:
+ output = self.initnet(x1)
+ return output
+
+
+def rescbam(num_class):
+ return resnet(block=Bottleneck, block_num=[3, 4, 6, 3], num_class=num_class)
+
+
+if __name__ == '__main__':
+ input1 = torch.randn(4, 3, 640, 640)
+ input2 = torch.randn(4, 3, 640, 640)
+ input3 = torch.randn(4, 3, 640, 640)
+
+ # rescbam测试
+ # Resnet50 = rescbam(512)
+ # output = Resnet50.forward(input1)
+ # print(Resnet50)
+
+ # trnet测试
+ trnet = TripletNet(512)
+ output = trnet(input1, input2, input3)
+ print(output)
diff --git a/model/resnet.py b/model/resnet.py
new file mode 100644
index 0000000..3c4fdf0
--- /dev/null
+++ b/model/resnet.py
@@ -0,0 +1,189 @@
+"""resnet in pytorch
+
+
+
+[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
+
+ Deep Residual Learning for Image Recognition
+ https://arxiv.org/abs/1512.03385v1
+"""
+
+import torch
+import torch.nn as nn
+from config import config as conf
+from CBAM import CBAM
+
+class BasicBlock(nn.Module):
+ """Basic Block for resnet 18 and resnet 34
+
+ """
+
+ #BasicBlock and BottleNeck block
+ #have different output size
+ #we use class attribute expansion
+ #to distinct
+ expansion = 1
+
+ def __init__(self, in_channels, out_channels, stride=1):
+ super().__init__()
+
+ #residual function
+ self.residual_function = nn.Sequential(
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm2d(out_channels * BasicBlock.expansion)
+ )
+
+ #shortcut
+ self.shortcut = nn.Sequential()
+
+ #the shortcut output dimension is not the same with residual function
+ #use 1*1 convolution to match the dimension
+ if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(out_channels * BasicBlock.expansion)
+ )
+
+ def forward(self, x):
+ return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
+
+class BottleNeck(nn.Module):
+ """Residual block for resnet over 50 layers
+
+ """
+ expansion = 4
+ def __init__(self, in_channels, out_channels, stride=1):
+ super().__init__()
+ self.residual_function = nn.Sequential(
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
+ nn.BatchNorm2d(out_channels * BottleNeck.expansion),
+ )
+
+ self.shortcut = nn.Sequential()
+
+ if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
+ nn.BatchNorm2d(out_channels * BottleNeck.expansion)
+ )
+
+ def forward(self, x):
+ return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
+
+class ResNet(nn.Module):
+
+ def __init__(self, block, num_block, cbam = False, num_classes=conf.embedding_size):
+ super().__init__()
+
+ self.in_channels = 64
+
+ # self.conv1 = nn.Sequential(
+ # nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
+ # nn.BatchNorm2d(64),
+ # nn.ReLU(inplace=True))
+
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(3, 64,stride=2,kernel_size=7,padding=3,bias=False),
+ nn.BatchNorm2d(64),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
+
+ self.cbam = CBAM(self.in_channels)
+
+ #we use a different inputsize than the original paper
+ #so conv2_x's stride is 1
+ self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
+ self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
+ self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
+ self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
+ self.cbam1 = CBAM(self.in_channels)
+ self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal(m.weight,mode = 'fan_out',
+ nonlinearity='relu')
+ if isinstance(m, (nn.BatchNorm2d)):
+ nn.init.constant_(m.weight, 1.0)
+ nn.init.constant_(m.bias, 1.0)
+
+ def _make_layer(self, block, out_channels, num_blocks, stride):
+ """make resnet layers(by layer i didnt mean this 'layer' was the
+ same as a neuron netowork layer, ex. conv layer), one layer may
+ contain more than one residual block
+
+ Args:
+ block: block type, basic block or bottle neck block
+ out_channels: output depth channel number of this layer
+ num_blocks: how many blocks per layer
+ stride: the stride of the first block of this layer
+
+ Return:
+ return a resnet layer
+ """
+
+ # we have num_block blocks per layer, the first block
+ # could be 1 or 2, other blocks would always be 1
+ strides = [stride] + [1] * (num_blocks - 1)
+ layers = []
+ for stride in strides:
+ layers.append(block(self.in_channels, out_channels, stride))
+ self.in_channels = out_channels * block.expansion
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ output = self.conv1(x)
+ if cbam:
+ output = self.cbam(x)
+ output = self.conv2_x(output)
+ output = self.conv3_x(output)
+ output = self.conv4_x(output)
+ output = self.conv5_x(output)
+ if cbam:
+ output = self.cbam1(x)
+ print('pollBefore',output.shape)
+ output = self.avg_pool(output)
+ print('poolAfter',output.shape)
+ output = output.view(output.size(0), -1)
+ print('fcBefore',output.shape)
+ output = self.fc(output)
+
+ return output
+
+def resnet18(cbam = False):
+ """ return a ResNet 18 object
+ """
+ return ResNet(BasicBlock, [2, 2, 2, 2], cbam)
+
+def resnet34():
+ """ return a ResNet 34 object
+ """
+ return ResNet(BasicBlock, [3, 4, 6, 3])
+
+def resnet50():
+ """ return a ResNet 50 object
+ """
+ return ResNet(BottleNeck, [3, 4, 6, 3])
+
+def resnet101():
+ """ return a ResNet 101 object
+ """
+ return ResNet(BottleNeck, [3, 4, 23, 3])
+
+def resnet152():
+ """ return a ResNet 152 object
+ """
+ return ResNet(BottleNeck, [3, 8, 36, 3])
+
+
diff --git a/model/resnet_attention.py b/model/resnet_attention.py
new file mode 100644
index 0000000..660f205
--- /dev/null
+++ b/model/resnet_attention.py
@@ -0,0 +1,271 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ChannelAttention(nn.Module):
+ """通道注意力模块,通过全局平均池化和最大池化提取特征,经过MLP生成通道权重"""
+
+ def __init__(self, in_channels, reduction_ratio=16):
+ super(ChannelAttention, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.max_pool = nn.AdaptiveMaxPool2d(1)
+
+ # 共享的MLP层
+ self.fc = nn.Sequential(
+ nn.Conv2d(in_channels, in_channels // reduction_ratio, 1, bias=False),
+ nn.ReLU(),
+ nn.Conv2d(in_channels // reduction_ratio, in_channels, 1, bias=False)
+ )
+
+ def forward(self, x):
+ avg_out = self.fc(self.avg_pool(x))
+ max_out = self.fc(self.max_pool(x))
+ out = avg_out + max_out
+ return torch.sigmoid(out)
+
+
+class SpatialAttention(nn.Module):
+ """空间注意力模块,通过通道维度的平均和最大值操作,生成空间权重"""
+
+ def __init__(self, kernel_size=7):
+ super(SpatialAttention, self).__init__()
+ self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
+
+ def forward(self, x):
+ avg_out = torch.mean(x, dim=1, keepdim=True)
+ max_out, _ = torch.max(x, dim=1, keepdim=True)
+ out = torch.cat([avg_out, max_out], dim=1)
+ out = self.conv(out)
+ return torch.sigmoid(out)
+
+
+class CBAM(nn.Module):
+ """CBAM注意力模块,串联通道注意力和空间注意力"""
+
+ def __init__(self, in_channels, reduction_ratio=16, kernel_size=7):
+ super(CBAM, self).__init__()
+ self.channel_att = ChannelAttention(in_channels, reduction_ratio)
+ self.spatial_att = SpatialAttention(kernel_size)
+
+ def forward(self, x):
+ x = x * self.channel_att(x)
+ x = x * self.spatial_att(x)
+ return x
+
+
+class BasicBlock(nn.Module):
+ """ResNet基础残差块,适用于ResNet18和ResNet34"""
+ expansion = 1
+
+ def __init__(self, in_channels, out_channels, stride=1, downsample=None, use_cbam=False):
+ super(BasicBlock, self).__init__()
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(out_channels)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(out_channels)
+
+ self.downsample = downsample
+ self.stride = stride
+
+ # 是否使用CBAM注意力机制
+ self.use_cbam = use_cbam
+ if use_cbam:
+ self.cbam = CBAM(out_channels)
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ # # 如果使用注意力机制,应用CBAM
+ if self.use_cbam:
+ out = self.cbam(out)
+
+ # 如果有下采样,调整shortcut连接
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ # 残差连接
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ """ResNet瓶颈残差块,适用于ResNet50及更深的网络"""
+ expansion = 4
+
+ def __init__(self, in_channels, out_channels, stride=1, downsample=None, use_cbam=False):
+ super(Bottleneck, self).__init__()
+ # 1x1卷积降维
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(out_channels)
+ # 3x3卷积
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(out_channels)
+ # 1x1卷积升维
+ self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+
+ self.downsample = downsample
+ self.stride = stride
+
+ # 是否使用CBAM注意力机制
+ self.use_cbam = use_cbam
+ if use_cbam:
+ self.cbam = CBAM(out_channels * self.expansion)
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ # # 如果使用注意力机制,应用CBAM
+ if self.use_cbam:
+ out = self.cbam(out)
+
+ # 如果有下采样,调整shortcut连接
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ # 残差连接
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+ """集成了CBAM注意力机制的ResNet模型"""
+
+ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, use_cbam=True):
+ super(ResNet, self).__init__()
+ self.in_channels = 64
+ self.use_cbam = use_cbam
+
+ # 初始卷积层
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+ self.cbam1 = CBAM(64)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ # 残差块层
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+
+ self.cbam2 = CBAM(512)
+ # 全局平均池化和分类器
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+ # 初始化权重
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ # 零初始化最后一个BN层的权重,使残差分支初始为0
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ nn.init.constant_(m.bn3.weight, 0)
+ elif isinstance(m, BasicBlock):
+ nn.init.constant_(m.bn2.weight, 0)
+
+ def _make_layer(self, block, out_channels, blocks, stride=1):
+ downsample = None
+ # 如果通道数不匹配或需要调整步长,创建下采样层
+ if stride != 1 or self.in_channels != out_channels * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.in_channels, out_channels * block.expansion, kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(out_channels * block.expansion),
+ )
+
+ layers = []
+ # 第一个块可能需要下采样
+ layers.append(block(self.in_channels, out_channels, stride, downsample, use_cbam=self.use_cbam))
+ self.in_channels = out_channels * block.expansion
+
+ # 添加剩余的块
+ for _ in range(1, blocks):
+ layers.append(block(self.in_channels, out_channels, use_cbam=self.use_cbam))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ # 特征提取
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ # if self.use_cbam:
+ # x = self.cbam1(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ # if self.use_cbam:
+ # x = self.cbam2(x)
+ # 分类
+ x = self.avgpool(x)
+ x = torch.flatten(x, 1)
+ x = self.fc(x)
+
+ return x
+
+
+# 工厂函数,创建不同深度的ResNet模型
+def resnet18_cbam(pretrained=False, **kwargs):
+ return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
+
+
+def resnet34_cbam(pretrained=False, **kwargs):
+ return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
+
+
+def resnet50_cbam(pretrained=False, **kwargs):
+ return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
+
+
+def resnet101_cbam(pretrained=False, **kwargs):
+ return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
+
+
+def resnet152_cbam(pretrained=False, **kwargs):
+ return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
+
+
+# 测试模型
+if __name__ == "__main__":
+ # 创建一个带有CBAM注意力机制的ResNet50模型
+ model = resnet50_cbam(num_classes=10)
+ # 测试输入
+ x = torch.randn(1, 3, 224, 224)
+ y = model(x)
+ print(f"输入形状: {x.shape}")
+ print(f"输出形状: {y.shape}")
\ No newline at end of file
diff --git a/model/resnet_pre.py b/model/resnet_pre.py
new file mode 100644
index 0000000..724d3e7
--- /dev/null
+++ b/model/resnet_pre.py
@@ -0,0 +1,480 @@
+import torch
+import torch.nn as nn
+from config import config as conf
+
+try:
+ from torch.hub import load_state_dict_from_url
+except ImportError:
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
+# from .utils import load_state_dict_from_url
+
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+ 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
+ 'wide_resnet50_2', 'wide_resnet101_2']
+
+model_urls = {
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+ 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+ 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+ 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
+ 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class SpatialAttention(nn.Module):
+ def __init__(self, kernel_size=7):
+ super(SpatialAttention, self).__init__()
+
+ assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
+ padding = 3 if kernel_size == 7 else 1
+
+ self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, x):
+ avg_out = torch.mean(x, dim=1, keepdim=True)
+ max_out, _ = torch.max(x, dim=1, keepdim=True)
+ x = torch.cat([avg_out, max_out], dim=1)
+ x = self.conv1(x)
+ return self.sigmoid(x)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+ base_width=64, dilation=1, norm_layer=None, cam=False, bam=False):
+ super(BasicBlock, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ if groups != 1 or base_width != 64:
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ self.cam = cam
+ self.bam = bam
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = norm_layer(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = norm_layer(planes)
+ self.downsample = downsample
+ self.stride = stride
+ if self.cam:
+ if planes == 64:
+ self.globalAvgPool = nn.AvgPool2d(56, stride=1)
+ elif planes == 128:
+ self.globalAvgPool = nn.AvgPool2d(28, stride=1)
+ elif planes == 256:
+ self.globalAvgPool = nn.AvgPool2d(14, stride=1)
+ elif planes == 512:
+ self.globalAvgPool = nn.AvgPool2d(7, stride=1)
+
+ self.fc1 = nn.Linear(in_features=planes, out_features=round(planes / 16))
+ self.fc2 = nn.Linear(in_features=round(planes / 16), out_features=planes)
+ self.sigmod = nn.Sigmoid()
+ if self.bam:
+ self.bam = SpatialAttention()
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ if self.cam:
+ ori_out = self.globalAvgPool(out)
+ out = out.view(out.size(0), -1)
+ out = self.fc1(out)
+ out = self.relu(out)
+ out = self.fc2(out)
+ out = self.sigmod(out)
+ out = out.view(out.size(0), out.size(-1), 1, 1)
+ out = out * ori_out
+
+ if self.bam:
+ out = out * self.bam(out)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
+ # while original implementation places the stride at the first 1x1 convolution(self.conv1)
+ # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
+ # This variant is also known as ResNet V1.5 and improves accuracy according to
+ # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
+
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+ base_width=64, dilation=1, norm_layer=None, cam=False, bam=False):
+ super(Bottleneck, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ width = int(planes * (base_width / 64.)) * groups
+ self.cam = cam
+ self.bam = bam
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv1x1(inplanes, width)
+ self.bn1 = norm_layer(width)
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
+ self.bn2 = norm_layer(width)
+ self.conv3 = conv1x1(width, planes * self.expansion)
+ self.bn3 = norm_layer(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+ if self.cam:
+ if planes == 64:
+ self.globalAvgPool = nn.AvgPool2d(56, stride=1)
+ elif planes == 128:
+ self.globalAvgPool = nn.AvgPool2d(28, stride=1)
+ elif planes == 256:
+ self.globalAvgPool = nn.AvgPool2d(14, stride=1)
+ elif planes == 512:
+ self.globalAvgPool = nn.AvgPool2d(7, stride=1)
+
+ self.fc1 = nn.Linear(planes * self.expansion, round(planes / 4))
+ self.fc2 = nn.Linear(round(planes / 4), planes * self.expansion)
+ self.sigmod = nn.Sigmoid()
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ if self.cam:
+ ori_out = self.globalAvgPool(out)
+ out = out.view(out.size(0), -1)
+ out = self.fc1(out)
+ out = self.relu(out)
+ out = self.fc2(out)
+ out = self.sigmod(out)
+ out = out.view(out.size(0), out.size(-1), 1, 1)
+ out = out * ori_out
+ out += identity
+ out = self.relu(out)
+ return out
+
+
+class ResNet(nn.Module):
+
+ def __init__(self, block, layers, num_classes=conf.embedding_size, zero_init_residual=False,
+ groups=1, width_per_group=64, replace_stride_with_dilation=None,
+ norm_layer=None, scale=conf.channel_ratio):
+ super(ResNet, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ self._norm_layer = norm_layer
+ print("ResNet scale: >>>>>>>>>> ", scale)
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ # each element in the tuple indicates if we should replace
+ # the 2x2 stride with a dilated convolution instead
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError("replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
+ bias=False)
+ self.bn1 = norm_layer(self.inplanes)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.adaptiveMaxPool = nn.AdaptiveMaxPool2d((1, 1))
+ self.maxpool2 = nn.Sequential(
+ nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
+ nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
+ nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
+ nn.MaxPool2d(kernel_size=2, stride=1, padding=0)
+ )
+ self.layer1 = self._make_layer(block, int(64 * scale), layers[0])
+ self.layer2 = self._make_layer(block, int(128 * scale), layers[1], stride=2,
+ dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block, int(256 * scale), layers[2], stride=2,
+ dilate=replace_stride_with_dilation[1])
+ self.layer4 = self._make_layer(block, int(512 * scale), layers[3], stride=2,
+ dilate=replace_stride_with_dilation[2])
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+ self.fc = nn.Linear(int(512 * block.expansion * scale), num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ # Zero-initialize the last BN in each residual branch,
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ nn.init.constant_(m.bn3.weight, 0)
+ elif isinstance(m, BasicBlock):
+ nn.init.constant_(m.bn2.weight, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+ norm_layer = self._norm_layer
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ norm_layer(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+ self.base_width, previous_dilation, norm_layer))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes, groups=self.groups,
+ base_width=self.base_width, dilation=self.dilation,
+ norm_layer=norm_layer))
+ return nn.Sequential(*layers)
+
+ def _forward_impl(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ x = self.avgpool(x)
+ x = torch.flatten(x, 1)
+ x = self.fc(x)
+
+ return x
+
+ def forward(self, x):
+ return self._forward_impl(x)
+
+
+# def _resnet(arch, block, layers, pretrained, progress, **kwargs):
+# model = ResNet(block, layers, **kwargs)
+# if pretrained:
+# state_dict = load_state_dict_from_url(model_urls[arch],
+# progress=progress)
+# model.load_state_dict(state_dict, strict=False)
+# return model
+
+class CustomResNet18(nn.Module):
+ def __init__(self, model, num_classes=conf.custom_num_classes):
+ super(CustomResNet18, self).__init__()
+ self.custom_model = nn.Sequential(*list(model.children())[:-1])
+ self.fc = nn.Linear(model.fc.in_features, num_classes)
+
+ def forward(self, x):
+ x = self.custom_model(x)
+ x = x.view(x.size(0), -1)
+ x = self.fc(x)
+ return x
+def _resnet(arch, block, layers, pretrained, progress, **kwargs):
+ model = ResNet(block, layers, **kwargs)
+ if pretrained:
+ state_dict = load_state_dict_from_url(model_urls[arch],
+ progress=progress)
+
+ src_state_dict = state_dict
+ target_state_dict = model.state_dict()
+ skip_keys = []
+ # skip mismatch size tensors in case of pretraining
+ for k in src_state_dict.keys():
+ if k not in target_state_dict:
+ continue
+ if src_state_dict[k].size() != target_state_dict[k].size():
+ skip_keys.append(k)
+ for k in skip_keys:
+ del src_state_dict[k]
+ missing_keys, unexpected_keys = model.load_state_dict(src_state_dict, strict=False)
+
+ return model
+
+
+def resnet14(pretrained=True, progress=True, **kwargs):
+ r"""ResNet-14 model from
+ `"Deep Residual Learning for Image Recognition" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet18', BasicBlock, [2, 1, 1, 2], pretrained, progress,
+ **kwargs)
+
+
+def resnet18(pretrained=True, progress=True, **kwargs):
+ r"""ResNet-18 model from
+ `"Deep Residual Learning for Image Recognition" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ **kwargs: Additional arguments passed to ResNet, including:
+ scale (float): Channel scaling ratio (default: conf.channel_ratio)
+ """
+ return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
+ **kwargs)
+
+
+def resnet34(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-34 model from
+ `"Deep Residual Learning for Image Recognition" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet50(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-50 model from
+ `"Deep Residual Learning for Image Recognition" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet101(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-101 model from
+ `"Deep Residual Learning for Image Recognition" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet152(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-152 model from
+ `"Deep Residual Learning for Image Recognition" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
+ r"""ResNeXt-50 32x4d model from
+ `"Aggregated Residual Transformation for Deep Neural Networks" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['groups'] = 32
+ kwargs['width_per_group'] = 4
+ return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
+ pretrained, progress, **kwargs)
+
+
+def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
+ r"""ResNeXt-101 32x8d model from
+ `"Aggregated Residual Transformation for Deep Neural Networks" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['groups'] = 32
+ kwargs['width_per_group'] = 8
+ return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
+ pretrained, progress, **kwargs)
+
+
+def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
+ r"""Wide ResNet-50-2 model from
+ `"Wide Residual Networks" `_
+
+ The model is the same as ResNet except for the bottleneck number of channels
+ which is twice larger in every block. The number of channels in outer 1x1
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 2
+ return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
+ pretrained, progress, **kwargs)
+
+
+def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
+ r"""Wide ResNet-101-2 model from
+ `"Wide Residual Networks" `_
+
+ The model is the same as ResNet except for the bottleneck number of channels
+ which is twice larger in every block. The number of channels in outer 1x1
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 2
+ return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
+ pretrained, progress, **kwargs)
diff --git a/model/utils.py b/model/utils.py
new file mode 100644
index 0000000..638ef07
--- /dev/null
+++ b/model/utils.py
@@ -0,0 +1,4 @@
+try:
+ from torch.hub import load_state_dict_from_url
+except ImportError:
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
diff --git a/model/vit.py b/model/vit.py
new file mode 100644
index 0000000..f598d34
--- /dev/null
+++ b/model/vit.py
@@ -0,0 +1,42 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import torch
+import torch.nn as nn
+from functools import partial, reduce
+from operator import mul
+
+from timm.models.vision_transformer import VisionTransformer, _cfg
+
+__all__ = [
+ 'vit_small',
+ 'vit_base',
+]
+
+
+def vit_small(**kwargs):
+ model = VisionTransformer(
+ patch_size=16, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, num_classes=256,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ # model.default_cfg = _cfg()
+ return model
+
+
+def vit_base(**kwargs):
+ model = VisionTransformer(
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, num_classes=256,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ model.default_cfg = _cfg(num_classes=256)
+ return model
+
+
+if __name__ == '__main__':
+ img = torch.randn(8, 3, 224, 224)
+ vit = vit_base()
+ out = vit(img)
+ print(out.shape)
+ # print(count_parameters(vit))
diff --git a/test_ori.py b/test_ori.py
new file mode 100644
index 0000000..4357d89
--- /dev/null
+++ b/test_ori.py
@@ -0,0 +1,331 @@
+# -*- coding: utf-8 -*-
+import os.path as osp
+from typing import Dict, List, Set, Tuple
+
+import torch
+import torch.nn as nn
+import numpy as np
+from PIL import Image
+import json
+import matplotlib.pyplot as plt
+
+# from config import config as conf
+from tools.dataset import get_transform
+from configs import trainer_tools
+import yaml
+
+with open('configs/test.yml', 'r') as f:
+ conf = yaml.load(f, Loader=yaml.FullLoader)
+
+# Constants from config
+embedding_size = conf["base"]["embedding_size"]
+img_size = conf["transform"]["img_size"]
+device = conf["base"]["device"]
+
+def unique_image(pair_list: str) -> Set[str]:
+ unique_images = set()
+ try:
+ with open(pair_list, 'r') as f:
+ for line in f:
+ line = line.strip()
+ if not line:
+ continue
+ try:
+ img1, img2, _ = line.split()
+ unique_images.update([img1, img2])
+ except ValueError as e:
+ print(f"Skipping malformed line: {line}")
+ except IOError as e:
+ print(f"Error reading pair list file: {e}")
+ raise
+
+ return unique_images
+
+
+def group_image(images: Set[str], batch_size: int) -> List[List[str]]:
+ """
+ Group image paths into batches of specified size.
+
+ Args:
+ images: Set of image paths to group
+ batch_size: Number of images per batch
+
+ Returns:
+ List of batches, where each batch is a list of image paths
+ """
+ image_list = list(images)
+ num_images = len(image_list)
+ batches = []
+
+ for i in range(0, num_images, batch_size):
+ batch_end = min(i + batch_size, num_images)
+ batches.append(image_list[i:batch_end])
+
+ return batches
+
+
+def _preprocess(images: list, transform) -> torch.Tensor:
+ res = []
+ for img in images:
+ im = Image.open(img)
+ im = transform(im)
+ res.append(im)
+ # data = torch.cat(res, dim=0) # shape: (batch, 128, 128)
+ # data = data[:, None, :, :] # shape: (batch, 1, 128, 128)
+ data = torch.stack(res)
+ return data
+
+
+def test_preprocess(images: list, transform) -> torch.Tensor:
+ res = []
+ for img in images:
+ im = Image.open(img)
+ if im.mode == 'RGBA':
+ im = im.convert('RGB')
+ im = transform(im)
+ res.append(im)
+ data = torch.stack(res)
+ return data
+
+
+def featurize(
+ images: List[str],
+ transform: callable,
+ net: nn.Module,
+ device: torch.device,
+ train: bool = False
+) -> Dict[str, torch.Tensor]:
+ try:
+ # Select appropriate preprocessing
+ preprocess_fn = _preprocess if train else test_preprocess
+
+ # Preprocess and move to device
+ data = preprocess_fn(images, transform)
+ data = data.to(device)
+ net = net.to(device)
+
+ # Extract features with automatic mixed precision
+ with torch.no_grad():
+ if conf['models']['half']:
+ data = data.half()
+ features = net(data)
+ # Create path-to-feature mapping
+ return {img: feature for img, feature in zip(images, features)}
+
+ except Exception as e:
+ print(f"Error in feature extraction: {e}")
+ raise
+def cosin_metric(x1, x2):
+ return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))
+def threshold_search(y_score, y_true):
+ y_score = np.asarray(y_score)
+ y_true = np.asarray(y_true)
+ best_acc = 0
+ best_th = 0
+ for i in range(len(y_score)):
+ th = y_score[i]
+ y_test = (y_score >= th)
+ acc = np.mean((y_test == y_true).astype(int))
+ if acc > best_acc:
+ best_acc = acc
+ best_th = th
+ return best_acc, best_th
+
+
+def showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct):
+ x = np.linspace(start=0, stop=1.0, num=50, endpoint=True).tolist()
+ plt.figure(figsize=(10, 6))
+ plt.plot(x, recall, color='red', label='recall:TP/TPFN')
+ plt.plot(x, recall_TN, color='black', label='recall_TN:TN/TNFP')
+ plt.plot(x, PrecisePos, color='blue', label='PrecisePos:TP/TPFN')
+ plt.plot(x, PreciseNeg, color='green', label='PreciseNeg:TN/TNFP')
+ plt.plot(x, Correct, color='m', label='Correct:(TN+TP)/(TPFN+TNFP)')
+ plt.legend()
+ plt.xlabel('threshold')
+ # plt.ylabel('Similarity')
+ plt.grid(True, linestyle='--', alpha=0.5)
+ plt.savefig('grid.png')
+ plt.show()
+ plt.close()
+
+
+def showHist(same, cross):
+ Same = np.array(same)
+ Cross = np.array(cross)
+
+ fig, axs = plt.subplots(2, 1)
+ axs[0].hist(Same, bins=50, edgecolor='black')
+ axs[0].set_xlim([-0.1, 1])
+ axs[0].set_title('Same Barcode')
+
+ axs[1].hist(Cross, bins=50, edgecolor='black')
+ axs[1].set_xlim([-0.1, 1])
+ axs[1].set_title('Cross Barcode')
+ plt.savefig('plot.png')
+
+
+def compute_accuracy_recall(score, labels):
+ th = 0.1
+ squence = np.linspace(-1, 1, num=50)
+ recall, PrecisePos, PreciseNeg, recall_TN, Correct = [], [], [], [], []
+ Same = score[:len(score) // 2]
+ Cross = score[len(score) // 2:]
+ for th in squence:
+ t_score = (score > th)
+ t_labels = (labels == 1)
+ TP = np.sum(np.logical_and(t_score, t_labels))
+ FN = np.sum(np.logical_and(np.logical_not(t_score), t_labels))
+ f_score = (score < th)
+ f_labels = (labels == 0)
+ TN = np.sum(np.logical_and(f_score, f_labels))
+ FP = np.sum(np.logical_and(np.logical_not(f_score), f_labels))
+ print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))
+
+ PrecisePos.append(0 if TP / (TP + FP) == 'nan' else TP / (TP + FP))
+ PreciseNeg.append(0 if TN == 0 else TN / (TN + FN))
+ recall.append(0 if TP == 0 else TP / (TP + FN))
+ recall_TN.append(0 if TN == 0 else TN / (TN + FP))
+ Correct.append(0 if TP == 0 else (TP + TN) / (TP + FP + TN + FN))
+
+ showHist(Same, Cross)
+ showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct)
+
+
+def compute_accuracy(
+ feature_dict: Dict[str, torch.Tensor],
+ pair_list: str,
+ test_root: str
+) -> Tuple[float, float]:
+ try:
+ with open(pair_list, 'r') as f:
+ pairs = f.readlines()
+ except IOError as e:
+ print(f"Error reading pair list: {e}")
+ raise
+
+ similarities = []
+ labels = []
+
+ for pair in pairs:
+ pair = pair.strip()
+ if not pair:
+ continue
+
+ try:
+ img1, img2, label = pair.split()
+ img1_path = osp.join(test_root, img1)
+ img2_path = osp.join(test_root, img2)
+
+ # Verify features exist
+ if img1_path not in feature_dict or img2_path not in feature_dict:
+ raise ValueError(f"Missing features for image pair: {img1_path}, {img2_path}")
+
+ # Get features and compute similarity
+ feat1 = feature_dict[img1_path].cpu().numpy()
+ feat2 = feature_dict[img2_path].cpu().numpy()
+ similarity = cosin_metric(feat1, feat2)
+
+ similarities.append(similarity)
+ labels.append(int(label))
+
+ except Exception as e:
+ print(f"Skipping invalid pair: {pair}. Error: {e}")
+ continue
+
+ # Find optimal threshold and accuracy
+ accuracy, threshold = threshold_search(similarities, labels)
+ compute_accuracy_recall(np.array(similarities), np.array(labels))
+
+ return accuracy, threshold
+
+
+def deal_group_pair(pairList1, pairList2):
+ allsimilarity = []
+ one_similarity = []
+ for pair1 in pairList1:
+ for pair2 in pairList2:
+ similarity = cosin_metric(pair1.cpu().numpy(), pair2.cpu().numpy())
+ one_similarity.append(similarity)
+ allsimilarity.append(max(one_similarity)) # 最大值
+ # allsimilarity.append(sum(one_similarity) / len(one_similarity)) # 均值
+ # allsimilarity.append(statistics.median(one_similarity)) # 中位数
+ # print(allsimilarity)
+ # print(labels)
+ return allsimilarity
+
+
+def compute_group_accuracy(content_list_read):
+ allSimilarity, allLabel = [], []
+ Same, Cross = [], []
+ for data_loaded in content_list_read:
+ print(data_loaded)
+ one_group_list = []
+ try:
+ for i in range(2):
+ images = [osp.join(conf.test_val, img) for img in data_loaded[i]]
+ group = group_image(images, conf.test_batch_size)
+ d = featurize(group[0], conf.test_transform, model, conf.device)
+ one_group_list.append(d.values())
+ if data_loaded[-1] == '1':
+ similarity = deal_group_pair(one_group_list[0], one_group_list[1])
+ Same.append(similarity)
+ else:
+ similarity = deal_group_pair(one_group_list[0], one_group_list[1])
+ Cross.append(similarity)
+ allLabel.append(data_loaded[-1])
+ allSimilarity.extend(similarity)
+ except Exception as e:
+ continue
+ # print(allSimilarity)
+ # print(allLabel)
+ return allSimilarity, allLabel
+
+
+def init_model():
+ tr_tools = trainer_tools(conf)
+ backbone_mapping = tr_tools.get_backbone()
+ if conf['models']['backbone'] in backbone_mapping:
+ model = backbone_mapping[conf['models']['backbone']]().to(conf['base']['device'])
+ else:
+ raise ValueError('不支持该模型: {}'.format({conf['models']['backbone']}))
+ print('load model {} '.format(conf['models']['backbone']))
+ if torch.cuda.device_count() > 1 and conf['base']['distributed']:
+ model = nn.DataParallel(model).to(conf['base']['device'])
+ model.load_state_dict(torch.load(conf['models']['model_path'], map_location=conf['base']['device']))
+ if conf['models']['half']:
+ model.half()
+ first_param_dtype = next(model.parameters()).dtype
+ print("模型的第一个参数的数据类型: {}".format(first_param_dtype))
+ else:
+ model.load_state_dict(torch.load(conf['model']['model_path'], map_location=conf['base']['device']))
+ if conf.model_half:
+ model.half()
+ first_param_dtype = next(model.parameters()).dtype
+ print("模型的第一个参数的数据类型: {}".format(first_param_dtype))
+ return model
+
+
+if __name__ == '__main__':
+ model = init_model()
+ model.eval()
+
+ if not conf['data']['group_test']:
+ images = unique_image(conf['data']['test_list'])
+ images = [osp.join(conf['data']['test_dir'], img) for img in images]
+ groups = group_image(images, conf['data']['test_batch_size']) # 根据batch_size取图片
+ feature_dict = dict()
+ _, test_transform = get_transform(conf)
+ for group in groups:
+ d = featurize(group, test_transform, model, conf['base']['device'])
+ feature_dict.update(d)
+ accuracy, threshold = compute_accuracy(feature_dict, conf['data']['test_list'], conf['data']['test_dir'])
+ print(
+ "Test Model: {} Accuracy: {} Threshold: {}".format(conf['models']['model_path'], accuracy, threshold)
+ )
+ elif conf['data']['group_test']:
+ filename = conf['data']['test_group_json']
+ with open(filename, 'r', encoding='utf-8') as file:
+ content_list_read = json.load(file)
+ Similarity, Label = compute_group_accuracy(content_list_read)
+ compute_accuracy_recall(np.array(Similarity), np.array(Label))
+ # compute_group_accuracy(data_loaded)
diff --git a/tools/__init__.py b/tools/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tools/__pycache__/gift_data_pretreatment.cpython-38.pyc b/tools/__pycache__/gift_data_pretreatment.cpython-38.pyc
new file mode 100644
index 0000000..5b10726
Binary files /dev/null and b/tools/__pycache__/gift_data_pretreatment.cpython-38.pyc differ
diff --git a/tools/dataset.py b/tools/dataset.py
new file mode 100644
index 0000000..b1e45ff
--- /dev/null
+++ b/tools/dataset.py
@@ -0,0 +1,68 @@
+from torch.utils.data import DataLoader
+from torchvision.datasets import ImageFolder
+import torchvision.transforms.functional as F
+import torchvision.transforms as T
+# from config import config as conf
+import torch
+
+def pad_to_square(img):
+ w, h = img.size
+ max_wh = max(w, h)
+ padding = [(max_wh - w) // 2, (max_wh - h) // 2, (max_wh - w) // 2, (max_wh - h) // 2] # (left, top, right, bottom)
+ return F.pad(img, padding, fill=0, padding_mode='constant')
+
+def get_transform(cfg):
+ train_transform = T.Compose([
+ T.Lambda(pad_to_square), # 补边
+ T.ToTensor(),
+ T.Resize((cfg['transform']['img_size'], cfg['transform']['img_size']), antialias=True),
+ # T.RandomCrop(img_size * 4 // 5),
+ T.RandomHorizontalFlip(p=cfg['transform']['RandomHorizontalFlip']),
+ T.RandomRotation(cfg['transform']['RandomRotation']),
+ T.ColorJitter(brightness=cfg['transform']['ColorJitter']),
+ T.ConvertImageDtype(torch.float32),
+ T.Normalize(mean=[cfg['transform']['img_mean']], std=[cfg['transform']['img_std']]),
+ ])
+ test_transform = T.Compose([
+ # T.Lambda(pad_to_square), # 补边
+ T.ToTensor(),
+ T.Resize((cfg['transform']['img_size'], cfg['transform']['img_size']), antialias=True),
+ T.ConvertImageDtype(torch.float32),
+ T.Normalize(mean=[cfg['transform']['img_mean']], std=[cfg['transform']['img_std']]),
+ ])
+ return train_transform, test_transform
+
+def load_data(training=True, cfg=None):
+ train_transform, test_transform = get_transform(cfg)
+ if training:
+ dataroot = cfg['data']['data_train_dir']
+ transform = train_transform
+ # transform = conf.train_transform
+ batch_size = cfg['data']['train_batch_size']
+ else:
+ dataroot = cfg['data']['data_val_dir']
+ # transform = conf.test_transform
+ transform = test_transform
+ batch_size = cfg['data']['val_batch_size']
+
+ data = ImageFolder(dataroot, transform=transform)
+ class_num = len(data.classes)
+ loader = DataLoader(data,
+ batch_size=batch_size,
+ shuffle=True,
+ pin_memory=cfg['base']['pin_memory'],
+ num_workers=cfg['data']['num_workers'],
+ drop_last=True)
+ return loader, class_num
+
+# def load_gift_data(action):
+# train_data = ImageFolder(conf.train_gift_root, transform=conf.train_transform)
+# train_dataset = DataLoader(train_data, batch_size=conf.train_gift_batchsize, shuffle=True,
+# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
+# val_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
+# val_dataset = DataLoader(val_data, batch_size=conf.val_gift_batchsize, shuffle=True,
+# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
+# test_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
+# test_dataset = DataLoader(test_data, batch_size=conf.test_gift_batchsize, shuffle=True,
+# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
+# return train_dataset, val_dataset, test_dataset
diff --git a/tools/dataset.txt b/tools/dataset.txt
new file mode 100644
index 0000000..9227a87
--- /dev/null
+++ b/tools/dataset.txt
@@ -0,0 +1,10 @@
+./quant_imgs/20179457_20240924-110903_back_addGood_b82d2842766e_80_15583929052_tid-8_fid-72_bid-3.jpg
+./quant_imgs/6928926002103_20240309-195044_front_returnGood_70f75407ef0e_225_18120111822_14_01.jpg
+./quant_imgs/6928926002103_20240309-212145_front_returnGood_70f75407ef0e_225_18120111822_11_01.jpg
+./quant_imgs/6928947479083_20241017-133830_front_returnGood_5478c9a48b7e_10_13799009402_tid-1_fid-20_bid-1.jpg
+./quant_imgs/6928947479083_20241018-110450_front_addGood_5478c9a48c28_165_13773168720_tid-6_fid-36_bid-1.jpg
+./quant_imgs/6930044166421_20240117-141516_c6a23f41-5b16-44c6-a03e-c32c25763442_back_returnGood_6930044166421_17_01.jpg
+./quant_imgs/6930044166421_20240308-150916_back_returnGood_70f75407ef0e_175_13815402763_7_01.jpg
+./quant_imgs/6930044168920_20240117-165633_3303629b-5fbd-423b-913d-8a64c1aa51dc_front_addGood_6930044168920_26_01.jpg
+./quant_imgs/6930058201507_20240305-175434_front_addGood_70f75407ef0e_95_18120111822_28_01.jpg
+./quant_imgs/6930639267885_20241014-120446_back_addGood_5478c9a48c3e_135_13773168720_tid-5_fid-99_bid-0.jpg
diff --git a/tools/fp32comparefp16.py b/tools/fp32comparefp16.py
new file mode 100644
index 0000000..37a8424
--- /dev/null
+++ b/tools/fp32comparefp16.py
@@ -0,0 +1,112 @@
+import os
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from test_ori import group_image, init_model, featurize
+from config import config as conf
+import json
+import os.path as osp
+
+def compare_fp16_fp32(values_pf16, values_pf32, dataTest):
+ if dataTest:
+ norm_values_pf16 = torch.norm(values_pf16, p=2)
+ norm_values_pf32 = torch.norm(values_pf32, p=2)
+ euclidean_distance = torch.norm(norm_values_pf16 - norm_values_pf32, p=2)
+ print(f"欧几里得距离: {euclidean_distance}")
+ cosine_sim = torch.dot(values_pf16.float(), values_pf32) / (norm_values_pf16 * norm_values_pf32)
+ print(f"余弦相似度: {cosine_sim}")
+ else:
+
+ pass
+def cosin_metric(x1, x2, fp32=True):
+ if fp32:
+ return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))
+ else:
+ x1_fp16 = x1.astype(np.float16)
+ x2_fp16 = x2.astype(np.float16)
+ # print(type(x1))
+ # pdb.set_trace()
+ return np.dot(x1_fp16, x2_fp16) / (np.linalg.norm(x1_fp16) * np.linalg.norm(x2_fp16))
+def deal_group_pair(pairList1, pairList2):
+ one_similarity_fp16, one_similarity_fp32, allsimilarity_fp32, allsimilarity_fp16 = [], [], [], []
+ for pair1 in pairList1:
+ for pair2 in pairList2:
+ # similarity = cosin_metric(pair1.cpu().numpy(), pair2.cpu().numpy())
+ one_similarity_fp32.append(cosin_metric(pair1.cpu().numpy(), pair2.cpu().numpy(), True))
+ one_similarity_fp16.append(cosin_metric(pair1.cpu().numpy(), pair2.cpu().numpy(), False))
+ allsimilarity_fp32.append(one_similarity_fp32)
+ allsimilarity_fp16.append(one_similarity_fp16)
+ one_similarity_fp16, one_similarity_fp32 = [], []
+ return np.array(allsimilarity_fp32), np.array(allsimilarity_fp16)
+
+def compute_group_accuracy(content_list_read, model):
+ allSimilarity, allLabel = [], []
+ Same, Cross = [], []
+ flag_same = True
+ flag_diff = True
+ for data_loaded in content_list_read:
+ one_group_list = []
+ try:
+ if (flag_same and str(data_loaded[-1]) == '1') or (flag_diff and str(data_loaded[-1]) == '0'):
+ for i in range(2):
+ images = [osp.join(conf.test_val, img) for img in data_loaded[i]]
+ group = group_image(images, conf.test_batch_size)
+ d = featurize(group[0], conf.test_transform, model, conf.device)
+ one_group_list.append(d.values())
+ if str(data_loaded[-1]) == '1':
+ flag_same = False
+ allsimilarity_fp32, allsimilarity_fp16 = deal_group_pair(one_group_list[0], one_group_list[1])
+ print('fp32 same-- >', allsimilarity_fp32)
+ print('fp16 same-- >', allsimilarity_fp16)
+ else:
+ flag_diff = False
+ allsimilarity_fp32, allsimilarity_fp16 = deal_group_pair(one_group_list[0], one_group_list[1])
+ print('fp32 diff-- >', allsimilarity_fp32)
+ print('fp16 diff-- >', allsimilarity_fp16)
+ except Exception as e:
+ continue
+ # print(allSimilarity)
+ # print(allLabel)
+ return allSimilarity, allLabel
+def get_feature_list(imgPth):
+ imgs = get_files(imgPth)
+ group = group_image(imgs, conf.test_batch_size)
+ model = init_model()
+ model.eval()
+ fe = featurize(group[0], conf.test_transform, model, conf.device)
+ return fe
+
+
+def get_files(imgPth):
+ imgsList = []
+ for img in os.walk(imgPth):
+ for img_name in img[2]:
+ img_path = os.sep.join([img[0], img_name])
+ imgsList.append(img_path)
+ return imgsList
+import pdb
+
+def compare(imgPth, group=False):
+ model = init_model()
+ model.eval()
+ if not group:
+ values_pf16, values_pf32 = [], []
+ fe = get_feature_list(imgPth)
+ # pdb.set_trace()
+ values_pf32 += [value.cpu() for value in fe.values()]
+ values_pf16 += [value.cpu().half() for value in fe.values()]
+ for value_pf16, value_pf32 in zip(values_pf16, values_pf32):
+ compare_fp16_fp32(value_pf16, value_pf32, dataTest=True)
+ else:
+ filename = conf.test_group_json
+ with open(filename, 'r', encoding='utf-8') as file:
+ content_list_read = json.load(file)
+ compute_group_accuracy(content_list_read, model)
+ pass
+
+
+if __name__ == '__main__':
+ imgPth = './data/test/inner/3701375401900'
+ compare(imgPth)
diff --git a/tools/gift_assessment.py b/tools/gift_assessment.py
new file mode 100644
index 0000000..d632330
--- /dev/null
+++ b/tools/gift_assessment.py
@@ -0,0 +1,369 @@
+import os
+import pdb
+import shutil
+import sys
+
+sys.path.append('../model')
+import matplotlib.pyplot as plt
+import numpy as np
+from model.mlp import Net2, Net3, Net4
+from model import resnet18
+import torch
+from gift_data_pretreatment import getFeatureList
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+def init_model(pkl_flag):
+ res_pth = r"../checkpoints/resnet18_1009/best.pth"
+ if pkl_flag:
+ gift_pth = r'../checkpoints/gift_model/action2/gift_v11.pth'
+ gift_model = Net3(pretrained=True, num_classes=1)
+ gift_model.load_state_dict(torch.load(gift_pth))
+ else:
+ gift_pth = r'../checkpoints/gift_model/action3/best.pth'
+ gift_model = Net4('resnet18', True, True) # 预训练模型
+ try:
+ print('>>multiple_cards load pre model <<')
+ gift_model.load_state_dict({k.replace('module.', ''): v for k, v in
+ torch.load(gift_pth,
+ map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu')).items()})
+ except Exception as e:
+ print('>> load pre model <<')
+ gift_model.load_state_dict(torch.load(gift_pth,
+ map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu')))
+ res_model = resnet18()
+ res_model.load_state_dict({k.replace('module.', ''): v for k, v in
+ torch.load(res_pth, map_location=torch.device(device)).items()})
+ return res_model, gift_model
+
+
+def showHist(nongifts, gifts):
+ # Same = filtered_data[:, 1].astype(np.float32)
+ # Cross = filtered_data[:, 2].astype(np.float32)
+
+ fig, axs = plt.subplots(2, 1)
+ axs[0].hist(nongifts, bins=50, edgecolor='blue')
+ axs[0].set_xlim([-0.1, 1])
+ axs[0].set_title('nongifts')
+
+ axs[1].hist(gifts, bins=50, edgecolor='green')
+ axs[1].set_xlim([-0.1, 1])
+ axs[1].set_title('gifts')
+ # plt.savefig('plot.png')
+ plt.show()
+
+
+def calculate_precision_recall(nongift, gift, points):
+ precision, recall = [], []
+ for point in points:
+ TP = np.sum(gift > point)
+ FN = np.sum(gift < point)
+ FP = np.sum(nongift > point)
+ TN = np.sum(nongift < point)
+ if TP == 0:
+ precision.append(0)
+ recall.append(0)
+ else:
+ precision.append(TP / (TP + FP))
+ recall.append(TP / (TP + FN))
+ print("point >> {} TP>>{}, FP>>{}, TN>>{}, FN>>{}".format(point, TP, FP, TN, FN))
+ if point == 0.5:
+ print("point >> {} TP>>{}, FP>>{}, TN>>{}, FN>>{}".format(point, TP, FP, TN, FN))
+ return precision, recall
+
+
+def showgrid(all_prec, all_recall, points):
+ plt.figure(figsize=(10, 6))
+ plt.plot(points[:-1], all_prec[:-1], color='blue', label='precision')
+ plt.plot(points[:-1], all_recall[:-1], color='red', label='recall')
+ plt.legend()
+ plt.xlabel('threshold')
+ # plt.ylabel('Similarity')
+ plt.grid(True, linestyle='--', alpha=0.5)
+ # plt.savefig('grid.png')
+ plt.show()
+ plt.close()
+ pass
+
+
+def discriminate_action(roots): # 判断加购还是退购
+ pth = os.sep.join([roots, 'process.data'])
+ with open(pth, 'r') as f:
+ lines = f.readlines()
+ for line in lines:
+ content = line.strip()
+ if 'weightValue' in content:
+ # print(content.split(":")[-1].split(',')[0])
+ if int(content.split(":")[-1].split(',')[0]) > 0:
+ return 'add'
+ else:
+ return 'return'
+
+
+def median(lst):
+ sorted_lst = sorted(lst)
+ n = len(sorted_lst)
+ if n % 2 == 1:
+ # 如果列表长度是奇数,中位数是中间的那个元素
+ return sorted_lst[n // 2]
+ else:
+ # 如果列表长度是偶数,中位数是中间两个元素的平均值
+ mid1 = sorted_lst[(n // 2) - 1]
+ mid2 = sorted_lst[n // 2]
+ return (mid1 + mid2) / 2
+
+
+def get_special_data(data, p):
+ # print(data)
+ length = len(data)
+ if length > 5:
+ if p == 'max':
+ return max(data[:round(length * 0.5)])
+ elif p == 'average':
+ return sum(data[:round(length * 0.5)]) / len(data[:round(length * 0.5)])
+ elif p == 'median':
+ return median(data[:round(length * 0.5)])
+ else:
+ return sum(data) / len(data)
+
+
+def read_data_file(pth):
+ result = []
+ with open(pth, 'r') as data_file:
+ lines = data_file.readlines()
+ for line in lines:
+ if line.split(':')[0] == 'free_gift__result':
+ if '0_tracking_output.data' in pth:
+ result = line.split(':')[1].split(',')[:-1]
+ else:
+ result = line.split(':')[1].split(',')[:-2]
+ result = [float(i) for i in result]
+ return result
+
+
+def get_tracking_data(pth):
+ result = []
+ with open(pth, 'r') as data_file:
+ lines = data_file.readlines()
+ for line in lines:
+ if len(line.split(',')) == 65:
+ result.append([float(item) for item in line.split(',')[:-1]])
+ return result
+
+
+def clean_reurn_data(pth):
+ for roots, dirs, files in os.walk(pth):
+ # print(roots, dirs, files)
+ if len(dirs) == 0:
+ flag = discriminate_action(roots)
+ if flag == 'return':
+ shutil.rmtree(roots)
+
+
+def get_gift_files(pth): # 测试后直接分析测试结果文件
+ add_special_output_0, return_special_output_0, return_special_output_1, add_special_output_1 = [], [], [], []
+ add_tracking_output_0, return_tracking_output_0, add_tracking_output_1, return_tracking_output_1 = [], [], [], []
+ for roots, dirs, files in os.walk(pth):
+ # print(roots, dirs, files)
+ if len(dirs) == 0:
+ flag = discriminate_action(roots)
+ for file in files:
+ if file == '0_tracking_output.data':
+ result = read_data_file(os.path.join(roots, file))
+ if not len(result) == 0:
+ if flag == 'add':
+ add_special_output_0.append(get_special_data(result, 'average')) # 加购后摄
+ else:
+ return_special_output_0.append(get_special_data(result, 'average')) # 退购后摄
+ if flag == 'add':
+ add_tracking_output_0 += read_data_file(os.path.join(roots, file))
+ else:
+ return_tracking_output_0 += read_data_file(os.path.join(roots, file))
+ elif file == '1_tracking_output.data':
+ result = read_data_file(os.path.join(roots, file))
+ if not len(result) == 0:
+ if flag == 'add':
+ add_special_output_1.append(get_special_data(result, 'average')) # 加购前摄
+ else:
+ return_special_output_1.append(get_special_data(result, 'average')) # 退购前摄
+ if flag == 'add':
+ add_tracking_output_1 += read_data_file(os.path.join(roots, file))
+ else:
+ return_tracking_output_1 += read_data_file(os.path.join(roots, file))
+ comprehensive_dicts = {"add_special_output_0": add_special_output_0,
+ "return_special_output_0": return_special_output_0,
+ "add_tracking_output_0": add_tracking_output_0,
+ "return_tracking_output_0": return_tracking_output_0,
+ "add_special_output_1": add_special_output_1,
+ "return_special_output_1": return_special_output_1,
+ "add_tracking_output_1": add_tracking_output_1,
+ "return_tracking_output_1": return_tracking_output_1,
+ }
+ # print(tracking_output_0, tracking_output_1)
+ showHist(np.array(comprehensive_dicts['add_tracking_output_0']),
+ np.array(comprehensive_dicts['add_tracking_output_1']))
+ # showHist(np.array(comprehensive_dicts['add_special_output_0']),
+ # np.array(comprehensive_dicts['add_special_output_1']))
+ return comprehensive_dicts
+
+
+def get_feature_array(img_pth_lists, res_model, gift_model, pkl_flag=True):
+ features_np = []
+ if pkl_flag:
+ for img_lists in img_pth_lists:
+ # print(img_lists)
+ fe_nps = getFeatureList(None, img_lists, res_model)
+ # fe_nps.squeeze()
+ try:
+ fe_nps = fe_nps[0][:, 256:]
+ except Exception as e:
+ print(e)
+ continue
+ fe_nps = torch.from_numpy(fe_nps)
+ fe_nps = fe_nps.view(fe_nps.shape[0], 64, 13, 13)
+ if len(fe_nps):
+ fe_np = gift_model(fe_nps)
+ fe_np = np.squeeze(fe_np.detach().numpy())
+ features_np.append(fe_np)
+ else:
+ for img_lists in img_pth_lists:
+ fe_nps = getFeatureList(None, img_lists, gift_model)
+ if len(fe_nps) > 0:
+ fe_nps = np.concatenate(fe_nps)
+ features_np.append(fe_nps)
+ return features_np
+
+
+import pickle
+
+
+def create_gift_subimg_np(data_pth, pkl_flag):
+ gift_array_pth = os.path.join(data_pth, 'gift.pkl')
+ nongift_array_pth = os.path.join(data_pth, 'nongift.pkl')
+ res_model, gift_model = init_model(pkl_flag)
+ res_model = res_model.eval()
+ gift_model = gift_model.eval()
+ gift_img_pth_list, gift_lists, nongift_img_pth_list, nongift_lists = [], [], [], []
+
+ for root, dirs, files in os.walk(data_pth):
+ if ('commodity' in root and 'subimg' in root):
+ print("commodity >> {}".format(root))
+ for file in files:
+ nongift_img_pth_list.append(os.sep.join([root, file]))
+ nongift_lists.append(nongift_img_pth_list)
+ nongift_img_pth_list = []
+ elif ('Havegift' in root and 'subimg' in root):
+ print("Havegift >> {}".format(root))
+ for file in files:
+ gift_img_pth_list.append(os.sep.join([root, file]))
+ gift_lists.append(gift_img_pth_list)
+ gift_img_pth_list = []
+ nongift = get_feature_array(nongift_lists, res_model, gift_model, pkl_flag)
+ gift = get_feature_array(gift_lists, res_model, gift_model, pkl_flag)
+ with open(nongift_array_pth, 'wb') as file:
+ pickle.dump(nongift, file)
+ with open(gift_array_pth, 'wb') as file:
+ pickle.dump(gift, file)
+
+
+def top_25_percent_mean(arr):
+ # 1. 对数组进行从高到低排序
+ sorted_arr = np.sort(arr)[::-1]
+
+ # 2. 计算数组长度的25%
+ top_25_percent_length = int(len(sorted_arr) * 0.25)
+
+ # 3. 取排序后数组的前25%元素
+ top_25_percent = sorted_arr[:top_25_percent_length]
+
+ # 4. 计算这些元素的平均值
+ mean_value = np.mean(top_25_percent)
+
+ return top_25_percent
+
+
+def assess_gift_subimg(data_pth, pkl_flag=False): # 分析分割后子图,
+ points = (np.linspace(1, 100, 100)) / 100
+ gift_pkl_pth = os.path.join(data_pth, 'gift.pkl')
+ nongift_pkl_pth = os.path.join(data_pth, 'nongift.pkl')
+ if not os.path.exists(gift_pkl_pth):
+ create_gift_subimg_np(data_pth, pkl_flag)
+ with open(nongift_pkl_pth, 'rb') as f:
+ nongift = pickle.load(f)
+ with open(gift_pkl_pth, 'rb') as f:
+ gift = pickle.load(f)
+ # showHist(nongift.flatten(), gift.flatten())
+
+ '''
+ 一分位均值
+ '''
+ nongift_mean = [np.mean(top_25_percent_mean(items)) for items in nongift]
+ gift_mean = [np.mean(top_25_percent_mean(items)) for items in gift]
+ '''
+ 中位数
+ '''
+ # nongift_mean = [np.median(items) for items in nongift]
+ # gift_mean = [np.median(items) for items in gift] # 平均值
+
+ '''
+ 全部结果
+ '''
+ # nongifts = [items for items in nongift]
+ # gifts = [items for items in gift]
+ # showHist(nongifts, gifts)
+
+ '''
+ 平均值
+ '''
+ # nongift_mean = [np.mean(items) for items in nongift]
+ # gift_mean = [np.mean(items) for items in gift]
+
+ showHist(np.array(nongift_mean), np.array(gift_mean)) # 最大值
+ precision, recall = calculate_precision_recall(np.array(nongift_mean),
+ np.array(gift_mean),
+ points)
+ showgrid(precision, recall, points)
+
+
+def get_comprehensive_dicts(data_pth):
+ gift_pth = r'../checkpoints/gift_model/action2/best.pth'
+ g_model = Net3(pretrained=True, num_classes=1)
+ g_model.load_state_dict(torch.load(gift_pth))
+ g_model.eval()
+ result = []
+ file_name = ['0_tracking_output.data',
+ '1_tracking_output.data']
+ for root, dirs, files in os.walk(data_pth):
+ if not len(dirs):
+ for file in files:
+ if file in file_name:
+ print(os.path.join(root, file))
+ result += get_tracking_data(os.path.join(root, file))
+ result = torch.from_numpy(np.array(result))
+ input = result.view(result.shape[0], 64, 1, 1)
+ input = input.to('cpu')
+ input = input.to(torch.float32)
+ ji = g_model(input)
+ print(ji)
+
+
+if __name__ == '__main__':
+ # pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241203赠品测试数据\\赠品\\images'
+ # pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241203赠品测试数据\\没有赠品的商品\\images'
+ # pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241203赠品测试数据\\同样的商品没有捆绑赠品\\images'
+ # pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\赠品测试\\20241213赠品测试数据\\赠品'
+ # pth = r'C:\Users\HP\Desktop\zengpin\1227'
+ # get_gift_files(pth)
+
+ # 根据子图分析结果
+ pth = r'D:\Project\contrast_nettest\data\gift_test'
+ assess_gift_subimg(pth)
+
+ # 根据完整数据集分析结果
+ # pth = r'C:\Users\HP\Desktop\zengpin\1231'
+ # get_comprehensive_dicts(pth)
+
+# 删除退购视频
+# pth = r'C:\Users\HP\Desktop\gift_test\20241213\非赠品'
+# clean_reurn_data(pth)
diff --git a/tools/gift_data_pretreatment.py b/tools/gift_data_pretreatment.py
new file mode 100644
index 0000000..8fdb99e
--- /dev/null
+++ b/tools/gift_data_pretreatment.py
@@ -0,0 +1,92 @@
+import torch
+from config import config as conf
+from PIL import Image
+import numpy as np
+
+
+def convert_rgba_to_rgb(image_path, output_path=None):
+ """
+ 将给定路径的4通道PNG图像转换为3通道,并保存到指定输出路径。
+
+ :param image_path: 输入图像的路径
+ :param output_path: 转换后的图像保存路径
+ """
+ # 打开图像
+ img = Image.open(image_path)
+ # 转换图像模式从RGBA到RGB
+ # .convert('RGB')会丢弃Alpha通道并转换为纯RGB图像
+ if img.mode == 'RGBA':
+ # 转换为RGB模式
+ img_rgb = img.convert('RGB')
+ # 保存转换后的图像
+ img_rgb.save(image_path)
+ # print(f"Image converted from RGBA to RGB and saved to {image_path}")
+ # else:
+ # # 如果已经是RGB或其他模式,直接保存
+ # img.save(image_path)
+ # print(f"Image already in {img.mode} mode, saved to {image_path}")
+
+
+def test_preprocess(images: list, actionModel=False) -> torch.Tensor:
+ res = []
+ for img in images:
+ try:
+ # print(img)
+ im = conf.test_transform(img) if actionModel else conf.test_transform(Image.open(img))
+ res.append(im)
+ except:
+ continue
+ data = torch.stack(res)
+ return data
+
+
+def inference(images, model, actionModel=False):
+ data = test_preprocess(images, actionModel)
+ if torch.cuda.is_available():
+ data = data.to(conf.device)
+ features = model(data)
+ return features
+
+
+def group_image(images, batch=64) -> list:
+ """Group image paths by batch size"""
+ size = len(images)
+ res = []
+ for i in range(0, size, batch):
+ end = min(batch + i, size)
+ res.append(images[i:end])
+ return res
+
+def normalize(queFeatList):
+ for num1 in range(len(queFeatList)):
+ for num2 in range(len(queFeatList[num1])):
+ queFeatList[num1][num2] = queFeatList[num1][num2] / np.linalg.norm(queFeatList[num1][num2])
+ return queFeatList
+
+def getFeatureList(barList, imgList, model):
+ # featList = [[] for i in range(len(barList))]
+ # for index, feat in enumerate(imgList):
+ fe_nps = []
+ groups = group_image(imgList)
+ for group in groups:
+ feat_tensor = inference(group, model)
+ # for fe in feat_tensor:
+ if feat_tensor.device == 'cpu':
+ fe_np = feat_tensor.squeeze().detach().numpy()
+ # fe_np = fe_np[:, 256:]
+ # fe_np = fe_np.reshape(fe_np.shape[0], fe_np.shape[1], 1, 1)
+ else:
+ fe_np = feat_tensor.squeeze().detach().cpu().numpy()
+ # fe_np = fe_np[:, 256:]
+ # fe_np = fe_np[256:]
+ # fe_np = fe_np.reshape(fe_np.shape[0], fe_np.shape[1], 1, 1)
+ # fe_np = fe_np.reshape(1, fe_np.shape[0], 1, 1)
+ # print(fe_np)
+
+ fe_nps.append(fe_np)
+ # if fe_nps:
+ # merged_fe_np = np.concatenate(fe_nps, axis=0)
+ # else:
+ # merged_fe_np = np.array([]) #
+ # fe_list = normalize(fe_nps)
+ return fe_nps
diff --git a/tools/json_contrast.py b/tools/json_contrast.py
new file mode 100644
index 0000000..c59198a
--- /dev/null
+++ b/tools/json_contrast.py
@@ -0,0 +1,118 @@
+import json
+import numpy as np
+import matplotlib.pyplot as plt
+import numpy as np
+import random
+
+
+def showHist(same, cross):
+ Same = np.array(same)
+ Cross = np.array(cross)
+
+ fig, axs = plt.subplots(2, 1)
+ axs[0].hist(Same, bins=50, edgecolor='black')
+ axs[0].set_xlim([-0.1, 1])
+ axs[0].set_title('Same Barcode')
+
+ axs[1].hist(Cross, bins=50, edgecolor='black')
+ axs[1].set_xlim([-0.1, 1])
+ axs[1].set_title('Cross Barcode')
+ # plt.savefig('plot.png')
+ plt.show()
+
+
+def showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct):
+ x = np.linspace(start=0, stop=1.0, num=50, endpoint=True).tolist()
+ plt.figure(figsize=(10, 6))
+ plt.plot(x, recall, color='red', label='recall:TP/TPFN')
+ plt.plot(x, recall_TN, color='black', label='recall_TN:TN/TNFP')
+ plt.plot(x, PrecisePos, color='blue', label='PrecisePos:TP/TPFN')
+ plt.plot(x, PreciseNeg, color='green', label='PreciseNeg:TN/TNFP')
+ plt.plot(x, Correct, color='m', label='Correct:(TN+TP)/(TPFN+TNFP)')
+ plt.legend()
+ plt.xlabel('threshold')
+ # plt.ylabel('Similarity')
+ plt.grid(True, linestyle='--', alpha=0.5)
+ plt.savefig('grid.png')
+ plt.show()
+ plt.close()
+
+
+def compute_accuracy_recall(score, labels):
+ th = 0.1
+ squence = np.linspace(-1, 1, num=50)
+ recall, PrecisePos, PreciseNeg, recall_TN, Correct = [], [], [], [], []
+ Same = score[:len(score) // 2]
+ Cross = score[len(score) // 2:]
+ for th in squence:
+ t_score = (score > th)
+ t_labels = (labels == 1)
+ TP = np.sum(np.logical_and(t_score, t_labels))
+ FN = np.sum(np.logical_and(np.logical_not(t_score), t_labels))
+ f_score = (score < th)
+ f_labels = (labels == 0)
+ TN = np.sum(np.logical_and(f_score, f_labels))
+ FP = np.sum(np.logical_and(np.logical_not(f_score), f_labels))
+ print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))
+
+ PrecisePos.append(0 if TP / (TP + FP) == 'nan' else TP / (TP + FP))
+ PreciseNeg.append(0 if TN == 0 else TN / (TN + FN))
+ recall.append(0 if TP == 0 else TP / (TP + FN))
+ recall_TN.append(0 if TN == 0 else TN / (TN + FP))
+ Correct.append(0 if TP == 0 else (TP + TN) / (TP + FP + TN + FN))
+
+ showHist(Same, Cross)
+ showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct)
+
+
+def get_similarity(features1, features2, n, m):
+ features1 = np.array(features1)
+ features2 = np.array(features2)
+ all_similarity = []
+ for feature1 in features1:
+ for feature2 in features2:
+ similarity = np.dot(feature1, feature2) / (np.linalg.norm(feature1) * np.linalg.norm(feature2))
+ all_similarity.append(similarity)
+ test_similarity = np.array(all_similarity)
+ np_all_array = np.array(all_similarity).reshape(len(features1), len(features2))
+ if n == 5 and m == 5:
+ print(all_similarity)
+ return np.mean(np_all_array), all_similarity
+ # return sum(all_similarity)/len(all_similarity), all_similarity
+ # return max(all_similarity), all_similarity
+
+
+def deal_similarity(dicts):
+ all_similarity = []
+ similarity = []
+ same_barcode, diff_barcode = [], []
+ for n, (key1, value1) in enumerate(dicts.items()):
+ print('key1 >> {}'.format(key1))
+ for m, (key2, value2) in enumerate(dicts.items()):
+ print('key1 >> {} key2 >> {} peidui {}{}'.format(key1, key2, n, m))
+ max_similarity, some_similarity = get_similarity(value1, value2, n, m)
+ similarity.append(max_similarity)
+ if key1 == key2:
+ same_barcode += some_similarity
+ else:
+ diff_barcode += some_similarity
+ all_similarity.append(similarity)
+ similarity = []
+ all_similarity = np.array(all_similarity)
+ random.shuffle(diff_barcode)
+ same_list = [1] * len(same_barcode)
+ diff_list = [0] * len(same_barcode)
+ all_list = same_list + diff_list
+ all_score = same_barcode + diff_barcode[:len(same_barcode)]
+ compute_accuracy_recall(np.array(all_score), np.array(all_list))
+ print(all_similarity.shape)
+
+
+with open('../search_library/data_zhanting.json', 'r') as file:
+ data = json.load(file)
+dicts = {}
+for dict in data['total']:
+ key = dict['key']
+ value = dict['value']
+ dicts[key] = value
+deal_similarity(dicts)
diff --git a/tools/model_onnx_transform.py b/tools/model_onnx_transform.py
new file mode 100644
index 0000000..815e557
--- /dev/null
+++ b/tools/model_onnx_transform.py
@@ -0,0 +1,63 @@
+import pdb
+import torch
+import torch.nn as nn
+from model import resnet18
+from config import config as conf
+from collections import OrderedDict
+import cv2
+
+def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth'):
+ # 定义模型
+ if model_name == 'resnet18':
+ model = resnet18(scale=0.75)
+
+ print('model_name >>> {}'.format(model_name))
+ if conf.multiple_cards:
+ model = model.to(torch.device('cpu'))
+ checkpoint = torch.load(pretrained_weights)
+ new_state_dict = OrderedDict()
+ for k, v in checkpoint.items():
+ name = k[7:] # remove "module."
+ new_state_dict[name] = v
+ model.load_state_dict(new_state_dict)
+ else:
+ model.load_state_dict(torch.load(pretrained_weights, map_location=torch.device('cpu')))
+ # try:
+ # model.load_state_dict(torch.load(pretrained_weights, map_location=torch.device('cpu')))
+ # except Exception as e:
+ # print(e)
+ # # model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(pretrained_weights, map_location='cpu').items()})
+ # model = nn.DataParallel(model).to(conf.device)
+ # model.load_state_dict(torch.load(conf.test_model, map_location=torch.device('cpu')))
+
+
+ # 转换为ONNX
+ if model_name == 'gift_type2':
+ input_shape = [1, 64, 13, 13]
+ elif model_name == 'gift_type3':
+ input_shape = [1, 3, 224, 224]
+ else:
+ # 假设输入数据的大小是通道数*高度*宽度,例如3*224*224
+ input_shape = [1, 3, 224, 224]
+
+ img = cv2.imread('./dog_224x224.jpg')
+
+ output_file = pretrained_weights.replace('pth', 'onnx')
+
+ # 导出模型
+ torch.onnx.export(model,
+ torch.randn(input_shape),
+ output_file,
+ verbose=True,
+ input_names=['input'],
+ output_names=['output']) ##, optset_version=12
+
+ model.eval()
+ trace_model = torch.jit.trace(model, torch.randn(1, 3, 224, 224))
+ trace_model.save(output_file.replace('.onnx', '.pt'))
+ print(f"Model exported to {output_file}")
+
+
+if __name__ == '__main__':
+ tranform_onnx_model(model_name='resnet18', # ['resnet18', 'gift_type2', 'gift_type3'] #gift_type2指resnet18中间数据判断;gift3_type3指resnet原图计算推理
+ pretrained_weights='./checkpoints/resnet18_scale=1.0/best.pth')
diff --git a/tools/model_rknn_transform.py b/tools/model_rknn_transform.py
new file mode 100644
index 0000000..a00bf0e
--- /dev/null
+++ b/tools/model_rknn_transform.py
@@ -0,0 +1,186 @@
+import os
+import pdb
+import urllib
+import traceback
+import time
+import sys
+import numpy as np
+import cv2
+from config import config as conf
+from rknn.api import RKNN
+
+import config
+
+# ONNX_MODEL = 'resnet50v2.onnx'
+# RKNN_MODEL = 'resnet50v2.rknn'
+ONNX_MODEL = 'checkpoints/resnet18_scale=1.0/best.onnx'
+RKNN_MODEL = 'checkpoints/resnet18_scale=1.0/best.rknn'
+
+
+# ONNX_MODEL = 'v3_small_0424.onnx'
+# RKNN_MODEL = 'v3_small_0424.rknn'
+
+def show_outputs(outputs):
+ # print('***************outputs', outputs)
+ output = outputs[0][0]
+ # print('len(outputs)',len(output), output)
+ output_sorted = sorted(output, reverse=True)
+ top5_str = 'resnet50v2\n-----TOP 5-----\n'
+ for i in range(5):
+ value = output_sorted[i]
+ index = np.where(output == value)
+ for j in range(len(index)):
+ if (i + j) >= 5:
+ break
+ if value > 0:
+ topi = '{}: {}\n'.format(index[j], value)
+ else:
+ topi = '-1: 0.0\n'
+ top5_str += topi
+ # pdb.set_trace()
+ print(top5_str)
+
+
+def readable_speed(speed):
+ speed_bytes = float(speed)
+ speed_kbytes = speed_bytes / 1024
+ if speed_kbytes > 1024:
+ speed_mbytes = speed_kbytes / 1024
+ if speed_mbytes > 1024:
+ speed_gbytes = speed_mbytes / 1024
+ return "{:.2f} GB/s".format(speed_gbytes)
+ else:
+ return "{:.2f} MB/s".format(speed_mbytes)
+ else:
+ return "{:.2f} KB/s".format(speed_kbytes)
+
+
+def show_progress(blocknum, blocksize, totalsize):
+ speed = (blocknum * blocksize) / (time.time() - start_time)
+ speed_str = " Speed: {}".format(readable_speed(speed))
+ recv_size = blocknum * blocksize
+
+ f = sys.stdout
+ progress = (recv_size / totalsize)
+ progress_str = "{:.2f}%".format(progress * 100)
+ n = round(progress * 50)
+ s = ('#' * n).ljust(50, '-')
+ f.write(progress_str.ljust(8, ' ') + '[' + s + ']' + speed_str)
+ f.flush()
+ f.write('\r\n')
+
+
+if __name__ == '__main__':
+
+ # Create RKNN object
+ rknn = RKNN(verbose=True)
+
+ # If resnet50v2 does not exist, download it.
+ # Download address:
+ # https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v2/resnet50v2.onnx
+ if not os.path.exists(ONNX_MODEL):
+ print('--> Download {}'.format(ONNX_MODEL))
+ url = 'https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v2/resnet50v2.onnx'
+ download_file = ONNX_MODEL
+ try:
+ start_time = time.time()
+ urllib.request.urlretrieve(url, download_file, show_progress)
+ except:
+ print('Download {} failed.'.format(download_file))
+ print(traceback.format_exc())
+ exit(-1)
+ print('done')
+
+ # pre-process config
+ print('--> config model')
+ # rknn.config(mean_values=[123.675, 116.28, 103.53], std_values=[58.82, 58.82, 58.82])
+ rknn.config(
+ mean_values=[[127.5, 127.5, 127.5]],
+ std_values=[[127.5, 127.5, 127.5]],
+ target_platform='rk3588',
+ model_pruning=False,
+ compress_weight=False,
+ single_core_mode=True)
+ # rknn.config(
+ # mean_values=[[127.5, 127.5, 127.5]], # 对于单通道图像,可以设置为 [[127.5]]
+ # std_values=[[127.5, 127.5, 127.5]], # 对于单通道图像,可以设置为 [[127.5]]
+ # target_platform='rk3588', # 设置目标平台
+ # # quantize_dtype='int8',
+ # # quantize_algo='normal',
+ # # output_optimize=False,
+ # # output_format='rknnb'
+ # )
+ print('done')
+
+ # Load model
+ print('--> Loading model')
+ ret = rknn.load_onnx(model=ONNX_MODEL)
+ if ret != 0:
+ print('Load model failed!')
+ exit(ret)
+ print('done')
+
+ # Build model
+ print('--> Building model')
+ ret = rknn.build(do_quantization=True, dataset='./dataset.txt')
+ # ret = rknn.build(do_quantization=False, dataset='./dataset.txt')
+ if ret != 0:
+ print('Build model failed!')
+ exit(ret)
+ print('done')
+
+ # Export rknn model
+ print('--> Export rknn model')
+ ret = rknn.export_rknn(RKNN_MODEL)
+ if ret != 0:
+ print('Export rknn model failed!')
+ exit(ret)
+ print('done')
+
+ # Set inputs
+ img = cv2.imread('./dog_224x224.jpg')
+ # img = cv2.imread('./data/gift_test/Havegift/20241213-161415-cb8e0762-f376-45d1-8f36-7dc070990fa5/subimg/cam1_9_tid2_fid(18, 33250169482).png')
+ # print('img', img)
+ # with open('pixel_values.txt', 'w') as file:
+
+ # for y in range(img.shape[0]):
+ # for x in range(img.shape[1]):
+ # b, g, r = img[y, x]
+ # file.write(f'{r},{g},{b}\n')
+
+ # img = cv2.imread('./810115161912_810115161912_20240131-145622_0da14e4d-a3da-499f-b512-2d4168ab1c87_front_addGood_70f75407b7ae_29_01.jpg')
+ img = cv2.resize(img, (224, 224))
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+ # img = conf.test_transform(img)
+ # img = img.numpy()
+ # img = img.transpose(1, 2, 0)
+
+ # Init runtime environment
+ print('--> Init runtime environment')
+ ret = rknn.init_runtime()
+ # ret = rknn.init_runtime('rk3588')
+ if ret != 0:
+ print('Init runtime environment failed!')
+ exit(ret)
+ print('done')
+
+ # Inference
+ print('--> Running model')
+ T1 = time.time()
+ outputs = rknn.inference(inputs=[img])
+ # outputs = rknn.inference(inputs=img)
+ T2 = time.time()
+ print('消耗时间 >>> {}'.format(T2 - T1))
+ with open('result_0415_128.txt', 'a') as f:
+ f.write(str(outputs))
+ # pdb.set_trace()
+ print('***outputs', outputs)
+ np.save('./onnx_resnet50v2_0.npy', outputs[0])
+ x = outputs[0]
+ output = np.exp(x) / np.sum(np.exp(x))
+ outputs = [output]
+ show_outputs(outputs)
+ print('done')
+
+ rknn.release()
diff --git a/tools/operate_usearch.py b/tools/operate_usearch.py
new file mode 100644
index 0000000..b8dc4a7
--- /dev/null
+++ b/tools/operate_usearch.py
@@ -0,0 +1,233 @@
+import os
+import numpy as np
+from usearch.index import Index
+import json
+import struct
+
+
+def create_index():
+ index = Index(
+ ndim=256,
+ metric='cos',
+ # dtype='f32',
+ dtype='f16',
+ connectivity=32,
+ expansion_add=40, # 128,
+ expansion_search=10, # 64,
+ multi=True
+ )
+ return index
+
+
+def compare_feature(features1, features2, model='1'):
+ """
+ :param model 比对策略
+ '0':模拟一个轨迹的图像(所有的图像、或者挑选的若干图像)与标准库,先求每个图片与标准库的最大值,再求所有图片对应最大值的均值
+ '1':带对比的所有相似度的均值
+ '2':比对1:1的最大值
+ :param feature1:
+ :param feature2:
+ :return:
+ """
+ similarity_group, similarity_groups = [], []
+ if model == '0':
+ for feature1 in features1:
+ for feature2 in features2[0]:
+ similarity = np.dot(feature1, feature2) / (np.linalg.norm(feature1) * np.linalg.norm(feature2))
+ similarity_group.append(similarity)
+ similarity_groups.append(max(similarity_group))
+ similarity_group = []
+ return sum(similarity_groups) / len(similarity_groups)
+
+ elif model == '1':
+ feature2 = features2[0]
+ for feature1 in features1:
+ for num in range(len(feature2)):
+ similarity = np.dot(feature1, feature2[num]) / (
+ np.linalg.norm(feature1) * np.linalg.norm(feature2[num]))
+ similarity_group.append(similarity)
+ similarity_groups.append(sum(similarity_group) / len(similarity_group))
+ similarity_group = []
+ # return sum(similarity_groups)/len(similarity_groups), max(similarity_groups)
+ if len(similarity_groups) == 0:
+ return -1
+ return sum(similarity_groups) / len(similarity_groups)
+ elif model == '2':
+ feature2 = features2[0]
+ for feature1 in features1:
+ for num in range(len(feature2)):
+ similarity = np.dot(feature1, feature2[num]) / (
+ np.linalg.norm(feature1) * np.linalg.norm(feature2[num]))
+ similarity_group.append(similarity)
+ return max(similarity_group)
+
+def get_barcode_feature(data):
+ barcode = data['key']
+ features = data['value']
+ return [barcode] * len(features), features
+
+
+def analysis_file(file_path):
+ """
+ :param file_path:
+ :return:
+ """
+ barcodes, features = [], []
+ with open(file_path, 'r', encoding='utf-8') as f:
+ data = json.load(f)
+ for dic in data['total']:
+ barcode, feature = get_barcode_feature(dic)
+ barcodes.append(barcode)
+ features.append(feature)
+ return barcodes, features
+
+
+def create_base_index(index_file_pth=None,
+ barcodes=None,
+ features=None,
+ save_index_name=None):
+ index = create_index()
+ if index_file_pth is not None:
+ # save_index_name = index_file_pth.split('json')[0] + 'usearch'
+ save_index_name = index_file_pth.split('json')[0] + 'data'
+ barcodes, features = analysis_file(index_file_pth)
+ else:
+ assert barcodes is not None and features is not None, 'barcodes and features must be not None'
+ for barcode, feature in zip(barcodes, features):
+ try:
+ index.add(np.array(barcode), np.array(feature))
+ except Exception as e:
+ print(e)
+ continue
+ index.save(save_index_name)
+
+
+def get_feature_index(index_file_pth=None,
+ barcodes=None):
+ assert index_file_pth is not None, 'index_file_pth must be not None'
+ index = Index.restore(index_file_pth, view=True)
+ feature_lists = index.get(np.array(barcodes))
+ print("memory {} size {}".format(index.memory_usage, index.size))
+ print("feature_lists {}".format(feature_lists))
+ return feature_lists
+
+
+def search_in_index(query=None,
+ barcode=None, # barcode -> int or np.ndarray
+ index_name=None,
+ temp_index=False, # 是否为临时库
+ model='0',
+ ):
+ if temp_index:
+ assert index_name is not None, 'index_name must be not None'
+ index = Index.restore(index_name, view=True)
+ if barcode is not None: # 1:1对比测试
+ feature_lists = index.get(np.array(barcode))
+ results = compare_feature(query, feature_lists)
+ else:
+ results = index.search(query, count=5)
+ return results
+ else: # 标准库
+ assert index_name is not None, 'index_name must be not None'
+ index = Index.restore(index_name, view=True)
+ if barcode is not None: # 1:1对比测试
+ feature_lists = index.get(np.array(barcode))
+ results = compare_feature(query, feature_lists, model)
+ else:
+ results = index.search(query, count=10)
+ return results
+
+
+def delete_index(index_name=None, key=None, index=None):
+ assert key is not None, 'key must be not None'
+ if index is None:
+ assert index_name is not None, 'index_name must be not None'
+ index = Index.restore(index_name, view=True)
+ index.remove(index_name)
+ else:
+ index.remove(key)
+
+from scipy.spatial.distance import cdist
+def compute_similarity_matrix(featurelists1, featurelists2):
+ """计算图片之间的余弦相似度矩阵"""
+ # 计算所有向量对之间的余弦相似度
+ cosine_similarities = 1 - cdist(featurelists1, featurelists2, metric='cosine')
+ cosine_similarities = np.around(cosine_similarities, decimals=3)
+ return cosine_similarities
+
+def check_usearch_json_diff(index_file_pth, json_file_pth):
+ json_features = None
+ feature_lists = get_feature_index(index_file_pth, ['6923644272159'])
+ with open(json_file_pth, 'r') as json_file:
+ json_data = json.load(json_file)
+ for data in json_data['total']:
+ if data['key'] == '6923644272159':
+ json_features = data['value']
+ json_features = np.array(json_features)
+ feature_lists = np.array(feature_lists[0])
+ compute_similarity_matrix(json_features, feature_lists)
+
+
+def write_binary_file(filename, datas):
+ with open(filename, 'wb') as f:
+ # 先写入数据中的key数量(为C++读取提供便利)
+ key_count = len(datas)
+ f.write(struct.pack('I', key_count)) # 'I'代表无符号整型(4字节)
+
+ for data in datas:
+ key = data['key']
+ feats = data['value']
+ key_bytes = key.encode('utf-8')
+ key_len = len(key)
+ length_byte = struct.pack('= 2, nested_list))) # 去除无轨迹的数据
+ filtered_correct = filtered_list[filtered_list[:, 0] != 'wrong'] # 获取比对正确的时项
+ filtered_wrong = filtered_list[filtered_list[:, 0] == 'wrong'] # 获取比对错误的时项
+ showHist(filtered_correct)
+ # showHist(filtered_wrong)
+ print(filtered_list)
+
+
+def deal_process(file_pth):
+ flag = False
+ event = file_pth.split('\\')[-2]
+ target_barcode = file_pth.split('\\')[-2].split('_')[-1]
+ temp_list = []
+
+ with open(file_pth, 'r') as f:
+ for line in f:
+ if 'oneToOne' in line:
+ flag = True
+ continue
+ if flag:
+ line = line.replace('\n', '')
+ comparison_data = line.split(',')
+ forecast_barcode = comparison_data[0]
+ value = comparison_data[-1].split(':')[-1]
+ if value == '':
+ break
+ if len(temp_list) == 0:
+ if forecast_barcode == target_barcode:
+ temp_list.append('correct')
+ else:
+ temp_list.append('wrong')
+ temp_list.append(float(value))
+ temp_list.append(event)
+ return temp_list
+
+
+def anaylze_scratch(scratch_pth):
+ purchase, back = [], []
+ for root, dirs, files in os.walk(scratch_pth):
+ if len(root) > 0:
+ if len(root.split('_')) == 4: # 加购
+ process = os.path.join(root, 'process.data')
+ if not os.path.exists(process):
+ continue
+ purchase.append(deal_process(process))
+ elif len(root.split('_')) == 3:
+ process = os.path.join(root, 'process.data')
+ if not os.path.exists(process):
+ continue
+ back.append(deal_process(process))
+ # get_tartget_list(purchase)
+ get_tartget_list(back)
+ print(purchase)
+
+
+if __name__ == '__main__':
+ # scratch_pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\展厅测试\\1108_展厅模型v800测试\\'
+ scratch_pth = r'\\192.168.1.28\\share\\测试视频数据以及日志\\各模块测试记录\\展厅测试\\1120_展厅模型v801测试\\扫A放A\\'
+ anaylze_scratch(scratch_pth)
diff --git a/tools/write_feature_json.py b/tools/write_feature_json.py
new file mode 100644
index 0000000..59e1d6e
--- /dev/null
+++ b/tools/write_feature_json.py
@@ -0,0 +1,411 @@
+import json
+import os
+import logging
+import numpy as np
+from typing import Dict, List, Optional, Tuple
+from tools.dataset import get_transform
+from model import resnet18
+import torch
+from PIL import Image
+import pandas as pd
+from tqdm import tqdm
+import yaml
+import shutil
+import struct
+
+# Configure logging
+logging.basicConfig(
+ level=logging.INFO,
+ format='%(asctime)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+
+class FeatureExtractor:
+ def __init__(self, conf):
+ self.conf = conf
+ self.model = self.initModel()
+ _, self.test_transform = get_transform(self.conf)
+ pass
+
+ def initModel(self, inference_model: Optional[str] = None) -> torch.nn.Module:
+ """
+ Initialize and load the ResNet18 model for inference.
+
+ Args:
+ inference_model: Optional path to model weights. Uses conf.test_model if None.
+
+ Returns:
+ Loaded and configured PyTorch model in evaluation mode.
+
+ Raises:
+ FileNotFoundError: If model weights file is not found
+ RuntimeError: If model loading fails
+ """
+ model_path = inference_model if inference_model else self.conf['models']['checkpoints']
+
+ try:
+ # Verify model file exists
+ if not os.path.exists(model_path):
+ raise FileNotFoundError(f"Model weights file not found: {model_path}")
+
+ # Initialize model
+ model = resnet18().to(self.conf['base']['device'])
+
+ # Handle multi-GPU case
+ if conf['base']['distributed']:
+ model = torch.nn.DataParallel(model)
+
+ # Load weights
+ state_dict = torch.load(model_path, map_location=conf['base']['device'])
+ model.load_state_dict(state_dict)
+
+ model.eval()
+ logger.info(f"Successfully loaded model from {model_path}")
+ return model
+
+ except Exception as e:
+ logger.error(f"Failed to initialize model: {str(e)}")
+ raise
+
+ def convert_rgba_to_rgb(self, image_path):
+ # 打开图像
+ img = Image.open(image_path)
+ # 转换图像模式从RGBA到RGB
+ # .convert('RGB')会丢弃Alpha通道并转换为纯RGB图像
+ if img.mode == 'RGBA':
+ # 转换为RGB模式
+ img_rgb = img.convert('RGB')
+ # 保存转换后的图像
+ img_rgb.save(image_path)
+ print(f"Image converted from RGBA to RGB and saved to {image_path}")
+
+ def test_preprocess(self, images: list, actionModel=False) -> torch.Tensor:
+ res = []
+ for img in images:
+ try:
+ im = self.test_transform(img) if actionModel else self.test_transform(Image.open(img))
+ res.append(im)
+ except:
+ continue
+ data = torch.stack(res)
+ return data
+
+ def inference(self, images, model, actionModel=False):
+ data = self.test_preprocess(images, actionModel)
+ if torch.cuda.is_available():
+ data = data.to(conf['base']['device'])
+ features = model(data)
+ if conf['data']['half']:
+ features = features.half()
+ return features
+
+ def group_image(self, images, batch=64) -> list:
+ """Group image paths by batch size"""
+ size = len(images)
+ res = []
+ for i in range(0, size, batch):
+ end = min(batch + i, size)
+ res.append(images[i:end])
+ return res
+
+ def getFeatureList(self, barList, imgList):
+ featList = [[] for _ in range(len(barList))]
+
+ for index, image_paths in enumerate(imgList):
+ try:
+ # Process images in batches
+ for batch in self.group_image(image_paths):
+ # Get features for batch
+ features = self.inference(batch, self.model)
+
+ # Process each feature in batch
+ for feat in features:
+ # Move to CPU and convert to numpy
+ feat_np = feat.squeeze().detach().cpu().numpy()
+
+ # Normalize first 256 dimensions
+ normalized = self.normalize_256(feat_np[:256])
+
+ # Combine with remaining dimensions
+ combined = np.concatenate([normalized, feat_np[256:]], axis=0)
+
+ featList[index].append(combined)
+
+ except Exception as e:
+ logger.error(f"Error processing images for index {index}: {str(e)}")
+ continue
+ return featList
+
+ def get_files(
+ self,
+ folder: str,
+ filter: Optional[List[str]] = None,
+ create_single_json: bool = False
+ ) -> Dict[str, List[str]]:
+ """
+ Recursively collect image files from directory structure.
+
+ Args:
+ folder: Root directory to scan
+ filter: Optional list of barcodes to include
+ create_single_json: Whether to create individual JSON files per barcode
+
+ Returns:
+ Dictionary mapping barcode names to lists of image paths
+
+ Example:
+ {
+ "barcode1": ["path/to/img1.jpg", "path/to/img2.jpg"],
+ "barcode2": ["path/to/img3.jpg"]
+ }
+ """
+ file_dicts = {}
+ total_files = 0
+ feature_counts = []
+ barcode_count = 0
+ subclass = [str(i) for i in range(100)]
+ # Validate input directory
+ if not os.path.isdir(folder):
+ raise ValueError(f"Invalid directory: {folder}")
+
+ # Process each barcode directory
+ for root, dirs, files in tqdm(os.walk(folder), desc="Scanning directories"):
+ if not dirs: # Leaf directory (contains images)
+ basename = os.path.basename(root)
+ if basename in subclass:
+ ori_barcode = root.split('/')[-2]
+ barcode = root.split('/')[-2] + '_' + basename
+ else:
+ ori_barcode = basename
+ barcode = basename
+ # Apply filter if provided
+ if filter and ori_barcode not in filter:
+ continue
+ elif len(ori_barcode) > 13 or len(ori_barcode) < 8:
+ logger.warning(f"Skipping invalid barcode {ori_barcode}")
+ with open(conf['save']['error_barcodes'], 'a') as f:
+ f.write(ori_barcode + '\n')
+ f.close()
+ continue
+
+ # Process image files
+ if files:
+ image_paths = self._process_image_files(root, files)
+ if not image_paths:
+ continue
+
+ # Update counters
+ barcode_count += 1
+ file_count = len(image_paths)
+ total_files += file_count
+ feature_counts.append(file_count)
+
+ # Handle output mode
+ if create_single_json:
+ self._process_single_barcode(barcode, image_paths)
+ else:
+ if barcode.split('_')[-1] == '0':
+ barcode = barcode.split('_')[0]
+ file_dicts[barcode] = image_paths
+
+ # # Log summary
+ # logger.info(f"Processed {barcode_count} barcodes with {total_files} total images")
+ # logger.debug(f"Image counts per barcode: {feature_counts}")
+
+ # Batch process if not creating individual JSONs
+ if not create_single_json and file_dicts:
+ self.createFeatureDict(
+ file_dicts,
+ create_single_json=False,
+ )
+ return file_dicts
+
+ def _process_image_files(self, root: str, files: List[str]) -> List[str]:
+ """Process and validate image files in a directory."""
+ valid_paths = []
+ for filename in files:
+ file_path = os.path.join(root, filename)
+ try:
+ # Convert RGBA to RGB if needed
+ self.convert_rgba_to_rgb(file_path)
+ valid_paths.append(file_path)
+ except Exception as e:
+ logger.warning(f"Skipping invalid image {file_path}: {str(e)}")
+ return valid_paths
+
+ def _process_single_barcode(self, barcode: str, image_paths: List[str]):
+ """Process a single barcode and create individual JSON file."""
+ temp_dict = {barcode: image_paths}
+ self.createFeatureDict(
+ temp_dict,
+ create_single_json=True,
+ )
+
+ def normalize_256(self, queFeatList):
+ queFeatList = queFeatList / np.linalg.norm(queFeatList)
+ return queFeatList
+
+ def img2feature(
+ self,
+ imgs_dict: Dict[str, List[str]]
+ ) -> Tuple[List[str], List[List[np.ndarray]]]:
+ """
+ Extract features for all images in the dictionary.
+
+ Args:
+ imgs_dict: Dictionary mapping barcodes to image paths
+ model: Pretrained feature extraction model
+ barcode_flag: Whether to include barcode info (unused)
+
+ Returns:
+ Tuple containing:
+ - List of barcode IDs
+ - List of feature lists (one per barcode)
+
+ Raises:
+ ValueError: If input dictionary is empty
+ RuntimeError: If feature extraction fails
+ """
+ if not imgs_dict:
+ raise ValueError("No images provided for feature extraction")
+
+ try:
+ barcode_list = list(imgs_dict.keys())
+ image_list = list(imgs_dict.values())
+ feature_list = self.getFeatureList(barcode_list, image_list)
+
+ logger.info(f"Successfully extracted features for {len(barcode_list)} barcodes")
+ return barcode_list, feature_list
+
+ except Exception as e:
+ logger.error(f"Feature extraction failed: {str(e)}")
+ raise RuntimeError(f"Feature extraction failed: {str(e)}")
+
+ def createFeatureDict(self, imgs_dict,
+ create_single_json=False): # imgs->{barcode1:[img1_1...img1_n], barcode2:[img2_1...img2_n]}
+ dicts_all = {}
+ value_list = []
+ barcode_list, imgs_list = self.img2feature(imgs_dict)
+ for i in range(len(barcode_list)):
+ dicts = {}
+
+ imgs_list_ = []
+ for j in range(len(imgs_list[i])):
+ imgs_list_.append(imgs_list[i][j].tolist())
+
+ dicts['key'] = barcode_list[i]
+ truncated_imgs_list = [subarray[:256] for subarray in imgs_list_]
+ dicts['value'] = truncated_imgs_list
+ if create_single_json:
+ # json_path = os.path.join("./search_library/v8021_overseas/", str(barcode_list[i]) + '.json')
+ json_path = os.path.join(self.conf['save']['json_path'], str(barcode_list[i]) + '.json')
+ with open(json_path, 'w') as json_file:
+ json.dump(dicts, json_file)
+ else:
+ value_list.append(dicts)
+ if not create_single_json:
+ dicts_all['total'] = value_list
+ with open(self.conf['save']['json_bin'], 'w') as json_file:
+ json.dump(dicts_all, json_file)
+ self.create_binary_files(self.conf['save']['json_bin'])
+
+ def statisticsBarcodes(self, pth, filter=None):
+ feature_num = 0
+ feature_num_lists = []
+ nn = 0
+ with open(conf['save']['barcodes_statistics'], 'w', encoding='utf-8') as f:
+ for barcode in os.listdir(pth):
+ print("barcode length >> {}".format(len(barcode)))
+ if len(barcode) > 13 or len(barcode) < 8:
+ continue
+ if filter is not None:
+ f.writelines(barcode + '\n')
+ if barcode in filter:
+ print(barcode)
+ feature_num += len(os.listdir(os.path.join(pth, barcode)))
+ nn += 1
+ else:
+ print('barcode name >>{}'.format(barcode))
+ f.writelines(barcode + '\n')
+ feature_num += len(os.listdir(os.path.join(pth, barcode)))
+ feature_num_lists.append(feature_num)
+ print("特征总量: {}".format(feature_num))
+ print("barcode总量: {}".format(nn))
+ f.close()
+
+ def get_shop_barcodes(self, file_path):
+ if file_path:
+ df = pd.read_excel(file_path)
+ column_values = list(df.iloc[:, 6].values)
+ column_values = list(map(str, column_values))
+ return column_values
+ else:
+ return None
+
+ def del_base_dir(self, pth):
+ for root, dirs, files in os.walk(pth):
+ if len(dirs) == 1:
+ if dirs[0] == 'base':
+ shutil.rmtree(os.path.join(root, dirs[0]))
+
+ def write_binary_file(self, filename, datas):
+ with open(filename, 'wb') as f:
+ # 先写入数据中的key数量(为C++读取提供便利)
+ key_count = len(datas)
+ f.write(struct.pack('I', key_count)) # 'I'代表无符号整型(4字节)
+ for data in datas:
+ key = data['key']
+ feats = data['value']
+ key_bytes = key.encode('utf-8')
+ key_len = len(key)
+ length_byte = struct.pack(' 1 and conf['base']['distributed']:
+ print("Let's use", torch.cuda.device_count(), "GPUs!")
+ model = nn.DataParallel(model)
+ metric = nn.DataParallel(metric)
+
+# Training Setup
+if conf['training']['loss'] == 'focal_loss':
+ criterion = FocalLoss(gamma=2)
+else:
+ criterion = nn.CrossEntropyLoss()
+
+optimizer_mapping = tr_tools.get_optimizer(model, metric)
+if conf['training']['optimizer'] in optimizer_mapping:
+ optimizer = optimizer_mapping[conf['training']['optimizer']]()
+ scheduler = optim.lr_scheduler.StepLR(
+ optimizer,
+ step_size=conf['training']['lr_step'],
+ gamma=conf['training']['lr_decay']
+ )
+else:
+ raise ValueError('不支持的优化器类型: {}'.format(conf['training']['optimizer']))
+
+# Checkpoints Setup
+checkpoints = conf['training']['checkpoints']
+os.makedirs(checkpoints, exist_ok=True)
+
+if __name__ == '__main__':
+ print('backbone>{} '.format(conf['models']['backbone']),
+ 'metric>{} '.format(conf['training']['metric']),
+ 'checkpoints>{} '.format(conf['training']['checkpoints']),
+ )
+ train_losses = []
+ val_losses = []
+ epochs = []
+ temp_loss = 100
+ if conf['training']['restore']:
+ print('load pretrain model: {}'.format(conf['training']['restore_model']))
+ model.load_state_dict(torch.load(conf['training']['restore_model'],
+ map_location=conf['base']['device']))
+
+ for e in range(conf['training']['epochs']):
+ train_loss = 0
+ model.train()
+
+ for train_data, train_labels in tqdm(train_dataloader,
+ desc="Epoch {}/{}"
+ .format(e, conf['training']['epochs']),
+ ascii=True,
+ total=len(train_dataloader)):
+ train_data = train_data.to(conf['base']['device'])
+ train_labels = train_labels.to(conf['base']['device'])
+
+ train_embeddings = model(train_data).to(conf['base']['device']) # [256,512]
+ # pdb.set_trace()
+
+ if not conf['training']['metric'] == 'softmax':
+ thetas = metric(train_embeddings, train_labels) # [256,357]
+ else:
+ thetas = metric(train_embeddings)
+ tloss = criterion(thetas, train_labels)
+ optimizer.zero_grad()
+ tloss.backward()
+ optimizer.step()
+ train_loss += tloss.item()
+ train_lossAvg = train_loss / len(train_dataloader)
+ train_losses.append(train_lossAvg)
+ epochs.append(e)
+ val_loss = 0
+ model.eval()
+ with torch.no_grad():
+ for val_data, val_labels in tqdm(val_dataloader, desc="val",
+ ascii=True, total=len(val_dataloader)):
+ val_data = val_data.to(conf['base']['device'])
+ val_labels = val_labels.to(conf['base']['device'])
+ val_embeddings = model(val_data).to(conf['base']['device'])
+ if not conf['training']['metric'] == 'softmax':
+ thetas = metric(val_embeddings, val_labels)
+ else:
+ thetas = metric(val_embeddings)
+ vloss = criterion(thetas, val_labels)
+ val_loss += vloss.item()
+ val_lossAvg = val_loss / len(val_dataloader)
+ val_losses.append(val_lossAvg)
+ if val_lossAvg < temp_loss:
+ if torch.cuda.device_count() > 1:
+ torch.save(model.state_dict(), osp.join(checkpoints, 'best.pth'))
+ else:
+ torch.save(model.state_dict(), osp.join(checkpoints, 'best.pth'))
+ temp_loss = val_lossAvg
+
+ scheduler.step()
+ current_lr = optimizer.param_groups[0]['lr']
+ log_info = ("Epoch {}/{}, train_loss: {}, val_loss: {} lr:{}"
+ .format(e, conf['training']['epochs'], train_lossAvg, val_lossAvg, current_lr))
+ print(log_info)
+ # 写入日志文件
+ with open(osp.join(conf['logging']['logging_dir']), 'a') as f:
+ f.write(log_info + '\n')
+ print("第%d个epoch的学习率:%f" % (e, current_lr))
+ if torch.cuda.device_count() > 1 and conf['base']['distributed']:
+ torch.save(model.module.state_dict(), osp.join(checkpoints, 'last.pth'))
+ else:
+ torch.save(model.state_dict(), osp.join(checkpoints, 'last.pth'))
+ plt.plot(epochs, train_losses, color='blue')
+ plt.plot(epochs, val_losses, color='red')
+ # plt.savefig('lossMobilenetv3.png')
+ plt.savefig('loss/mobilenetv3Large_2250_0316.png')
diff --git a/train_distill.py b/train_distill.py
new file mode 100644
index 0000000..24a0448
--- /dev/null
+++ b/train_distill.py
@@ -0,0 +1,205 @@
+"""
+ResNet50蒸馏训练ResNet18实现
+学生网络使用ArcFace损失
+支持单机双卡训练
+"""
+
+import os
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim.lr_scheduler import CosineAnnealingLR
+from torch.cuda.amp import GradScaler
+from model import resnet18, resnet50, ArcFace
+from tqdm import tqdm
+import torch.nn.functional as F
+from tools.dataset import load_data
+# from config import config as conf
+import yaml
+import math
+def setup(rank, world_size):
+ os.environ['MASTER_ADDR'] = '0.0.0.0'
+ os.environ['MASTER_PORT'] = '12355'
+ dist.init_process_group("nccl", rank=rank, world_size=world_size)
+
+def cleanup():
+ dist.destroy_process_group()
+
+class DistillTrainer:
+ def __init__(self, rank, world_size, conf):
+ self.rank = rank
+ self.world_size = world_size
+ self.device = torch.device(f'cuda:{rank}')
+
+ # 初始化模型
+ self.teacher = resnet50(pretrained=True, scale=conf['models']['channel_ratio']).to(self.device)
+ self.student = resnet18(pretrained=True, scale=conf['models']['student_channel_ratio']).to(self.device)
+
+ # 加载预训练教师模型
+ # teacher_path = os.path.join('checkpoints', 'resnet50_0519', 'best.pth')
+ teacher_path = conf['models']['teacher_model_path']
+ if os.path.exists(teacher_path):
+ teacher_state = torch.load(teacher_path, map_location=self.device)
+ new_state_dict = {}
+ for k, v in teacher_state.items():
+ if k.startswith('module.'):
+ new_state_dict[k[7:]] = v # 去除前7个字符'module.'
+ else:
+ new_state_dict[k] = v
+ # 加载处理后的状态字典
+ self.teacher.load_state_dict(new_state_dict, strict=False)
+
+ if self.rank == 0:
+ print(f"Successfully loaded teacher model from {teacher_path}")
+ else:
+ raise FileNotFoundError(f"Teacher model weights not found at {teacher_path}")
+
+ # 数据加载
+ self.train_loader, num_classes = load_data(training=True, cfg=conf)
+ self.val_loader, _ = load_data(training=False, cfg=conf)
+
+ # ArcFace损失
+ self.metric = ArcFace(conf['base']['embedding_size'], num_classes).to(self.device)
+
+ # 分布式训练
+ if world_size > 1:
+ self.teacher = DDP(self.teacher, device_ids=[rank])
+ self.student = DDP(self.student, device_ids=[rank])
+ self.metric = DDP(self.metric, device_ids=[rank])
+
+ # 优化器
+ self.optimizer = torch.optim.SGD([
+ {'params': self.student.parameters()},
+ {'params': self.metric.parameters()}
+ ], lr=conf['training']['lr'], momentum=0.9, weight_decay=5e-4)
+
+ self.scheduler = CosineAnnealingLR(self.optimizer, T_max=conf['training']['epochs'])
+ self.scaler = GradScaler()
+
+ # 损失函数
+ self.arcface_loss = nn.CrossEntropyLoss()
+ self.distill_loss = nn.KLDivLoss(reduction='batchmean')
+ self.conf = conf
+
+ def cosine_annealing(self, epoch, total_epochs, initial_weight, final_weight=0.1):
+ """
+ 余弦退火法动态调整蒸馏权重
+ 参数:
+ epoch: 当前训练轮次
+ total_epochs: 总训练轮次
+ initial_weight: 初始蒸馏权重(如0.8)
+ final_weight: 最终蒸馏权重(如0.1)
+ 返回:
+ 当前轮次的蒸馏权重
+ """
+ return final_weight + 0.5 * (initial_weight - final_weight) * (1 + math.cos(math.pi * epoch / total_epochs))
+ def train_epoch(self, epoch):
+ self.teacher.eval()
+ self.student.train()
+
+ if self.rank == 0:
+ print(f"\nTeacher network type: {type(self.teacher)}")
+ print(f"Student network type: {type(self.student)}")
+
+ total_loss = 0
+ for data, labels in tqdm(self.train_loader, desc=f"Epoch {epoch}"):
+ data = data.to(self.device)
+ labels = labels.to(self.device)
+
+ # with autocast():
+ # 教师输出
+ with torch.no_grad():
+ teacher_logits = self.teacher(data)
+
+ # 学生输出
+ student_features = self.student(data)
+ student_logits = self.metric(student_features, labels)
+
+ # 计算损失
+ arc_loss = self.arcface_loss(student_logits, labels)
+ distill_loss = self.distill_loss(
+ F.log_softmax(student_features / self.conf['training']['temperature'], dim=1),
+ F.softmax(teacher_logits / self.conf['training']['temperature'], dim=1)
+ ) * (self.conf['training']['temperature'] ** 2) # 温度缩放后需要乘以T^2保持梯度规模
+ current_distill_weight = self.cosine_annealing(epoch, self.conf['training']['epochs'], self.conf['training']['distill_weight'])
+ loss = (1-current_distill_weight) * arc_loss + current_distill_weight * distill_loss
+
+ self.optimizer.zero_grad()
+ self.scaler.scale(loss).backward()
+ self.scaler.step(self.optimizer)
+ self.scaler.update()
+
+ total_loss += loss.item()
+
+ self.scheduler.step()
+ return total_loss / len(self.train_loader)
+
+ def validate(self):
+ self.student.eval()
+ total_loss = 0
+ correct = 0
+ total = 0
+
+ with torch.no_grad():
+ for data, labels in self.val_loader:
+ data = data.to(self.device)
+ labels = labels.to(self.device)
+
+ features = self.student(data)
+ logits = self.metric(features, labels)
+
+ loss = self.arcface_loss(logits, labels)
+ total_loss += loss.item()
+
+ _, predicted = torch.max(logits.data, 1)
+ total += labels.size(0)
+ correct += (predicted == labels).sum().item()
+
+ return total_loss / len(self.val_loader), correct / total
+
+ def save_checkpoint(self, epoch, is_best=False):
+ if self.rank != 0:
+ return
+
+ state = {
+ 'epoch': epoch,
+ 'student_state_dict': self.student.state_dict(),
+ 'metric_state_dict': self.metric.state_dict(),
+ 'optimizer_state_dict': self.optimizer.state_dict(),
+ }
+
+ filename = 'best.pth' if is_best else f'checkpoint_{epoch}.pth'
+ if not os.path.exists(self.conf['training']['checkpoints']):
+ os.makedirs(self.conf['training']['checkpoints'])
+ if filename != 'best.pth':
+ torch.save(state, os.path.join(self.conf['training']['checkpoints'], filename))
+ else:
+ torch.save(state['student_state_dict'], os.path.join(self.conf['training']['checkpoints'], filename))
+
+def train(rank, world_size):
+ setup(rank, world_size)
+ with open('configs/distill.yml', 'r') as f:
+ conf = yaml.load(f, Loader=yaml.FullLoader)
+ trainer = DistillTrainer(rank, world_size, conf)
+ best_acc = 0
+ for epoch in range(conf['training']['epochs']):
+ train_loss = trainer.train_epoch(epoch)
+ val_loss, val_acc = trainer.validate()
+
+ if rank == 0:
+ print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
+
+ if val_acc > best_acc:
+ best_acc = val_acc
+ trainer.save_checkpoint(epoch, is_best=True)
+
+ cleanup()
+
+if __name__ == '__main__':
+ world_size = torch.cuda.device_count()
+ if world_size > 1:
+ mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
+ else:
+ train(0, 1)