增加学习率调度方式
This commit is contained in:
@ -12,7 +12,7 @@ import matplotlib.pyplot as plt
|
||||
from configs import trainer_tools
|
||||
import yaml
|
||||
|
||||
with open('configs/scatter.yml', 'r') as f:
|
||||
with open('configs/compare.yml', 'r') as f:
|
||||
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
# Data Setup
|
||||
@ -47,11 +47,11 @@ else:
|
||||
optimizer_mapping = tr_tools.get_optimizer(model, metric)
|
||||
if conf['training']['optimizer'] in optimizer_mapping:
|
||||
optimizer = optimizer_mapping[conf['training']['optimizer']]()
|
||||
scheduler = optim.lr_scheduler.StepLR(
|
||||
optimizer,
|
||||
step_size=conf['training']['lr_step'],
|
||||
gamma=conf['training']['lr_decay']
|
||||
)
|
||||
scheduler_mapping = tr_tools.get_scheduler(optimizer)
|
||||
scheduler = scheduler_mapping[conf['training']['scheduler']]()
|
||||
print('使用{}优化器 使用{}调度器'.format(conf['training']['optimizer'],
|
||||
conf['training']['scheduler']))
|
||||
|
||||
else:
|
||||
raise ValueError('不支持的优化器类型: {}'.format(conf['training']['optimizer']))
|
||||
|
||||
|
Reference in New Issue
Block a user