"""
|
||
ResNet50蒸馏训练ResNet18实现
|
||
学生网络使用ArcFace损失
|
||
支持单机双卡训练
|
||
"""
|
||
|
||
import os
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.distributed as dist
|
||
import torch.multiprocessing as mp
|
||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||
from torch.cuda.amp import GradScaler
|
||
from model import resnet18, resnet50, ArcFace
|
||
from tqdm import tqdm
|
||
import torch.nn.functional as F
|
||
from tools.dataset import load_data
|
||
# from config import config as conf
|
||
import yaml
|
||
import math
|
||
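# NOTE: sketch of the expected configs/distill.yml layout, inferred from the keys
# read below (conf['models'][...], conf['base'][...], conf['training'][...]).
# The values shown are placeholders, not the project's actual settings.
#
#   base:
#     embedding_size: 512
#   models:
#     teacher_model_path: checkpoints/teacher/best.pth
#     channel_ratio: 1.0
#     student_channel_ratio: 1.0
#   training:
#     epochs: 100
#     lr: 0.1
#     temperature: 4
#     distill_weight: 0.8
#     checkpoints: checkpoints/distill
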
def setup(rank, world_size):
    # Initialise the NCCL process group for single-node distributed training
    os.environ['MASTER_ADDR'] = '0.0.0.0'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class DistillTrainer:
    def __init__(self, rank, world_size, conf):
        self.rank = rank
        self.world_size = world_size
        self.device = torch.device(f'cuda:{rank}')

        # Initialise models
        self.teacher = resnet50(pretrained=True, scale=conf['models']['channel_ratio']).to(self.device)
        self.student = resnet18(pretrained=True, scale=conf['models']['student_channel_ratio']).to(self.device)

        # Load the pretrained teacher weights
        # teacher_path = os.path.join('checkpoints', 'resnet50_0519', 'best.pth')
        teacher_path = conf['models']['teacher_model_path']
        if os.path.exists(teacher_path):
            teacher_state = torch.load(teacher_path, map_location=self.device)
            new_state_dict = {}
            for k, v in teacher_state.items():
                if k.startswith('module.'):
                    new_state_dict[k[7:]] = v  # strip the leading 'module.' (7 characters) left by DDP checkpoints
                else:
                    new_state_dict[k] = v
            # Load the cleaned-up state dict
            self.teacher.load_state_dict(new_state_dict, strict=False)

            if self.rank == 0:
                print(f"Successfully loaded teacher model from {teacher_path}")
        else:
            raise FileNotFoundError(f"Teacher model weights not found at {teacher_path}")

        # Data loading
        self.train_loader, num_classes = load_data(training=True, cfg=conf)
        self.val_loader, _ = load_data(training=False, cfg=conf)

        # ArcFace head (margin-based classification layer)
        self.metric = ArcFace(conf['base']['embedding_size'], num_classes).to(self.device)

        # Distributed training
        if world_size > 1:
            self.teacher = DDP(self.teacher, device_ids=[rank])
            self.student = DDP(self.student, device_ids=[rank])
            self.metric = DDP(self.metric, device_ids=[rank])

        # Optimizer over both the student backbone and the ArcFace head
        self.optimizer = torch.optim.SGD([
            {'params': self.student.parameters()},
            {'params': self.metric.parameters()}
        ], lr=conf['training']['lr'], momentum=0.9, weight_decay=5e-4)

        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=conf['training']['epochs'])
        self.scaler = GradScaler()

        # Loss functions
        self.arcface_loss = nn.CrossEntropyLoss()
        self.distill_loss = nn.KLDivLoss(reduction='batchmean')
        self.conf = conf

    def cosine_annealing(self, epoch, total_epochs, initial_weight, final_weight=0.1):
        """
        Cosine-annealed schedule for the distillation weight.

        Args:
            epoch: current training epoch
            total_epochs: total number of training epochs
            initial_weight: initial distillation weight (e.g. 0.8)
            final_weight: final distillation weight (e.g. 0.1)

        Returns:
            The distillation weight for the current epoch.
        """
        return final_weight + 0.5 * (initial_weight - final_weight) * (1 + math.cos(math.pi * epoch / total_epochs))

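    # Worked example of the schedule above (pure arithmetic, for illustration):
    # with initial_weight=0.8 and final_weight=0.1,
    #   epoch = 0              -> 0.1 + 0.35 * (1 + cos(0))    = 0.80
    #   epoch = total_epochs/2 -> 0.1 + 0.35 * (1 + cos(pi/2)) = 0.45
    #   epoch = total_epochs   -> 0.1 + 0.35 * (1 + cos(pi))   = 0.10
    # so the distillation term dominates early and the ArcFace term dominates late.
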
    def train_epoch(self, epoch):
        self.teacher.eval()
        self.student.train()

        if self.rank == 0:
            print(f"\nTeacher network type: {type(self.teacher)}")
            print(f"Student network type: {type(self.student)}")

        total_loss = 0
        for data, labels in tqdm(self.train_loader, desc=f"Epoch {epoch}"):
            data = data.to(self.device)
            labels = labels.to(self.device)

            # with autocast():
            # NOTE: autocast is commented out, so the forward/backward pass runs in fp32;
            # the GradScaler below still works, it just scales the loss without mixed precision.

            # Teacher forward pass (no gradients needed)
            with torch.no_grad():
                teacher_logits = self.teacher(data)

            # Student forward pass
            student_features = self.student(data)
            student_logits = self.metric(student_features, labels)

            # Losses: ArcFace classification loss plus KL divergence between the
            # temperature-softened student and teacher outputs
            arc_loss = self.arcface_loss(student_logits, labels)
            distill_loss = self.distill_loss(
                F.log_softmax(student_features / self.conf['training']['temperature'], dim=1),
                F.softmax(teacher_logits / self.conf['training']['temperature'], dim=1)
            ) * (self.conf['training']['temperature'] ** 2)  # multiply by T^2 after temperature scaling to preserve the gradient scale
            current_distill_weight = self.cosine_annealing(epoch, self.conf['training']['epochs'], self.conf['training']['distill_weight'])
            loss = (1 - current_distill_weight) * arc_loss + current_distill_weight * distill_loss

            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            total_loss += loss.item()

        self.scheduler.step()
        return total_loss / len(self.train_loader)

    def validate(self):
        self.student.eval()
        total_loss = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for data, labels in self.val_loader:
                data = data.to(self.device)
                labels = labels.to(self.device)

                features = self.student(data)
                logits = self.metric(features, labels)

                loss = self.arcface_loss(logits, labels)
                total_loss += loss.item()

                _, predicted = torch.max(logits.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        return total_loss / len(self.val_loader), correct / total

    def save_checkpoint(self, epoch, is_best=False):
        if self.rank != 0:
            return

        state = {
            'epoch': epoch,
            'student_state_dict': self.student.state_dict(),
            'metric_state_dict': self.metric.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }

        filename = 'best.pth' if is_best else f'checkpoint_{epoch}.pth'
        if not os.path.exists(self.conf['training']['checkpoints']):
            os.makedirs(self.conf['training']['checkpoints'])
        if filename != 'best.pth':
            torch.save(state, os.path.join(self.conf['training']['checkpoints'], filename))
        else:
            torch.save(state['student_state_dict'], os.path.join(self.conf['training']['checkpoints'], filename))


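# NOTE (sketch): save_checkpoint() writes two different formats. Numbered checkpoints
# hold the full training state, while 'best.pth' holds only the student weights, so
# reloading differs, e.g. (paths and epoch number are illustrative):
#   student.load_state_dict(torch.load('<checkpoints>/best.pth', map_location='cpu'))
#   full = torch.load('<checkpoints>/checkpoint_10.pth', map_location='cpu')
#   student.load_state_dict(full['student_state_dict'])
# Keys saved while wrapped in DDP carry a 'module.' prefix, as handled for the teacher above.
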
def train(rank, world_size):
    setup(rank, world_size)
    with open('configs/distill.yml', 'r') as f:
        conf = yaml.load(f, Loader=yaml.FullLoader)
    trainer = DistillTrainer(rank, world_size, conf)
    best_acc = 0
    for epoch in range(conf['training']['epochs']):
        train_loss = trainer.train_epoch(epoch)
        val_loss, val_acc = trainer.validate()

        if rank == 0:
            print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        if val_acc > best_acc:
            best_acc = val_acc
            trainer.save_checkpoint(epoch, is_best=True)

    cleanup()


if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    if world_size > 1:
        mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
    else:
        train(0, 1)
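# Usage note (assumption: this script is saved as e.g. train_distill.py; the actual
# filename is not specified here):
#   CUDA_VISIBLE_DEVICES=0,1 python train_distill.py
# spawns one process per visible GPU via mp.spawn; with a single GPU it falls back
# to a plain world_size=1 run.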