from model import (resnet18, resnet34, resnet50, mobilevit_s, MobileNetV3_Small, MobileNetV3_Large, mobilenet_v1,
                   PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5)
from timm.models import vit_base_patch16_224 as vit_base_16
from model.metric import ArcFace, CosFace
import torch.optim as optim
import torch.nn as nn
import timm


class trainer_tools:
    """Factory helpers that map config names to backbones, metric heads, optimizers and schedulers."""

    def __init__(self, conf):
        self.conf = conf

    def get_backbone(self):
        # Map backbone names to factories; adding an architecture only requires a new entry here.
        backbone_mapping = {
            'resnet18': lambda: resnet18(scale=self.conf['models']['channel_ratio']),
            'resnet34': lambda: resnet34(scale=self.conf['models']['channel_ratio']),
            'resnet50': lambda: resnet50(scale=self.conf['models']['channel_ratio']),
            'mobilevit_s': lambda: mobilevit_s(),
            'mobilenetv3_small': lambda: MobileNetV3_Small(),
            'PPLCNET_x1_0': lambda: PPLCNET_x1_0(),
            'PPLCNET_x0_5': lambda: PPLCNET_x0_5(),
            'PPLCNET_x2_5': lambda: PPLCNET_x2_5(),
            'mobilenetv3_large': lambda: MobileNetV3_Large(),
            'vit_base': lambda: vit_base_16(pretrained=True),
            'efficientnet': lambda: timm.create_model('efficientnet_b0', pretrained=True,
                                                      num_classes=self.conf['base']['embedding_size'])
        }
        return backbone_mapping

    def get_metric(self, class_num):
        # Metric heads are selected through a dict mapping to keep this block readable and easy to extend.
        metric_mapping = {
            'arcface': lambda: ArcFace(self.conf['base']['embedding_size'], class_num).to(self.conf['base']['device']),
            'cosface': lambda: CosFace(self.conf['base']['embedding_size'], class_num).to(self.conf['base']['device']),
            'softmax': lambda: nn.Linear(self.conf['base']['embedding_size'], class_num).to(self.conf['base']['device'])
        }
        return metric_mapping

    def get_optimizer(self, model, metric):
        # Both the backbone and the metric head are registered as parameter groups of one optimizer.
        optimizer_mapping = {
            'sgd': lambda: optim.SGD(
                [{'params': model.parameters()}, {'params': metric.parameters()}],
                lr=self.conf['training']['lr'],
                weight_decay=self.conf['training']['weight_decay']
            ),
            'adam': lambda: optim.Adam(
                [{'params': model.parameters()}, {'params': metric.parameters()}],
                lr=self.conf['training']['lr'],
                weight_decay=self.conf['training']['weight_decay']
            ),
            'adamw': lambda: optim.AdamW(
                [{'params': model.parameters()}, {'params': metric.parameters()}],
                lr=self.conf['training']['lr'],
                weight_decay=self.conf['training']['weight_decay']
            )
        }
        return optimizer_mapping

    def get_scheduler(self, optimizer):
        scheduler_mapping = {
            'step': lambda: optim.lr_scheduler.StepLR(
                optimizer,
                step_size=self.conf['training']['lr_step'],
                gamma=self.conf['training']['lr_decay']
            ),
            'cosine': lambda: optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=self.conf['training']['epochs'],
                eta_min=self.conf['training']['cosine_eta_min']
            ),
            'cosine_warm': lambda: optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer,
                T_0=self.conf['training'].get('cosine_t_0', 10),
                T_mult=self.conf['training'].get('cosine_t_mult', 1),
                eta_min=self.conf['training'].get('cosine_eta_min', 0)
            )
        }
        return scheduler_mapping
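
# Usage sketch (illustrative, not part of this module): every get_* method returns a
# name -> factory mapping, so the caller looks up the configured name and calls the
# factory. The config keys used below ('models.backbone', 'training.metric',
# 'training.optimizer', 'training.scheduler') are assumptions about the config layout
# and should be adjusted to the project's actual YAML fields.
#
#   tools = trainer_tools(conf)
#   backbone = tools.get_backbone()[conf['models']['backbone']]()
#   metric = tools.get_metric(class_num)[conf['training']['metric']]()
#   optimizer = tools.get_optimizer(backbone, metric)[conf['training']['optimizer']]()
#   scheduler = tools.get_scheduler(optimizer)[conf['training']['scheduler']]()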