增加学习率调度方式

This commit is contained in:
lee
2025-06-13 10:45:53 +08:00
parent 37ecef40f7
commit 1803f319a5
13 changed files with 319 additions and 294 deletions

View File

@ -297,8 +297,8 @@ def init_model():
first_param_dtype = next(model.parameters()).dtype
print("模型的第一个参数的数据类型: {}".format(first_param_dtype))
else:
model.load_state_dict(torch.load(conf['model']['model_path'], map_location=conf['base']['device']))
if conf.model_half:
model.load_state_dict(torch.load(conf['models']['model_path'], map_location=conf['base']['device']))
if conf['models']['half']:
model.half()
first_param_dtype = next(model.parameters()).dtype
print("模型的第一个参数的数据类型: {}".format(first_param_dtype))