增加学习率调度方式
This commit is contained in:
265
.idea/CopilotChatHistory.xml
generated
265
.idea/CopilotChatHistory.xml
generated
@ -3,6 +3,18 @@
|
|||||||
<component name="CopilotChatHistory">
|
<component name="CopilotChatHistory">
|
||||||
<option name="conversations">
|
<option name="conversations">
|
||||||
<list>
|
<list>
|
||||||
|
<Conversation>
|
||||||
|
<option name="createTime" value="1749718122230" />
|
||||||
|
<option name="id" value="01976353bef6703884544447c919013c" />
|
||||||
|
<option name="title" value="新对话 2025年6月12日 16:48:42" />
|
||||||
|
<option name="updateTime" value="1749718122230" />
|
||||||
|
</Conversation>
|
||||||
|
<Conversation>
|
||||||
|
<option name="createTime" value="1749648208122" />
|
||||||
|
<option name="id" value="01975f28f0fa7128afe7feddcdedb740" />
|
||||||
|
<option name="title" value="新对话 2025年6月11日 21:23:28" />
|
||||||
|
<option name="updateTime" value="1749648208122" />
|
||||||
|
</Conversation>
|
||||||
<Conversation>
|
<Conversation>
|
||||||
<option name="createTime" value="1749522765718" />
|
<option name="createTime" value="1749522765718" />
|
||||||
<option name="id" value="019757aed78e777c96c4b7007ff2fecc" />
|
<option name="id" value="019757aed78e777c96c4b7007ff2fecc" />
|
||||||
@ -57,16 +69,7 @@
|
|||||||
</option>
|
</option>
|
||||||
<option name="status" value="SUCCESS" />
|
<option name="status" value="SUCCESS" />
|
||||||
<option name="variables">
|
<option name="variables">
|
||||||
<list>
|
<list />
|
||||||
<CodebaseVariable>
|
|
||||||
<option name="selectedPlaceHolder">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
<option name="selectedVariable">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
</CodebaseVariable>
|
|
||||||
</list>
|
|
||||||
</option>
|
</option>
|
||||||
</Turn>
|
</Turn>
|
||||||
<Turn>
|
<Turn>
|
||||||
@ -91,16 +94,7 @@
|
|||||||
</option>
|
</option>
|
||||||
<option name="status" value="SUCCESS" />
|
<option name="status" value="SUCCESS" />
|
||||||
<option name="variables">
|
<option name="variables">
|
||||||
<list>
|
<list />
|
||||||
<CodebaseVariable>
|
|
||||||
<option name="selectedPlaceHolder">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
<option name="selectedVariable">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
</CodebaseVariable>
|
|
||||||
</list>
|
|
||||||
</option>
|
</option>
|
||||||
</Turn>
|
</Turn>
|
||||||
</list>
|
</list>
|
||||||
@ -135,16 +129,7 @@
|
|||||||
</option>
|
</option>
|
||||||
<option name="status" value="SUCCESS" />
|
<option name="status" value="SUCCESS" />
|
||||||
<option name="variables">
|
<option name="variables">
|
||||||
<list>
|
<list />
|
||||||
<CodebaseVariable>
|
|
||||||
<option name="selectedPlaceHolder">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
<option name="selectedVariable">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
</CodebaseVariable>
|
|
||||||
</list>
|
|
||||||
</option>
|
</option>
|
||||||
</Turn>
|
</Turn>
|
||||||
<Turn>
|
<Turn>
|
||||||
@ -169,16 +154,7 @@
|
|||||||
</option>
|
</option>
|
||||||
<option name="status" value="SUCCESS" />
|
<option name="status" value="SUCCESS" />
|
||||||
<option name="variables">
|
<option name="variables">
|
||||||
<list>
|
<list />
|
||||||
<CodebaseVariable>
|
|
||||||
<option name="selectedPlaceHolder">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
<option name="selectedVariable">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
</CodebaseVariable>
|
|
||||||
</list>
|
|
||||||
</option>
|
</option>
|
||||||
</Turn>
|
</Turn>
|
||||||
<Turn>
|
<Turn>
|
||||||
@ -203,16 +179,7 @@
|
|||||||
</option>
|
</option>
|
||||||
<option name="status" value="SUCCESS" />
|
<option name="status" value="SUCCESS" />
|
||||||
<option name="variables">
|
<option name="variables">
|
||||||
<list>
|
<list />
|
||||||
<CodebaseVariable>
|
|
||||||
<option name="selectedPlaceHolder">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
<option name="selectedVariable">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
</CodebaseVariable>
|
|
||||||
</list>
|
|
||||||
</option>
|
</option>
|
||||||
</Turn>
|
</Turn>
|
||||||
<Turn>
|
<Turn>
|
||||||
@ -237,16 +204,7 @@
|
|||||||
</option>
|
</option>
|
||||||
<option name="status" value="SUCCESS" />
|
<option name="status" value="SUCCESS" />
|
||||||
<option name="variables">
|
<option name="variables">
|
||||||
<list>
|
<list />
|
||||||
<CodebaseVariable>
|
|
||||||
<option name="selectedPlaceHolder">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
<option name="selectedVariable">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
</CodebaseVariable>
|
|
||||||
</list>
|
|
||||||
</option>
|
</option>
|
||||||
</Turn>
|
</Turn>
|
||||||
<Turn>
|
<Turn>
|
||||||
@ -271,16 +229,7 @@
|
|||||||
</option>
|
</option>
|
||||||
<option name="status" value="SUCCESS" />
|
<option name="status" value="SUCCESS" />
|
||||||
<option name="variables">
|
<option name="variables">
|
||||||
<list>
|
<list />
|
||||||
<CodebaseVariable>
|
|
||||||
<option name="selectedPlaceHolder">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
<option name="selectedVariable">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
</CodebaseVariable>
|
|
||||||
</list>
|
|
||||||
</option>
|
</option>
|
||||||
</Turn>
|
</Turn>
|
||||||
<Turn>
|
<Turn>
|
||||||
@ -305,16 +254,7 @@
|
|||||||
</option>
|
</option>
|
||||||
<option name="status" value="SUCCESS" />
|
<option name="status" value="SUCCESS" />
|
||||||
<option name="variables">
|
<option name="variables">
|
||||||
<list>
|
<list />
|
||||||
<CodebaseVariable>
|
|
||||||
<option name="selectedPlaceHolder">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
<option name="selectedVariable">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
</CodebaseVariable>
|
|
||||||
</list>
|
|
||||||
</option>
|
</option>
|
||||||
</Turn>
|
</Turn>
|
||||||
<Turn>
|
<Turn>
|
||||||
@ -339,16 +279,7 @@
|
|||||||
</option>
|
</option>
|
||||||
<option name="status" value="SUCCESS" />
|
<option name="status" value="SUCCESS" />
|
||||||
<option name="variables">
|
<option name="variables">
|
||||||
<list>
|
<list />
|
||||||
<CodebaseVariable>
|
|
||||||
<option name="selectedPlaceHolder">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
<option name="selectedVariable">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
</CodebaseVariable>
|
|
||||||
</list>
|
|
||||||
</option>
|
</option>
|
||||||
</Turn>
|
</Turn>
|
||||||
<Turn>
|
<Turn>
|
||||||
@ -373,16 +304,7 @@
|
|||||||
</option>
|
</option>
|
||||||
<option name="status" value="SUCCESS" />
|
<option name="status" value="SUCCESS" />
|
||||||
<option name="variables">
|
<option name="variables">
|
||||||
<list>
|
<list />
|
||||||
<CodebaseVariable>
|
|
||||||
<option name="selectedPlaceHolder">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
<option name="selectedVariable">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
</CodebaseVariable>
|
|
||||||
</list>
|
|
||||||
</option>
|
</option>
|
||||||
</Turn>
|
</Turn>
|
||||||
<Turn>
|
<Turn>
|
||||||
@ -407,16 +329,7 @@
|
|||||||
</option>
|
</option>
|
||||||
<option name="status" value="SUCCESS" />
|
<option name="status" value="SUCCESS" />
|
||||||
<option name="variables">
|
<option name="variables">
|
||||||
<list>
|
<list />
|
||||||
<CodebaseVariable>
|
|
||||||
<option name="selectedPlaceHolder">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
<option name="selectedVariable">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
</CodebaseVariable>
|
|
||||||
</list>
|
|
||||||
</option>
|
</option>
|
||||||
</Turn>
|
</Turn>
|
||||||
<Turn>
|
<Turn>
|
||||||
@ -441,16 +354,7 @@
|
|||||||
</option>
|
</option>
|
||||||
<option name="status" value="SUCCESS" />
|
<option name="status" value="SUCCESS" />
|
||||||
<option name="variables">
|
<option name="variables">
|
||||||
<list>
|
<list />
|
||||||
<CodebaseVariable>
|
|
||||||
<option name="selectedPlaceHolder">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
<option name="selectedVariable">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
</CodebaseVariable>
|
|
||||||
</list>
|
|
||||||
</option>
|
</option>
|
||||||
</Turn>
|
</Turn>
|
||||||
<Turn>
|
<Turn>
|
||||||
@ -475,16 +379,7 @@
|
|||||||
</option>
|
</option>
|
||||||
<option name="status" value="SUCCESS" />
|
<option name="status" value="SUCCESS" />
|
||||||
<option name="variables">
|
<option name="variables">
|
||||||
<list>
|
<list />
|
||||||
<CodebaseVariable>
|
|
||||||
<option name="selectedPlaceHolder">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
<option name="selectedVariable">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
</CodebaseVariable>
|
|
||||||
</list>
|
|
||||||
</option>
|
</option>
|
||||||
</Turn>
|
</Turn>
|
||||||
<Turn>
|
<Turn>
|
||||||
@ -509,16 +404,7 @@
|
|||||||
</option>
|
</option>
|
||||||
<option name="status" value="SUCCESS" />
|
<option name="status" value="SUCCESS" />
|
||||||
<option name="variables">
|
<option name="variables">
|
||||||
<list>
|
<list />
|
||||||
<CodebaseVariable>
|
|
||||||
<option name="selectedPlaceHolder">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
<option name="selectedVariable">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
</CodebaseVariable>
|
|
||||||
</list>
|
|
||||||
</option>
|
</option>
|
||||||
</Turn>
|
</Turn>
|
||||||
<Turn>
|
<Turn>
|
||||||
@ -543,16 +429,7 @@
|
|||||||
</option>
|
</option>
|
||||||
<option name="status" value="SUCCESS" />
|
<option name="status" value="SUCCESS" />
|
||||||
<option name="variables">
|
<option name="variables">
|
||||||
<list>
|
<list />
|
||||||
<CodebaseVariable>
|
|
||||||
<option name="selectedPlaceHolder">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
<option name="selectedVariable">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
</CodebaseVariable>
|
|
||||||
</list>
|
|
||||||
</option>
|
</option>
|
||||||
</Turn>
|
</Turn>
|
||||||
<Turn>
|
<Turn>
|
||||||
@ -577,16 +454,7 @@
|
|||||||
</option>
|
</option>
|
||||||
<option name="status" value="SUCCESS" />
|
<option name="status" value="SUCCESS" />
|
||||||
<option name="variables">
|
<option name="variables">
|
||||||
<list>
|
<list />
|
||||||
<CodebaseVariable>
|
|
||||||
<option name="selectedPlaceHolder">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
<option name="selectedVariable">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
</CodebaseVariable>
|
|
||||||
</list>
|
|
||||||
</option>
|
</option>
|
||||||
</Turn>
|
</Turn>
|
||||||
<Turn>
|
<Turn>
|
||||||
@ -611,16 +479,7 @@
|
|||||||
</option>
|
</option>
|
||||||
<option name="status" value="SUCCESS" />
|
<option name="status" value="SUCCESS" />
|
||||||
<option name="variables">
|
<option name="variables">
|
||||||
<list>
|
<list />
|
||||||
<CodebaseVariable>
|
|
||||||
<option name="selectedPlaceHolder">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
<option name="selectedVariable">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
</CodebaseVariable>
|
|
||||||
</list>
|
|
||||||
</option>
|
</option>
|
||||||
</Turn>
|
</Turn>
|
||||||
<Turn>
|
<Turn>
|
||||||
@ -645,16 +504,7 @@
|
|||||||
</option>
|
</option>
|
||||||
<option name="status" value="SUCCESS" />
|
<option name="status" value="SUCCESS" />
|
||||||
<option name="variables">
|
<option name="variables">
|
||||||
<list>
|
<list />
|
||||||
<CodebaseVariable>
|
|
||||||
<option name="selectedPlaceHolder">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
<option name="selectedVariable">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
</CodebaseVariable>
|
|
||||||
</list>
|
|
||||||
</option>
|
</option>
|
||||||
</Turn>
|
</Turn>
|
||||||
<Turn>
|
<Turn>
|
||||||
@ -679,16 +529,7 @@
|
|||||||
</option>
|
</option>
|
||||||
<option name="status" value="SUCCESS" />
|
<option name="status" value="SUCCESS" />
|
||||||
<option name="variables">
|
<option name="variables">
|
||||||
<list>
|
<list />
|
||||||
<CodebaseVariable>
|
|
||||||
<option name="selectedPlaceHolder">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
<option name="selectedVariable">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
</CodebaseVariable>
|
|
||||||
</list>
|
|
||||||
</option>
|
</option>
|
||||||
</Turn>
|
</Turn>
|
||||||
<Turn>
|
<Turn>
|
||||||
@ -713,16 +554,7 @@
|
|||||||
</option>
|
</option>
|
||||||
<option name="status" value="SUCCESS" />
|
<option name="status" value="SUCCESS" />
|
||||||
<option name="variables">
|
<option name="variables">
|
||||||
<list>
|
<list />
|
||||||
<CodebaseVariable>
|
|
||||||
<option name="selectedPlaceHolder">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
<option name="selectedVariable">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
</CodebaseVariable>
|
|
||||||
</list>
|
|
||||||
</option>
|
</option>
|
||||||
</Turn>
|
</Turn>
|
||||||
<Turn>
|
<Turn>
|
||||||
@ -773,16 +605,7 @@
|
|||||||
</option>
|
</option>
|
||||||
<option name="status" value="SUCCESS" />
|
<option name="status" value="SUCCESS" />
|
||||||
<option name="variables">
|
<option name="variables">
|
||||||
<list>
|
<list />
|
||||||
<CodebaseVariable>
|
|
||||||
<option name="selectedPlaceHolder">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
<option name="selectedVariable">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
</CodebaseVariable>
|
|
||||||
</list>
|
|
||||||
</option>
|
</option>
|
||||||
</Turn>
|
</Turn>
|
||||||
<Turn>
|
<Turn>
|
||||||
@ -807,16 +630,7 @@
|
|||||||
</option>
|
</option>
|
||||||
<option name="status" value="SUCCESS" />
|
<option name="status" value="SUCCESS" />
|
||||||
<option name="variables">
|
<option name="variables">
|
||||||
<list>
|
<list />
|
||||||
<CodebaseVariable>
|
|
||||||
<option name="selectedPlaceHolder">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
<option name="selectedVariable">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
</CodebaseVariable>
|
|
||||||
</list>
|
|
||||||
</option>
|
</option>
|
||||||
</Turn>
|
</Turn>
|
||||||
<Turn>
|
<Turn>
|
||||||
@ -841,16 +655,7 @@
|
|||||||
</option>
|
</option>
|
||||||
<option name="status" value="SUCCESS" />
|
<option name="status" value="SUCCESS" />
|
||||||
<option name="variables">
|
<option name="variables">
|
||||||
<list>
|
<list />
|
||||||
<CodebaseVariable>
|
|
||||||
<option name="selectedPlaceHolder">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
<option name="selectedVariable">
|
|
||||||
<Object />
|
|
||||||
</option>
|
|
||||||
</CodebaseVariable>
|
|
||||||
</list>
|
|
||||||
</option>
|
</option>
|
||||||
</Turn>
|
</Turn>
|
||||||
</list>
|
</list>
|
||||||
|
@ -15,8 +15,8 @@ base:
|
|||||||
|
|
||||||
# 模型配置
|
# 模型配置
|
||||||
models:
|
models:
|
||||||
backbone: 'resnet18'
|
backbone: 'resnet34'
|
||||||
channel_ratio: 0.75
|
channel_ratio: 1.0
|
||||||
|
|
||||||
# 训练参数
|
# 训练参数
|
||||||
training:
|
training:
|
||||||
@ -29,11 +29,14 @@ training:
|
|||||||
lr_step: 10 # 学习率调整间隔(epoch)
|
lr_step: 10 # 学习率调整间隔(epoch)
|
||||||
lr_decay: 0.98 # 学习率衰减率
|
lr_decay: 0.98 # 学习率衰减率
|
||||||
weight_decay: 0.0005 # 权重衰减
|
weight_decay: 0.0005 # 权重衰减
|
||||||
scheduler: "cosine_annealing" # 学习率调度器(可选:cosine_annealing/step/none)
|
scheduler: "cosine" # 学习率调度器(可选:cosine/cosine_warm/step/None)
|
||||||
num_workers: 32 # 数据加载线程数
|
num_workers: 32 # 数据加载线程数
|
||||||
checkpoints: "./checkpoints/resnet18_test/" # 模型保存目录
|
checkpoints: "./checkpoints/resnet34_20250612_scale=1.0/" # 模型保存目录
|
||||||
restore: false
|
restore: false
|
||||||
restore_model: "resnet18_test/epoch_600.pth" # 模型恢复路径
|
restore_model: "resnet18_test/epoch_600.pth" # 模型恢复路径
|
||||||
|
cosine_t_0: 10 # 初始周期长度
|
||||||
|
cosine_t_mult: 1 # 周期长度倍率
|
||||||
|
cosine_eta_min: 0.00001 # 最小学习率
|
||||||
|
|
||||||
# 验证参数
|
# 验证参数
|
||||||
validation:
|
validation:
|
||||||
|
@ -8,13 +8,13 @@ base:
|
|||||||
log_level: "info" # 日志级别(debug/info/warning/error)
|
log_level: "info" # 日志级别(debug/info/warning/error)
|
||||||
embedding_size: 256 # 特征维度
|
embedding_size: 256 # 特征维度
|
||||||
pin_memory: true # 是否启用pin_memory
|
pin_memory: true # 是否启用pin_memory
|
||||||
distributed: true # 是否启用分布式训练
|
distributed: false # 是否启用分布式训练
|
||||||
|
|
||||||
# 模型配置
|
# 模型配置
|
||||||
models:
|
models:
|
||||||
backbone: 'resnet18'
|
backbone: 'resnet18'
|
||||||
channel_ratio: 1.0
|
channel_ratio: 0.75
|
||||||
model_path: "./checkpoints/resnet18_scatter_6.2/best.pth"
|
model_path: "./checkpoints/resnet18_0515/best.pth"
|
||||||
half: false # 是否启用半精度测试(fp16)
|
half: false # 是否启用半精度测试(fp16)
|
||||||
|
|
||||||
# 数据配置
|
# 数据配置
|
||||||
@ -22,9 +22,9 @@ data:
|
|||||||
group_test: False # 数据集名称(示例用,可替换为实际数据集)
|
group_test: False # 数据集名称(示例用,可替换为实际数据集)
|
||||||
test_batch_size: 128 # 训练批次大小
|
test_batch_size: 128 # 训练批次大小
|
||||||
num_workers: 32 # 数据加载线程数
|
num_workers: 32 # 数据加载线程数
|
||||||
test_dir: "../data_center/scatter/" # 验证数据集根目录
|
test_dir: "../data_center/contrast_learning/contrast_test_data" # 验证数据集根目录
|
||||||
test_group_json: "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json"
|
test_group_json: "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json"
|
||||||
test_list: "../data_center/scatter/val_pair.txt"
|
test_list: "../data_center/contrast_learning/contrast_test_data/test_pair.txt"
|
||||||
|
|
||||||
transform:
|
transform:
|
||||||
img_size: 224 # 图像尺寸
|
img_size: 224 # 图像尺寸
|
||||||
|
27
configs/transform.yml
Normal file
27
configs/transform.yml
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
# configs/transform.yml
|
||||||
|
# pth转换onnx配置文件
|
||||||
|
|
||||||
|
# 基础配置
|
||||||
|
base:
|
||||||
|
experiment_name: "model_comparison" # 实验名称(用于结果保存目录)
|
||||||
|
seed: 42 # 随机种子(保证可复现性)
|
||||||
|
device: "cuda" # 训练设备(cuda/cpu)
|
||||||
|
log_level: "info" # 日志级别(debug/info/warning/error)
|
||||||
|
embedding_size: 256 # 特征维度
|
||||||
|
pin_memory: true # 是否启用pin_memory
|
||||||
|
distributed: true # 是否启用分布式训练
|
||||||
|
|
||||||
|
|
||||||
|
# 模型配置
|
||||||
|
models:
|
||||||
|
backbone: 'resnet50'
|
||||||
|
channel_ratio: 1.0
|
||||||
|
model_path: "../checkpoints/resnet50_0519/best.pth"
|
||||||
|
onnx_model: "../checkpoints/resnet50_0519/best.onnx"
|
||||||
|
rknn_model: "../checkpoints/resnet50_0519/best.rknn"
|
||||||
|
|
||||||
|
# 日志与监控
|
||||||
|
logging:
|
||||||
|
logging_dir: "./logs" # 日志保存目录
|
||||||
|
tensorboard: true # 是否启用TensorBoard
|
||||||
|
checkpoint_interval: 30 # 检查点保存间隔(epoch)
|
@ -1,4 +1,4 @@
|
|||||||
from model import (resnet18, mobilevit_s, MobileNetV3_Small, MobileNetV3_Large, mobilenet_v1,
|
from model import (resnet18, resnet34, resnet50, mobilevit_s, MobileNetV3_Small, MobileNetV3_Large, mobilenet_v1,
|
||||||
PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5)
|
PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5)
|
||||||
from timm.models import vit_base_patch16_224 as vit_base_16
|
from timm.models import vit_base_patch16_224 as vit_base_16
|
||||||
from model.metric import ArcFace, CosFace
|
from model.metric import ArcFace, CosFace
|
||||||
@ -14,6 +14,8 @@ class trainer_tools:
|
|||||||
def get_backbone(self):
|
def get_backbone(self):
|
||||||
backbone_mapping = {
|
backbone_mapping = {
|
||||||
'resnet18': lambda: resnet18(scale=self.conf['models']['channel_ratio']),
|
'resnet18': lambda: resnet18(scale=self.conf['models']['channel_ratio']),
|
||||||
|
'resnet34': lambda: resnet34(scale=self.conf['models']['channel_ratio']),
|
||||||
|
'resnet50': lambda: resnet50(scale=self.conf['models']['channel_ratio']),
|
||||||
'mobilevit_s': lambda: mobilevit_s(),
|
'mobilevit_s': lambda: mobilevit_s(),
|
||||||
'mobilenetv3_small': lambda: MobileNetV3_Small(),
|
'mobilenetv3_small': lambda: MobileNetV3_Small(),
|
||||||
'PPLCNET_x1_0': lambda: PPLCNET_x1_0(),
|
'PPLCNET_x1_0': lambda: PPLCNET_x1_0(),
|
||||||
@ -54,3 +56,24 @@ class trainer_tools:
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
return optimizer_mapping
|
return optimizer_mapping
|
||||||
|
|
||||||
|
def get_scheduler(self, optimizer):
|
||||||
|
scheduler_mapping = {
|
||||||
|
'step': lambda: optim.lr_scheduler.StepLR(
|
||||||
|
optimizer,
|
||||||
|
step_size=self.conf['training']['lr_step'],
|
||||||
|
gamma=self.conf['training']['lr_decay']
|
||||||
|
),
|
||||||
|
'cosine': lambda: optim.lr_scheduler.CosineAnnealingLR(
|
||||||
|
optimizer,
|
||||||
|
T_max=self.conf['training']['epochs'],
|
||||||
|
eta_min=self.conf['training']['cosine_eta_min']
|
||||||
|
),
|
||||||
|
'cosine_warm': lambda: optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
||||||
|
optimizer,
|
||||||
|
T_0=self.conf['training'].get('cosine_t_0', 10),
|
||||||
|
T_mult=self.conf['training'].get('cosine_t_mult', 1),
|
||||||
|
eta_min=self.conf['training'].get('cosine_eta_min', 0)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return scheduler_mapping
|
||||||
|
171
getpairs.py
Normal file
171
getpairs.py
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
import os
|
||||||
|
import random
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Tuple, Dict, Optional
|
||||||
|
import logging
|
||||||
|
|
||||||
|
class PairGenerator:
|
||||||
|
"""Generate positive and negative image pairs for contrastive learning."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._setup_logging()
|
||||||
|
|
||||||
|
def _setup_logging(self):
|
||||||
|
"""Configure logging settings."""
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||||
|
)
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def _get_image_files(self, root_dir: str) -> Dict[str, List[str]]:
|
||||||
|
"""Scan directory and return dict of {folder: [image_paths]}."""
|
||||||
|
root = Path(root_dir)
|
||||||
|
if not root.is_dir():
|
||||||
|
raise ValueError(f"Invalid directory: {root_dir}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
str(folder): [str(f) for f in folder.iterdir() if f.is_file()]
|
||||||
|
for folder in root.iterdir() if folder.is_dir()
|
||||||
|
}
|
||||||
|
|
||||||
|
def _generate_same_pairs(
|
||||||
|
self,
|
||||||
|
files_dict: Dict[str, List[str]],
|
||||||
|
num_pairs: int,
|
||||||
|
group_size: Optional[int] = None
|
||||||
|
) -> List[Tuple[str, str, int]]:
|
||||||
|
"""Generate positive pairs from same folder."""
|
||||||
|
pairs = []
|
||||||
|
|
||||||
|
for folder, files in files_dict.items():
|
||||||
|
if len(files) < 2:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if group_size:
|
||||||
|
# Group mode: generate all possible pairs within group
|
||||||
|
for i in range(0, len(files), group_size):
|
||||||
|
group = files[i:i+group_size]
|
||||||
|
pairs.extend([
|
||||||
|
(group[i], group[j], 1)
|
||||||
|
for i in range(len(group))
|
||||||
|
for j in range(i+1, len(group))
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
# Individual mode: random pairs
|
||||||
|
try:
|
||||||
|
pairs.extend(self._random_pairs(files, min(3, len(files)//2)))
|
||||||
|
except ValueError as e:
|
||||||
|
self.logger.warning(f"Skipping folder {folder}: {str(e)}")
|
||||||
|
|
||||||
|
random.shuffle(pairs)
|
||||||
|
return pairs[:num_pairs]
|
||||||
|
|
||||||
|
def _generate_cross_pairs(
|
||||||
|
self,
|
||||||
|
files_dict: Dict[str, List[str]],
|
||||||
|
num_pairs: int
|
||||||
|
) -> List[Tuple[str, str, int]]:
|
||||||
|
"""Generate negative pairs from different folders."""
|
||||||
|
folders = list(files_dict.keys())
|
||||||
|
pairs = []
|
||||||
|
|
||||||
|
while len(pairs) < num_pairs and len(folders) >= 2:
|
||||||
|
folder1, folder2 = random.sample(folders, 2)
|
||||||
|
file1 = random.choice(files_dict[folder1])
|
||||||
|
file2 = random.choice(files_dict[folder2])
|
||||||
|
|
||||||
|
if not any((f1 == file1 and f2 == file2) or (f1 == file2 and f2 == file1)
|
||||||
|
for f1, f2, _ in pairs):
|
||||||
|
pairs.append((file1, file2, 0))
|
||||||
|
|
||||||
|
return pairs
|
||||||
|
|
||||||
|
def _random_pairs(self, files: List[str], num_pairs: int) -> List[Tuple[str, str, int]]:
|
||||||
|
"""Generate random pairs from file list."""
|
||||||
|
if len(files) < 2 * num_pairs:
|
||||||
|
raise ValueError("Not enough files for requested pairs")
|
||||||
|
|
||||||
|
indices = random.sample(range(len(files)), 2 * num_pairs)
|
||||||
|
indices.sort()
|
||||||
|
return [(files[i], files[i+1], 1) for i in range(0, len(indices), 2)]
|
||||||
|
|
||||||
|
def get_pairs(self, root_dir: str, num_pairs: int = 2000) -> List[Tuple[str, str, int]]:
|
||||||
|
"""
|
||||||
|
Generate individual image pairs with labels (1=same, 0=different).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
root_dir: Directory containing subfolders of images
|
||||||
|
num_pairs: Number of pairs to generate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of (path1, path2, label) tuples
|
||||||
|
"""
|
||||||
|
files_dict = self._get_image_files(root_dir)
|
||||||
|
|
||||||
|
same_pairs = self._generate_same_pairs(files_dict, num_pairs)
|
||||||
|
cross_pairs = self._generate_cross_pairs(files_dict, len(same_pairs))
|
||||||
|
|
||||||
|
pairs = same_pairs + cross_pairs
|
||||||
|
self.logger.info(f"Generated {len(pairs)} pairs ({len(same_pairs)} positive, {len(cross_pairs)} negative)")
|
||||||
|
return pairs
|
||||||
|
|
||||||
|
def get_group_pairs(
|
||||||
|
self,
|
||||||
|
root_dir: str,
|
||||||
|
img_num: int = 20,
|
||||||
|
group_num: int = 10,
|
||||||
|
num_pairs: int = 5000
|
||||||
|
) -> List[Tuple[str, str, int]]:
|
||||||
|
"""
|
||||||
|
Generate grouped image pairs with labels (1=same, 0=different).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
root_dir: Directory containing subfolders of images
|
||||||
|
img_num: Minimum images required per folder
|
||||||
|
group_num: Number of images per group
|
||||||
|
num_pairs: Number of pairs to generate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of (path1, path2, label) tuples
|
||||||
|
"""
|
||||||
|
# Filter folders with enough images
|
||||||
|
files_dict = {
|
||||||
|
k: v for k, v in self._get_image_files(root_dir).items()
|
||||||
|
if len(v) >= img_num
|
||||||
|
}
|
||||||
|
|
||||||
|
# Split into groups
|
||||||
|
grouped_files = {}
|
||||||
|
for folder, files in files_dict.items():
|
||||||
|
random.shuffle(files)
|
||||||
|
grouped_files[folder] = [
|
||||||
|
files[i:i+group_num]
|
||||||
|
for i in range(0, len(files), group_num)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Generate pairs
|
||||||
|
same_pairs = self._generate_same_pairs(
|
||||||
|
grouped_files, num_pairs, group_size=group_num
|
||||||
|
)
|
||||||
|
cross_pairs = self._generate_cross_pairs(
|
||||||
|
grouped_files, len(same_pairs)
|
||||||
|
)
|
||||||
|
|
||||||
|
pairs = same_pairs + cross_pairs
|
||||||
|
self.logger.info(f"Generated {len(pairs)} group pairs")
|
||||||
|
|
||||||
|
# Save to JSON
|
||||||
|
with open("cross_same.json", 'w') as f:
|
||||||
|
json.dump(pairs, f)
|
||||||
|
|
||||||
|
return pairs
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
generator = PairGenerator()
|
||||||
|
|
||||||
|
# Example usage:
|
||||||
|
pairs = generator.get_pairs('/home/lc/contrast_nettest/data/contrast_test_data/test') # Individual pairs
|
||||||
|
# groups = generator.get_group_pairs('val') # Group pairs
|
@ -297,8 +297,8 @@ def init_model():
|
|||||||
first_param_dtype = next(model.parameters()).dtype
|
first_param_dtype = next(model.parameters()).dtype
|
||||||
print("模型的第一个参数的数据类型: {}".format(first_param_dtype))
|
print("模型的第一个参数的数据类型: {}".format(first_param_dtype))
|
||||||
else:
|
else:
|
||||||
model.load_state_dict(torch.load(conf['model']['model_path'], map_location=conf['base']['device']))
|
model.load_state_dict(torch.load(conf['models']['model_path'], map_location=conf['base']['device']))
|
||||||
if conf.model_half:
|
if conf['models']['half']:
|
||||||
model.half()
|
model.half()
|
||||||
first_param_dtype = next(model.parameters()).dtype
|
first_param_dtype = next(model.parameters()).dtype
|
||||||
print("模型的第一个参数的数据类型: {}".format(first_param_dtype))
|
print("模型的第一个参数的数据类型: {}".format(first_param_dtype))
|
||||||
|
@ -37,11 +37,11 @@ def load_data(training=True, cfg=None):
|
|||||||
if training:
|
if training:
|
||||||
dataroot = cfg['data']['data_train_dir']
|
dataroot = cfg['data']['data_train_dir']
|
||||||
transform = train_transform
|
transform = train_transform
|
||||||
# transform = conf.train_transform
|
# transform.yml = conf.train_transform
|
||||||
batch_size = cfg['data']['train_batch_size']
|
batch_size = cfg['data']['train_batch_size']
|
||||||
else:
|
else:
|
||||||
dataroot = cfg['data']['data_val_dir']
|
dataroot = cfg['data']['data_val_dir']
|
||||||
# transform = conf.test_transform
|
# transform.yml = conf.test_transform
|
||||||
transform = test_transform
|
transform = test_transform
|
||||||
batch_size = cfg['data']['val_batch_size']
|
batch_size = cfg['data']['val_batch_size']
|
||||||
|
|
||||||
@ -56,13 +56,13 @@ def load_data(training=True, cfg=None):
|
|||||||
return loader, class_num
|
return loader, class_num
|
||||||
|
|
||||||
# def load_gift_data(action):
|
# def load_gift_data(action):
|
||||||
# train_data = ImageFolder(conf.train_gift_root, transform=conf.train_transform)
|
# train_data = ImageFolder(conf.train_gift_root, transform.yml=conf.train_transform)
|
||||||
# train_dataset = DataLoader(train_data, batch_size=conf.train_gift_batchsize, shuffle=True,
|
# train_dataset = DataLoader(train_data, batch_size=conf.train_gift_batchsize, shuffle=True,
|
||||||
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
|
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
|
||||||
# val_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
|
# val_data = ImageFolder(conf.test_gift_root, transform.yml=conf.test_transform)
|
||||||
# val_dataset = DataLoader(val_data, batch_size=conf.val_gift_batchsize, shuffle=True,
|
# val_dataset = DataLoader(val_data, batch_size=conf.val_gift_batchsize, shuffle=True,
|
||||||
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
|
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
|
||||||
# test_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
|
# test_data = ImageFolder(conf.test_gift_root, transform.yml=conf.test_transform)
|
||||||
# test_dataset = DataLoader(test_data, batch_size=conf.test_gift_batchsize, shuffle=True,
|
# test_dataset = DataLoader(test_data, batch_size=conf.test_gift_batchsize, shuffle=True,
|
||||||
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
|
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
|
||||||
# return train_dataset, val_dataset, test_dataset
|
# return train_dataset, val_dataset, test_dataset
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
./quant_imgs/20179457_20240924-110903_back_addGood_b82d2842766e_80_15583929052_tid-8_fid-72_bid-3.jpg
|
../quant_imgs/20179457_20240924-110903_back_addGood_b82d2842766e_80_15583929052_tid-8_fid-72_bid-3.jpg
|
||||||
./quant_imgs/6928926002103_20240309-195044_front_returnGood_70f75407ef0e_225_18120111822_14_01.jpg
|
../quant_imgs/6928926002103_20240309-195044_front_returnGood_70f75407ef0e_225_18120111822_14_01.jpg
|
||||||
./quant_imgs/6928926002103_20240309-212145_front_returnGood_70f75407ef0e_225_18120111822_11_01.jpg
|
../quant_imgs/6928926002103_20240309-212145_front_returnGood_70f75407ef0e_225_18120111822_11_01.jpg
|
||||||
./quant_imgs/6928947479083_20241017-133830_front_returnGood_5478c9a48b7e_10_13799009402_tid-1_fid-20_bid-1.jpg
|
../quant_imgs/6928947479083_20241017-133830_front_returnGood_5478c9a48b7e_10_13799009402_tid-1_fid-20_bid-1.jpg
|
||||||
./quant_imgs/6928947479083_20241018-110450_front_addGood_5478c9a48c28_165_13773168720_tid-6_fid-36_bid-1.jpg
|
../quant_imgs/6928947479083_20241018-110450_front_addGood_5478c9a48c28_165_13773168720_tid-6_fid-36_bid-1.jpg
|
||||||
./quant_imgs/6930044166421_20240117-141516_c6a23f41-5b16-44c6-a03e-c32c25763442_back_returnGood_6930044166421_17_01.jpg
|
../quant_imgs/6930044166421_20240117-141516_c6a23f41-5b16-44c6-a03e-c32c25763442_back_returnGood_6930044166421_17_01.jpg
|
||||||
./quant_imgs/6930044166421_20240308-150916_back_returnGood_70f75407ef0e_175_13815402763_7_01.jpg
|
../quant_imgs/6930044166421_20240308-150916_back_returnGood_70f75407ef0e_175_13815402763_7_01.jpg
|
||||||
./quant_imgs/6930044168920_20240117-165633_3303629b-5fbd-423b-913d-8a64c1aa51dc_front_addGood_6930044168920_26_01.jpg
|
../quant_imgs/6930044168920_20240117-165633_3303629b-5fbd-423b-913d-8a64c1aa51dc_front_addGood_6930044168920_26_01.jpg
|
||||||
./quant_imgs/6930058201507_20240305-175434_front_addGood_70f75407ef0e_95_18120111822_28_01.jpg
|
../quant_imgs/6930058201507_20240305-175434_front_addGood_70f75407ef0e_95_18120111822_28_01.jpg
|
||||||
./quant_imgs/6930639267885_20241014-120446_back_addGood_5478c9a48c3e_135_13773168720_tid-5_fid-99_bid-0.jpg
|
../quant_imgs/6930639267885_20241014-120446_back_addGood_5478c9a48c3e_135_13773168720_tid-5_fid-99_bid-0.jpg
|
||||||
|
@ -2,17 +2,29 @@ import pdb
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from model import resnet18
|
from model import resnet18
|
||||||
from config import config as conf
|
# from config import config as conf
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from configs import trainer_tools
|
||||||
import cv2
|
import cv2
|
||||||
|
import yaml
|
||||||
|
|
||||||
def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth'):
|
def tranform_onnx_model():
|
||||||
# 定义模型
|
# # 定义模型
|
||||||
if model_name == 'resnet18':
|
# if model_name == 'resnet18':
|
||||||
model = resnet18(scale=0.75)
|
# model = resnet18(scale=0.75)
|
||||||
|
|
||||||
print('model_name >>> {}'.format(model_name))
|
with open('../configs/transform.yml', 'r') as f:
|
||||||
if conf.multiple_cards:
|
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
|
|
||||||
|
tr_tools = trainer_tools(conf)
|
||||||
|
backbone_mapping = tr_tools.get_backbone()
|
||||||
|
if conf['models']['backbone'] in backbone_mapping:
|
||||||
|
model = backbone_mapping[conf['models']['backbone']]().to(conf['base']['device'])
|
||||||
|
else:
|
||||||
|
raise ValueError('不支持该模型: {}'.format({conf['models']['backbone']}))
|
||||||
|
pretrained_weights = conf['models']['model_path']
|
||||||
|
print('model_name >>> {}'.format(conf['models']['backbone']))
|
||||||
|
if conf['base']['distributed']:
|
||||||
model = model.to(torch.device('cpu'))
|
model = model.to(torch.device('cpu'))
|
||||||
checkpoint = torch.load(pretrained_weights)
|
checkpoint = torch.load(pretrained_weights)
|
||||||
new_state_dict = OrderedDict()
|
new_state_dict = OrderedDict()
|
||||||
@ -22,22 +34,8 @@ def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth
|
|||||||
model.load_state_dict(new_state_dict)
|
model.load_state_dict(new_state_dict)
|
||||||
else:
|
else:
|
||||||
model.load_state_dict(torch.load(pretrained_weights, map_location=torch.device('cpu')))
|
model.load_state_dict(torch.load(pretrained_weights, map_location=torch.device('cpu')))
|
||||||
# try:
|
|
||||||
# model.load_state_dict(torch.load(pretrained_weights, map_location=torch.device('cpu')))
|
|
||||||
# except Exception as e:
|
|
||||||
# print(e)
|
|
||||||
# # model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(pretrained_weights, map_location='cpu').items()})
|
|
||||||
# model = nn.DataParallel(model).to(conf.device)
|
|
||||||
# model.load_state_dict(torch.load(conf.test_model, map_location=torch.device('cpu')))
|
|
||||||
|
|
||||||
|
|
||||||
# 转换为ONNX
|
# 转换为ONNX
|
||||||
if model_name == 'gift_type2':
|
|
||||||
input_shape = [1, 64, 13, 13]
|
|
||||||
elif model_name == 'gift_type3':
|
|
||||||
input_shape = [1, 3, 224, 224]
|
|
||||||
else:
|
|
||||||
# 假设输入数据的大小是通道数*高度*宽度,例如3*224*224
|
|
||||||
input_shape = [1, 3, 224, 224]
|
input_shape = [1, 3, 224, 224]
|
||||||
|
|
||||||
img = cv2.imread('./dog_224x224.jpg')
|
img = cv2.imread('./dog_224x224.jpg')
|
||||||
@ -59,5 +57,4 @@ def tranform_onnx_model(model_name, pretrained_weights='checkpoints/v3_small.pth
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tranform_onnx_model(model_name='resnet18', # ['resnet18', 'gift_type2', 'gift_type3'] #gift_type2指resnet18中间数据判断;gift3_type3指resnet原图计算推理
|
tranform_onnx_model()
|
||||||
pretrained_weights='./checkpoints/resnet18_scale=1.0/best.pth')
|
|
||||||
|
@ -6,15 +6,14 @@ import time
|
|||||||
import sys
|
import sys
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
from config import config as conf
|
|
||||||
from rknn.api import RKNN
|
from rknn.api import RKNN
|
||||||
|
import yaml
|
||||||
import config
|
with open('../configs/transform.yml', 'r') as f:
|
||||||
|
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
# ONNX_MODEL = 'resnet50v2.onnx'
|
# ONNX_MODEL = 'resnet50v2.onnx'
|
||||||
# RKNN_MODEL = 'resnet50v2.rknn'
|
# RKNN_MODEL = 'resnet50v2.rknn'
|
||||||
ONNX_MODEL = 'checkpoints/resnet18_scale=1.0/best.onnx'
|
ONNX_MODEL = conf['models']['onnx_model']
|
||||||
RKNN_MODEL = 'checkpoints/resnet18_scale=1.0/best.rknn'
|
RKNN_MODEL = conf['models']['rknn_model']
|
||||||
|
|
||||||
|
|
||||||
# ONNX_MODEL = 'v3_small_0424.onnx'
|
# ONNX_MODEL = 'v3_small_0424.onnx'
|
||||||
|
@ -50,7 +50,7 @@ class FeatureExtractor:
|
|||||||
raise FileNotFoundError(f"Model weights file not found: {model_path}")
|
raise FileNotFoundError(f"Model weights file not found: {model_path}")
|
||||||
|
|
||||||
# Initialize model
|
# Initialize model
|
||||||
model = resnet18().to(self.conf['base']['device'])
|
model = resnet18(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
|
||||||
|
|
||||||
# Handle multi-GPU case
|
# Handle multi-GPU case
|
||||||
if conf['base']['distributed']:
|
if conf['base']['distributed']:
|
||||||
|
@ -12,7 +12,7 @@ import matplotlib.pyplot as plt
|
|||||||
from configs import trainer_tools
|
from configs import trainer_tools
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
with open('configs/scatter.yml', 'r') as f:
|
with open('configs/compare.yml', 'r') as f:
|
||||||
conf = yaml.load(f, Loader=yaml.FullLoader)
|
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
|
|
||||||
# Data Setup
|
# Data Setup
|
||||||
@ -47,11 +47,11 @@ else:
|
|||||||
optimizer_mapping = tr_tools.get_optimizer(model, metric)
|
optimizer_mapping = tr_tools.get_optimizer(model, metric)
|
||||||
if conf['training']['optimizer'] in optimizer_mapping:
|
if conf['training']['optimizer'] in optimizer_mapping:
|
||||||
optimizer = optimizer_mapping[conf['training']['optimizer']]()
|
optimizer = optimizer_mapping[conf['training']['optimizer']]()
|
||||||
scheduler = optim.lr_scheduler.StepLR(
|
scheduler_mapping = tr_tools.get_scheduler(optimizer)
|
||||||
optimizer,
|
scheduler = scheduler_mapping[conf['training']['scheduler']]()
|
||||||
step_size=conf['training']['lr_step'],
|
print('使用{}优化器 使用{}调度器'.format(conf['training']['optimizer'],
|
||||||
gamma=conf['training']['lr_decay']
|
conf['training']['scheduler']))
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError('不支持的优化器类型: {}'.format(conf['training']['optimizer']))
|
raise ValueError('不支持的优化器类型: {}'.format(conf['training']['optimizer']))
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user