ieemoo-ai-contrast/train_distill.py

"""
ResNet50蒸馏训练ResNet18实现
学生网络使用ArcFace损失
支持单机双卡训练
"""
import os
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.cuda.amp import GradScaler
from model import resnet18, resnet50, ArcFace
from tqdm import tqdm
import torch.nn.functional as F
from tools.dataset import load_data
# from config import config as conf
import yaml
import math
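

# Each spawned worker calls setup() with its rank; all ranks rendezvous at
# MASTER_ADDR:MASTER_PORT on the local machine. The NCCL backend requires
# one CUDA device per rank.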
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = '0.0.0.0'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class DistillTrainer:
    def __init__(self, rank, world_size, conf):
        self.rank = rank
        self.world_size = world_size
        self.device = torch.device(f'cuda:{rank}')
        # Initialize models
        self.teacher = resnet50(pretrained=True, scale=conf['models']['channel_ratio']).to(self.device)
        self.student = resnet18(pretrained=True, scale=conf['models']['student_channel_ratio']).to(self.device)
        # Load the pretrained teacher weights
        # teacher_path = os.path.join('checkpoints', 'resnet50_0519', 'best.pth')
        teacher_path = conf['models']['teacher_model_path']
        if os.path.exists(teacher_path):
            teacher_state = torch.load(teacher_path, map_location=self.device)
            new_state_dict = {}
            for k, v in teacher_state.items():
                if k.startswith('module.'):
                    new_state_dict[k[7:]] = v  # strip the 7-character 'module.' prefix added by DDP
                else:
                    new_state_dict[k] = v
            # Load the cleaned-up state dict
            self.teacher.load_state_dict(new_state_dict, strict=False)
            if self.rank == 0:
                print(f"Successfully loaded teacher model from {teacher_path}")
        else:
            raise FileNotFoundError(f"Teacher model weights not found at {teacher_path}")
        # Data loading
        self.train_loader, num_classes = load_data(training=True, cfg=conf)
        self.val_loader, _ = load_data(training=False, cfg=conf)
        # ArcFace head
        self.metric = ArcFace(conf['base']['embedding_size'], num_classes).to(self.device)
        # Distributed training
        if world_size > 1:
            self.teacher = DDP(self.teacher, device_ids=[rank])
            self.student = DDP(self.student, device_ids=[rank])
            self.metric = DDP(self.metric, device_ids=[rank])
        # Optimizer
        self.optimizer = torch.optim.SGD([
            {'params': self.student.parameters()},
            {'params': self.metric.parameters()}
        ], lr=conf['training']['lr'], momentum=0.9, weight_decay=5e-4)
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=conf['training']['epochs'])
        self.scaler = GradScaler()
        # Loss functions
        self.arcface_loss = nn.CrossEntropyLoss()
        self.distill_loss = nn.KLDivLoss(reduction='batchmean')
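        # Note: nn.KLDivLoss(reduction='batchmean') expects log-probabilities
        # as its input and probabilities as its target; train_epoch passes
        # log_softmax / softmax accordingly.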
        self.conf = conf

    def cosine_annealing(self, epoch, total_epochs, initial_weight, final_weight=0.1):
        """
        Cosine-annealing schedule for the distillation weight.
        Args:
            epoch: current training epoch
            total_epochs: total number of training epochs
            initial_weight: initial distillation weight, e.g. 0.8
            final_weight: final distillation weight, e.g. 0.1
        Returns:
            The distillation weight for the current epoch.
        """
        return final_weight + 0.5 * (initial_weight - final_weight) * (1 + math.cos(math.pi * epoch / total_epochs))

    def train_epoch(self, epoch):
        self.teacher.eval()
        self.student.train()
        if self.rank == 0:
            print(f"\nTeacher network type: {type(self.teacher)}")
            print(f"Student network type: {type(self.student)}")
        total_loss = 0
        for data, labels in tqdm(self.train_loader, desc=f"Epoch {epoch}"):
            data = data.to(self.device)
            labels = labels.to(self.device)
            # with autocast():  # autocast is disabled, so GradScaler wraps plain fp32 training
            # Teacher forward pass
            with torch.no_grad():
                teacher_logits = self.teacher(data)
            # Student forward pass
            student_features = self.student(data)
            student_logits = self.metric(student_features, labels)
            # Compute losses
            arc_loss = self.arcface_loss(student_logits, labels)
            distill_loss = self.distill_loss(
                F.log_softmax(student_features / self.conf['training']['temperature'], dim=1),
                F.softmax(teacher_logits / self.conf['training']['temperature'], dim=1)
            ) * (self.conf['training']['temperature'] ** 2)  # multiply by T^2 to keep gradient magnitudes comparable after temperature scaling
            current_distill_weight = self.cosine_annealing(epoch, self.conf['training']['epochs'], self.conf['training']['distill_weight'])
            loss = (1 - current_distill_weight) * arc_loss + current_distill_weight * distill_loss
            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            total_loss += loss.item()
        self.scheduler.step()
        return total_loss / len(self.train_loader)

    def validate(self):
        self.student.eval()
        total_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for data, labels in self.val_loader:
                data = data.to(self.device)
                labels = labels.to(self.device)
                features = self.student(data)
                logits = self.metric(features, labels)
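                # Assumes the ArcFace head applies its angular margin at eval
                # time as well; if so, the accuracy below is a conservative estimate.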
                loss = self.arcface_loss(logits, labels)
                total_loss += loss.item()
                _, predicted = torch.max(logits.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        return total_loss / len(self.val_loader), correct / total

    def save_checkpoint(self, epoch, is_best=False):
        if self.rank != 0:
            return
        state = {
            'epoch': epoch,
            'student_state_dict': self.student.state_dict(),
            'metric_state_dict': self.metric.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }
        filename = 'best.pth' if is_best else f'checkpoint_{epoch}.pth'
        if not os.path.exists(self.conf['training']['checkpoints']):
            os.makedirs(self.conf['training']['checkpoints'])
        if filename != 'best.pth':
            torch.save(state, os.path.join(self.conf['training']['checkpoints'], filename))
        else:
            # The best checkpoint stores only the student weights, not the full training state
            torch.save(state['student_state_dict'], os.path.join(self.conf['training']['checkpoints'], filename))


def train(rank, world_size):
    setup(rank, world_size)
    with open('configs/distill.yml', 'r') as f:
        conf = yaml.load(f, Loader=yaml.FullLoader)
    trainer = DistillTrainer(rank, world_size, conf)
    best_acc = 0
    for epoch in range(conf['training']['epochs']):
        train_loss = trainer.train_epoch(epoch)
        val_loss, val_acc = trainer.validate()
        if rank == 0:
            print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
        if val_acc > best_acc:
            best_acc = val_acc
            trainer.save_checkpoint(epoch, is_best=True)
    cleanup()


if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    if world_size > 1:
        mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
    else:
        train(0, 1)
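
# For reference, a minimal configs/distill.yml consistent with the keys read
# above might look like the following (values are illustrative assumptions,
# not taken from the repository):
#
#   base:
#     embedding_size: 256
#   models:
#     teacher_model_path: checkpoints/resnet50_0519/best.pth
#     channel_ratio: 1.0
#     student_channel_ratio: 1.0
#   training:
#     lr: 0.01
#     epochs: 100
#     temperature: 4.0
#     distill_weight: 0.8
#     checkpoints: checkpoints/distill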