diff --git a/deploy/slim/prune/README.md b/deploy/slim/prune/README.md index bff1b78e6b583592fb699ba46c6a3740a63dae75..20d8c1e928cbc20c4bd891656e39f9d2bee0dc68 100644 --- a/deploy/slim/prune/README.md +++ b/deploy/slim/prune/README.md @@ -51,14 +51,14 @@ python setup.py install 进入PaddleOCR根目录,通过以下命令对模型进行敏感度分析训练: ```bash -python deploy/slim/prune/sensitivity_anal.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights="your trained model" Global.test_batch_size_per_card=1 +python deploy/slim/prune/sensitivity_anal.py -c configs/det/det_mv3_db_v1.1.yml -o Global.pretrain_weights="your trained model" Global.test_batch_size_per_card=1 ``` ### 4. 模型裁剪训练 裁剪时通过之前的敏感度分析文件决定每个网络层的裁剪比例。在具体实现时,为了尽可能多的保留从图像中提取的低阶特征,我们跳过了backbone中靠近输入的4个卷积层。同样,为了减少由于裁剪导致的模型性能损失,我们通过之前敏感度分析所获得的敏感度表,人工挑选出了一些冗余较少,对裁剪较为敏感的[网络层](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/deploy/slim/prune/pruning_and_finetune.py#L41)(指在较低的裁剪比例下就导致很高性能损失的网络层),并在之后的裁剪过程中选择避开这些网络层。裁剪过后finetune的过程沿用OCR检测模型原始的训练策略。 ```bash -python deploy/slim/prune/pruning_and_finetune.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./deploy/slim/prune/pretrain_models/det_mv3_db/best_accuracy Global.test_batch_size_per_card=1 +python deploy/slim/prune/pruning_and_finetune.py -c configs/det/det_mv3_db_v1.1.yml -o Global.pretrain_weights=./deploy/slim/prune/pretrain_models/det_mv3_db/best_accuracy Global.test_batch_size_per_card=1 ``` 通过对比可以发现,经过裁剪训练保存的模型更小。 @@ -66,7 +66,7 @@ python deploy/slim/prune/pruning_and_finetune.py -c configs/det/det_mv3_db.yml - 在得到裁剪训练保存的模型后,我们可以将其导出为inference_model: ```bash -python deploy/slim/prune/export_prune_model.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./output/det_db/best_accuracy Global.test_batch_size_per_card=1 Global.save_inference_dir=inference_model +python deploy/slim/prune/export_prune_model.py -c configs/det/det_mv3_db_v1.1.yml -o Global.pretrain_weights=./output/det_db/best_accuracy Global.test_batch_size_per_card=1 Global.save_inference_dir=inference_model ``` inference model的预测和部署参考: diff --git a/deploy/slim/prune/README_en.md b/deploy/slim/prune/README_en.md index 7adbd86c6145a6854e19cca62b9f995e07cb6b13..3136dc8ad4ac66d512003f7133b5c65918415a7f 100644 --- a/deploy/slim/prune/README_en.md +++ b/deploy/slim/prune/README_en.md @@ -55,7 +55,7 @@ Enter the PaddleOCR root directory,perform sensitivity analysis on the model w ```bash -python deploy/slim/prune/sensitivity_anal.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./deploy/slim/prune/pretrain_models/det_mv3_db/best_accuracy Global.test_batch_size_per_card=1 +python deploy/slim/prune/sensitivity_anal.py -c configs/det/det_mv3_db_v1.1.yml -o Global.pretrain_weights=./deploy/slim/prune/pretrain_models/det_mv3_db/best_accuracy Global.test_batch_size_per_card=1 ``` @@ -67,7 +67,7 @@ python deploy/slim/prune/sensitivity_anal.py -c configs/det/det_mv3_db.yml -o Gl ```bash -python deploy/slim/prune/pruning_and_finetune.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./deploy/slim/prune/pretrain_models/det_mv3_db/best_accuracy Global.test_batch_size_per_card=1 +python deploy/slim/prune/pruning_and_finetune.py -c configs/det/det_mv3_db_v1.1.yml -o Global.pretrain_weights=./deploy/slim/prune/pretrain_models/det_mv3_db/best_accuracy Global.test_batch_size_per_card=1 ``` @@ -76,7 +76,7 @@ python deploy/slim/prune/pruning_and_finetune.py -c configs/det/det_mv3_db.yml - We can export the pruned model as inference_model for deployment: ```bash -python deploy/slim/prune/export_prune_model.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./output/det_db/best_accuracy Global.test_batch_size_per_card=1 Global.save_inference_dir=inference_model +python deploy/slim/prune/export_prune_model.py -c configs/det/det_mv3_db_v1.1.yml -o Global.pretrain_weights=./output/det_db/best_accuracy Global.test_batch_size_per_card=1 Global.save_inference_dir=inference_model ``` Reference for prediction and deployment of inference model: diff --git a/deploy/slim/prune/pruning_and_finetune.py b/deploy/slim/prune/pruning_and_finetune.py index 0a03cb449523232dbeedf1d7066b4ea4ba01f31f..bf54b79915591d6516798110bdf283ab36e060f6 100644 --- a/deploy/slim/prune/pruning_and_finetune.py +++ b/deploy/slim/prune/pruning_and_finetune.py @@ -92,7 +92,8 @@ def main(): sen = load_sensitivities("sensitivities_0.data") for i in skip_list: - sen.pop(i) + if i in sen.keys(): + sen.pop(i) back_bone_list = ['conv' + str(x) for x in range(1, 5)] for i in back_bone_list: for key in list(sen.keys()):