From 268ed03b3ec01af8ff2de9f9329659022c13bc5e Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Mon, 11 May 2020 19:59:07 +0800 Subject: [PATCH] update detection doc and infer code --- doc/detection.md | 16 +++++++-- ppocr/data/det/db_process.py | 3 -- ppocr/data/reader_main.py | 6 ---- tools/eval_utils/eval_det_iou.py | 4 +++ tools/eval_utils/eval_det_utils.py | 8 +++++ tools/infer/utility.py | 8 ----- tools/infer_det.py | 54 ++++++++++++++++++++++++++++-- tools/program.py | 16 --------- 8 files changed, 76 insertions(+), 39 deletions(-) diff --git a/doc/detection.md b/doc/detection.md index b8b8a9cc..6124b74f 100644 --- a/doc/detection.md +++ b/doc/detection.md @@ -32,8 +32,8 @@ json.dumps编码前的图像标注信息是包含多个字典的list,字典中 ## 3.2 快速启动训练 -首先下载pretrain model,目前支持两种backbone,分别是MobileNetV3、ResNet50,您可以根据需求使用PaddleClas中的模型更换 -backbone。 +首先下载pretrain model,PaddleOCR的检测模型目前支持两种backbone,分别是MobileNetV3、ResNet50_vd, +您可以根据需求使用[PaddleClas](https://github.com/PaddlePaddle/PaddleClas/tree/master/ppcls/modeling/architectures)中的模型更换backbone。 ``` # 下载MobileNetV3的预训练模型 wget -P /PaddleOCR/pretrained_model/ 模型链接 @@ -63,7 +63,17 @@ PaddleOCR计算三个OCR检测相关的指标,分别是:Precision、Recall 运行如下代码,根据配置文件det_db_mv3.yml中save_res_path指定的测试集检测结果文件,计算评估指标。 ``` -python3 tools/eval.py -c configs/det/det_db_mv3.yml -o checkpoints ./output/best_accuracy +python3 tools/eval.py -c configs/det/det_db_mv3.yml -o checkpoints="./output/best_accuracy" ``` ## 3.4 测试检测效果 + +测试单张图像的检测效果 +``` +python3 tools/infer_det.py -c configs/det/det_db_mv3.yml -o TestReader.single_img_path="./demo.jpg" +``` + +测试文件夹下所有图像的检测效果 +``` +python3 tools/infer_det.py -c configs/det/det_db_mv3.yml -o TestReader.single_img_path="./demo_img/" +``` diff --git a/ppocr/data/det/db_process.py b/ppocr/data/det/db_process.py index 2a6393a1..faca9ac2 100644 --- a/ppocr/data/det/db_process.py +++ b/ppocr/data/det/db_process.py @@ -124,9 +124,6 @@ class DBProcessTest(object): def resize_image_type0(self, im): """ resize image to a size multiple of 32 which is 
required by the network - :param im: the resized image - :param max_side_len: limit of max image size to avoid out of memory in gpu - :return: the resized image and the resize ratio """ max_side_len = self.max_side_len h, w, _ = im.shape diff --git a/ppocr/data/reader_main.py b/ppocr/data/reader_main.py index 323620bc..55bd1e08 100755 --- a/ppocr/data/reader_main.py +++ b/ppocr/data/reader_main.py @@ -73,9 +73,3 @@ def reader_main(config=None, mode=None): return paddle.reader.multiprocess_reader(readers, False) else: return function(mode) - - -def test_reader(image_shape, img_path): - img = cv2.imread(img_path) - norm_img = process_image(img, image_shape) - return norm_img diff --git a/tools/eval_utils/eval_det_iou.py b/tools/eval_utils/eval_det_iou.py index c6dacb3e..2f5ff2f1 100644 --- a/tools/eval_utils/eval_det_iou.py +++ b/tools/eval_utils/eval_det_iou.py @@ -3,6 +3,10 @@ from collections import namedtuple import numpy as np from shapely.geometry import Polygon +""" +reference from : +https://github.com/MhLiao/DB/blob/3c32b808d4412680310d3d28eeb6a2d5bf1566c5/concern/icdar2015_eval/detection/iou.py#L8 +""" class DetectionIoUEvaluator(object): diff --git a/tools/eval_utils/eval_det_utils.py b/tools/eval_utils/eval_det_utils.py index 015cba99..f0be714f 100644 --- a/tools/eval_utils/eval_det_utils.py +++ b/tools/eval_utils/eval_det_utils.py @@ -98,6 +98,14 @@ def load_label_infor(label_file_path, do_ignore=False): def cal_det_metrics(gt_label_path, save_res_path): + """ + calculate the detection metrics + Args: + gt_label_path(string): The groundtruth detection label file path + save_res_path(string): The saved predicted detection label path + return: + calculated metrics including Hmean、precision and recall + """ evaluator = DetectionIoUEvaluator() gt_label_infor = load_label_infor(gt_label_path, do_ignore=True) dt_label_infor = load_label_infor(save_res_path) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index e8c76de9..f1f7a8a0 100755 --- 
a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -99,15 +99,7 @@ def create_predictor(args, mode): config.disable_gpu() config.disable_glog_info() - # config.switch_ir_optim(args.ir_optim) - # if args.use_tensorrt: - # config.enable_tensorrt_engine( - # precision_mode=AnalysisConfig.Precision.Half - # if args.use_fp16 else AnalysisConfig.Precision.Float32, - # max_batch_size=args.batch_size) - - # config.enable_memory_optim() # use zero copy config.switch_use_feed_fetch_ops(False) predictor = create_paddle_predictor(config) diff --git a/tools/infer_det.py b/tools/infer_det.py index d616323d..7998cdb6 100755 --- a/tools/infer_det.py +++ b/tools/infer_det.py @@ -34,7 +34,7 @@ def set_paddle_flags(**kwargs): # NOTE(paddle-dev): All of these flags should be # set before `import paddle`. Otherwise, it would -# not take any effect. +# not take any effect. set_paddle_flags( FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory ) @@ -44,6 +44,7 @@ from ppocr.utils.utility import create_module import program from ppocr.utils.save_load import init_model from ppocr.data.reader_main import reader_main +import cv2 from ppocr.utils.utility import initial_logger logger = initial_logger() @@ -67,6 +68,50 @@ def draw_det_res(dt_boxes, config, img_name, ino): logger.info("The detected Image saved in {}".format(save_path)) +def simple_reader(img_file, config): + imgs_lists = [] + if img_file is None or not os.path.exists(img_file): + raise Exception("not found any img file in {}".format(img_file)) + + img_end = ['jpg', 'png', 'jpeg', 'JPEG', 'JPG', 'bmp'] + if os.path.isfile(img_file) and img_file.split('.')[-1] in img_end: + imgs_lists.append(img_file) + elif os.path.isdir(img_file): + for single_file in os.listdir(img_file): + if single_file.split('.')[-1] in img_end: + imgs_lists.append(os.path.join(img_file, single_file)) + if len(imgs_lists) == 0: + raise Exception("not found any img file in {}".format(img_file)) + + batch_size = 
config['Global']['test_batch_size_per_card'] + global_params = config['Global'] + params = deepcopy(config['TestReader']) + params.update(global_params) + reader_function = params['process_function'] + process_function = create_module(reader_function)(params) + + def batch_iter_reader(): + batch_outs = [] + for img_path in imgs_lists: + img = cv2.imread(img_path) + if img.shape[-1] == 1 or len(list(img.shape)) == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + if img is None: + logger.info("load image error:" + img_path) + continue + outs = process_function(img) + outs.append(os.path.basename(img_path)) + print(outs[0].shape, outs[2]) + batch_outs.append(outs) + if len(batch_outs) == batch_size: + yield batch_outs + batch_outs = [] + if len(batch_outs) != 0: + yield batch_outs + + return batch_iter_reader + + def main(): config = program.load_config(FLAGS.config) program.merge_config(FLAGS.opt) @@ -103,7 +148,9 @@ def main(): save_res_path = config['Global']['save_res_path'] with open(save_res_path, "wb") as fout: - test_reader = reader_main(config=config, mode='test') + # test_reader = reader_main(config=config, mode='test') + single_img_path = config['TestReader']['single_img_path'] + test_reader = simple_reader(img_file=single_img_path, config=config) tackling_num = 0 for data in test_reader(): img_num = len(data) @@ -116,6 +163,7 @@ def main(): img_list.append(data[ino][0]) ratio_list.append(data[ino][1]) img_name_list.append(data[ino][2]) + img_list = np.concatenate(img_list, axis=0) outs = exe.run(eval_prog,\ feed={'image': img_list},\ @@ -126,7 +174,7 @@ def main(): postprocess_params.update(global_params) postprocess = create_module(postprocess_params['function'])\ (params=postprocess_params) - dt_boxes_list = postprocess(outs, ratio_list) + dt_boxes_list = postprocess({"maps": outs[0]}, ratio_list) for ino in range(img_num): dt_boxes = dt_boxes_list[ino] img_name = img_name_list[ino] diff --git a/tools/program.py b/tools/program.py index 
a34e56ca..f74aacc7 100755 --- a/tools/program.py +++ b/tools/program.py @@ -185,22 +185,6 @@ def build(config, main_prog, startup_prog, mode): def build_export(config, main_prog, startup_prog): """ - Build a program using a model and an optimizer - 1. create feeds - 2. create a dataloader - 3. create a model - 4. create fetchs - 5. create an optimizer - - Args: - config(dict): config - main_prog(): main program - startup_prog(): startup program - is_train(bool): train or valid - - Returns: - dataloader(): a bridge between the model and the data - fetchs(dict): dict of model outputs(included loss and measures) """ with fluid.program_guard(main_prog, startup_prog): with fluid.unique_name.guard(): -- GitLab