增加学习率调度方式

This commit is contained in:
lee
2025-06-13 10:45:53 +08:00
parent 37ecef40f7
commit 1803f319a5
13 changed files with 319 additions and 294 deletions

View File

@ -12,7 +12,7 @@ import matplotlib.pyplot as plt
from configs import trainer_tools
import yaml
with open('configs/scatter.yml', 'r') as f:
with open('configs/compare.yml', 'r') as f:
conf = yaml.load(f, Loader=yaml.FullLoader)
# Data Setup
@ -47,11 +47,11 @@ else:
optimizer_mapping = tr_tools.get_optimizer(model, metric)
if conf['training']['optimizer'] in optimizer_mapping:
optimizer = optimizer_mapping[conf['training']['optimizer']]()
scheduler = optim.lr_scheduler.StepLR(
optimizer,
step_size=conf['training']['lr_step'],
gamma=conf['training']['lr_decay']
)
scheduler_mapping = tr_tools.get_scheduler(optimizer)
scheduler = scheduler_mapping[conf['training']['scheduler']]()
print('使用{}优化器 使用{}调度器'.format(conf['training']['optimizer'],
conf['training']['scheduler']))
else:
raise ValueError('不支持的优化器类型: {}'.format(conf['training']['optimizer']))