|10(t
zn|MR1X4`?$(L#jzn<$^`(r}V&z
zGEsJm0_UL~1a~{iZ0TKiPOa41!&sTb%6xBLG#usC%8?#KtsKV67(N<8Hac({(NHuR
zjYPREgsUVs9S0D=_QE%^6v`~RFj&%0Z$p(tA>@ok=$XJDQTP$sg;puNMf-H|=lbin
z@yUVp%l-B0scOA5u>MWp$MO6Tq=)LwRIk6+^;g4Gl#0?+014Ab8j~>miOAWs`%>JK
z%qQ>u6hF}hdzB;?b@3YYBWg&37q8QdWZ9t1b8>6pm*sitFXXH3W|ZCt3yAQJj+JK1VRN%Hil>NxFGJ00i9k+X#g
zO}C#yAIcUvWZJf1%&=ny!UD8ug)n
z(6u&lI@E_UQEs3O<404^4+z62DC{*TYcAp)l7OrrMLbBVHphCH3q@
zXGD}@KF;gO?udKPNV5?%rL_v*FHPh>(ndJY+z>C*1d1=?^mi}yNUl`OqzE1hGWhrRB5L|zd20U{B8;RKA^!g8jbv>LE&Iramr$xr_fME|J~}7
zku%42@jhnnu-8cd$eSv`BJxm40v7^lh|U}p9&n#Rh=^dC#EirUOz-v}0ttwp8G$`0
zHH}@*BBeKxs+*he!C)5x={QTUW8(iF9wC08nkh7>EP3n2q^oLmm*8t=x4U!h5WNklx0E2#mIoGpf=Gij-T%SGTJIcP#M@OclRN%>^{l
zQR;=#!yQAMB`6gOc5^g+hJsjCBlr=}0Tm%sga)tD
z6^G)aj>?en%f$6BRpNxmp*H#gP0y)U>=$FCtrC6l94%AgoKTgMK#7w~@b!@Dp=4@`
zLvn;F$=rvJ!3hwrdrS%V<(>uEa2?3|C30c|&MBiR2PhdZ&SDLS4@jG+G_qTbRZyw3
zSP#gY>?Ym#XH)27ObzJeHI-CSB59~NJd1cdANPz=;GidPSqDt_mZCWE&*9Jj{EC{N2K>Lm
z+yMN7f@f4zmkWmnFw<(Mtinu%`2kEm9AQJ9QNq)wsW4y(5@HNhCb|!OgBA&tS;pE#$_|A
zM->h6C&0TSEBNjemK46EFuN^7y6WSjrrP%>SP@)8qnUB`O&NIn4eRIKv%upr^2iJ)
zj5wr4%+c@V4DliMlTN8rM_sd2`b@u}YzLJAeu3W*9n2Bmrl!zu4pZ+B(Uf3(`aVHH
zTG=Hqy6_^E$L9;JO#10=e2#%Rb8@AXO>gHd1*xY6MU^dkEnIo{H0njDBOw#7`jyrm
zSm$kHT{wGI324t`$K#@YvA8IWMj(@oV1~YF(uYPAbja69v&oj@(y^mHK+X6*N~eAt
z4iV9%x%m7g_1y*mY|!w9380c!#Y&}CN?~^Eo6@Au5Yt%Clt_*sccMZ;t-6damvP-E
zz6(`INRzn1sZ`5>o8$Fm&gj(2`MD0l+1L##^%$eKIzr7+YRGHK#1i)dNc&?1IYSUj
zMb6|mxKlSBJ3Ny8ZV0Y?-+i?7kI~44
znh=O@668;*`3^N7DBM~T4wQ>6g1kqK3ayno{tbq1?C9f15Xn(-$47p80hL!si-zK_
zY3285etKI3TswBwO2OZt8HJq*6=z5x
HbNc@Q8LGrc
literal 0
HcmV?d00001
diff --git a/config.py b/config.py
new file mode 100644
index 0000000..d0fc3f0
--- /dev/null
+++ b/config.py
@@ -0,0 +1,122 @@
+import torch
+import torchvision.transforms as T
+import torchvision.transforms.functional as F
+
+
+def pad_to_square(img):
+ w, h = img.size
+ max_wh = max(w, h)
+ padding = [(max_wh - w) // 2, (max_wh - h) // 2, (max_wh - w) // 2, (max_wh - h) // 2] # (left, top, right, bottom)
+ return F.pad(img, padding, fill=0, padding_mode='constant')
+
+
+class Config:
+ # network settings
+ backbone = 'resnet18' # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large,
+ # mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5, vit_base]
+ metric = 'arcface' # [cosface, arcface, softmax]
+ cbam = False
+ embedding_size = 256 # 256 # gift:2 contrast:256
+ drop_ratio = 0.5
+ img_size = 224
+ multiple_cards = True # 多卡加载
+ model_half = False # 模型半精度测试
+ data_half = True # 数据半精度测试
+ channel_ratio = 0.75 # 通道剪枝比例
+ quantization_test = False # int8量化模型测试
+
+ # custom base_data settings
+ custom_backbone = False # 迁移学习载入除最后一层的所有层
+ custom_num_classes = 128 # 迁移学习的类别数量
+
+ # if quantization_test:
+ # device = torch.device('cpu')
+ # else:
+ # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ teacher = 'vit' # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large, mobilenet_v1,
+ # PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5]
+
+ student = 'resnet'
+ # data preprocess
+ """transforms.RandomCrop(size),
+ transforms.RandomVerticalFlip(p=0.5),
+ transforms.RandomHorizontalFlip(),
+ RandomRotate(15, 0.3),
+ # RandomGaussianBlur()"""
+ train_transform = T.Compose([
+ T.Lambda(pad_to_square), # 补边
+ T.ToTensor(),
+ T.Resize((img_size, img_size), antialias=True),
+ # T.RandomCrop(img_size * 4 // 5),
+ T.RandomHorizontalFlip(p=0.5),
+ T.RandomRotation(180),
+ T.ColorJitter(brightness=0.5),
+ T.ConvertImageDtype(torch.float32),
+ T.Normalize(mean=[0.5], std=[0.5]),
+ ])
+ test_transform = T.Compose([
+ # T.Lambda(pad_to_square), # 补边
+ T.ToTensor(),
+ T.Resize((img_size, img_size), antialias=True),
+ T.ConvertImageDtype(torch.float32),
+ # T.Normalize(mean=[0,0,0], std=[255,255,255]),
+ T.Normalize(mean=[0.5], std=[0.5]),
+ ])
+
+ # dataset
+ train_root = '../data_center/scatter/train' # ['./data/2250_train/base_data', # './data/2000_train/base_data', './data/zhanting/base_data', './data/base_train/one_stage/train']
+ test_root = '../data_center/scatter/val' # ["./data/2250_train/val", "./data/2000_train/val/", './data/zhanting/val', './data/base_train/one_stage/val']
+
+ # training settings
+ checkpoints = "checkpoints/resnet18_scatter_6.2/" # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3]
+ restore = True
+ # restore_model = "checkpoints/renet18_2250_0315/best_resnet18_2250_0315.pth" # best_resnet18_1491_0306.pth
+ restore_model = "checkpoints/resnet18_scatter_6.2/best.pth" # best_resnet18_1491_0306.pth
+
+ # test settings
+ testbackbone = 'resnet18' # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5]
+
+ # test_val = "./data/2250_train"
+ # test_list = "./data/2250_train/val_pair.txt"
+ # test_group_json = "./data/2250_train/cross_same.json"
+
+ test_val = "../data_center/scatter/" # [../data_center/contrast_learning/model_test_data/val_2250]
+ test_list = "../data_center/scatter/val_pair.txt" # [./data/test/public_single_pairs.txt]
+ test_group_json = "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json" # [./data/2250_train/cross_same.json]
+ # test_group_json = "./data/test/inner_group_pairs.json"
+
+ # test_model = "checkpoints/resnet18_scatter_6.2/best.pth"
+ test_model = "checkpoints/resnet18_1009/best.pth"
+ # test_model = "checkpoints/zhanting/inland/res_801.pth"
+ # test_model = "checkpoints/resnet18_20250504/best.pth"
+ # test_model = "checkpoints/resnet18_vit-base_20250430/best.pth"
+ group_test = False
+ # group_test = False
+
+ train_batch_size = 128 # 256
+ test_batch_size = 128 # 256
+
+ epoch = 5 # 512
+ optimizer = 'sgd' # ['sgd', 'adam', 'adamw']
+ lr = 5e-3 # 1e-2
+ lr_step = 10 # 10
+ lr_decay = 0.98 # 0.98
+ weight_decay = 5e-4
+ loss = 'cross_entropy' # ['focal_loss', 'cross_entropy']
+ log_path = './log'
+ lr_min = 1e-6 # min lr
+
+ pin_memory = False # if memory is large, set it True to speed up a bit
+ num_workers = 32 # 64
+ compare = False # compare the result of different models
+
+ '''
+ train_distill settings
+ '''
+ warmup_epochs = 3 # warmup_epoch
+ distributed = True # distributed training
+ teacher_path = "./checkpoints/resnet50_0519/best.pth"
+ distill_weight = 0.8 # 蒸馏权重
+
+config = Config()
\ No newline at end of file
diff --git a/configs/__init__.py b/configs/__init__.py
new file mode 100644
index 0000000..18ef40f
--- /dev/null
+++ b/configs/__init__.py
@@ -0,0 +1 @@
+from .utils import trainer_tools
\ No newline at end of file
diff --git a/configs/compare.yml b/configs/compare.yml
new file mode 100644
index 0000000..a03261a
--- /dev/null
+++ b/configs/compare.yml
@@ -0,0 +1,69 @@
+# configs/compare.yml
+# 专为模型训练对比设计的配置文件
+# 支持对比不同训练策略(如蒸馏vs独立训练)
+
+# 基础配置
+base:
+ experiment_name: "model_comparison" # 实验名称(用于结果保存目录)
+ seed: 42 # 随机种子(保证可复现性)
+ device: "cuda" # 训练设备(cuda/cpu)
+ log_level: "info" # 日志级别(debug/info/warning/error)
+ embedding_size: 256 # 特征维度
+ pin_memory: true # 是否启用pin_memory
+ distributed: true # 是否启用分布式训练
+
+
+# 模型配置
+models:
+ backbone: 'resnet18'
+ channel_ratio: 0.75
+
+# 训练参数
+training:
+ epochs: 600 # 总训练轮次
+ batch_size: 128 # 批次大小
+ lr: 0.001 # 初始学习率
+ optimizer: "sgd" # 优化器类型
+ metric: 'arcface' # 损失函数类型(可选:arcface/cosface/sphereface/softmax)
+ loss: "cross_entropy" # 损失函数类型(可选:cross_entropy/cross_entropy_smooth/center_loss/center_loss_smooth/arcface/cosface/sphereface/softmax)
+ lr_step: 10 # 学习率调整间隔(epoch)
+ lr_decay: 0.98 # 学习率衰减率
+ weight_decay: 0.0005 # 权重衰减
+ scheduler: "cosine_annealing" # 学习率调度器(可选:cosine_annealing/step/none)
+ num_workers: 32 # 数据加载线程数
+ checkpoints: "./checkpoints/resnet18_test/" # 模型保存目录
+ restore: false
+ restore_model: "resnet18_test/epoch_600.pth" # 模型恢复路径
+
+# 验证参数
+validation:
+ num_workers: 32 # 数据加载线程数
+ val_batch_size: 128 # 测试批次大小
+
+# 数据配置
+data:
+ dataset: "imagenet" # 数据集名称(示例用,可替换为实际数据集)
+ train_batch_size: 128 # 训练批次大小
+ val_batch_size: 128 # 验证批次大小
+ num_workers: 32 # 数据加载线程数
+ data_train_dir: "../data_center/contrast_learning/data_base/train" # 训练数据集根目录
+ data_val_dir: "../data_center/contrast_learning/data_base/val" # 验证数据集根目录
+
+transform:
+ img_size: 224 # 图像尺寸
+ img_mean: 0.5 # 图像均值
+ img_std: 0.5 # 图像方差
+ RandomHorizontalFlip: 0.5 # 随机水平翻转概率
+ RandomRotation: 180 # 随机旋转角度
+ ColorJitter: 0.5 # 随机颜色抖动强度
+
+# 日志与监控
+logging:
+ logging_dir: "./logs" # 日志保存目录
+ tensorboard: true # 是否启用TensorBoard
+ checkpoint_interval: 30 # 检查点保存间隔(epoch)
+
+# 分布式训练(可选)
+distributed:
+ enabled: false # 是否启用分布式训练
+ backend: "nccl" # 分布式后端(nccl/gloo)
diff --git a/configs/distill.yml b/configs/distill.yml
new file mode 100644
index 0000000..8332c16
--- /dev/null
+++ b/configs/distill.yml
@@ -0,0 +1,75 @@
+# configs/compare.yml
+# 专为模型训练对比设计的配置文件
+# 支持对比不同训练策略(如蒸馏vs独立训练)
+
+# 基础配置
+base:
+ experiment_name: "model_comparison" # 实验名称(用于结果保存目录)
+ seed: 42 # 随机种子(保证可复现性)
+ device: "cuda" # 训练设备(cuda/cpu)
+ log_level: "info" # 日志级别(debug/info/warning/error)
+ embedding_size: 256 # 特征维度
+ pin_memory: true # 是否启用pin_memory
+ distributed: true # 是否启用分布式训练
+
+
+# 模型配置
+models:
+ backbone: 'resnet18'
+ channel_ratio: 1.0 # 主干特征通道缩放比例(默认)
+ student_channel_ratio: 0.75
+ teacher_model_path: "./checkpoints/resnet50_0519/best.pth"
+
+# 训练参数
+training:
+ epochs: 600 # 总训练轮次
+ batch_size: 128 # 批次大小
+ lr: 0.001 # 初始学习率
+ optimizer: "sgd" # 优化器类型
+ metric: 'arcface' # 损失函数类型(可选:arcface/cosface/sphereface/softmax)
+ loss: "cross_entropy" # 损失函数类型(可选:cross_entropy/cross_entropy_smooth/center_loss/center_loss_smooth/arcface/cosface/sphereface/softmax)
+ lr_step: 10 # 学习率调整间隔(epoch)
+ lr_decay: 0.98 # 学习率衰减率
+ weight_decay: 0.0005 # 权重衰减
+ scheduler: "cosine_annealing" # 学习率调度器(可选:cosine_annealing/step/none)
+ num_workers: 32 # 数据加载线程数
+ checkpoints: "./checkpoints/resnet18_test/" # 模型保存目录
+ restore: false
+ restore_model: "resnet18_test/epoch_600.pth" # 模型恢复路径
+ distill_weight: 0.8 # 蒸馏损失权重
+ temperature: 4 # 蒸馏温度
+
+
+
+# 验证参数
+validation:
+ num_workers: 32 # 数据加载线程数
+ val_batch_size: 128 # 测试批次大小
+
+# 数据配置
+data:
+ dataset: "imagenet" # 数据集名称(示例用,可替换为实际数据集)
+ train_batch_size: 128 # 训练批次大小
+ val_batch_size: 100 # 验证批次大小
+ num_workers: 4 # 数据加载线程数
+ data_train_dir: "../data_center/contrast_learning/data_base/train" # 训练数据集根目录
+ data_val_dir: "../data_center/contrast_learning/data_base/val" # 验证数据集根目录
+
+transform:
+ img_size: 224 # 图像尺寸
+ img_mean: 0.5 # 图像均值
+ img_std: 0.5 # 图像方差
+ RandomHorizontalFlip: 0.5 # 随机水平翻转概率
+ RandomRotation: 180 # 随机旋转角度
+ ColorJitter: 0.5 # 随机颜色抖动强度
+
+# 日志与监控
+logging:
+ logging_dir: "./logs" # 日志保存目录
+ tensorboard: true # 是否启用TensorBoard
+ checkpoint_interval: 30 # 检查点保存间隔(epoch)
+
+# 分布式训练(可选)
+distributed:
+ enabled: false # 是否启用分布式训练
+ backend: "nccl" # 分布式后端(nccl/gloo)
diff --git a/configs/scatter.yml b/configs/scatter.yml
new file mode 100644
index 0000000..7612e64
--- /dev/null
+++ b/configs/scatter.yml
@@ -0,0 +1,69 @@
+# configs/scatter.yml
+# 专为模型训练对比设计的配置文件
+# 支持对比不同训练策略(如蒸馏vs独立训练)
+
+# 基础配置
+base:
+ device: "cuda" # 训练设备(cuda/cpu)
+ log_level: "info" # 日志级别(debug/info/warning/error)
+ embedding_size: 256 # 特征维度
+ pin_memory: true # 是否启用pin_memory
+ distributed: true # 是否启用分布式训练
+
+
+# 模型配置
+models:
+ backbone: 'resnet18'
+ channel_ratio: 1.0
+
+# 训练参数
+training:
+ epochs: 300 # 总训练轮次
+ batch_size: 64 # 批次大小
+ lr: 0.005 # 初始学习率
+ optimizer: "sgd" # 优化器类型
+ metric: 'arcface' # 损失函数类型(可选:arcface/cosface/sphereface/softmax)
+ loss: "cross_entropy" # 损失函数类型(可选:cross_entropy/cross_entropy_smooth/center_loss/center_loss_smooth/arcface/cosface/sphereface/softmax)
+ lr_step: 10 # 学习率调整间隔(epoch)
+ lr_decay: 0.98 # 学习率衰减率
+ weight_decay: 0.0005 # 权重衰减
+ scheduler: "cosine_annealing" # 学习率调度器(可选:cosine_annealing/step/none)
+ num_workers: 32 # 数据加载线程数
+ checkpoints: "./checkpoints/resnet18_scatter_6.2/" # 模型保存目录
+ restore: True
+ restore_model: "checkpoints/resnet18_scatter_6.2/best.pth" # 模型恢复路径
+
+
+
+# 验证参数
+validation:
+ num_workers: 32 # 数据加载线程数
+ val_batch_size: 128 # 测试批次大小
+
+# 数据配置
+data:
+ dataset: "imagenet" # 数据集名称(示例用,可替换为实际数据集)
+ train_batch_size: 128 # 训练批次大小
+ val_batch_size: 100 # 验证批次大小
+ num_workers: 32 # 数据加载线程数
+ data_train_dir: "../data_center/scatter/train" # 训练数据集根目录
+ data_val_dir: "../data_center/scatter/val" # 验证数据集根目录
+
+transform:
+ img_size: 224 # 图像尺寸
+ img_mean: 0.5 # 图像均值
+ img_std: 0.5 # 图像方差
+ RandomHorizontalFlip: 0.5 # 随机水平翻转概率
+ RandomRotation: 180 # 随机旋转角度
+ ColorJitter: 0.5 # 随机颜色抖动强度
+
+# 日志与监控
+logging:
+ logging_dir: "./log/2025.6.2-scatter.txt" # 日志保存目录
+ tensorboard: true # 是否启用TensorBoard
+ checkpoint_interval: 30 # 检查点保存间隔(epoch)
+
+# 分布式训练(可选)
+distributed:
+ enabled: false # 是否启用分布式训练
+ backend: "nccl" # 分布式后端(nccl/gloo)
diff --git a/configs/test.yml b/configs/test.yml
new file mode 100644
index 0000000..cb10797
--- /dev/null
+++ b/configs/test.yml
@@ -0,0 +1,41 @@
+# configs/test.yml
+# 专为模型训练对比设计的配置文件
+# 支持对比不同训练策略(如蒸馏vs独立训练)
+
+# 基础配置
+base:
+ device: "cuda" # 训练设备(cuda/cpu)
+ log_level: "info" # 日志级别(debug/info/warning/error)
+ embedding_size: 256 # 特征维度
+ pin_memory: true # 是否启用pin_memory
+ distributed: true # 是否启用分布式训练
+
+# 模型配置
+models:
+ backbone: 'resnet18'
+ channel_ratio: 1.0
+ model_path: "./checkpoints/resnet18_scatter_6.2/best.pth"
+ half: false # 是否启用半精度测试(fp16)
+
+# 数据配置
+data:
+ group_test: False # 数据集名称(示例用,可替换为实际数据集)
+ test_batch_size: 128 # 训练批次大小
+ num_workers: 32 # 数据加载线程数
+ test_dir: "../data_center/scatter/" # 验证数据集根目录
+ test_group_json: "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json"
+ test_list: "../data_center/scatter/val_pair.txt"
+
+transform:
+ img_size: 224 # 图像尺寸
+ img_mean: 0.5 # 图像均值
+ img_std: 0.5 # 图像方差
+ RandomHorizontalFlip: 0.5 # 随机水平翻转概率
+ RandomRotation: 180 # 随机旋转角度
+ ColorJitter: 0.5 # 随机颜色抖动强度
+
+save:
+ save_dir: ""
+ save_name: ""
+
+
diff --git a/configs/utils.py b/configs/utils.py
new file mode 100644
index 0000000..899294f
--- /dev/null
+++ b/configs/utils.py
@@ -0,0 +1,56 @@
+from model import (resnet18, mobilevit_s, MobileNetV3_Small, MobileNetV3_Large, mobilenet_v1,
+ PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5)
+from timm.models import vit_base_patch16_224 as vit_base_16
+from model.metric import ArcFace, CosFace
+import torch.optim as optim
+import torch.nn as nn
+import timm
+
+
+class trainer_tools:
+ def __init__(self, conf):
+ self.conf = conf
+
+ def get_backbone(self):
+ backbone_mapping = {
+ 'resnet18': lambda: resnet18(scale=self.conf['models']['channel_ratio']),
+ 'mobilevit_s': lambda: mobilevit_s(),
+ 'mobilenetv3_small': lambda: MobileNetV3_Small(),
+ 'PPLCNET_x1_0': lambda: PPLCNET_x1_0(),
+ 'PPLCNET_x0_5': lambda: PPLCNET_x0_5(),
+ 'PPLCNET_x2_5': lambda: PPLCNET_x2_5(),
+ 'mobilenetv3_large': lambda: MobileNetV3_Large(),
+ 'vit_base': lambda: vit_base_16(pretrained=True),
+ 'efficientnet': lambda: timm.create_model('efficientnet_b0', pretrained=True,
+ num_classes=self.conf.embedding_size)
+ }
+ return backbone_mapping
+
+ def get_metric(self, class_num):
+ # 优化后的metric选择代码块,使用字典映射提高可读性和扩展性
+ metric_mapping = {
+ 'arcface': lambda: ArcFace(self.conf['base']['embedding_size'], class_num).to(self.conf['base']['device']),
+ 'cosface': lambda: CosFace(self.conf['base']['embedding_size'], class_num).to(self.conf['base']['device']),
+ 'softmax': lambda: nn.Linear(self.conf['base']['embedding_size'], class_num).to(self.conf['base']['device'])
+ }
+ return metric_mapping
+
+ def get_optimizer(self, model, metric):
+ optimizer_mapping = {
+ 'sgd': lambda: optim.SGD(
+ [{'params': model.parameters()}, {'params': metric.parameters()}],
+ lr=self.conf['training']['lr'],
+ weight_decay=self.conf['training']['weight_decay']
+ ),
+ 'adam': lambda: optim.Adam(
+ [{'params': model.parameters()}, {'params': metric.parameters()}],
+ lr=self.conf['training']['lr'],
+ weight_decay=self.conf['training']['weight_decay']
+ ),
+ 'adamw': lambda: optim.AdamW(
+ [{'params': model.parameters()}, {'params': metric.parameters()}],
+ lr=self.conf['training']['lr'],
+ weight_decay=self.conf['training']['weight_decay']
+ )
+ }
+ return optimizer_mapping
diff --git a/configs/write_feature.yml b/configs/write_feature.yml
new file mode 100644
index 0000000..fdf7d77
--- /dev/null
+++ b/configs/write_feature.yml
@@ -0,0 +1,47 @@
+# configs/write_feature.yml
+# 专为模型训练对比设计的配置文件
+# 支持对比不同训练策略(如蒸馏vs独立训练)
+
+# 基础配置
+base:
+ device: "cuda" # 训练设备(cuda/cpu)
+ log_level: "info" # 日志级别(debug/info/warning/error)
+ embedding_size: 256 # 特征维度
+ distributed: true # 是否启用分布式训练
+ pin_memory: true # 是否启用pin_memory
+
+# 模型配置
+models:
+ backbone: 'resnet18'
+ channel_ratio: 0.75
+ checkpoints: "../checkpoints/resnet18_1009/best.pth"
+
+# 数据配置
+data:
+ train_batch_size: 128 # 训练批次大小
+ test_batch_size: 128 # 验证批次大小
+ num_workers: 32 # 数据加载线程数
+ half: true # 是否启用半精度数据
+ img_dirs_path: "/shareData/temp_data/comparison/Hangzhou_Yunhe/base_data/05-09"
+# img_dirs_path: "/home/lc/contrast_nettest/data/feature_json"
+ xlsx_pth: false # 过滤商品, 默认None不进行过滤
+
+transform:
+ img_size: 224 # 图像尺寸
+ img_mean: 0.5 # 图像均值
+ img_std: 0.5 # 图像方差
+ RandomHorizontalFlip: 0.5 # 随机水平翻转概率
+ RandomRotation: 180 # 随机旋转角度
+ ColorJitter: 0.5 # 随机颜色抖动强度
+
+# 日志与监控
+logging:
+ logging_dir: "./logs" # 日志保存目录
+ tensorboard: true # 是否启用TensorBoard
+ checkpoint_interval: 30 # 检查点保存间隔(epoch)
+
+save:
+ json_bin: "../search_library/yunhedian_05-09.json" # 保存整个json文件
+ json_path: "../data/feature_json_compare/" # 保存单个json文件
+ error_barcodes: "error_barcodes.txt"
+ barcodes_statistics: "../search_library/barcodes_statistics.txt"
\ No newline at end of file
diff --git a/model/BAM.py b/model/BAM.py
new file mode 100644
index 0000000..4ac61ae
--- /dev/null
+++ b/model/BAM.py
@@ -0,0 +1,88 @@
+import torch.nn as nn
+import torchvision
+from torch.nn import init
+
+
+class Flatten(nn.Module):
+ def forward(self, x):
+ return x.view(x.shape[0], -1)
+
+
+class ChannelAttention(nn.Module):
+ def __int__(self, channel, reduction, num_layers):
+ super(ChannelAttention, self).__init__()
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
+ gate_channels = [channel]
+ gate_channels += [len(channel) // reduction] * num_layers
+ gate_channels += [channel]
+
+ self.ca = nn.Sequential()
+ self.ca.add_module('flatten', Flatten())
+ for i in range(len(gate_channels) - 2):
+ self.ca.add_module('', nn.Linear(gate_channels[i], gate_channels[i + 1]))
+ self.ca.add_module('', nn.BatchNorm1d(gate_channels[i + 1]))
+ self.ca.add_module('', nn.ReLU())
+ self.ca.add_module('', nn.Linear(gate_channels[-2], gate_channels[-1]))
+
+ def forward(self, x):
+ res = self.avgpool(x)
+ res = self.ca(res)
+ res = res.unsqueeze(-1).unsqueeze(-1).expand_as(x)
+ return res
+
+
+class SpatialAttention(nn.Module):
+ def __int__(self, channel, reduction=16, num_lay=3, dilation=2):
+ super(SpatialAttention).__init__()
+ self.sa = nn.Sequential()
+ self.sa.add_module('', nn.Conv2d(kernel_size=1, in_channels=channel, out_channels=(channel // reduction) * 3))
+ self.sa.add_module('', nn.BatchNorm2d(num_features=(channel // reduction)))
+ self.sa.add_module('', nn.ReLU())
+ for i in range(num_lay):
+ self.sa.add_module('', nn.Conv2d(kernel_size=3,
+ in_channels=(channel // reduction),
+ out_channels=(channel // reduction),
+ padding=1,
+ dilation=2))
+ self.sa.add_module('', nn.BatchNorm2d(channel // reduction))
+ self.sa.add_module('', nn.ReLU())
+ self.sa.add_module('', nn.Conv2d(channel // reduction, 1, kernel_size=1))
+
+ def forward(self, x):
+ res = self.sa(x)
+ res = res.expand_as(x)
+ return res
+
+
+class BAMblock(nn.Module):
+ def __init__(self, channel=512, reduction=16, dia_val=2):
+ super(BAMblock, self).__init__()
+ self.ca = ChannelAttention(channel, reduction)
+ self.sa = SpatialAttention(channel, reduction, dia_val)
+ self.sigmoid = nn.Sigmoid()
+
+ def init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ init.kaiming_normal(m.weight, mode='fan_out')
+ if m.bais is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d):
+ init.constant_(m.weight, 1)
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ init.normal_(m.weight, std=0.001)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ b, c, _, _ = x.size()
+ sa_out = self.sa(x)
+ ca_out = self.ca(x)
+ weight = self.sigmoid(sa_out + ca_out)
+ out = (1 + weight) * x
+ return out
+
+
+if __name__ == "__main__":
+ print(512 // 14)
diff --git a/model/CBAM.py b/model/CBAM.py
new file mode 100644
index 0000000..69747e0
--- /dev/null
+++ b/model/CBAM.py
@@ -0,0 +1,70 @@
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+
+class channelAttention(nn.Module):
+ def __init__(self, channel, reduction=16):
+ super(channelAttention, self).__init__()
+ self.Maxpooling = nn.AdaptiveMaxPool2d(1)
+ self.Avepooling = nn.AdaptiveAvgPool2d(1)
+ self.ca = nn.Sequential()
+ self.ca.add_module('conv1',nn.Conv2d(channel, channel//reduction, 1, bias=False))
+ self.ca.add_module('Relu', nn.ReLU())
+ self.ca.add_module('conv2',nn.Conv2d(channel//reduction, channel, 1, bias=False))
+ self.sigmod = nn.Sigmoid()
+
+ def forward(self, x):
+ M_out = self.Maxpooling(x)
+ A_out = self.Avepooling(x)
+ M_out = self.ca(M_out)
+ A_out = self.ca(A_out)
+ out = self.sigmod(M_out+A_out)
+ return out
+
+class SpatialAttention(nn.Module):
+ def __init__(self, kernel_size=7):
+ super().__init__()
+ self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=kernel_size, padding=kernel_size // 2)
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, x):
+ max_result, _ = torch.max(x, dim=1, keepdim=True)
+ avg_result = torch.mean(x, dim=1, keepdim=True)
+ result = torch.cat([max_result, avg_result], dim=1)
+ output = self.conv(result)
+ output = self.sigmoid(output)
+ return output
+
+class CBAM(nn.Module):
+ def __init__(self, channel, reduction=16, kernel_size=7):
+ super().__init__()
+ self.ca = channelAttention(channel, reduction)
+ self.sa = SpatialAttention(kernel_size)
+
+ def init_weights(self):
+ for m in self.modules():#权重初始化
+ if isinstance(m, nn.Conv2d):
+ init.kaiming_normal_(m.weight, mode='fan_out')
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d):
+ init.constant_(m.weight, 1)
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ init.normal_(m.weight, std=0.001)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ # b,c_,_ = x.size()
+ # residual = x
+ out = x*self.ca(x)
+ out = out*self.sa(out)
+ return out
+
+if __name__ == '__main__':
+ input=torch.randn(50,512,7,7)
+ kernel_size=input.shape[2]
+ cbam = CBAM(channel=512,reduction=16,kernel_size=kernel_size)
+ output=cbam(input)
+ print(output.shape)
diff --git a/model/Tool.py b/model/Tool.py
new file mode 100644
index 0000000..3c65931
--- /dev/null
+++ b/model/Tool.py
@@ -0,0 +1,37 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class GeM(nn.Module):
+ def __init__(self, p=3, eps=1e-6):
+ super(GeM, self).__init__()
+ self.p = nn.Parameter(torch.ones(1) * p)
+ self.eps = eps
+
+ def forward(self, x):
+ return self.gem(x, p=self.p, eps=self.eps, stride=2)
+
+ def gem(self, x, p=3, eps=1e-6, stride=2):
+ return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1)), stride=2).pow(1. / p)
+
+ def __repr__(self):
+ return self.__class__.__name__ + \
+ '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + \
+ ', ' + 'eps=' + str(self.eps) + ')'
+
+
+class TripletLoss(nn.Module):
+ def __init__(self, margin):
+ super(TripletLoss, self).__init__()
+ self.margin = margin
+
+ def forward(self, anchor, positive, negative, size_average=True):
+ distance_positive = (anchor - positive).pow(2).sum(1)
+ distance_negative = (anchor - negative).pow(2).sum(1)
+ losses = F.relu(distance_negative - distance_positive + self.margin)
+ return losses.mean() if size_average else losses.sum()
+
+
+if __name__ == '__main__':
+ print('')
diff --git a/model/__init__.py b/model/__init__.py
new file mode 100644
index 0000000..fef1029
--- /dev/null
+++ b/model/__init__.py
@@ -0,0 +1,14 @@
+from .fmobilenet import FaceMobileNet
+# from .resnet_face import ResIRSE
+from .mobilevit import mobilevit_s
+from .metric import ArcFace, CosFace
+from .loss import FocalLoss
+from .resbam import resnet
+from .resnet_pre import resnet18, resnet34, resnet50, resnet14, CustomResNet18
+from .mobilenet_v2 import mobilenet_v2
+from .mobilenet_v3 import MobileNetV3_Small, MobileNetV3_Large
+# from .mobilenet_v1 import mobilenet_v1
+from .lcnet import PPLCNET_x0_25, PPLCNET_x0_35, PPLCNET_x0_5, PPLCNET_x0_75, PPLCNET_x1_0, PPLCNET_x1_5, PPLCNET_x2_0, \
+ PPLCNET_x2_5
+from .vit import vit_base
+from .mlp import MLP
\ No newline at end of file
diff --git a/model/__pycache__/CBAM.cpython-38.pyc b/model/__pycache__/CBAM.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fb7929fe6eb133670914d4c8af2c780336eb539a
GIT binary patch
literal 3018
zcmbVOOK%)S5bo}I?c0w8;)GXtBrp;(At#6s#Bl^1Vk^Srz{qGc*>2mj-ksg_jGbVu
zPmYw6C61hs%>l`Oh6BgGa`KTY;=otEv)*+c7iP7!)z#g#_4ul49yFUZhUeN3H#YvL
zGWHiOrXL@R+xX_cfk-BKmqnvr-V>4FBi`u=!8q%=k;~aPObY2dWztci;~hAW4~Z*1
zNW4-KK;p{)lAx4SR8>~Ch5l8sqnfPBnykzEu^ZJ@z0;75%PeZD3ZzZdl&wQ2YN?j0
zAMj{K&8nGBF_%Ri
zbOmO%y_IB{>Mj+9%8Im~ozCKV+nLf#JOan)H~$%=U@r@vpY4A7b;ynTFimpfey6%SBk9d>#(4S8Kxr-LXeEx*EG^>Lh%7U6OEMW0
z>8@Hyp4{*EyEmn&fgONNv(3r&((dMXdudmVw`*(a(GEo~=^D{aOf8Wz?)BwPR~dg9
zCP5ucx%a>XYw2dMpGxEB#K_PwZmzl;CMeTjs#?jNHYMVOn`#`lyGfqM@d^7UynAc?
zzV3HayI6;xg--G!&J^;N7wgcay6ek#mR1%AKWPFy!8)g0GT!11d}{nEpT~Fa^7AQQ
z9LJ(fp&36O7{b&tzWFVXAv;1Ej`_X-hctq7z=@$KWn$=*GG_?>62rrAb19;9QJ!sO-T+gu1Ej`QrRX@Q7(w+W
zIqJ)l*wY)=TUF?f}2QtKgqCv8K7i$8jyD1=Z55>L*7`sCcaP`4?L*H^fI*Yez
zL-Tzoj3d*Y3AU9QfB~o|-l1>2qOaRq#_1(bjN4O5W}J3X=tb=6Pv}T6a{L|eQ4ObJ
zt@532VR#(XlHJYm(l5n6qCR*5{1%dLLq!Svi+mQvOyJ*Jcp*{~;D1C`yhmhGOB48h
zh>Zo(S;E)TBzZM_$)~fyQniBxJ^U#R7Ek(5S>(zfn?iNeN?zd6}xh1$~~%-x2m|+E%$tBXW}bP)^M@{{g|WBkEoy1kDc@S%jt;L
znPk*gX)@BN(Bzjz_U6AQfB(7i@cviWa`Hlk4HJw;S8i%)o@RNGWNoFViks%9nb}U#
z9y&mr^>r`l#>RiF(#@?xQw^Ic8US4^VpG496z#25oWd0J@1>bavvWmhn~Hw94J
z>>GAEk96kLG;+!%vy*HVu%89jvXoYPb2e?$cg{6R0vG{r8Q=UW2*8Ep(G0uDFwaI4
z4RdU?8iP!XCKV08(3zC6CJoOSYzveQj=Xw&`;Mvhl$X_C=!>m_L4@##%0QwFr!g3}?E_MhRx`FaTFG@UgS
z+rk#J%w{=ij+cuRPkoz?_$&5x#?Cj}CG4MCCyb|)OlHPQGxROIWNjsQ%c7KF4m?bH
v`ZHp^fll$xGI8mw{hP`cBfx!0pv>vDpn~Mii}?%lt@$@9NOBFqtnmK;UrCm{
literal 0
HcmV?d00001
diff --git a/model/__pycache__/Tool.cpython-38.pyc b/model/__pycache__/Tool.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3aadd9d1255b6132c7589ee942e6b5f6f3b17135
GIT binary patch
literal 2081
zcmZuy&2Jk;6rYdX_1bX)O+u5Z0-+wD3zZLu1FG7BR75#YD^kN{6|q|FjNLW+VP@7T
z80VBoJw;p*C;QkN7ybqQh`Dk=IUsRE%Aw}H**JCr&w6j(zMYwUAHVnf_4;~?0NwfJ
zr{2#FA%Eb*DuIoB2CTmUA&8(E3G0>SEM&CSdCnN2BzHoWk|#tk!B2?bk`26L9yY+@
z2p25wg~gXm;eE?gC|ja&OyLj0FA}e1miol&ge3An8v)ND@*(LMGZTNskIHoRenA
zg)6)h5;|gC_@a42LszuK8d$uS#35k=Zr?EciF|s-;f_B~|N3qE&+orJO7LkxsJpJEsCRCGd)*_6#7xQb7m%AI4dzG>xlAL#Y&lC*wy^
zTBKDJ8CDdgxfiQAmz7k;txAJFtTX&UHQ+e1_COH2L2uG=8&c3&
zc$A=n3ql5bBryArLC`j1Jg@3wlbYe8R9U
zQ%5XTKBGN`z&qtrF1!i54K;%Z8r^^zo~Omx>-l_s20XW)3#i(Ld;9el|Nb;^4gJEj
z;=_J4gp__Hj0<(hp;E*01=J}W%fPh(D$Lh1YP>*bt=!;1W>kK?U--ve6hv#8qs(s?1Vd84hK(
zpoC0eU9W|Rmnb<@QG_Md?|>lGqioJBZ_gQHuWR=0xMtL&pvs7qMp6_EkE}Wc5|4^?w=;uSC|n
zAg+!ZY`uuv!9u%%qtHq9AqwniPHdiZ7Q_
zVE;;+VH$A%V?4MA0{mdRWC{~{MyC+yj2^jQ~QW>NaYQlG(_HW1L*Z1pw5J5HBv0=cO|Ku+(Xl9!ldE
zvL9Qmjjb5Phf>9T8Qu_pRoE?|YwN8mrpq&42KA771rs)B2%xutEgT5lw!iV#YThm;
zR6zH^djzGriy}aQDX&5J0c!sTI4mJn+g8P=ZO0ZBdnb?60$@#oHHH7l#nus=v@xXO!B=SLkrgnPKMMVLmxdUZZpT%UukXUk4{f$>ydLlbG-}yW+A7m{IKdcnubt23enjQz(&D1vq>GT$qgvmcka70Ir4*mL{@phl8R&^>xp$wSFC@sVuQ|#bx)$=@F56&q?XjYdrCFzr=C~2(QtfrExvX{Z?GZzgUC)qH-oQ|!LH~eQoAA<
z>utn(VnCNQE=H4a`xx@K>su?{5hy&pU)<7^l)+AJJxi4)9c9rA)xuR75-Q71TGAAi
z$;mv(n;9yLMgD)Sd*!JdXY-D#0)x%qFcb+OIlI0pv0r9CBVC%Os<2UIs4;9YY%|mu
zb{HB2V4%ELO=3Sa)nb?7?~4zwrxWFW@!W9g`pZzwf>2P|&u(#!5x>1RpQ0#HvYj|_>$WNTpcUJBIB*@+c9pu7S;uLRf^d5{aaR)QWzs121wyw*`wg(w}pVG!jLg{oA0
zsZfonZF8hGERg8b0ErPxY>=4L0*Mt%G9a;O1|*qSk_AbY=0K8*B{`7fX#pgKSds@x
zkuHH`DV7vik(QpSemD}sA
z!piLnv~s9o4PCpaz*8AtWEI(>K-Wj4afw|TzcGQ?Ld}~K<*!BHF0l&KT>e6NPx(}N
zto%-4Z?Vhmw?_mIVQ3a_R87#%R*5bN4=x;lY?{=8uh^*s4M|}_ZyyHCUyIxeobI^xWBSjFG9Q(EA
z1K32qFIW)`fmiSYAc+<217(*;+-1O)5^W9r0Ib$9z&IIU-s}3$@A^}9hF0Rs3BlJU
z4m!V(M-4S7!TUf$Wu%M=Vjxt|H08OLdQZV)(5WNT^W
z3e**#1Tq+qEYZmtvA-THMK-FFM*7kNDyJE;n0|f
z_&IEf@h~xJsyZ*>2ht^pbEYQGqZIYzMI@-TgpMf0lWH>0n@>!R`8GC@Q-t6qofsfm
zY_(`{lzdBA%0&}+#F^JWq5A@oP%nw34PZ_n=@^bi<|2mpci^L{#nF1v{MV3MM}qSa
z#-qn<`}4^3Z=h6KK-wblH2Njw=@NQ+H8q3(+1*kr{#_jIEO(>cb9b+#$2!;N{3`Z&
zgq|aJejh3>V4sOQQSvRhGjpHy!u>B=0R+k!`ifV9%rF>9P%ij1i4MPlCa3sR#ssQS
z{m=;Yp$YdqRv3q72r*+Qqkhz4q9qZ-te-&BRO4@=7MGE{1!St95Xg0Y213w85+>*4
z;~i)g;965;Bq)(GDg9?qPK@&Zz{fQlGtMH*M0I{|L!Zq&^cJI0KyQ933`C
z$=T?@ev`yOAYEi+@-V3RGSkqhMQupPhD;tN1oEuNCn`Q>!C*akyYiwH?q2~}i$Zqe
z$CX}y4MuN-SAa_=4V!d28JPeF4HHPT`1RD{$61~4OH5bN3h2v0B~cVqd<{2Zn4L%U
zAz7d6aQ;&N2#1Vaxq58CbG-?eJs#hMJ4U=yXky-*UuaC6^a_*!SN$F|{r%ckBvfAEt-jJ<
zX|Kr0{JZ*02b^HrDv|St4ToCW5X6}p!fx)nKI{zw86%>&fxSi1L|PSdgvj3ZT0z+C
zc6xX2!MlWysjA-$Y716KFjTwkb(^ji+yQ3(an(#gHOW)L;I6;NBy7RwxCTVV?{}OX
z=;0|xC!QA4qC|EbYf(vwiA`Vv+W@QaHho`CsL
zypr5i`1n)^V+g#1J;y1{<|tVJFd_@D0N27iXi50uxxvPB1Hd>9UqckO+$MV{ckgu;
zE?5Fhh}{M{pe_wqgeZ0gPvg+QI27t2K^dGh;AtLWbTUFSmSRgQmRg|=M>>^eWLt}7
zW!sEwYsJ0h;esgi#9JIBh6#Fg2ZS&9;WAgE~g-``q0gr=+m_Uo5X+hp4
zx_qdEWGO6-aEUF4eRVyx
zk}$CeqY0;%cjG-D^4DdGi{v*Xaw7g_^ms3NydOP2h#ntCk6Xf&y=}qih2#??7)-T5J**%Hs!c3
zg$C!kRJ@SE5BHmYfD&0#&4N8ozXz4$){H~b+574bpx15eISNAg5=RN1PCG6fGo30C
z#)FQ##|GtE?ONb=dL9d|ZEW1CwR?MlYy%u_!3UCHTN-}ZWAH1YUgJLldG(4g!Y*&_
z3tK*FzAs_cSed`pcKd!aY;}D&`DRl**%IbM`S(DgG}^|>4edKP9_D(9da1A6NeJ!H
zgKhpPR0sp|3;eQB{hrXfJ?0Aw0nd0Cnb_p8VH?0%*$rIIUHrG|^iH+yvBxu$
z>b7^+v>}RDN)*WvoDe7n4>w=of{@_A1tA3q4sfUw5{F=k1IGx3gzvAO8P8-rQKDUx
z8uj1Rf4%>u>eUV8B{jVJsaXG}1`%dC*Rd4cP1W6dm>
zJi%E7OBliwmav7hWG``iSt2FU^o4o~zof;`Uzt`*Mr7JrGQw@N^q}S@E#eE-mmEDJ
ziWjnBxshueej(b%Pc9#DusT?|yNQxEJpM
zai6#!??Lf^co6TLcu3^&9ug0WNAMmNkBZ0eJ`x-ildHy$x%^F#+h&E4fbSl|-VEYU
z-71U*W5MXk_%`P(7zswVdEpLV?np3#&$Il=eUj|#z6rQLKwyFZ1j}ajr>SZXc!K}~
zcHc&=#DKhn_9)wFylVlj*O
zwjsx2DApLOWU(3Z2Cx^|Hs9hg5pB)N6vh;Z%4VS%Ej;sh!eM|VjCnOnBLo{`KQUgn
zFp}$mdn|7#WAcRJ*Cg#(ga15y@u`J5SzQT=(L%9WiKHJ!UL}a4AdD8)sv;;al#BSD
ztX)^`*>Wk2iq*B+xp7*MR(8OZ&2fJ{s@8llc@~`U%i%dxnmw_oTMA2+F!C$KK$)dV
z1QcIYymml_9Kuv8)tF|KCU*cBPxA>b>3cs7P$cd(OR{={o6iCp8=9Nhm~R>2s1fm)
zILyb{2)G!7uQ4BH+s2j|o5G0MFk5BvTx`bt2}U?@%1n5?^r!c!{Odd@HP`%W@+hXy
zTZ(U}lqhYK!m3o6wNk~~@XPCgvMa$=Z{yRu>09$x15bA^T`325VTv$I#A8zyzmsS2
zmJ|3c60*I}&ZvYB00fNOOTiG%TWpcsh_TQkk#&Qp{g1h0B)@
z)jHx`H~|nr*JAKt*Qi6eBXh+fF6{v4*#fGxXUy(#l`2(gWxp8Y9c72>wLr?dF;=BL
zuT&{To~MjTMcFR}mtH=jY#Ef-&+Hk>3WM^Zybm2QL&o!pWj_o(FJTSISkjjAc(n2`
z-p4m{ExIS0bKgWi*`Kx=5=0D2ceP|kSRE1-#yC+8*^xQylj7`gNsH_m@)5jbjzCN2
z8_Kp=l~=*B$IuqiK*l!*4@haW=K1A4hc;dD*8*(X^*k8GdYS53&wG8{FDET}WAYTH
zQN4|Q7lU%Sxp%!dnY>Y6%WSi{1?zUny7SJgBy*a{b7fsA
z`Qc@GFFx}D)@Svi+V
zmB5$EB8Q+gMH!35ys7O^rHadbr4p1?Mh0TN7?rA(!}R5ASUV(%l3&QeN#?kV$NFe<
zpuxGr^@hSCS^Wh(;S4~`cUhe+SX=P*Th^8hJp#AteAkGL*sL2{4lriyG%*7HuWkX$
zV`mrMkh}^wWh9(J$hX4MW+2aEYGrMdf~#_xs+r(g&94a05A&%5N2B-^#fyqBD}Gh+
z>xZf6a~K^`Kx2FW8$8N4`X5G27xiPD
zgSgMUfrj4CK$1PHfrKV{Bv9pqkt~TorI5RbQe{awOR~CN3ng(@TZyD>D{sk5gm{s_
zECFiI4<;7#0wEp%P|kW7cwwOXQKC~*CT2s*a8)mtxXq_^&kNPc#wnpQ8`bly9~GBp
zt8xt(9b|19QPUDtMu-rcXxF!BrI!fI6VT+$QSA`lOO0CSWJ^x#XYpa~B4;o@u?&}6
z#xYo!^`YZ#j^8Zu!;OtXC5c@d+nuN+PP$HMDqG={!otj8LXEVQ&%z6Q8~vmv^%bgh
zbKS3@v0)c3NOt+=*db)KOk}3Jd5};@W=-=Ll7#u3=8@5v7LI2+)54*xBhzwKdT!3U
zbYV7#KERwP|2{9tdG9m$)enCDGb(>PUnD2nEL10Z*Y+|Eog#aAyOo2*sC#B)k=93!
zFy3Y>CNfocrREWG$3c%2dZaaNL{mD-M6Sq;q92_S-eu^~)9Qg^W1~k;M~`f)2abb{
z9=#nsdRsmET0QzZdi1q=477R-cJ%0vtla@&)oEXY=xs+1Xhk29!()!ZL(~3{hgYl0fP)=YWq}4bp1$!wuNMF1pN9b+gSW#5A-h|?zwi=|fI*+B;m?V&N)~nzN
z`v4qk;J*|w4V=Y{k2mjcs%Oi94?CC19*q3|vO>O!CFJ)Byhb1((7A_opbpuC%mV+V
z6Bl%a`=6jBan(^hr0q#^7|I2-9%it(z&M7c{Adzz
zzeC+9;7X6cg9I)Um;g{VZZS4aN*!i1L1ld{kbV>>clNo<-i2?@&V1u}t@&ERBuP>d
z$0SiEDFIdH^Uutw3=XdFBA6(3I#(6bay^_RkZ5iZ(^uCr+2`>U9u{)O-Hez<#?6cgD+*CK1HO9En
z?EP`tLSATW4&LMrx*xCK#G38#JV~1=vbSP7CC*87XGyB8TkRGONpzA6?Ba;hsHI!A
zOj3*7l^(iuNDe4^23*^;at2&=QN)ys0Jwt5;ObWKZ}RCx(cGfjD15~d?fp?r(7_!p=fU4KMu5toF*C_ySeHQsIJ
zS-AcnwWHZk{(K%WgXY4ejI)Y+j-^o9`Z}+T4kAJI&dnj0fpJ_r;hmcL6;PTOhI5I$
zj;~uHW>J^FotNb28W3LEWdfH%VRvs-#=sdY}@Xqev8froBsj6ZfWzh
zHe>(%jMdwD|HRbq(feR|3O={8e4>b5`F^nRlirCxbgaB}iz_!&@Wap8c>C&8?{}>J
z!`rhuo9?c22dkg*PW+`~^&j1y)veFD`qW=LR(}IublJG(BP@D(9G&o&0C2PU`RiGOHwvT
za+kz@NnDU`5O|!xH2~c4;;I8TzBtw*_SrRwi+rT8%YF^#XoN#x+iY7~xi;|h+2q3X
z9G#@~?MxaM-?;ql>EmPg9phP+rRQVA9Q)A8x{jO4(Jz-9&s|E}Z9glWa#M~$t~bs9
E2THDKGynhq
literal 0
HcmV?d00001
diff --git a/model/__pycache__/loss.cpython-38.pyc b/model/__pycache__/loss.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2c845d39aa9076e19ad90c0af58f1f868d4eda6b
GIT binary patch
literal 821
zcmYjPON$dh5U%c?nLIXGbrtmJNf-|~cobQL^>NW4h?gPI4AY&MFfY41=!VI~UD&^o
z9QTI^{)WDK%3ttg)nwOALDlr5`m6fBnlHm)AJBfSe(K*0;16vMEx^rdbo(5G0>yGj
z_Fq=;gfn;#iYsvmN@%_i7a|E1`wVfg!+aQL6c3H?$m}x=8WN@;;VQa
zmzAa&*n#tsG_6;}B261@snQ+c-}8L_TF)gnZx@x{Zuxd!r6txK_)Pe+A3#X5PkQtK
z4wTcJxJGV(G;GVpn+^awjJunV%R8>0LW4-a5Y*!B`C32!SnOnMN3hL8E
z{S@lu#P&>X%@W;v8hb}HR*itQns$fbMFt_oDOf8NaX0c-Se1>`tGF1