rebuild
This commit is contained in:
train_distill.py (new file, 205 lines)
@@ -0,0 +1,205 @@
"""
Distillation training of ResNet18 from a ResNet50 teacher.

The student network is trained with an ArcFace loss.
Supports single-node, two-GPU (DistributedDataParallel) training.
"""

import os
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.cuda.amp import GradScaler
from tqdm import tqdm
import yaml

from model import resnet18, resnet50, ArcFace
from tools.dataset import load_data
# from config import config as conf
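
# The actual configs/distill.yml is not included in this commit; the sketch below only lists
# the keys this script reads, with placeholder values (load_data may read additional dataset
# keys not shown here):
#
#   base:
#     embedding_size: 512
#   models:
#     channel_ratio: 1.0
#     student_channel_ratio: 1.0
#     teacher_model_path: checkpoints/resnet50_teacher/best.pth
#   training:
#     epochs: 100
#     lr: 0.01
#     temperature: 4
#     distill_weight: 0.8
#     checkpoints: checkpoints/distill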


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = '0.0.0.0'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class DistillTrainer:
    def __init__(self, rank, world_size, conf):
        self.rank = rank
        self.world_size = world_size
        self.device = torch.device(f'cuda:{rank}')

        # Initialize the models
        self.teacher = resnet50(pretrained=True, scale=conf['models']['channel_ratio']).to(self.device)
        self.student = resnet18(pretrained=True, scale=conf['models']['student_channel_ratio']).to(self.device)

        # Load the pretrained teacher weights
        # teacher_path = os.path.join('checkpoints', 'resnet50_0519', 'best.pth')
        teacher_path = conf['models']['teacher_model_path']
        if os.path.exists(teacher_path):
            teacher_state = torch.load(teacher_path, map_location=self.device)
            new_state_dict = {}
            for k, v in teacher_state.items():
                if k.startswith('module.'):
                    new_state_dict[k[7:]] = v  # strip the leading 'module.' (7 characters) added by DDP
                else:
                    new_state_dict[k] = v
            # Load the cleaned-up state dict
            self.teacher.load_state_dict(new_state_dict, strict=False)

            if self.rank == 0:
                print(f"Successfully loaded teacher model from {teacher_path}")
        else:
            raise FileNotFoundError(f"Teacher model weights not found at {teacher_path}")

        # Data loading
        self.train_loader, num_classes = load_data(training=True, cfg=conf)
        self.val_loader, _ = load_data(training=False, cfg=conf)

        # ArcFace loss head
        self.metric = ArcFace(conf['base']['embedding_size'], num_classes).to(self.device)

        # Distributed training
        if world_size > 1:
            self.teacher = DDP(self.teacher, device_ids=[rank])
            self.student = DDP(self.student, device_ids=[rank])
            self.metric = DDP(self.metric, device_ids=[rank])

        # Optimizer over both the student backbone and the ArcFace head
        self.optimizer = torch.optim.SGD([
            {'params': self.student.parameters()},
            {'params': self.metric.parameters()}
        ], lr=conf['training']['lr'], momentum=0.9, weight_decay=5e-4)

        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=conf['training']['epochs'])
        self.scaler = GradScaler()

        # Loss functions
        self.arcface_loss = nn.CrossEntropyLoss()
        self.distill_loss = nn.KLDivLoss(reduction='batchmean')
        self.conf = conf
    def cosine_annealing(self, epoch, total_epochs, initial_weight, final_weight=0.1):
        """
        Cosine-annealing schedule for the distillation weight.

        Args:
            epoch: current training epoch
            total_epochs: total number of training epochs
            initial_weight: initial distillation weight (e.g. 0.8)
            final_weight: final distillation weight (e.g. 0.1)

        Returns:
            the distillation weight for the current epoch
        """
        return final_weight + 0.5 * (initial_weight - final_weight) * (1 + math.cos(math.pi * epoch / total_epochs))
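
    # Example of the schedule above: with initial_weight=0.8 and final_weight=0.1 over 100 epochs,
    # the weight starts at 0.8, reaches 0.45 at epoch 50, and decays towards 0.1 by the last epoch.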

    def train_epoch(self, epoch):
        self.teacher.eval()
        self.student.train()

        if self.rank == 0:
            print(f"\nTeacher network type: {type(self.teacher)}")
            print(f"Student network type: {type(self.student)}")

        total_loss = 0
        for data, labels in tqdm(self.train_loader, desc=f"Epoch {epoch}"):
            data = data.to(self.device)
            labels = labels.to(self.device)

            # with autocast():
            # Teacher forward pass (no gradients needed)
            with torch.no_grad():
                teacher_logits = self.teacher(data)

            # Student forward pass
            student_features = self.student(data)
            student_logits = self.metric(student_features, labels)

            # Compute losses
            arc_loss = self.arcface_loss(student_logits, labels)
            distill_loss = self.distill_loss(
                F.log_softmax(student_features / self.conf['training']['temperature'], dim=1),
                F.softmax(teacher_logits / self.conf['training']['temperature'], dim=1)
            ) * (self.conf['training']['temperature'] ** 2)  # multiply by T^2 to keep the gradient scale after temperature softening
            current_distill_weight = self.cosine_annealing(epoch, self.conf['training']['epochs'], self.conf['training']['distill_weight'])
            loss = (1 - current_distill_weight) * arc_loss + current_distill_weight * distill_loss

            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            total_loss += loss.item()

        self.scheduler.step()
        return total_loss / len(self.train_loader)
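
    # The per-batch loss above, written out once for reference (T and the starting weight w come
    # from the config; e.g. T=4, w=0.8 are placeholders, not values shipped with this commit):
    #   soft_student = log_softmax(student_features / T)
    #   soft_teacher = softmax(teacher_logits / T)
    #   distill_loss = KLDiv(soft_student, soft_teacher) * T**2
    #   loss         = (1 - w) * arc_loss + w * distill_loss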

    def validate(self):
        self.student.eval()
        total_loss = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for data, labels in self.val_loader:
                data = data.to(self.device)
                labels = labels.to(self.device)

                features = self.student(data)
                logits = self.metric(features, labels)

                loss = self.arcface_loss(logits, labels)
                total_loss += loss.item()

                _, predicted = torch.max(logits.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        return total_loss / len(self.val_loader), correct / total

    def save_checkpoint(self, epoch, is_best=False):
        if self.rank != 0:
            return

        state = {
            'epoch': epoch,
            'student_state_dict': self.student.state_dict(),
            'metric_state_dict': self.metric.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }

        filename = 'best.pth' if is_best else f'checkpoint_{epoch}.pth'
        if not os.path.exists(self.conf['training']['checkpoints']):
            os.makedirs(self.conf['training']['checkpoints'])
        if filename != 'best.pth':
            # periodic checkpoints keep the full training state
            torch.save(state, os.path.join(self.conf['training']['checkpoints'], filename))
        else:
            # best.pth keeps only the student weights
            torch.save(state['student_state_dict'], os.path.join(self.conf['training']['checkpoints'], filename))


def train(rank, world_size):
    setup(rank, world_size)
    with open('configs/distill.yml', 'r') as f:
        conf = yaml.load(f, Loader=yaml.FullLoader)
    trainer = DistillTrainer(rank, world_size, conf)
    best_acc = 0
    for epoch in range(conf['training']['epochs']):
        train_loss = trainer.train_epoch(epoch)
        val_loss, val_acc = trainer.validate()

        if rank == 0:
            print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

            if val_acc > best_acc:
                best_acc = val_acc
                trainer.save_checkpoint(epoch, is_best=True)

    cleanup()


if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    if world_size > 1:
        mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
    else:
        train(0, 1)
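
# Loading the exported student for inference (sketch only; `scale` must match the
# models.student_channel_ratio used at training time, and the path is a placeholder):
#   model = resnet18(pretrained=False, scale=1.0)
#   state = torch.load('checkpoints/distill/best.pth', map_location='cpu')
#   state = {k.replace('module.', '', 1): v for k, v in state.items()}  # drop DDP prefix if present
#   model.load_state_dict(state)
#   model.eval()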