This commit is contained in:
lee
2025-06-11 15:23:50 +08:00
commit 37ecef40f7
79 changed files with 26981 additions and 0 deletions

56
configs/utils.py Normal file
View File

@ -0,0 +1,56 @@
from model import (resnet18, mobilevit_s, MobileNetV3_Small, MobileNetV3_Large, mobilenet_v1,
PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5)
from timm.models import vit_base_patch16_224 as vit_base_16
from model.metric import ArcFace, CosFace
import torch.optim as optim
import torch.nn as nn
import timm
class trainer_tools:
def __init__(self, conf):
self.conf = conf
def get_backbone(self):
backbone_mapping = {
'resnet18': lambda: resnet18(scale=self.conf['models']['channel_ratio']),
'mobilevit_s': lambda: mobilevit_s(),
'mobilenetv3_small': lambda: MobileNetV3_Small(),
'PPLCNET_x1_0': lambda: PPLCNET_x1_0(),
'PPLCNET_x0_5': lambda: PPLCNET_x0_5(),
'PPLCNET_x2_5': lambda: PPLCNET_x2_5(),
'mobilenetv3_large': lambda: MobileNetV3_Large(),
'vit_base': lambda: vit_base_16(pretrained=True),
'efficientnet': lambda: timm.create_model('efficientnet_b0', pretrained=True,
num_classes=self.conf.embedding_size)
}
return backbone_mapping
def get_metric(self, class_num):
# 优化后的metric选择代码块使用字典映射提高可读性和扩展性
metric_mapping = {
'arcface': lambda: ArcFace(self.conf['base']['embedding_size'], class_num).to(self.conf['base']['device']),
'cosface': lambda: CosFace(self.conf['base']['embedding_size'], class_num).to(self.conf['base']['device']),
'softmax': lambda: nn.Linear(self.conf['base']['embedding_size'], class_num).to(self.conf['base']['device'])
}
return metric_mapping
def get_optimizer(self, model, metric):
optimizer_mapping = {
'sgd': lambda: optim.SGD(
[{'params': model.parameters()}, {'params': metric.parameters()}],
lr=self.conf['training']['lr'],
weight_decay=self.conf['training']['weight_decay']
),
'adam': lambda: optim.Adam(
[{'params': model.parameters()}, {'params': metric.parameters()}],
lr=self.conf['training']['lr'],
weight_decay=self.conf['training']['weight_decay']
),
'adamw': lambda: optim.AdamW(
[{'params': model.parameters()}, {'params': metric.parameters()}],
lr=self.conf['training']['lr'],
weight_decay=self.conf['training']['weight_decay']
)
}
return optimizer_mapping