From 061820c34f4fc8d0fb1e772e7d41a226956b4462 Mon Sep 17 00:00:00 2001 From: lee <770918727@qq.com> Date: Thu, 19 Jun 2025 17:36:24 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- configs/transform.yml | 11 ++++++----- configs/write_feature.yml | 6 +++--- model/resnet_pre.py | 2 +- tools/model_rknn_transform.py | 7 +++++-- 4 files changed, 15 insertions(+), 11 deletions(-) diff --git a/configs/transform.yml b/configs/transform.yml index d03e296..35a853f 100644 --- a/configs/transform.yml +++ b/configs/transform.yml @@ -14,11 +14,12 @@ base: # 模型配置 models: - backbone: 'resnet50' - channel_ratio: 1.0 - model_path: "../checkpoints/resnet50_0519/best.pth" - onnx_model: "../checkpoints/resnet50_0519/best.onnx" - rknn_model: "../checkpoints/resnet50_0519/best.rknn" + backbone: 'resnet18' + channel_ratio: 0.75 + model_path: "../checkpoints/resnet18_1009/best.pth" + onnx_model: "../checkpoints/resnet18_1009/best.onnx" + rknn_model: "../checkpoints/resnet18_1009/best_rknn2.3.2.rknn" + rknn_batch_size: 1 # 日志与监控 logging: diff --git a/configs/write_feature.yml b/configs/write_feature.yml index c01576a..4485f3e 100644 --- a/configs/write_feature.yml +++ b/configs/write_feature.yml @@ -22,7 +22,7 @@ data: test_batch_size: 128 # 验证批次大小 num_workers: 32 # 数据加载线程数 half: true # 是否启用半精度数据 - img_dirs_path: "/personalDocument/lic/contrast_base" + img_dirs_path: "/personalDocument/lic/base+stlib" # img_dirs_path: "/home/lc/contrast_nettest/data/feature_json" xlsx_pth: false # 过滤商品, 默认None不进行过滤 @@ -42,7 +42,7 @@ logging: save: json_bin: "../search_library/yunhedian_05-09.json" # 保存整个json文件 - json_path: "../feature_json/" # 保存单个json文件 + json_path: "../feature_json/base+stlib/" # 保存单个json文件路径 error_barcodes: "error_barcodes.txt" barcodes_statistics: "../search_library/barcodes_statistics.txt" - create_single_json: true \ No newline at end of file + create_single_json: true # 是否保存单个json文件 \ No newline at end of file diff --git a/model/resnet_pre.py b/model/resnet_pre.py index 724d3e7..d645617 100644 --- a/model/resnet_pre.py +++ b/model/resnet_pre.py @@ -205,7 +205,7 @@ class ResNet(nn.Module): if norm_layer is None: norm_layer = nn.BatchNorm2d self._norm_layer = norm_layer - print("ResNet scale: >>>>>>>>>> ", scale) + print("通道剪枝 {}".format(scale)) self.inplanes = 64 self.dilation = 1 if replace_stride_with_dilation is None: diff --git a/tools/model_rknn_transform.py b/tools/model_rknn_transform.py index e2a9df9..074e838 100644 --- a/tools/model_rknn_transform.py +++ b/tools/model_rknn_transform.py @@ -99,7 +99,8 @@ if __name__ == '__main__': target_platform='rk3588', model_pruning=False, compress_weight=False, - single_core_mode=True) + single_core_mode=True, + enable_flash_attention=True) # rknn.config( # mean_values=[[127.5, 127.5, 127.5]], # 对于单通道图像,可以设置为 [[127.5]] # std_values=[[127.5, 127.5, 127.5]], # 对于单通道图像,可以设置为 [[127.5]] @@ -121,7 +122,9 @@ if __name__ == '__main__': # Build model print('--> Building model') - ret = rknn.build(do_quantization=True, dataset='./dataset.txt') + ret = rknn.build(do_quantization=True, + dataset='./dataset.txt', + rknn_batch_size=conf['models']['rknn_batch_size']) # ret = rknn.build(do_quantization=False, dataset='./dataset.txt') if ret != 0: print('Build model failed!')