From 7dccfe57a08a2fb0f0d798f51f70b09886c4ea65 Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Wed, 17 Nov 2021 12:32:05 +0800 Subject: [PATCH] improve system prediction and remove some hard code (#4643) * fix center yaml * rm init_center param * fix typo * improve pred system --- .../ch_PP-OCRv2_rec_enhanced_ctc_loss.yml | 3 +- doc/doc_ch/models_list.md | 4 +-- doc/doc_en/models_list_en.md | 4 +-- ppocr/losses/center_loss.py | 10 ++---- tools/infer/predict_system.py | 36 +++++++++++-------- tools/infer/utility.py | 8 ++++- 6 files changed, 37 insertions(+), 28 deletions(-) diff --git a/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_enhanced_ctc_loss.yml b/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_enhanced_ctc_loss.yml index 71612030..5be96969 100644 --- a/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_enhanced_ctc_loss.yml +++ b/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_enhanced_ctc_loss.yml @@ -62,8 +62,7 @@ Loss: weight: 0.05 num_classes: 6625 feat_dim: 96 - init_center: false - center_file_path: "./train_center.pkl" + center_file_path: # you can also try to add ace loss on your own dataset # - ACELoss: # weight: 0.1 diff --git a/doc/doc_ch/models_list.md b/doc/doc_ch/models_list.md index 31ab6a2c..8f1a53bc 100644 --- a/doc/doc_ch/models_list.md +++ b/doc/doc_ch/models_list.md @@ -33,8 +33,8 @@ PaddleOCR提供的可下载模型包括`推理模型`、`训练模型`、`预训 |模型名称|模型简介|配置文件|推理模型大小|下载地址| | --- | --- | --- | --- | --- | -|ch_PP-OCRv2_det_slim|【最新】slim量化+蒸馏版超轻量模型,支持中英文、多语种文本检测|[ch_PP-OCRv2_det_cml.yml](../../configs/det/ch_PP-OCRv2/ch_PP-OCR_det_cml.yml)| 3M |[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar)| -|ch_PP-OCRv2_det|【最新】原始超轻量模型,支持中英文、多语种文本检测|[ch_PP-OCRv2_det_cml.yml](../../configs/det/ch_PP-OCRv2/ch_PP-OCR_det_cml.yml)|3M|[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_distill_train.tar)| +|ch_PP-OCRv2_det_slim|【最新】slim量化+蒸馏版超轻量模型,支持中英文、多语种文本检测|[ch_PP-OCRv2_det_cml.yml](../../configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml)| 3M |[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar)| +|ch_PP-OCRv2_det|【最新】原始超轻量模型,支持中英文、多语种文本检测|[ch_PP-OCRv2_det_cml.yml](../../configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml)|3M|[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_distill_train.tar)| |ch_ppocr_mobile_slim_v2.0_det|slim裁剪版超轻量模型,支持中英文、多语种文本检测|[ch_det_mv3_db_v2.0.yml](../../configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml)| 2.6M |[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar)| |ch_ppocr_mobile_v2.0_det|原始超轻量模型,支持中英文、多语种文本检测|[ch_det_mv3_db_v2.0.yml](../../configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml)|3M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar)| |ch_ppocr_server_v2.0_det|通用模型,支持中英文、多语种文本检测,比超轻量模型更大,但效果更好|[ch_det_res18_db_v2.0.yml](../../configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml)|47M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar)| diff --git a/doc/doc_en/models_list_en.md b/doc/doc_en/models_list_en.md index dbb48602..e3cf251c 100644 --- a/doc/doc_en/models_list_en.md +++ b/doc/doc_en/models_list_en.md @@ -29,8 +29,8 @@ Relationship of the above models is as follows. |model name|description|config|model size|download| | --- | --- | --- | --- | --- | -|ch_PP-OCRv2_det_slim|[New] slim quantization with distillation lightweight model, supporting Chinese, English, multilingual text detection|[ch_PP-OCRv2_det_cml.yml](../../configs/det/ch_PP-OCRv2/ch_PP-OCR_det_cml.yml)| 3M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar)| -|ch_PP-OCRv2_det|[New] Original lightweight model, supporting Chinese, English, multilingual text detection|[ch_PP-OCRv2_det_cml.yml](../../configs/det/ch_PP-OCRv2/ch_PP-OCR_det_cml.yml)|3M|[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_distill_train.tar)| +|ch_PP-OCRv2_det_slim|[New] slim quantization with distillation lightweight model, supporting Chinese, English, multilingual text detection|[ch_PP-OCRv2_det_cml.yml](../../configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml)| 3M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar)| +|ch_PP-OCRv2_det|[New] Original lightweight model, supporting Chinese, English, multilingual text detection|[ch_PP-OCRv2_det_cml.yml](../../configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml)|3M|[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_distill_train.tar)| |ch_ppocr_mobile_slim_v2.0_det|Slim pruned lightweight model, supporting Chinese, English, multilingual text detection|[ch_det_mv3_db_v2.0.yml](../../configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml)|2.6M |[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar)| |ch_ppocr_mobile_v2.0_det|Original lightweight model, supporting Chinese, English, multilingual text detection|[ch_det_mv3_db_v2.0.yml](../../configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml)|3M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar)| |ch_ppocr_server_v2.0_det|General model, which is larger than the lightweight model, but achieved better performance|[ch_det_res18_db_v2.0.yml](../../configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml)|47M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar)| diff --git a/ppocr/losses/center_loss.py b/ppocr/losses/center_loss.py index f8c57fdd..f62b8af3 100644 --- a/ppocr/losses/center_loss.py +++ b/ppocr/losses/center_loss.py @@ -30,21 +30,17 @@ class CenterLoss(nn.Layer): Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. """ - def __init__(self, - num_classes=6625, - feat_dim=96, - init_center=False, - center_file_path=None): + def __init__(self, num_classes=6625, feat_dim=96, center_file_path=None): super().__init__() self.num_classes = num_classes self.feat_dim = feat_dim self.centers = paddle.randn( shape=[self.num_classes, self.feat_dim]).astype("float64") - if init_center: + if center_file_path is not None: assert os.path.exists( center_file_path - ), f"center path({center_file_path}) must exist when init_center is set as True." + ), f"center path({center_file_path}) must exist when it is not None." with open(center_file_path, 'rb') as f: char_dict = pickle.load(f) for key in char_dict.keys(): diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index b5edd015..8d674809 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -49,11 +49,19 @@ class TextSystem(object): if self.use_angle_cls: self.text_classifier = predict_cls.TextClassifier(args) - def print_draw_crop_rec_res(self, img_crop_list, rec_res): + self.args = args + self.crop_image_res_index = 0 + + def draw_crop_rec_res(self, output_dir, img_crop_list, rec_res): + os.makedirs(output_dir, exist_ok=True) bbox_num = len(img_crop_list) for bno in range(bbox_num): - cv2.imwrite("./output/img_crop_%d.jpg" % bno, img_crop_list[bno]) - logger.info(bno, rec_res[bno]) + cv2.imwrite( + os.path.join(output_dir, + f"mg_crop_{bno+self.crop_image_res_index}.jpg"), + img_crop_list[bno]) + logger.debug(f"{bno}, {rec_res[bno]}") + self.crop_image_res_index += bbox_num def __call__(self, img, cls=True): ori_im = img.copy() @@ -80,7 +88,9 @@ class TextSystem(object): rec_res, elapse = self.text_recognizer(img_crop_list) logger.debug("rec_res num : {}, elapse : {}".format( len(rec_res), elapse)) - # self.print_draw_crop_rec_res(img_crop_list, rec_res) + if self.args.save_crop_res: + self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list, + rec_res) filter_boxes, filter_rec_res = [], [] for box, rec_reuslt in zip(dt_boxes, rec_res): text, score = rec_reuslt @@ -135,17 +145,17 @@ def main(args): if not flag: img = cv2.imread(image_file) if img is None: - logger.info("error in loading image:{}".format(image_file)) + logger.debug("error in loading image:{}".format(image_file)) continue starttime = time.time() dt_boxes, rec_res = text_sys(img) elapse = time.time() - starttime total_time += elapse - logger.info( + logger.debug( str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse)) for text, score in rec_res: - logger.info("{}, {:.3f}".format(text, score)) + logger.debug("{}, {:.3f}".format(text, score)) if is_visualize: image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) @@ -160,19 +170,17 @@ def main(args): scores, drop_score=drop_score, font_path=font_path) - draw_img_save = "./inference_results/" - if not os.path.exists(draw_img_save): - os.makedirs(draw_img_save) + draw_img_save_dir = args.draw_img_save_dir + os.makedirs(draw_img_save_dir, exist_ok=True) if flag: image_file = image_file[:-3] + "png" cv2.imwrite( - os.path.join(draw_img_save, os.path.basename(image_file)), + os.path.join(draw_img_save_dir, os.path.basename(image_file)), draw_img[:, :, ::-1]) - logger.info("The visualized image saved in {}".format( - os.path.join(draw_img_save, os.path.basename(image_file)))) + logger.debug("The visualized image saved in {}".format( + os.path.join(draw_img_save_dir, os.path.basename(image_file)))) logger.info("The predict total time is {}".format(time.time() - _st)) - logger.info("\nThe predict total time is {}".format(total_time)) if args.benchmark: text_sys.text_detector.autolog.report() text_sys.text_recognizer.autolog.report() diff --git a/tools/infer/utility.py b/tools/infer/utility.py index cab91841..85f68d9b 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -110,7 +110,13 @@ def init_args(): parser.add_argument("--enable_mkldnn", type=str2bool, default=False) parser.add_argument("--cpu_threads", type=int, default=10) parser.add_argument("--use_pdserving", type=str2bool, default=False) - parser.add_argument("--warmup", type=str2bool, default=True) + parser.add_argument("--warmup", type=str2bool, default=False) + + # + parser.add_argument( + "--draw_img_save_dir", type=str, default="./inference_results") + parser.add_argument("--save_crop_res", type=str2bool, default=False) + parser.add_argument("--crop_res_save_dir", type=str, default="./output") # multi-process parser.add_argument("--use_mp", type=str2bool, default=False) -- GitLab