数据分析

This commit is contained in:
lee
2025-07-17 14:33:18 +08:00
parent 09f41f6289
commit 54898e30ec
12 changed files with 233 additions and 34 deletions

View File

@ -17,7 +17,7 @@ import yaml
from datetime import datetime
def load_configuration(config_path='configs/scatter.yml'):
def load_configuration(config_path='configs/compare.yml'):
"""加载配置文件"""
with open(config_path, 'r') as f:
return yaml.load(f, Loader=yaml.FullLoader)
@ -74,13 +74,13 @@ def train_one_epoch(model, metric, criterion, optimizer, dataloader, device, sca
data = data.to(device)
labels = labels.to(device)
with torch.cuda.amp.autocast():
embeddings = model(data)
if not conf['training']['metric'] == 'softmax':
thetas = metric(embeddings, labels)
else:
thetas = metric(embeddings)
loss = criterion(thetas, labels)
# with torch.cuda.amp.autocast():
embeddings = model(data)
if not conf['training']['metric'] == 'softmax':
thetas = metric(embeddings, labels)
else:
thetas = metric(embeddings)
loss = criterion(thetas, labels)
optimizer.zero_grad()
scaler.scale(loss).backward()