Files
ieemoo-ai-contrast/configs/utils.py
2025-06-11 15:23:50 +08:00

57 lines
2.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from model import (resnet18, mobilevit_s, MobileNetV3_Small, MobileNetV3_Large, mobilenet_v1,
PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5)
from timm.models import vit_base_patch16_224 as vit_base_16
from model.metric import ArcFace, CosFace
import torch.optim as optim
import torch.nn as nn
import timm
class trainer_tools:
    """Lookup tables of lazy factories for backbones, metric heads and optimizers.

    Each ``get_*`` method returns a dict that maps a configuration name to a
    zero-argument callable. Nothing is constructed until the caller invokes
    the selected entry, e.g. ``tools.get_backbone()['resnet18']()``.
    """

    def __init__(self, conf):
        """Keep a reference to the configuration.

        Args:
            conf: Nested mapping. The methods of this class read
                ``conf['models']['channel_ratio']``,
                ``conf['base']['embedding_size']``, ``conf['base']['device']``,
                ``conf['training']['lr']`` and
                ``conf['training']['weight_decay']``.
        """
        self.conf = conf

    def get_backbone(self):
        """Return a dict mapping backbone name to a zero-argument factory.

        The factories are lambdas so that a backbone is only instantiated
        when its entry is actually called.
        """
        backbone_mapping = {
            'resnet18': lambda: resnet18(scale=self.conf['models']['channel_ratio']),
            'mobilevit_s': lambda: mobilevit_s(),
            'mobilenetv3_small': lambda: MobileNetV3_Small(),
            'PPLCNET_x1_0': lambda: PPLCNET_x1_0(),
            'PPLCNET_x0_5': lambda: PPLCNET_x0_5(),
            'PPLCNET_x2_5': lambda: PPLCNET_x2_5(),
            'mobilenetv3_large': lambda: MobileNetV3_Large(),
            'vit_base': lambda: vit_base_16(pretrained=True),
            # The embedding size is read by key, exactly as get_metric does:
            # attribute access (conf.embedding_size) fails on a plain dict.
            'efficientnet': lambda: timm.create_model(
                'efficientnet_b0',
                pretrained=True,
                num_classes=self.conf['base']['embedding_size'],
            ),
        }
        return backbone_mapping

    def get_metric(self, class_num):
        """Return a dict mapping metric name to a zero-argument factory.

        Metric selection uses a dict mapping for readability and
        extensibility. Every head is built from
        ``conf['base']['embedding_size']`` and ``class_num`` and then moved to
        ``conf['base']['device']``.

        Args:
            class_num: Second positional argument passed to each head
                constructor (the output size for the ``nn.Linear`` head).
        """
        def build(head_cls):
            # Configuration is read at call time, so the factories stay lazy.
            base_conf = self.conf['base']
            head = head_cls(base_conf['embedding_size'], class_num)
            return head.to(base_conf['device'])

        metric_mapping = {
            'arcface': lambda: build(ArcFace),
            'cosface': lambda: build(CosFace),
            'softmax': lambda: build(nn.Linear),
        }
        return metric_mapping

    def get_optimizer(self, model, metric):
        """Return a dict mapping optimizer name to a zero-argument factory.

        Each optimizer is created with two parameter groups, ``model`` first
        and ``metric`` second, sharing ``conf['training']['lr']`` and
        ``conf['training']['weight_decay']``.

        Args:
            model: Object exposing ``parameters()``; forms the first group.
            metric: Object exposing ``parameters()``; forms the second group.
        """
        def build(optimizer_cls):
            param_groups = [
                {'params': model.parameters()},
                {'params': metric.parameters()},
            ]
            training_conf = self.conf['training']
            return optimizer_cls(
                param_groups,
                lr=training_conf['lr'],
                weight_decay=training_conf['weight_decay'],
            )

        optimizer_mapping = {
            'sgd': lambda: build(optim.SGD),
            'adam': lambda: build(optim.Adam),
            'adamw': lambda: build(optim.AdamW),
        }
        return optimizer_mapping