diff --git a/configs/det/det_db_icdar15_reader.yml b/configs/det/det_db_icdar15_reader.yml index 388cd3184538a344fa42b01625d02cc8a0634df1..0f99257b53a366ccdb2521ca742198adfe3ff556 100755 --- a/configs/det/det_db_icdar15_reader.yml +++ b/configs/det/det_db_icdar15_reader.yml @@ -15,7 +15,7 @@ EvalReader: TestReader: reader_function: ppocr.data.det.dataset_traversal,EvalTestReader process_function: ppocr.data.det.db_process,DBProcessTest - single_img_path: + infer_img: img_set_dir: ./train_data/icdar2015/text_localization/ label_file_path: ./train_data/icdar2015/text_localization/test_icdar2015_label.txt test_image_shape: [736, 1280] diff --git a/configs/det/det_east_icdar15_reader.yml b/configs/det/det_east_icdar15_reader.yml index 478bfcd834b8090988a4187e71bd74beca9fa095..060ed4dd380d0457574c1d20be3225c7fd188108 100755 --- a/configs/det/det_east_icdar15_reader.yml +++ b/configs/det/det_east_icdar15_reader.yml @@ -17,7 +17,7 @@ EvalReader: TestReader: reader_function: ppocr.data.det.dataset_traversal,EvalTestReader process_function: ppocr.data.det.east_process,EASTProcessTest - single_img_path: + infer_img: img_set_dir: ./train_data/icdar2015/text_localization/ label_file_path: ./train_data/icdar2015/text_localization/test_icdar2015_label.txt do_eval: True diff --git a/configs/rec/rec_benchmark_reader.yml b/configs/rec/rec_benchmark_reader.yml index 43cd514ca9f8e54c54d0a84b48c9b44347aaef0e..3d1e3e0b22ce04573c73f51cbef26133415b9aa3 100755 --- a/configs/rec/rec_benchmark_reader.yml +++ b/configs/rec/rec_benchmark_reader.yml @@ -10,4 +10,4 @@ EvalReader: TestReader: reader_function: ppocr.data.rec.dataset_traversal,LMDBReader lmdb_sets_dir: ./train_data/data_lmdb_release/evaluation/ - infer_img: + infer_img: ./infer_img diff --git a/doc/detection.md b/doc/detection.md index fce534d1b406d9fb06d94d12c831de22ff70a481..2fbe3c427f726c3661bd945c5aa310a874f2a1d2 100644 --- a/doc/detection.md +++ b/doc/detection.md @@ -79,10 +79,10 @@ python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="./ou 测试单张图像的检测效果 ``` -python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.single_img_path="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy" +python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy" ``` 测试文件夹下所有图像的检测效果 ``` -python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.single_img_path="./doc/imgs_en/" Global.checkpoints="./output/det_db/best_accuracy" +python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.infer_img="./doc/imgs_en/" Global.checkpoints="./output/det_db/best_accuracy" ``` diff --git a/doc/inference.md b/doc/inference.md index 371d2f87d7feb364a1028bae2cf70af5a7c21d16..ca1d6af5adfbf608be3bbd226c2ab0d16c6b7d2e 100644 --- a/doc/inference.md +++ b/doc/inference.md @@ -200,7 +200,7 @@ python3 tools/infer/predict_system.py --image_dir="./doc/imgs/2.jpg" --det_model 如果想尝试使用其他检测算法或者识别算法,请参考上述文本检测模型推理和文本识别模型推理,更新相应配置和模型,下面给出基于EAST文本检测和STAR-Net文本识别执行命令: ``` -python3 tools/infer/predict_system.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/" --det_algorithm="EAST" --rec_model_dir="./inference/rec/" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en" +python3 tools/infer/predict_system.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/" --det_algorithm="EAST" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en" ``` 执行命令后,识别结果图像如下: diff --git a/ppocr/data/det/dataset_traversal.py b/ppocr/data/det/dataset_traversal.py index 3bd931857866e79d04f1fd70d07a368d744dbf84..3051c60d370e532248dc45497792f26017f0f337 100755 --- a/ppocr/data/det/dataset_traversal.py +++ b/ppocr/data/det/dataset_traversal.py @@ -84,7 +84,7 @@ class EvalTestReader(object): img_path = os.path.join(img_set_dir, img_name) img_list.append(img_path) else: - img_path = self.params['single_img_path'] + img_path = self.params['infer_img'] img_list = get_image_file_list(img_path) def batch_iter_reader(): diff --git a/tools/eval.py b/tools/eval.py index fcb572c65271270bd2b06a8d731b317b99dde6ed..d1762e9294d43a141694d14cca7d401ab134c972 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -78,6 +78,7 @@ def main(): 'fetch_name_list':eval_fetch_name_list,\ 'fetch_varname_list':eval_fetch_varname_list} metrics = eval_det_run(exe, config, eval_info_dict, "eval") + print("Eval result", metrics) else: reader_type = config['Global']['reader_yml'] if "benchmark" not in reader_type: diff --git a/tools/eval_utils/eval_det_utils.py b/tools/eval_utils/eval_det_utils.py index f0be714fc04f4afe2abf5be6f6e5fcb3a9803a66..252c93641e278e426b47cba760cb429ee0a4c93b 100644 --- a/tools/eval_utils/eval_det_utils.py +++ b/tools/eval_utils/eval_det_utils.py @@ -34,6 +34,7 @@ import json from copy import deepcopy import cv2 from ppocr.data.reader_main import reader_main +import os def cal_det_res(exe, config, eval_info_dict): @@ -43,6 +44,8 @@ def cal_det_res(exe, config, eval_info_dict): postprocess_params.update(global_params) postprocess = create_module(postprocess_params['function']) \ (params=postprocess_params) + if not os.path.exists(os.path.dirname(save_res_path)): + os.makedirs(os.path.dirname(save_res_path)) with open(save_res_path, "wb") as fout: tackling_num = 0 for data in eval_info_dict['reader'](): @@ -93,7 +96,7 @@ def load_label_infor(label_file_path, do_ignore=False): if text == "###" and do_ignore: ignore = True bbox_infor[bno]['ignore'] = ignore - img_name_label_dict[substr[0]] = bbox_infor + img_name_label_dict[os.path.basename(substr[0])] = bbox_infor return img_name_label_dict diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index c8c0797b499055ec681d0362a054a30d7322b65a..6fa51e70d4353158a8d537b6b51e7df5400b0e0c 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -31,6 +31,7 @@ class TextRecognizer(object): image_shape = [int(v) for v in args.rec_image_shape.split(",")] self.rec_image_shape = image_shape self.character_type = args.rec_char_type + self.rec_batch_num = args.rec_batch_num char_ops_params = {} char_ops_params["character_type"] = args.rec_char_type char_ops_params["character_dict_path"] = args.rec_char_dict_path @@ -59,8 +60,8 @@ class TextRecognizer(object): def __call__(self, img_list): img_num = len(img_list) - batch_num = 30 rec_res = [] + batch_num = self.rec_batch_num predict_time = 0 for beg_img_no in range(0, img_num, batch_num): end_img_no = min(img_num, beg_img_no + batch_num) diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index 8b0abd706cd60a23753a4b53915a0fe5fdaa66e4..3953fa0df76c3929a52af39901f7eeb0569c68b2 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -89,7 +89,7 @@ def sorted_boxes(dt_boxes): sorted boxes(array) with shape [4, 2] """ num_boxes = dt_boxes.shape[0] - sorted_boxes = sorted(dt_boxes, key=lambda x: x[0][1]) + sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0])) _boxes = list(sorted_boxes) for i in range(num_boxes - 1): diff --git a/tools/infer/utility.py b/tools/infer/utility.py index fd72a19f6dac1bc4ae79565e9a85134e2cd6945d..947a5495779f56af50944fbfcb44f2c6a4dc94ca 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -56,6 +56,7 @@ def parse_args(): parser.add_argument("--rec_model_dir", type=str) parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") parser.add_argument("--rec_char_type", type=str, default='ch') + parser.add_argument("--rec_batch_num", type=int, default=30) parser.add_argument( "--rec_char_dict_path", type=str, @@ -172,7 +173,8 @@ def draw_ocr(image, boxes, txts, scores, draw_txt=True, drop_score=0.5): continue font = ImageFont.truetype( "./doc/simfang.ttf", font_size, encoding="utf-8") - new_txt = str(count) + ': ' + txt + ' ' + '%.3f' % (scores[count]) + new_txt = str(count) + ': ' + txt + ' ' + '%.3f' % ( + scores[count]) draw_txt.text( (20, gap * (count + 1)), new_txt, txt_color, font=font) count += 1 diff --git a/tools/infer_det.py b/tools/infer_det.py index 9da617d1fa0039db39efbaaa913f545956524c94..800067655a1da29f4cd517207db24c2aac15c810 100755 --- a/tools/infer_det.py +++ b/tools/infer_det.py @@ -106,7 +106,6 @@ def main(): with open(save_res_path, "wb") as fout: test_reader = reader_main(config=config, mode='test') - # image_file_list = get_image_file_list(args.image_dir) tackling_num = 0 for data in test_reader(): img_num = len(data) @@ -135,7 +134,7 @@ def main(): elif config['Global']['algorithm'] == 'DB': dic = {'maps': outs[0]} else: - raise Exception("only support algorithm: ['EAST', 'BD']") + raise Exception("only support algorithm: ['EAST', 'DB']") dt_boxes_list = postprocess(dic, ratio_list) for ino in range(img_num): dt_boxes = dt_boxes_list[ino]