from model import (resnet18, resnet34, resnet50, resnet101, mobilevit_s,
                   MobileNetV3_Small, MobileNetV3_Large, mobilenet_v1,
                   PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5)
from timm.models import vit_base_patch16_224 as vit_base_16
from model.metric import ArcFace, CosFace
import torch.optim as optim
import torch.nn as nn
import timm


class trainer_tools:
    """Factory helpers that map config names to backbones, metric heads,
    optimizers and LR schedulers. Each getter returns a dict of
    name -> zero-argument constructor; the caller selects an entry by
    config key and invokes the lambda to build the object."""

    def __init__(self, conf):
        self.conf = conf

    def get_backbone(self):
        # Map backbone names to constructors; lambdas defer instantiation
        # until the caller picks one, so no unused models are built.
        backbone_mapping = {
            'resnet18': lambda: resnet18(scale=self.conf['models']['channel_ratio'], pretrained=True),
            'resnet34': lambda: resnet34(scale=self.conf['models']['channel_ratio'], pretrained=True),
            'resnet50': lambda: resnet50(scale=self.conf['models']['channel_ratio'], pretrained=True),
            'resnet101': lambda: resnet101(scale=self.conf['models']['channel_ratio'], pretrained=True),
            'mobilevit_s': lambda: mobilevit_s(),
            'mobilenetv3_small': lambda: MobileNetV3_Small(),
            'PPLCNET_x1_0': lambda: PPLCNET_x1_0(),
            'PPLCNET_x0_5': lambda: PPLCNET_x0_5(),
            'PPLCNET_x2_5': lambda: PPLCNET_x2_5(),
            'mobilenetv3_large': lambda: MobileNetV3_Large(),
            'vit_base': lambda: vit_base_16(pretrained=True),
            # Use dict-style config access for consistency with the rest of
            # the class (was the attribute-style self.conf.embedding_size).
            'efficientnet': lambda: timm.create_model('efficientnet_b0', pretrained=True,
                                                      num_classes=self.conf['base']['embedding_size'])
        }
        return backbone_mapping

    def get_metric(self, class_num):
        # Metric selection via a dict mapping for readability and extensibility.
        metric_mapping = {
            'arcface': lambda: ArcFace(self.conf['base']['embedding_size'], class_num).to(self.conf['base']['device']),
            'cosface': lambda: CosFace(self.conf['base']['embedding_size'], class_num).to(self.conf['base']['device']),
            'softmax': lambda: nn.Linear(self.conf['base']['embedding_size'], class_num).to(self.conf['base']['device'])
        }
        return metric_mapping

    def get_optimizer(self, model, metric):
        # The backbone and the metric head are optimized jointly, so both
        # parameter groups are passed to every optimizer.
        optimizer_mapping = {
            'sgd': lambda: optim.SGD(
                [{'params': model.parameters()}, {'params': metric.parameters()}],
                lr=self.conf['training']['lr'],
                weight_decay=self.conf['training']['weight_decay']
            ),
            'adam': lambda: optim.Adam(
                [{'params': model.parameters()}, {'params': metric.parameters()}],
                lr=self.conf['training']['lr'],
                weight_decay=self.conf['training']['weight_decay']
            ),
            'adamw': lambda: optim.AdamW(
                [{'params': model.parameters()}, {'params': metric.parameters()}],
                lr=self.conf['training']['lr'],
                weight_decay=self.conf['training']['weight_decay']
            )
        }
        return optimizer_mapping

    def get_scheduler(self, optimizer):
        # LR schedules; 'cosine_warm' falls back to defaults via .get() when
        # the corresponding config keys are absent.
        scheduler_mapping = {
            'step': lambda: optim.lr_scheduler.StepLR(
                optimizer,
                step_size=self.conf['training']['lr_step'],
                gamma=self.conf['training']['lr_decay']
            ),
            'cosine': lambda: optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=self.conf['training']['epochs'],
                eta_min=self.conf['training']['cosine_eta_min']
            ),
            'cosine_warm': lambda: optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer,
                T_0=self.conf['training'].get('cosine_t_0', 10),
                T_mult=self.conf['training'].get('cosine_t_mult', 1),
                eta_min=self.conf['training'].get('cosine_eta_min', 0)
            )
        }
        return scheduler_mapping
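

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file; a minimal illustration of how
# the mappings above are meant to be consumed). The config key names shown
# here ('models.backbone', 'training.metric', 'training.optimizer',
# 'training.scheduler') and the class_num value are assumptions, not keys
# confirmed by this module:
#
#   tools = trainer_tools(conf)
#   backbone = tools.get_backbone()[conf['models']['backbone']]()
#   metric = tools.get_metric(class_num=1000)[conf['training']['metric']]()
#   optimizer = tools.get_optimizer(backbone, metric)[conf['training']['optimizer']]()
#   scheduler = tools.get_scheduler(optimizer)[conf['training']['scheduler']]()
#
# Because every dict value is a zero-argument lambda, nothing is constructed
# until the lookup result is called, and unknown names surface as a KeyError
# at the lookup site.
# ---------------------------------------------------------------------------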