未验证 提交 ceda319a 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update ch_pp-ocrv3 (#2033)

* update ch_pp-ocrv3

* update README

* update default value

* add a param
上级 cb817fd9
...@@ -58,7 +58,7 @@ ...@@ -58,7 +58,7 @@
``` ```
- 通过命令行方式实现文字识别模型的调用,更多请见 [PaddleHub命令行指令](../../../../docs/docs_ch/tutorial/cmd_usage.rst) - 通过命令行方式实现文字识别模型的调用,更多请见 [PaddleHub命令行指令](../../../../docs/docs_ch/tutorial/cmd_usage.rst)
- ### 2、代码示例 - ### 2、预测代码示例
- ```python - ```python
import paddlehub as hub import paddlehub as hub
...@@ -87,14 +87,15 @@ ...@@ -87,14 +87,15 @@
- ```python - ```python
def recognize_text(images=[], def recognize_text(images=[],
paths=[], paths=[],
use_gpu=False, use_gpu=False,
output_dir='ocr_result', output_dir='ocr_result',
visualization=False, visualization=False,
box_thresh=0.5, box_thresh=0.6,
text_thresh=0.5, text_thresh=0.5,
angle_classification_thresh=0.9, angle_classification_thresh=0.9,
det_db_unclip_ratio=1.5) det_db_unclip_ratio=1.5,
det_db_score_mode="fast"):
``` ```
- 预测API,检测输入图片中的所有中文文本的位置。 - 预测API,检测输入图片中的所有中文文本的位置。
...@@ -110,6 +111,8 @@ ...@@ -110,6 +111,8 @@
- visualization (bool): 是否将识别结果保存为图片文件; - visualization (bool): 是否将识别结果保存为图片文件;
- output\_dir (str): 图片的保存路径,默认设为 ocr\_result; - output\_dir (str): 图片的保存路径,默认设为 ocr\_result;
- det\_db\_unclip\_ratio: 设置检测框的大小; - det\_db\_unclip\_ratio: 设置检测框的大小;
- det\_db\_score\_mode (str): 设置检测得分的计算方式,可选 “fast”(按外接矩形框计算)或 “slow”(按多边形计算),默认为 “fast”;
- **返回** - **返回**
- res (list\[dict\]): 识别结果的列表,列表中每一个元素为 dict,各字段为: - res (list\[dict\]): 识别结果的列表,列表中每一个元素为 dict,各字段为:
...@@ -166,6 +169,10 @@ ...@@ -166,6 +169,10 @@
初始发布 初始发布
* 1.1.0
移除 Fluid API
- ```shell - ```shell
$ hub install ch_pp-ocrv3==1.0.0 $ hub install ch_pp-ocrv3==1.1.0
``` ```
...@@ -22,11 +22,7 @@ import time ...@@ -22,11 +22,7 @@ import time
import cv2 import cv2
import numpy as np import numpy as np
import paddle import paddle
import paddle.fluid as fluid
import paddle.inference as paddle_infer import paddle.inference as paddle_infer
from paddle.fluid.core import AnalysisConfig
from paddle.fluid.core import create_paddle_predictor
from paddle.fluid.core import PaddleTensor
from PIL import Image from PIL import Image
import paddlehub as hub import paddlehub as hub
...@@ -35,7 +31,7 @@ from .utils import base64_to_cv2 ...@@ -35,7 +31,7 @@ from .utils import base64_to_cv2
from .utils import draw_ocr from .utils import draw_ocr
from .utils import get_image_ext from .utils import get_image_ext
from .utils import sorted_boxes from .utils import sorted_boxes
from paddlehub.common.logger import logger from paddlehub.utils.utils import logger
from paddlehub.module.module import moduleinfo from paddlehub.module.module import moduleinfo
from paddlehub.module.module import runnable from paddlehub.module.module import runnable
from paddlehub.module.module import serving from paddlehub.module.module import serving
...@@ -43,15 +39,14 @@ from paddlehub.module.module import serving ...@@ -43,15 +39,14 @@ from paddlehub.module.module import serving
@moduleinfo( @moduleinfo(
name="ch_pp-ocrv3", name="ch_pp-ocrv3",
version="1.0.0", version="1.1.0",
summary="The module can recognize the chinese texts in an image. Firstly, it will detect the text box positions \ summary="The module can recognize the chinese texts in an image. Firstly, it will detect the text box positions \
based on the differentiable_binarization_chn module. Then it classifies the text angle and recognizes the chinese texts. ", based on the differentiable_binarization_chn module. Then it classifies the text angle and recognizes the chinese texts. ",
author="paddle-dev", author="paddle-dev",
author_email="paddle-dev@baidu.com", author_email="paddle-dev@baidu.com",
type="cv/text_recognition") type="cv/text_recognition")
class ChPPOCRv3(hub.Module): class ChPPOCRv3:
def __init__(self, text_detector_module=None, enable_mkldnn=False):
def _initialize(self, text_detector_module=None, enable_mkldnn=False):
""" """
initialize with the necessary elements initialize with the necessary elements
""" """
...@@ -124,7 +119,7 @@ class ChPPOCRv3(hub.Module): ...@@ -124,7 +119,7 @@ class ChPPOCRv3(hub.Module):
if not self._text_detector_module: if not self._text_detector_module:
self._text_detector_module = hub.Module(name='ch_pp-ocrv3_det', self._text_detector_module = hub.Module(name='ch_pp-ocrv3_det',
enable_mkldnn=self.enable_mkldnn, enable_mkldnn=self.enable_mkldnn,
version='1.0.0') version='1.1.0')
return self._text_detector_module return self._text_detector_module
def read_images(self, paths=[]): def read_images(self, paths=[]):
...@@ -210,10 +205,11 @@ class ChPPOCRv3(hub.Module): ...@@ -210,10 +205,11 @@ class ChPPOCRv3(hub.Module):
use_gpu=False, use_gpu=False,
output_dir='ocr_result', output_dir='ocr_result',
visualization=False, visualization=False,
box_thresh=0.5, box_thresh=0.6,
text_thresh=0.5, text_thresh=0.5,
angle_classification_thresh=0.9, angle_classification_thresh=0.9,
det_db_unclip_ratio=1.5): det_db_unclip_ratio=1.5,
det_db_score_mode="fast"):
""" """
Get the chinese texts in the predicted images. Get the chinese texts in the predicted images.
Args: Args:
...@@ -227,6 +223,7 @@ class ChPPOCRv3(hub.Module): ...@@ -227,6 +223,7 @@ class ChPPOCRv3(hub.Module):
text_thresh(float): the threshold of the chinese text recognition confidence text_thresh(float): the threshold of the chinese text recognition confidence
angle_classification_thresh(float): the threshold of the angle classification confidence angle_classification_thresh(float): the threshold of the angle classification confidence
det_db_unclip_ratio(float): unclip ratio for post processing in DB detection. det_db_unclip_ratio(float): unclip ratio for post processing in DB detection.
det_db_score_mode(str): method to calc the final det score, one of fast(using box) and slow(using poly).
Returns: Returns:
res (list): The result of chinese texts and save path of images. res (list): The result of chinese texts and save path of images.
""" """
...@@ -253,7 +250,8 @@ class ChPPOCRv3(hub.Module): ...@@ -253,7 +250,8 @@ class ChPPOCRv3(hub.Module):
detection_results = self.text_detector_module.detect_text(images=predicted_data, detection_results = self.text_detector_module.detect_text(images=predicted_data,
use_gpu=self.use_gpu, use_gpu=self.use_gpu,
box_thresh=box_thresh, box_thresh=box_thresh,
det_db_unclip_ratio=det_db_unclip_ratio) det_db_unclip_ratio=det_db_unclip_ratio,
det_db_score_mode=det_db_score_mode)
boxes = [np.array(item['data']).astype(np.float32) for item in detection_results] boxes = [np.array(item['data']).astype(np.float32) for item in detection_results]
all_results = [] all_results = []
...@@ -281,7 +279,7 @@ class ChPPOCRv3(hub.Module): ...@@ -281,7 +279,7 @@ class ChPPOCRv3(hub.Module):
rec_res_final.append({ rec_res_final.append({
'text': text, 'text': text,
'confidence': float(score), 'confidence': float(score),
'text_box_position': boxes[index].astype(np.int).tolist() 'text_box_position': boxes[index].astype(np.int64).tolist()
}) })
result['data'] = rec_res_final result['data'] = rec_res_final
...@@ -444,6 +442,7 @@ class ChPPOCRv3(hub.Module): ...@@ -444,6 +442,7 @@ class ChPPOCRv3(hub.Module):
use_gpu=args.use_gpu, use_gpu=args.use_gpu,
output_dir=args.output_dir, output_dir=args.output_dir,
det_db_unclip_ratio=args.det_db_unclip_ratio, det_db_unclip_ratio=args.det_db_unclip_ratio,
det_db_score_mode=args.det_db_score_mode,
visualization=args.visualization) visualization=args.visualization)
return results return results
...@@ -467,6 +466,11 @@ class ChPPOCRv3(hub.Module): ...@@ -467,6 +466,11 @@ class ChPPOCRv3(hub.Module):
type=float, type=float,
default=1.5, default=1.5,
help="unclip ratio for post processing in DB detection.") help="unclip ratio for post processing in DB detection.")
self.arg_config_group.add_argument(
'--det_db_score_mode',
type=str,
default="fast",
help="method to calc the final det score, one of fast(using box) and slow(using poly).")
def add_module_input_arg(self): def add_module_input_arg(self):
""" """
......
import os
import shutil
import unittest
import cv2
import requests
import paddlehub as hub
# Expose only CUDA device 0 to this process; test_recognize_text3 runs with use_gpu=True.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):
    """End-to-end tests for the ``ch_pp-ocrv3`` PaddleHub module.

    The suite downloads one sample image, loads the module once for the whole
    class and checks ``recognize_text`` / ``save_inference_model`` against
    fixed reference results.
    """

    # Reference OCR output for the sample image; shared by every positive test
    # so the expectation is defined in exactly one place.
    EXPECTED_DATA = [
        {
            'text': 'GIVE.',
            'confidence': 0.9509806632995605,
            'text_box_position': [[283, 162], [352, 162], [352, 202], [283, 202]],
        },
        {
            'text': 'THANKS',
            'confidence': 0.9943129420280457,
            'text_box_position': [[261, 202], [376, 202], [376, 239], [261, 239]],
        },
    ]

    @classmethod
    def setUpClass(cls) -> None:
        """Download the sample image into ``tests/`` and load the module."""
        img_url = 'https://unsplash.com/photos/KTzZVDjUsXw/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8MzM3fHx0ZXh0fGVufDB8fHx8MTY2MzUxMTExMQ&force=true&w=640'
        os.makedirs('tests', exist_ok=True)
        # A timeout keeps the suite from hanging forever on a stalled connection.
        response = requests.get(img_url, timeout=60)
        assert response.status_code == 200, 'Network Error.'
        with open('tests/test.jpg', 'wb') as f:
            f.write(response.content)
        cls.module = hub.Module(name="ch_pp-ocrv3")

    @classmethod
    def tearDownClass(cls) -> None:
        """Remove every directory the suite may have created.

        ``inference`` and ``ocr_result`` only exist when the tests that create
        them actually ran and succeeded, so missing directories are ignored
        instead of raising and masking the real test outcome.
        """
        for directory in ('tests', 'inference', 'ocr_result'):
            shutil.rmtree(directory, ignore_errors=True)

    def test_recognize_text1(self):
        """Recognize text from an image path on CPU."""
        results = self.module.recognize_text(
            paths=['tests/test.jpg'],
            use_gpu=False,
            visualization=False,
        )
        self.assertEqual(results[0]['data'], self.EXPECTED_DATA)

    def test_recognize_text2(self):
        """Recognize text from an in-memory image on CPU."""
        results = self.module.recognize_text(
            images=[cv2.imread('tests/test.jpg')],
            use_gpu=False,
            visualization=False,
        )
        self.assertEqual(results[0]['data'], self.EXPECTED_DATA)

    def test_recognize_text3(self):
        """Recognize text from an in-memory image on GPU."""
        results = self.module.recognize_text(
            images=[cv2.imread('tests/test.jpg')],
            use_gpu=True,
            visualization=False,
        )
        self.assertEqual(results[0]['data'], self.EXPECTED_DATA)

    def test_recognize_text4(self):
        """Recognize text on CPU with visualization output enabled."""
        results = self.module.recognize_text(
            images=[cv2.imread('tests/test.jpg')],
            use_gpu=False,
            visualization=True,
        )
        self.assertEqual(results[0]['data'], self.EXPECTED_DATA)

    def test_recognize_text5(self):
        """Passing a path string via ``images`` must raise ``AttributeError``."""
        # NOTE(review): presumably the module treats each item as an ndarray
        # and fails on a missing array attribute - confirm against the module.
        with self.assertRaises(AttributeError):
            self.module.recognize_text(images=['tests/test.jpg'])

    def test_recognize_text6(self):
        """A path that does not exist must raise ``AssertionError``."""
        with self.assertRaises(AssertionError):
            self.module.recognize_text(paths=['no.jpg'])

    def test_save_inference_model(self):
        """Exported model and parameter files exist for recognizer and detector."""
        self.module.save_inference_model('./inference/model')

        self.assertTrue(os.path.exists('./inference/model/model.pdmodel'))
        self.assertTrue(os.path.exists('./inference/model/model.pdiparams'))
        self.assertTrue(os.path.exists('./inference/model/_text_detector_module.pdmodel'))
        self.assertTrue(os.path.exists('./inference/model/_text_detector_module.pdiparams'))
if __name__ == '__main__':
    # Run the whole suite when this file is executed directly as a script.
    unittest.main()
...@@ -69,7 +69,7 @@ def text_visual(texts, scores, font_file, img_h=400, img_w=600, threshold=0.): ...@@ -69,7 +69,7 @@ def text_visual(texts, scores, font_file, img_h=400, img_w=600, threshold=0.):
assert len(texts) == len(scores), "The number of txts and corresponding scores must match" assert len(texts) == len(scores), "The number of txts and corresponding scores must match"
def create_blank_img(): def create_blank_img():
blank_img = np.ones(shape=[img_h, img_w], dtype=np.int8) * 255 blank_img = np.ones(shape=[img_h, img_w], dtype=np.uint8) * 255
blank_img[:, img_w - 1:] = 0 blank_img[:, img_w - 1:] = 0
blank_img = Image.fromarray(blank_img).convert("RGB") blank_img = Image.fromarray(blank_img).convert("RGB")
draw_txt = ImageDraw.Draw(blank_img) draw_txt = ImageDraw.Draw(blank_img)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册