diff --git a/deploy/slim/prune/eval_det_utils.py b/deploy/slim/prune/eval_det_utils.py deleted file mode 100644 index d41490ef73cba49b592951427e6bcbe26d15fa6c..0000000000000000000000000000000000000000 --- a/deploy/slim/prune/eval_det_utils.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import sys -import logging -import numpy as np -import paddle.fluid as fluid - -__dir__ = os.path.dirname(__file__) -sys.path.append(__dir__) -sys.path.append(os.path.join(__dir__, '..', '..', '..')) - -__all__ = ['eval_det_run'] - -import logging -FORMAT = '%(asctime)s-%(levelname)s: %(message)s' -logging.basicConfig(level=logging.INFO, format=FORMAT) -logger = logging.getLogger(__name__) - -import cv2 -import json -from copy import deepcopy -from ppocr.utils.utility import create_module -from ppocr.data.reader_main import reader_main -from tools.eval_utils.eval_det_iou import DetectionIoUEvaluator - - -def cal_det_res(exe, config, eval_info_dict): - global_params = config['Global'] - save_res_path = global_params['save_res_path'] - postprocess_params = deepcopy(config["PostProcess"]) - postprocess_params.update(global_params) - postprocess = create_module(postprocess_params['function']) \ - (params=postprocess_params) - if not os.path.exists(os.path.dirname(save_res_path)): - os.makedirs(os.path.dirname(save_res_path)) - with open(save_res_path, "wb") as fout: - tackling_num = 0 - for data in eval_info_dict['reader'](): - img_num = len(data) - tackling_num = tackling_num + img_num - logger.info("test tackling num:%d", tackling_num) - img_list = [] - ratio_list = [] - img_name_list = [] - for ino in range(img_num): - img_list.append(data[ino][0]) - ratio_list.append(data[ino][1]) - img_name_list.append(data[ino][2]) - try: - img_list = np.concatenate(img_list, axis=0) - except: - err = "concatenate error usually caused by different input image shapes in evaluation or testing.\n \ - Please set \"test_batch_size_per_card\" in main yml as 1\n \ - or add \"test_image_shape: [h, w]\" in reader yml for EvalReader." - - raise Exception(err) - outs = exe.run(eval_info_dict['program'], \ - feed={'image': img_list}, \ - fetch_list=eval_info_dict['fetch_varname_list']) - outs_dict = {} - for tno in range(len(outs)): - fetch_name = eval_info_dict['fetch_name_list'][tno] - fetch_value = np.array(outs[tno]) - outs_dict[fetch_name] = fetch_value - dt_boxes_list = postprocess(outs_dict, ratio_list) - for ino in range(img_num): - dt_boxes = dt_boxes_list[ino] - img_name = img_name_list[ino] - dt_boxes_json = [] - for box in dt_boxes: - tmp_json = {"transcription": ""} - tmp_json['points'] = box.tolist() - dt_boxes_json.append(tmp_json) - otstr = img_name + "\t" + json.dumps(dt_boxes_json) + "\n" - fout.write(otstr.encode()) - return - - -def load_label_infor(label_file_path, do_ignore=False): - img_name_label_dict = {} - with open(label_file_path, "rb") as fin: - lines = fin.readlines() - for line in lines: - substr = line.decode().strip("\n").split("\t") - bbox_infor = json.loads(substr[1]) - bbox_num = len(bbox_infor) - for bno in range(bbox_num): - text = bbox_infor[bno]['transcription'] - ignore = False - if text == "###" and do_ignore: - ignore = True - bbox_infor[bno]['ignore'] = ignore - img_name_label_dict[os.path.basename(substr[0])] = bbox_infor - return img_name_label_dict - - -def cal_det_metrics(gt_label_path, save_res_path): - """ - calculate the detection metrics - Args: - gt_label_path(string): The groundtruth detection label file path - save_res_path(string): The saved predicted detection label path - return: - claculated metrics including Hmean, precision and recall - """ - evaluator = DetectionIoUEvaluator() - gt_label_infor = load_label_infor(gt_label_path, do_ignore=True) - dt_label_infor = load_label_infor(save_res_path) - results = [] - for img_name in gt_label_infor: - gt_label = gt_label_infor[img_name] - if img_name not in dt_label_infor: - dt_label = [] - else: - dt_label = dt_label_infor[img_name] - result = evaluator.evaluate_image(gt_label, dt_label) - results.append(result) - methodMetrics = evaluator.combine_results(results) - return methodMetrics - - -def eval_det_run(eval_args, mode='eval'): - exe = eval_args['exe'] - config = eval_args['config'] - eval_info_dict = eval_args['eval_info_dict'] - cal_det_res(exe, config, eval_info_dict) - - save_res_path = config['Global']['save_res_path'] - if mode == "eval": - gt_label_path = config['EvalReader']['label_file_path'] - metrics = cal_det_metrics(gt_label_path, save_res_path) - else: - gt_label_path = config['TestReader']['label_file_path'] - do_eval = config['TestReader']['do_eval'] - if do_eval: - metrics = cal_det_metrics(gt_label_path, save_res_path) - else: - metrics = {} - return metrics['hmean'] diff --git a/deploy/slim/prune/pruning_and_finetune.py b/deploy/slim/prune/pruning_and_finetune.py index 86baf02661eae892b2e57e2725a7897c1029b4cd..d0b657240270816ca4d4d1aed1849ce516d3afeb 100644 --- a/deploy/slim/prune/pruning_and_finetune.py +++ b/deploy/slim/prune/pruning_and_finetune.py @@ -104,14 +104,6 @@ def main(): # compile program for multi-devices init_model(config, train_program, exe) - # params = get_pruned_params(train_program) - ''' - sens_file = ['sensitivities_'+ str(x) for x in range(0,4)] - sens = [] - for f in sens_file: - sens.append(load_sensitivities(f+'.data')) - sen = merge_sensitive(sens) - ''' sen = load_sensitivities("sensitivities_0.data") for i in skip_list: sen.pop(i) @@ -161,28 +153,7 @@ def main(): program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict) -def test_reader(): - config = program.load_config(FLAGS.config) - program.merge_config(FLAGS.opt) - print(config) - train_reader = reader_main(config=config, mode="train") - import time - starttime = time.time() - count = 0 - try: - for data in train_reader(): - count += 1 - if count % 1 == 0: - batch_time = time.time() - starttime - starttime = time.time() - print("reader:", count, len(data), batch_time) - except Exception as e: - logger.info(e) - logger.info("finish reader: {}, Success!".format(count)) - - if __name__ == '__main__': parser = program.ArgsParser() FLAGS = parser.parse_args() main() -# test_reader() diff --git a/deploy/slim/prune/sensitivity_anal.py b/deploy/slim/prune/sensitivity_anal.py index beaeebede383933ef9394cfcbe6dc245bdebc853..2d51655f5b559471ebc9799c53a1daf13b4ea53c 100644 --- a/deploy/slim/prune/sensitivity_anal.py +++ b/deploy/slim/prune/sensitivity_anal.py @@ -42,7 +42,7 @@ import cv2 from paddle import fluid import paddleslim as slim from copy import deepcopy -from eval_det_utils import eval_det_run +from tools.eval_utils.eval_det_utils import eval_det_run from tools import program from ppocr.utils.utility import initial_logger @@ -65,6 +65,14 @@ def get_pruned_params(program): return params +def eval_function(eval_args, mode='eval'): + exe = eval_args['exe'] + config = eval_args['config'] + eval_info_dict = eval_args['eval_info_dict'] + metrics = eval_det_run(exe, config, eval_info_dict, mode=mode) + return metrics['hmean'] + + def main(): config = program.load_config(FLAGS.config) program.merge_config(FLAGS.opt) @@ -99,7 +107,7 @@ def main(): 'fetch_varname_list':eval_fetch_varname_list} eval_args = dict() eval_args = {'exe': exe, 'config': config, 'eval_info_dict': eval_info_dict} - metrics = eval_det_run(eval_args) + metrics = eval_function(eval_args) print("Baseline: {}".format(metrics)) params = get_pruned_params(eval_program) @@ -108,7 +116,7 @@ def main(): eval_program, place, params, - eval_det_run, + eval_function, sensitivities_file="sensitivities_0.data", pruned_ratios=[0.1], eval_args=eval_args,