"""Distill a ResNet50 teacher into a ResNet18 student.

The student network is trained with an ArcFace loss plus a temperature-scaled
KL-divergence distillation term. Supports single-machine multi-GPU training
via DistributedDataParallel (one process per GPU).
"""
import math
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import yaml
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm

from model import resnet18, resnet50, ArcFace
from tools.dataset import load_data


def setup(rank, world_size):
    """Initialise the NCCL process group for this rank.

    MASTER_ADDR / MASTER_PORT may be overridden through the environment;
    otherwise the loopback address and port 12355 are used (single machine).
    """
    # '0.0.0.0' is a bind address, not a portable address to connect to;
    # loopback is the correct rendezvous address for single-machine training.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '12355')
    # Pin this process to its own GPU before NCCL init so that communicator
    # setup does not land every rank on cuda:0.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    """Destroy the default process group."""
    dist.destroy_process_group()


class DistillTrainer:
    """Train a ResNet18 student (with an ArcFace head) from a frozen ResNet50 teacher."""

    def __init__(self, rank, world_size, conf):
        """Build models, data loaders, optimizer and losses for one rank.

        Args:
            rank: index of this process / GPU.
            world_size: total number of processes.
            conf: parsed YAML config dict (keys 'models', 'base', 'training').

        Raises:
            FileNotFoundError: if the teacher weights file does not exist.
        """
        self.rank = rank
        self.world_size = world_size
        self.device = torch.device(f'cuda:{rank}')
        self.conf = conf

        # Models
        self.teacher = resnet50(pretrained=True, scale=conf['models']['channel_ratio']).to(self.device)
        self.student = resnet18(pretrained=True, scale=conf['models']['student_channel_ratio']).to(self.device)

        # Pre-trained teacher weights
        self._load_teacher_weights(conf['models']['teacher_model_path'])
        # The teacher is inference-only: freeze it so it is never updated and
        # needs no DDP gradient synchronisation (every rank loads the same file).
        self.teacher.requires_grad_(False)
        self.teacher.eval()

        # Data
        self.train_loader, num_classes = load_data(training=True, cfg=conf)
        self.val_loader, _ = load_data(training=False, cfg=conf)

        # ArcFace classification head
        self.metric = ArcFace(conf['base']['embedding_size'], num_classes).to(self.device)

        # Distributed training: only the trainable modules are wrapped.
        if world_size > 1:
            self.student = DDP(self.student, device_ids=[rank])
            self.metric = DDP(self.metric, device_ids=[rank])

        # Optimizer / scheduler (T_max is measured in epochs; see train_epoch)
        self.optimizer = torch.optim.SGD([
            {'params': self.student.parameters()},
            {'params': self.metric.parameters()},
        ], lr=conf['training']['lr'], momentum=0.9, weight_decay=5e-4)
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=conf['training']['epochs'])
        self.scaler = GradScaler()

        # Losses
        self.arcface_loss = nn.CrossEntropyLoss()
        self.distill_loss = nn.KLDivLoss(reduction='batchmean')

    def _load_teacher_weights(self, teacher_path):
        """Load teacher weights from ``teacher_path``, stripping a DDP 'module.' prefix.

        Raises:
            FileNotFoundError: if ``teacher_path`` does not exist.
        """
        if not os.path.exists(teacher_path):
            raise FileNotFoundError(f"Teacher model weights not found at {teacher_path}")
        teacher_state = torch.load(teacher_path, map_location=self.device)
        prefix = 'module.'
        new_state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in teacher_state.items()
        }
        # strict=False tolerates key mismatches (e.g. a different head), but
        # report them so a silently half-loaded teacher does not go unnoticed.
        result = self.teacher.load_state_dict(new_state_dict, strict=False)
        if self.rank == 0:
            print(f"Successfully loaded teacher model from {teacher_path}")
            if result.missing_keys or result.unexpected_keys:
                print(f"Warning: teacher missing keys: {result.missing_keys}, "
                      f"unexpected keys: {result.unexpected_keys}")

    def cosine_annealing(self, epoch, total_epochs, initial_weight, final_weight=0.1):
        """Cosine-anneal the distillation weight from ``initial_weight`` to ``final_weight``.

        Args:
            epoch: current training epoch.
            total_epochs: total number of training epochs.
            initial_weight: distillation weight at epoch 0 (e.g. 0.8).
            final_weight: distillation weight reached at ``total_epochs`` (e.g. 0.1).

        Returns:
            The distillation weight for the current epoch.
        """
        return final_weight + 0.5 * (initial_weight - final_weight) * (1 + math.cos(math.pi * epoch / total_epochs))

    def train_epoch(self, epoch):
        """Run one training epoch and return the mean per-batch loss.

        The LR scheduler is stepped once per epoch, matching its T_max (epochs).
        """
        self.teacher.eval()
        self.student.train()
        if self.rank == 0:
            print(f"\nTeacher network type: {type(self.teacher)}")
            print(f"Student network type: {type(self.student)}")

        # If the loader uses a DistributedSampler, reseed its shuffle per epoch.
        sampler = getattr(self.train_loader, 'sampler', None)
        if hasattr(sampler, 'set_epoch'):
            sampler.set_epoch(epoch)

        # Epoch-level constants, hoisted out of the batch loop.
        temperature = self.conf['training']['temperature']
        distill_weight = self.cosine_annealing(
            epoch, self.conf['training']['epochs'], self.conf['training']['distill_weight'])

        total_loss = 0.0
        for data, labels in tqdm(self.train_loader, desc=f"Epoch {epoch}", disable=self.rank != 0):
            data = data.to(self.device)
            labels = labels.to(self.device)

            # NOTE(review): AMP autocast is disabled; GradScaler is still used,
            # which in fp32 only adds loss scaling and inf/NaN step skipping.
            with torch.no_grad():
                teacher_logits = self.teacher(data)

            student_features = self.student(data)
            student_logits = self.metric(student_features, labels)

            arc_loss = self.arcface_loss(student_logits, labels)
            # NOTE(review): this compares the student's *embedding*
            # (student_features) with the teacher's output, which requires both
            # to have the same width. Confirm feature-level distillation is
            # intended rather than logits-to-logits.
            distill_loss = self.distill_loss(
                F.log_softmax(student_features / temperature, dim=1),
                F.softmax(teacher_logits / temperature, dim=1),
            ) * (temperature ** 2)  # multiply by T^2 to keep gradient magnitude after temperature scaling
            loss = (1 - distill_weight) * arc_loss + distill_weight * distill_loss

            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            total_loss += loss.item()

        # Once per epoch: CosineAnnealingLR's T_max is expressed in epochs.
        self.scheduler.step()
        return total_loss / max(len(self.train_loader), 1)

    def validate(self):
        """Evaluate the student on this rank's validation loader.

        Returns:
            (mean per-batch ArcFace loss, top-1 accuracy); both 0.0 if empty.
        """
        self.student.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for data, labels in self.val_loader:
                data = data.to(self.device)
                labels = labels.to(self.device)
                features = self.student(data)
                # NOTE(review): ArcFace receives the true labels here; if it
                # applies its margin to the target class, this accuracy is a
                # pessimistic estimate. Confirm against the ArcFace module.
                logits = self.metric(features, labels)
                total_loss += self.arcface_loss(logits, labels).item()
                predicted = logits.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        avg_loss = total_loss / max(len(self.val_loader), 1)
        accuracy = correct / total if total else 0.0
        return avg_loss, accuracy

    def save_checkpoint(self, epoch, is_best=False):
        """Save a checkpoint on rank 0 (other ranks return immediately).

        ``checkpoint_{epoch}.pth`` stores epoch, student, metric and optimizer
        state; ``best.pth`` stores only the student's state dict. Under DDP the
        state-dict keys carry a 'module.' prefix.
        """
        if self.rank != 0:
            return
        ckpt_dir = self.conf['training']['checkpoints']
        os.makedirs(ckpt_dir, exist_ok=True)
        if is_best:
            torch.save(self.student.state_dict(), os.path.join(ckpt_dir, 'best.pth'))
            return
        state = {
            'epoch': epoch,
            'student_state_dict': self.student.state_dict(),
            'metric_state_dict': self.metric.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }
        torch.save(state, os.path.join(ckpt_dir, f'checkpoint_{epoch}.pth'))


def train(rank, world_size, config_path='configs/distill.yml'):
    """Per-process entry point: train, validate and keep the best student.

    The process group is always destroyed, even if training raises.
    """
    setup(rank, world_size)
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            conf = yaml.load(f, Loader=yaml.FullLoader)
        trainer = DistillTrainer(rank, world_size, conf)

        best_acc = 0
        for epoch in range(conf['training']['epochs']):
            train_loss = trainer.train_epoch(epoch)
            val_loss, val_acc = trainer.validate()
            if rank == 0:
                print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
                if val_acc > best_acc:
                    best_acc = val_acc
                    trainer.save_checkpoint(epoch, is_best=True)
    finally:
        cleanup()


if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    if world_size > 1:
        mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
    else:
        train(0, 1)