数据分析
This commit is contained in:
@ -17,7 +17,7 @@ import yaml
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def load_configuration(config_path='configs/compare.yml'):
    """Load and parse a YAML configuration file.

    Args:
        config_path: Path to the YAML file. Defaults to
            'configs/compare.yml'.

    Returns:
        The parsed configuration — typically a dict mirroring the
        YAML document's top-level mapping.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        yaml.YAMLError: If the file is not valid YAML.
    """
    # safe_load restricts parsing to plain YAML tags; FullLoader (used
    # previously) still allows construction of some Python objects,
    # which a config file never needs. Explicit encoding avoids
    # platform-dependent defaults.
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)
@ -74,13 +74,13 @@ def train_one_epoch(model, metric, criterion, optimizer, dataloader, device, sca
|
||||
data = data.to(device)
|
||||
labels = labels.to(device)
|
||||
|
||||
with torch.cuda.amp.autocast():
|
||||
embeddings = model(data)
|
||||
if not conf['training']['metric'] == 'softmax':
|
||||
thetas = metric(embeddings, labels)
|
||||
else:
|
||||
thetas = metric(embeddings)
|
||||
loss = criterion(thetas, labels)
|
||||
# with torch.cuda.amp.autocast():
|
||||
embeddings = model(data)
|
||||
if not conf['training']['metric'] == 'softmax':
|
||||
thetas = metric(embeddings, labels)
|
||||
else:
|
||||
thetas = metric(embeddings)
|
||||
loss = criterion(thetas, labels)
|
||||
|
||||
optimizer.zero_grad()
|
||||
scaler.scale(loss).backward()
|
||||
|
Reference in New Issue
Block a user