diff --git a/deploy/slim/prune/README.md b/deploy/slim/prune/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f28d2be01be6ae896956aa543a9bead85596609c --- /dev/null +++ b/deploy/slim/prune/README.md @@ -0,0 +1,40 @@ +> 运行示例前请先安装develop版本PaddleSlim + +# 模型裁剪压缩教程 + +## 概述 + +该示例使用PaddleSlim提供的[裁剪压缩API](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/)对OCR模型进行压缩。 +在阅读该示例前,建议您先了解以下内容: + +- [OCR模型的常规训练方法](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/detection.md) +- [PaddleSlim使用文档](https://paddlepaddle.github.io/PaddleSlim/) + +## 安装PaddleSlim +可按照[PaddleSlim使用文档](https://paddlepaddle.github.io/PaddleSlim/)中的步骤安装PaddleSlim。 + + + +## 敏感度分析训练 + +进入PaddleOCR根目录,通过以下命令对模型进行敏感度分析: + +```bash +python deploy/slim/prune/sensitivity_anal.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./deploy/slim/prune/pretrain_models/det_mv3_db/best_accuracy Global.test_batch_size_per_card=1 +``` + +## 裁剪模型与fine-tune + +```bash +python deploy/slim/prune/pruning_and_finetune.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./deploy/slim/prune/pretrain_models/det_mv3_db/best_accuracy Global.test_batch_size_per_card=1 +``` + + + +## 评估并导出 + +在得到裁剪训练保存的模型后,我们可以将其导出为inference_model,用于预测部署: + +```bash +python deploy/slim/prune/export_prune_model.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./output/det_db/best_accuracy Global.test_batch_size_per_card=1 Global.save_inference_dir=inference_model +``` diff --git a/deploy/slim/prune/eval_det_utils.py b/deploy/slim/prune/eval_det_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d41490ef73cba49b592951427e6bcbe26d15fa6c --- /dev/null +++ b/deploy/slim/prune/eval_det_utils.py @@ -0,0 +1,156 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)

# Message raised when the images of one batch cannot be stacked together.
_CONCAT_ERROR_MSG = (
    "concatenate error usually caused by different input image shapes "
    "in evaluation or testing.\n"
    "Please set \"test_batch_size_per_card\" in main yml as 1\n"
    "or add \"test_image_shape: [h, w]\" in reader yml for EvalReader.")


def cal_det_res(exe, config, eval_info_dict):
    """Run the detection program over the reader and save predicted boxes.

    Every batch yielded by ``eval_info_dict['reader']()`` is fed to
    ``eval_info_dict['program']`` under the feed name ``image``. The fetched
    outputs are post-processed into boxes and written to
    ``config['Global']['save_res_path']``, one line per image:
    ``<image name><TAB><JSON list of {"transcription": "", "points": ...}>``.

    Args:
        exe: executor exposing ``run(program, feed=..., fetch_list=...)``.
        config(dict): configuration holding ``Global`` and ``PostProcess``.
        eval_info_dict(dict): holds ``program``, ``reader``,
            ``fetch_name_list`` and ``fetch_varname_list``.

    Raises:
        Exception: if the images of one batch cannot be concatenated.
    """
    global_params = config['Global']
    save_res_path = global_params['save_res_path']
    postprocess_params = deepcopy(config["PostProcess"])
    postprocess_params.update(global_params)
    postprocess = create_module(postprocess_params['function'])(
        params=postprocess_params)

    # A bare file name has an empty dirname, and os.makedirs('') raises.
    save_dir = os.path.dirname(save_res_path)
    if save_dir and not os.path.exists(save_dir):
        os.makedirs(save_dir)

    fetch_names = eval_info_dict['fetch_name_list']
    with open(save_res_path, "wb") as fout:
        tackling_num = 0
        for data in eval_info_dict['reader']():
            img_num = len(data)
            tackling_num += img_num
            logger.info("test tackling num:%d", tackling_num)
            # Each sample is indexed as (image, ratio, image name).
            img_list = [data[ino][0] for ino in range(img_num)]
            ratio_list = [data[ino][1] for ino in range(img_num)]
            img_name_list = [data[ino][2] for ino in range(img_num)]
            try:
                img_list = np.concatenate(img_list, axis=0)
            except Exception:
                # Keep the original exception type; only stop swallowing
                # KeyboardInterrupt/SystemExit as the bare except did.
                raise Exception(_CONCAT_ERROR_MSG)
            outs = exe.run(eval_info_dict['program'],
                           feed={'image': img_list},
                           fetch_list=eval_info_dict['fetch_varname_list'])
            outs_dict = {}
            for tno, out in enumerate(outs):
                outs_dict[fetch_names[tno]] = np.array(out)
            dt_boxes_list = postprocess(outs_dict, ratio_list)
            for ino, img_name in enumerate(img_name_list):
                dt_boxes_json = [{
                    "transcription": "",
                    "points": box.tolist()
                } for box in dt_boxes_list[ino]]
                otstr = img_name + "\t" + json.dumps(dt_boxes_json) + "\n"
                fout.write(otstr.encode())
    return


def load_label_infor(label_file_path, do_ignore=False):
    """Load a detection label file into a dict keyed by image base name.

    Each non-blank line is ``<image path><TAB><JSON list of boxes>``. Every
    box dict gets an extra ``ignore`` key, which is True only when
    ``do_ignore`` is set and the box transcription is ``"###"``.

    Args:
        label_file_path(string): path of the label file to read.
        do_ignore(bool): whether ``"###"`` boxes are flagged as ignored.

    Returns:
        dict mapping ``os.path.basename(image path)`` to its list of boxes.
    """
    img_name_label_dict = {}
    with open(label_file_path, "rb") as fin:
        for line in fin:
            text_line = line.decode().strip("\n")
            if not text_line.strip():
                # Blank lines (e.g. a trailing newline) carry no label.
                continue
            substr = text_line.split("\t")
            bbox_infor = json.loads(substr[1])
            for bbox in bbox_infor:
                bbox['ignore'] = bool(do_ignore and
                                      bbox['transcription'] == "###")
            img_name_label_dict[os.path.basename(substr[0])] = bbox_infor
    return img_name_label_dict


def cal_det_metrics(gt_label_path, save_res_path):
    """
    calculate the detection metrics
    Args:
        gt_label_path(string): The groundtruth detection label file path
        save_res_path(string): The saved predicted detection label path
    return:
        the combined result returned by
        DetectionIoUEvaluator.combine_results (eval_det_run reads its
        'hmean' entry)
    """
    evaluator = DetectionIoUEvaluator()
    gt_label_infor = load_label_infor(gt_label_path, do_ignore=True)
    dt_label_infor = load_label_infor(save_res_path)
    results = []
    for img_name, gt_label in gt_label_infor.items():
        # Images without any prediction are evaluated against no boxes.
        dt_label = dt_label_infor.get(img_name, [])
        results.append(evaluator.evaluate_image(gt_label, dt_label))
    return evaluator.combine_results(results)


def eval_det_run(eval_args, mode='eval'):
    """Predict with the detection model and return the resulting hmean.

    Args:
        eval_args(dict): holds ``exe``, ``config`` and ``eval_info_dict``,
            which are forwarded to ``cal_det_res``.
        mode(string): ``'eval'`` scores against ``EvalReader`` labels; any
            other value uses ``TestReader`` and scores only when its
            ``do_eval`` flag is true.

    Returns:
        The ``hmean`` entry of the metrics, or None when no evaluation was
        requested (previously this case raised KeyError on an empty dict).
    """
    exe = eval_args['exe']
    config = eval_args['config']
    eval_info_dict = eval_args['eval_info_dict']
    cal_det_res(exe, config, eval_info_dict)

    save_res_path = config['Global']['save_res_path']
    if mode == "eval":
        reader_config = config['EvalReader']
        do_eval = True
    else:
        reader_config = config['TestReader']
        do_eval = reader_config['do_eval']
    if not do_eval:
        return None
    metrics = cal_det_metrics(reader_config['label_file_path'], save_res_path)
    return metrics['hmean']
def set_paddle_flags(**kwargs):
    """Export each keyword as an environment variable unless already set.

    Values are stored as ``str(value)``; variables that already exist in
    ``os.environ`` are left untouched.
    """
    for key, value in kwargs.items():
        os.environ.setdefault(key, str(value))


# NOTE(paddle-dev): All of these flags should be
# set before `import paddle`. Otherwise, it would
# not take any effect.
set_paddle_flags(
    FLAGS_eager_delete_tensor_gb=0,  # enable GC to save memory
)


def main():
    """Load pruned weights and export them as an inference model.

    The weights are read from ``Global.checkpoints`` when it is set, and
    from ``Global.pretrain_weights`` otherwise. The inference model is
    written to ``Global.save_inference_dir`` as the files ``model`` and
    ``params``.
    """
    startup_prog, eval_program, place, config, _ = program.preprocess()

    feeded_var_names, target_vars, fetches_var_name = program.build_export(
        config, eval_program, startup_prog)
    eval_program = eval_program.clone(for_test=True)
    exe = fluid.Executor(place)
    exe.run(startup_prog)

    global_config = config['Global']
    # .get() so that a config without a `checkpoints` entry falls back to
    # the pretrained weights instead of raising KeyError.
    path = global_config.get('checkpoints')
    if path is None:
        path = global_config['pretrain_weights']

    load_model(exe, eval_program, path)

    save_inference_dir = global_config['save_inference_dir']
    if not os.path.exists(save_inference_dir):
        os.makedirs(save_inference_dir)
    fluid.io.save_inference_model(
        dirname=save_inference_dir,
        feeded_var_names=feeded_var_names,
        main_program=eval_program,
        target_vars=target_vars,
        executor=exe,
        model_filename='model',
        params_filename='params')
    print("inference model saved in {}/model and {}/params".format(
        save_inference_dir, save_inference_dir))
    print("save success, output_name_list:", fetches_var_name)


if __name__ == '__main__':
    main()
a/deploy/slim/prune/pruning_and_finetune.py b/deploy/slim/prune/pruning_and_finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..86baf02661eae892b2e57e2725a7897c1029b4cd --- /dev/null +++ b/deploy/slim/prune/pruning_and_finetune.py @@ -0,0 +1,188 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys +import numpy as np +__dir__ = os.path.dirname(__file__) +sys.path.append(__dir__) +sys.path.append(os.path.join(__dir__, '..', '..', '..')) +sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools')) + + +def set_paddle_flags(**kwargs): + for key, value in kwargs.items(): + if os.environ.get(key, None) is None: + os.environ[key] = str(value) + + +# NOTE(paddle-dev): All of these flags should be +# set before `import paddle`. Otherwise, it would +# not take any effect. 
+set_paddle_flags( + FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory +) + +import tools.program as program +from paddle import fluid +from ppocr.utils.utility import initial_logger +logger = initial_logger() +from ppocr.data.reader_main import reader_main +from ppocr.utils.save_load import init_model +from ppocr.utils.character import CharacterOps +from ppocr.utils.utility import initial_logger +from paddleslim.prune import Pruner, save_model +from paddleslim.analysis import flops +from paddleslim.core.graph_wrapper import * +from paddleslim.prune import load_sensitivities, get_ratios_by_loss, merge_sensitive +logger = initial_logger() + +skip_list = [ + 'conv10_linear_weights', 'conv11_linear_weights', 'conv12_expand_weights', + 'conv12_linear_weights', 'conv12_se_2_weights', 'conv13_linear_weights', + 'conv2_linear_weights', 'conv4_linear_weights', 'conv5_expand_weights', + 'conv5_linear_weights', 'conv5_se_2_weights', 'conv6_linear_weights', + 'conv7_linear_weights', 'conv8_expand_weights', 'conv8_linear_weights', + 'conv9_expand_weights', 'conv9_linear_weights' +] + + +def main(): + config = program.load_config(FLAGS.config) + program.merge_config(FLAGS.opt) + logger.info(config) + + # check if set use_gpu=True in paddlepaddle cpu version + use_gpu = config['Global']['use_gpu'] + program.check_gpu(use_gpu) + + alg = config['Global']['algorithm'] + assert alg in ['EAST', 'DB', 'Rosetta', 'CRNN', 'STARNet', 'RARE'] + if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE']: + config['Global']['char_ops'] = CharacterOps(config['Global']) + + place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() + startup_program = fluid.Program() + train_program = fluid.Program() + train_build_outputs = program.build( + config, train_program, startup_program, mode='train') + train_loader = train_build_outputs[0] + train_fetch_name_list = train_build_outputs[1] + train_fetch_varname_list = train_build_outputs[2] + train_opt_loss_name = train_build_outputs[3] + + 
eval_program = fluid.Program() + eval_build_outputs = program.build( + config, eval_program, startup_program, mode='eval') + eval_fetch_name_list = eval_build_outputs[1] + eval_fetch_varname_list = eval_build_outputs[2] + eval_program = eval_program.clone(for_test=True) + + train_reader = reader_main(config=config, mode="train") + train_loader.set_sample_list_generator(train_reader, places=place) + + eval_reader = reader_main(config=config, mode="eval") + + exe = fluid.Executor(place) + exe.run(startup_program) + + # compile program for multi-devices + init_model(config, train_program, exe) + + # params = get_pruned_params(train_program) + ''' + sens_file = ['sensitivities_'+ str(x) for x in range(0,4)] + sens = [] + for f in sens_file: + sens.append(load_sensitivities(f+'.data')) + sen = merge_sensitive(sens) + ''' + sen = load_sensitivities("sensitivities_0.data") + for i in skip_list: + sen.pop(i) + back_bone_list = ['conv' + str(x) for x in range(1, 5)] + for i in back_bone_list: + for key in list(sen.keys()): + if i + '_' in key: + sen.pop(key) + ratios = get_ratios_by_loss(sen, 0.03) + logger.info("FLOPs before pruning: {}".format(flops(eval_program))) + pruner = Pruner(criterion='geometry_median') + print("ratios: {}".format(ratios)) + pruned_val_program, _, _ = pruner.prune( + eval_program, + fluid.global_scope(), + params=ratios.keys(), + ratios=ratios.values(), + place=place, + only_graph=True) + + pruned_program, _, _ = pruner.prune( + train_program, + fluid.global_scope(), + params=ratios.keys(), + ratios=ratios.values(), + place=place) + logger.info("FLOPs after pruning: {}".format(flops(pruned_val_program))) + train_compile_program = program.create_multi_devices_program( + pruned_program, train_opt_loss_name) + + + train_info_dict = {'compile_program':train_compile_program,\ + 'train_program':pruned_program,\ + 'reader':train_loader,\ + 'fetch_name_list':train_fetch_name_list,\ + 'fetch_varname_list':train_fetch_varname_list} + + eval_info_dict = 
{'program':pruned_val_program,\ + 'reader':eval_reader,\ + 'fetch_name_list':eval_fetch_name_list,\ + 'fetch_varname_list':eval_fetch_varname_list} + + if alg in ['EAST', 'DB']: + program.train_eval_det_run( + config, exe, train_info_dict, eval_info_dict, is_pruning=True) + else: + program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict) + + +def test_reader(): + config = program.load_config(FLAGS.config) + program.merge_config(FLAGS.opt) + print(config) + train_reader = reader_main(config=config, mode="train") + import time + starttime = time.time() + count = 0 + try: + for data in train_reader(): + count += 1 + if count % 1 == 0: + batch_time = time.time() - starttime + starttime = time.time() + print("reader:", count, len(data), batch_time) + except Exception as e: + logger.info(e) + logger.info("finish reader: {}, Success!".format(count)) + + +if __name__ == '__main__': + parser = program.ArgsParser() + FLAGS = parser.parse_args() + main() +# test_reader() diff --git a/deploy/slim/prune/sensitivity_anal.py b/deploy/slim/prune/sensitivity_anal.py new file mode 100644 index 0000000000000000000000000000000000000000..beaeebede383933ef9394cfcbe6dc245bdebc853 --- /dev/null +++ b/deploy/slim/prune/sensitivity_anal.py @@ -0,0 +1,121 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def set_paddle_flags(**kwargs):
    """Export each keyword as an environment variable unless already set.

    Values are stored as ``str(value)``; variables that already exist in
    ``os.environ`` are left untouched.
    """
    for key, value in kwargs.items():
        os.environ.setdefault(key, str(value))


# NOTE(paddle-dev): All of these flags should be
# set before `import paddle`. Otherwise, it would
# not take any effect.
set_paddle_flags(
    FLAGS_eager_delete_tensor_gb=0,  # enable GC to save memory
)


def get_pruned_params(program):
    """Return the names of the parameters considered for pruning.

    A parameter is kept when its shape has four dimensions and its name
    contains neither ``'depthwise'`` nor ``'transpose'``.

    Args:
        program: object whose ``global_block().all_parameters()`` yields
            parameters exposing ``shape`` and ``name``.

    Returns:
        list of parameter names, in the order the block reports them.
    """
    return [
        param.name for param in program.global_block().all_parameters()
        if len(param.shape) == 4 and 'depthwise' not in param.name and
        'transpose' not in param.name
    ]


def main():
    """Measure how sensitive each prunable parameter is to pruning.

    Builds and initialises the evaluation program, prints the baseline
    metric returned by ``eval_det_run`` and then runs
    ``slim.prune.sensitivity`` with a pruning ratio of 0.1, storing the
    result in ``sensitivities_0.data`` in the working directory.
    """
    config = program.load_config(FLAGS.config)
    program.merge_config(FLAGS.opt)
    logger.info(config)

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global']['use_gpu']
    program.check_gpu(use_gpu)

    alg = config['Global']['algorithm']
    assert alg in ['EAST', 'DB', 'Rosetta', 'CRNN', 'STARNet', 'RARE']
    if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE']:
        config['Global']['char_ops'] = CharacterOps(config['Global'])

    place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
    startup_prog = fluid.Program()
    eval_program = fluid.Program()
    # NOTE(review): built with mode='test' while the reader below uses
    # mode="eval" -- confirm this pairing is intended.
    eval_build_outputs = program.build(
        config, eval_program, startup_prog, mode='test')
    eval_fetch_name_list = eval_build_outputs[1]
    eval_fetch_varname_list = eval_build_outputs[2]
    eval_program = eval_program.clone(for_test=True)
    exe = fluid.Executor(place)
    exe.run(startup_prog)

    init_model(config, eval_program, exe)

    eval_reader = reader_main(config=config, mode="eval")
    eval_info_dict = {
        'program': eval_program,
        'reader': eval_reader,
        'fetch_name_list': eval_fetch_name_list,
        'fetch_varname_list': eval_fetch_varname_list
    }
    eval_args = {'exe': exe, 'config': config, 'eval_info_dict': eval_info_dict}
    metrics = eval_det_run(eval_args)
    print("Baseline: {}".format(metrics))

    params = get_pruned_params(eval_program)
    print('Start to analyze')
    # The sensitivities are persisted to the file; the return value is not
    # needed here.
    slim.prune.sensitivity(
        eval_program,
        place,
        params,
        eval_det_run,
        sensitivities_file="sensitivities_0.data",
        pruned_ratios=[0.1],
        eval_args=eval_args,
        criterion='geometry_median')


if __name__ == '__main__':
    parser = program.ArgsParser()
    FLAGS = parser.parse_args()
    main()
config['Global']['epoch_num'] @@ -294,7 +299,13 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict): best_batch_id = train_batch_id best_epoch = epoch save_path = save_model_dir + "/best_accuracy" - save_model(train_info_dict['train_program'], save_path) + if is_pruning: + slim.prune.save_model( + exe, train_info_dict['train_program'], + save_path) + else: + save_model(train_info_dict['train_program'], + save_path) strs = 'Test iter: {}, metrics:{}, best_hmean:{:.6f}, best_epoch:{}, best_batch_id:{}'.format( train_batch_id, metrics, best_eval_hmean, best_epoch, best_batch_id) @@ -305,10 +316,18 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict): train_loader.reset() if epoch == 0 and save_epoch_step == 1: save_path = save_model_dir + "/iter_epoch_0" - save_model(train_info_dict['train_program'], save_path) + if is_pruning: + slim.prune.save_model(exe, train_info_dict['train_program'], + save_path) + else: + save_model(train_info_dict['train_program'], save_path) if epoch > 0 and epoch % save_epoch_step == 0: save_path = save_model_dir + "/iter_epoch_%d" % (epoch) - save_model(train_info_dict['train_program'], save_path) + if is_pruning: + slim.prune.save_model(exe, train_info_dict['train_program'], + save_path) + else: + save_model(train_info_dict['train_program'], save_path) return