From ceda319accdeb81df85431b4ebc439f0398e80ec Mon Sep 17 00:00:00 2001 From: jm12138 <2286040843@qq.com> Date: Fri, 4 Nov 2022 17:53:14 +0800 Subject: [PATCH] update ch_pp-ocrv3 (#2033) * update ch_pp-ocrv3 * update README * update default value * add a param --- .../text_recognition/ch_pp-ocrv3/README.md | 27 ++-- .../text_recognition/ch_pp-ocrv3/module.py | 32 +++-- .../text_recognition/ch_pp-ocrv3/test.py | 120 ++++++++++++++++++ .../text_recognition/ch_pp-ocrv3/utils.py | 2 +- 4 files changed, 156 insertions(+), 25 deletions(-) create mode 100644 modules/image/text_recognition/ch_pp-ocrv3/test.py diff --git a/modules/image/text_recognition/ch_pp-ocrv3/README.md b/modules/image/text_recognition/ch_pp-ocrv3/README.md index ac3812a7..f38ef70c 100755 --- a/modules/image/text_recognition/ch_pp-ocrv3/README.md +++ b/modules/image/text_recognition/ch_pp-ocrv3/README.md @@ -58,7 +58,7 @@ ``` - 通过命令行方式实现文字识别模型的调用,更多请见 [PaddleHub命令行指令](../../../../docs/docs_ch/tutorial/cmd_usage.rst) -- ### 2、代码示例 +- ### 2、预测代码示例 - ```python import paddlehub as hub @@ -87,14 +87,15 @@ - ```python def recognize_text(images=[], - paths=[], - use_gpu=False, - output_dir='ocr_result', - visualization=False, - box_thresh=0.5, - text_thresh=0.5, - angle_classification_thresh=0.9, - det_db_unclip_ratio=1.5) + paths=[], + use_gpu=False, + output_dir='ocr_result', + visualization=False, + box_thresh=0.6, + text_thresh=0.5, + angle_classification_thresh=0.9, + det_db_unclip_ratio=1.5, + det_db_score_mode="fast"): ``` - 预测API,检测输入图片中的所有中文文本的位置。 @@ -110,6 +111,8 @@ - visualization (bool): 是否将识别结果保存为图片文件; - output\_dir (str): 图片的保存路径,默认设为 ocr\_result; - det\_db\_unclip\_ratio: 设置检测框的大小; + - det\_db\_score\_mode: 设置检测得分计算方式,“fast” / “slow” + - **返回** - res (list\[dict\]): 识别结果的列表,列表中每一个元素为 dict,各字段为: @@ -166,6 +169,10 @@ 初始发布 +* 1.1.0 + + 移除 Fluid API + - ```shell - $ hub install ch_pp-ocrv3==1.0.0 + $ hub install ch_pp-ocrv3==1.1.0 ``` diff --git 
a/modules/image/text_recognition/ch_pp-ocrv3/module.py b/modules/image/text_recognition/ch_pp-ocrv3/module.py index 133f8d78..f79de328 100644 --- a/modules/image/text_recognition/ch_pp-ocrv3/module.py +++ b/modules/image/text_recognition/ch_pp-ocrv3/module.py @@ -22,11 +22,7 @@ import time import cv2 import numpy as np import paddle -import paddle.fluid as fluid import paddle.inference as paddle_infer -from paddle.fluid.core import AnalysisConfig -from paddle.fluid.core import create_paddle_predictor -from paddle.fluid.core import PaddleTensor from PIL import Image import paddlehub as hub @@ -35,7 +31,7 @@ from .utils import base64_to_cv2 from .utils import draw_ocr from .utils import get_image_ext from .utils import sorted_boxes -from paddlehub.common.logger import logger +from paddlehub.utils.utils import logger from paddlehub.module.module import moduleinfo from paddlehub.module.module import runnable from paddlehub.module.module import serving @@ -43,15 +39,14 @@ from paddlehub.module.module import serving @moduleinfo( name="ch_pp-ocrv3", - version="1.0.0", + version="1.1.0", summary="The module can recognize the chinese texts in an image. Firstly, it will detect the text box positions \ based on the differentiable_binarization_chn module. Then it classifies the text angle and recognizes the chinese texts. 
", author="paddle-dev", author_email="paddle-dev@baidu.com", type="cv/text_recognition") -class ChPPOCRv3(hub.Module): - - def _initialize(self, text_detector_module=None, enable_mkldnn=False): +class ChPPOCRv3: + def __init__(self, text_detector_module=None, enable_mkldnn=False): """ initialize with the necessary elements """ @@ -124,7 +119,7 @@ class ChPPOCRv3(hub.Module): if not self._text_detector_module: self._text_detector_module = hub.Module(name='ch_pp-ocrv3_det', enable_mkldnn=self.enable_mkldnn, - version='1.0.0') + version='1.1.0') return self._text_detector_module def read_images(self, paths=[]): @@ -210,10 +205,11 @@ class ChPPOCRv3(hub.Module): use_gpu=False, output_dir='ocr_result', visualization=False, - box_thresh=0.5, + box_thresh=0.6, text_thresh=0.5, angle_classification_thresh=0.9, - det_db_unclip_ratio=1.5): + det_db_unclip_ratio=1.5, + det_db_score_mode="fast"): """ Get the chinese texts in the predicted images. Args: @@ -227,6 +223,7 @@ class ChPPOCRv3(hub.Module): text_thresh(float): the threshold of the chinese text recognition confidence angle_classification_thresh(float): the threshold of the angle classification confidence det_db_unclip_ratio(float): unclip ratio for post processing in DB detection. + det_db_score_mode(str): method to calc the final det score, one of fast(using box) and slow(using poly). Returns: res (list): The result of chinese texts and save path of images. 
""" @@ -253,7 +250,8 @@ class ChPPOCRv3(hub.Module): detection_results = self.text_detector_module.detect_text(images=predicted_data, use_gpu=self.use_gpu, box_thresh=box_thresh, - det_db_unclip_ratio=det_db_unclip_ratio) + det_db_unclip_ratio=det_db_unclip_ratio, + det_db_score_mode=det_db_score_mode) boxes = [np.array(item['data']).astype(np.float32) for item in detection_results] all_results = [] @@ -281,7 +279,7 @@ class ChPPOCRv3(hub.Module): rec_res_final.append({ 'text': text, 'confidence': float(score), - 'text_box_position': boxes[index].astype(np.int).tolist() + 'text_box_position': boxes[index].astype(np.int64).tolist() }) result['data'] = rec_res_final @@ -444,6 +442,7 @@ class ChPPOCRv3(hub.Module): use_gpu=args.use_gpu, output_dir=args.output_dir, det_db_unclip_ratio=args.det_db_unclip_ratio, + det_db_score_mode=args.det_db_score_mode, visualization=args.visualization) return results @@ -467,6 +466,11 @@ class ChPPOCRv3(hub.Module): type=float, default=1.5, help="unclip ratio for post processing in DB detection.") + self.arg_config_group.add_argument( + '--det_db_score_mode', + type=str, + default="fast", + help="method to calc the final det score, one of fast(using box) and slow(using poly).") def add_module_input_arg(self): """ diff --git a/modules/image/text_recognition/ch_pp-ocrv3/test.py b/modules/image/text_recognition/ch_pp-ocrv3/test.py new file mode 100644 index 00000000..a61f54e4 --- /dev/null +++ b/modules/image/text_recognition/ch_pp-ocrv3/test.py @@ -0,0 +1,120 @@ +import os +import shutil +import unittest + +import cv2 +import requests +import paddlehub as hub + + +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + + +class TestHubModule(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + img_url = 'https://unsplash.com/photos/KTzZVDjUsXw/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8MzM3fHx0ZXh0fGVufDB8fHx8MTY2MzUxMTExMQ&force=true&w=640' + if not os.path.exists('tests'): + os.makedirs('tests') + response = requests.get(img_url) + 
assert response.status_code == 200, 'Network Error.' + with open('tests/test.jpg', 'wb') as f: + f.write(response.content) + cls.module = hub.Module(name="ch_pp-ocrv3") + + @classmethod + def tearDownClass(cls) -> None: + shutil.rmtree('tests') + shutil.rmtree('inference') + shutil.rmtree('ocr_result') + + def test_recognize_text1(self): + results = self.module.recognize_text( + paths=['tests/test.jpg'], + use_gpu=False, + visualization=False, + ) + self.assertEqual(results[0]['data'], [ + { + 'text': 'GIVE.', 'confidence': 0.9509806632995605, + 'text_box_position': [[283, 162], [352, 162], [352, 202], [283, 202]] + }, + { + 'text': 'THANKS', 'confidence': 0.9943129420280457, + 'text_box_position': [[261, 202], [376, 202], [376, 239], [261, 239]] + }]) + + def test_recognize_text2(self): + results = self.module.recognize_text( + images=[cv2.imread('tests/test.jpg')], + use_gpu=False, + visualization=False, + ) + self.assertEqual(results[0]['data'], [ + { + 'text': 'GIVE.', 'confidence': 0.9509806632995605, + 'text_box_position': [[283, 162], [352, 162], [352, 202], [283, 202]] + }, + { + 'text': 'THANKS', 'confidence': 0.9943129420280457, + 'text_box_position': [[261, 202], [376, 202], [376, 239], [261, 239]] + }]) + + def test_recognize_text3(self): + results = self.module.recognize_text( + images=[cv2.imread('tests/test.jpg')], + use_gpu=True, + visualization=False, + ) + self.assertEqual(results[0]['data'], [ + { + 'text': 'GIVE.', 'confidence': 0.9509806632995605, + 'text_box_position': [[283, 162], [352, 162], [352, 202], [283, 202]] + }, + { + 'text': 'THANKS', 'confidence': 0.9943129420280457, + 'text_box_position': [[261, 202], [376, 202], [376, 239], [261, 239]] + }]) + + def test_recognize_text4(self): + results = self.module.recognize_text( + images=[cv2.imread('tests/test.jpg')], + use_gpu=False, + visualization=True, + ) + self.assertEqual(results[0]['data'], [ + { + 'text': 'GIVE.', 'confidence': 0.9509806632995605, + 'text_box_position': [[283, 162], 
[352, 162], [352, 202], [283, 202]] + }, + { + 'text': 'THANKS', 'confidence': 0.9943129420280457, + 'text_box_position': [[261, 202], [376, 202], [376, 239], [261, 239]] + }]) + + def test_recognize_text5(self): + self.assertRaises( + AttributeError, + self.module.recognize_text, + images=['tests/test.jpg'] + ) + + def test_recognize_text6(self): + self.assertRaises( + AssertionError, + self.module.recognize_text, + paths=['no.jpg'] + ) + + def test_save_inference_model(self): + self.module.save_inference_model('./inference/model') + + self.assertTrue(os.path.exists('./inference/model/model.pdmodel')) + self.assertTrue(os.path.exists('./inference/model/model.pdiparams')) + + self.assertTrue(os.path.exists('./inference/model/_text_detector_module.pdmodel')) + self.assertTrue(os.path.exists('./inference/model/_text_detector_module.pdiparams')) + + +if __name__ == "__main__": + unittest.main() diff --git a/modules/image/text_recognition/ch_pp-ocrv3/utils.py b/modules/image/text_recognition/ch_pp-ocrv3/utils.py index d6309f2f..b547a3f2 100644 --- a/modules/image/text_recognition/ch_pp-ocrv3/utils.py +++ b/modules/image/text_recognition/ch_pp-ocrv3/utils.py @@ -69,7 +69,7 @@ def text_visual(texts, scores, font_file, img_h=400, img_w=600, threshold=0.): assert len(texts) == len(scores), "The number of txts and corresponding scores must match" def create_blank_img(): - blank_img = np.ones(shape=[img_h, img_w], dtype=np.int8) * 255 + blank_img = np.ones(shape=[img_h, img_w], dtype=np.uint8) * 255 blank_img[:, img_w - 1:] = 0 blank_img = Image.fromarray(blank_img).convert("RGB") draw_txt = ImageDraw.Draw(blank_img) -- GitLab