commit 37ecef40f7 by lee, 2025-06-11 15:23:50 +08:00
79 changed files with 26981 additions and 0 deletions

train_distill.py (new file, 205 lines)
@@ -0,0 +1,205 @@
"""
ResNet50蒸馏训练ResNet18实现
学生网络使用ArcFace损失
支持单机双卡训练
"""
import os
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.cuda.amp import GradScaler
from model import resnet18, resnet50, ArcFace
from tqdm import tqdm
import torch.nn.functional as F
from tools.dataset import load_data
# from config import config as conf
import yaml
import math

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = '0.0.0.0'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()
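# setup()/cleanup() create and destroy one NCCL process group per GPU process;
# MASTER_ADDR/MASTER_PORT are hard-coded since training runs on a single node.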


class DistillTrainer:
    def __init__(self, rank, world_size, conf):
        self.rank = rank
        self.world_size = world_size
        self.device = torch.device(f'cuda:{rank}')
        # Initialize models
        self.teacher = resnet50(pretrained=True, scale=conf['models']['channel_ratio']).to(self.device)
        self.student = resnet18(pretrained=True, scale=conf['models']['student_channel_ratio']).to(self.device)
        # Load the pretrained teacher weights
        # teacher_path = os.path.join('checkpoints', 'resnet50_0519', 'best.pth')
        teacher_path = conf['models']['teacher_model_path']
        if os.path.exists(teacher_path):
            teacher_state = torch.load(teacher_path, map_location=self.device)
            new_state_dict = {}
            for k, v in teacher_state.items():
                if k.startswith('module.'):
                    new_state_dict[k[7:]] = v  # strip the 7-character 'module.' prefix added by DDP
                else:
                    new_state_dict[k] = v
            # Load the cleaned-up state dict
            self.teacher.load_state_dict(new_state_dict, strict=False)
            if self.rank == 0:
                print(f"Successfully loaded teacher model from {teacher_path}")
        else:
            raise FileNotFoundError(f"Teacher model weights not found at {teacher_path}")
        # Data loading
        self.train_loader, num_classes = load_data(training=True, cfg=conf)
        self.val_loader, _ = load_data(training=False, cfg=conf)
        # ArcFace loss head
        self.metric = ArcFace(conf['base']['embedding_size'], num_classes).to(self.device)
        # Distributed training
        if world_size > 1:
            self.teacher = DDP(self.teacher, device_ids=[rank])
            self.student = DDP(self.student, device_ids=[rank])
            self.metric = DDP(self.metric, device_ids=[rank])
        # Optimizer
        self.optimizer = torch.optim.SGD([
            {'params': self.student.parameters()},
            {'params': self.metric.parameters()}
        ], lr=conf['training']['lr'], momentum=0.9, weight_decay=5e-4)
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=conf['training']['epochs'])
        self.scaler = GradScaler()
        # Loss functions
        self.arcface_loss = nn.CrossEntropyLoss()
        self.distill_loss = nn.KLDivLoss(reduction='batchmean')
        self.conf = conf

    def cosine_annealing(self, epoch, total_epochs, initial_weight, final_weight=0.1):
        """
        Cosine annealing schedule for the distillation weight.

        Args:
            epoch: current training epoch
            total_epochs: total number of training epochs
            initial_weight: initial distillation weight (e.g. 0.8)
            final_weight: final distillation weight (e.g. 0.1)
        Returns:
            Distillation weight for the current epoch.
        """
        return final_weight + 0.5 * (initial_weight - final_weight) * (1 + math.cos(math.pi * epoch / total_epochs))
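    # Worked example: with initial_weight=0.8, final_weight=0.1 and total_epochs=100,
    # the weight is 0.8 at epoch 0, 0.45 at epoch 50, and approaches 0.1 by the final epoch.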

    def train_epoch(self, epoch):
        self.teacher.eval()
        self.student.train()
        if self.rank == 0:
            print(f"\nTeacher network type: {type(self.teacher)}")
            print(f"Student network type: {type(self.student)}")
        total_loss = 0
        for data, labels in tqdm(self.train_loader, desc=f"Epoch {epoch}"):
            data = data.to(self.device)
            labels = labels.to(self.device)
            # with autocast():
            # Teacher forward pass (frozen)
            with torch.no_grad():
                teacher_logits = self.teacher(data)
            # Student forward pass
            student_features = self.student(data)
            student_logits = self.metric(student_features, labels)
            # Compute losses
            arc_loss = self.arcface_loss(student_logits, labels)
            distill_loss = self.distill_loss(
                F.log_softmax(student_features / self.conf['training']['temperature'], dim=1),
                F.softmax(teacher_logits / self.conf['training']['temperature'], dim=1)
            ) * (self.conf['training']['temperature'] ** 2)  # multiply by T^2 after temperature scaling to preserve the gradient scale
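            # Note: nn.KLDivLoss expects log-probabilities as its first argument and
            # probabilities as its second, hence log_softmax (student) vs softmax (teacher).
            # The student term uses its embedding (student_features), so the teacher output
            # is assumed to have the same dimensionality.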
            current_distill_weight = self.cosine_annealing(epoch, self.conf['training']['epochs'], self.conf['training']['distill_weight'])
            loss = (1 - current_distill_weight) * arc_loss + current_distill_weight * distill_loss
            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            total_loss += loss.item()
        self.scheduler.step()
        return total_loss / len(self.train_loader)
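    # Note: autocast is commented out above, so GradScaler currently scales full-precision
    # losses; this is harmless, but mixed precision only takes effect once autocast is re-enabled.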

    def validate(self):
        self.student.eval()
        total_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for data, labels in self.val_loader:
                data = data.to(self.device)
                labels = labels.to(self.device)
                features = self.student(data)
                logits = self.metric(features, labels)
                loss = self.arcface_loss(logits, labels)
                total_loss += loss.item()
                _, predicted = torch.max(logits.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        return total_loss / len(self.val_loader), correct / total

    def save_checkpoint(self, epoch, is_best=False):
        if self.rank != 0:
            return
        state = {
            'epoch': epoch,
            'student_state_dict': self.student.state_dict(),
            'metric_state_dict': self.metric.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }
        filename = 'best.pth' if is_best else f'checkpoint_{epoch}.pth'
        if not os.path.exists(self.conf['training']['checkpoints']):
            os.makedirs(self.conf['training']['checkpoints'])
        if filename != 'best.pth':
            torch.save(state, os.path.join(self.conf['training']['checkpoints'], filename))
        else:
            torch.save(state['student_state_dict'], os.path.join(self.conf['training']['checkpoints'], filename))
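    # Note: 'best.pth' stores only the student state_dict (ready for export/inference),
    # while non-best filenames store the full training state (epoch, student, metric head, optimizer).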


def train(rank, world_size):
    setup(rank, world_size)
    with open('configs/distill.yml', 'r') as f:
        conf = yaml.load(f, Loader=yaml.FullLoader)
    trainer = DistillTrainer(rank, world_size, conf)
    best_acc = 0
    for epoch in range(conf['training']['epochs']):
        train_loss = trainer.train_epoch(epoch)
        val_loss, val_acc = trainer.validate()
        if rank == 0:
            print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
        if val_acc > best_acc:
            best_acc = val_acc
            trainer.save_checkpoint(epoch, is_best=True)
    cleanup()


if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    if world_size > 1:
        mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
    else:
        train(0, 1)
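
# A minimal sketch of the configs/distill.yml layout implied by the keys read above.
# All values are illustrative assumptions, not taken from the repository, and load_data
# will likely require additional dataset-related keys that are not shown here:
#
#   base:
#     embedding_size: 512
#   models:
#     teacher_model_path: checkpoints/resnet50/best.pth
#     channel_ratio: 1.0
#     student_channel_ratio: 1.0
#   training:
#     lr: 0.01
#     epochs: 100
#     temperature: 4
#     distill_weight: 0.8
#     checkpoints: checkpoints/distill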