未验证 提交 ceda319a 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update ch_pp-ocrv3 (#2033)

* update ch_pp-ocrv3

* update README

* update defalut valuex

* add a param
上级 cb817fd9
......@@ -58,7 +58,7 @@
```
- 通过命令行方式实现文字识别模型的调用,更多请见 [PaddleHub命令行指令](../../../../docs/docs_ch/tutorial/cmd_usage.rst)
- ### 2、代码示例
- ### 2、预测代码示例
- ```python
import paddlehub as hub
......@@ -91,10 +91,11 @@
use_gpu=False,
output_dir='ocr_result',
visualization=False,
box_thresh=0.5,
box_thresh=0.6,
text_thresh=0.5,
angle_classification_thresh=0.9,
det_db_unclip_ratio=1.5)
det_db_unclip_ratio=1.5,
det_db_score_mode="fast"):
```
- 预测API,检测输入图片中的所有中文文本的位置。
......@@ -110,6 +111,8 @@
- visualization (bool): 是否将识别结果保存为图片文件;
- output\_dir (str): 图片的保存路径,默认设为 ocr\_result;
- det\_db\_unclip\_ratio: 设置检测框的大小;
- det\_db\_score\_mode: 设置检测得分计算方式,“fast” / “slow”
- **返回**
- res (list\[dict\]): 识别结果的列表,列表中每一个元素为 dict,各字段为:
......@@ -166,6 +169,10 @@
初始发布
* 1.1.0
移除 Fluid API
- ```shell
$ hub install ch_pp-ocrv3==1.0.0
$ hub install ch_pp-ocrv3==1.1.0
```
......@@ -22,11 +22,7 @@ import time
import cv2
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.inference as paddle_infer
from paddle.fluid.core import AnalysisConfig
from paddle.fluid.core import create_paddle_predictor
from paddle.fluid.core import PaddleTensor
from PIL import Image
import paddlehub as hub
......@@ -35,7 +31,7 @@ from .utils import base64_to_cv2
from .utils import draw_ocr
from .utils import get_image_ext
from .utils import sorted_boxes
from paddlehub.common.logger import logger
from paddlehub.utils.utils import logger
from paddlehub.module.module import moduleinfo
from paddlehub.module.module import runnable
from paddlehub.module.module import serving
......@@ -43,15 +39,14 @@ from paddlehub.module.module import serving
@moduleinfo(
name="ch_pp-ocrv3",
version="1.0.0",
version="1.1.0",
summary="The module can recognize the chinese texts in an image. Firstly, it will detect the text box positions \
based on the differentiable_binarization_chn module. Then it classifies the text angle and recognizes the chinese texts. ",
author="paddle-dev",
author_email="paddle-dev@baidu.com",
type="cv/text_recognition")
class ChPPOCRv3(hub.Module):
def _initialize(self, text_detector_module=None, enable_mkldnn=False):
class ChPPOCRv3:
def __init__(self, text_detector_module=None, enable_mkldnn=False):
"""
initialize with the necessary elements
"""
......@@ -124,7 +119,7 @@ class ChPPOCRv3(hub.Module):
if not self._text_detector_module:
self._text_detector_module = hub.Module(name='ch_pp-ocrv3_det',
enable_mkldnn=self.enable_mkldnn,
version='1.0.0')
version='1.1.0')
return self._text_detector_module
def read_images(self, paths=[]):
......@@ -210,10 +205,11 @@ class ChPPOCRv3(hub.Module):
use_gpu=False,
output_dir='ocr_result',
visualization=False,
box_thresh=0.5,
box_thresh=0.6,
text_thresh=0.5,
angle_classification_thresh=0.9,
det_db_unclip_ratio=1.5):
det_db_unclip_ratio=1.5,
det_db_score_mode="fast"):
"""
Get the chinese texts in the predicted images.
Args:
......@@ -227,6 +223,7 @@ class ChPPOCRv3(hub.Module):
text_thresh(float): the threshold of the chinese text recognition confidence
angle_classification_thresh(float): the threshold of the angle classification confidence
det_db_unclip_ratio(float): unclip ratio for post processing in DB detection.
det_db_score_mode(str): method to calc the final det score, one of fast(using box) and slow(using poly).
Returns:
res (list): The result of chinese texts and save path of images.
"""
......@@ -253,7 +250,8 @@ class ChPPOCRv3(hub.Module):
detection_results = self.text_detector_module.detect_text(images=predicted_data,
use_gpu=self.use_gpu,
box_thresh=box_thresh,
det_db_unclip_ratio=det_db_unclip_ratio)
det_db_unclip_ratio=det_db_unclip_ratio,
det_db_score_mode=det_db_score_mode)
boxes = [np.array(item['data']).astype(np.float32) for item in detection_results]
all_results = []
......@@ -281,7 +279,7 @@ class ChPPOCRv3(hub.Module):
rec_res_final.append({
'text': text,
'confidence': float(score),
'text_box_position': boxes[index].astype(np.int).tolist()
'text_box_position': boxes[index].astype(np.int64).tolist()
})
result['data'] = rec_res_final
......@@ -444,6 +442,7 @@ class ChPPOCRv3(hub.Module):
use_gpu=args.use_gpu,
output_dir=args.output_dir,
det_db_unclip_ratio=args.det_db_unclip_ratio,
det_db_score_mode=args.det_db_score_mode,
visualization=args.visualization)
return results
......@@ -467,6 +466,11 @@ class ChPPOCRv3(hub.Module):
type=float,
default=1.5,
help="unclip ratio for post processing in DB detection.")
self.arg_config_group.add_argument(
'--det_db_score_mode',
type=str,
default="fast",
help="method to calc the final det score, one of fast(using box) and slow(using poly).")
def add_module_input_arg(self):
"""
......
import os
import shutil
import unittest
import cv2
import requests
import paddlehub as hub
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://unsplash.com/photos/KTzZVDjUsXw/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8MzM3fHx0ZXh0fGVufDB8fHx8MTY2MzUxMTExMQ&force=true&w=640'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
cls.module = hub.Module(name="ch_pp-ocrv3")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('inference')
shutil.rmtree('ocr_result')
def test_recognize_text1(self):
results = self.module.recognize_text(
paths=['tests/test.jpg'],
use_gpu=False,
visualization=False,
)
self.assertEqual(results[0]['data'], [
{
'text': 'GIVE.', 'confidence': 0.9509806632995605,
'text_box_position': [[283, 162], [352, 162], [352, 202], [283, 202]]
},
{
'text': 'THANKS', 'confidence': 0.9943129420280457,
'text_box_position': [[261, 202], [376, 202], [376, 239], [261, 239]]
}])
def test_recognize_text2(self):
results = self.module.recognize_text(
images=[cv2.imread('tests/test.jpg')],
use_gpu=False,
visualization=False,
)
self.assertEqual(results[0]['data'], [
{
'text': 'GIVE.', 'confidence': 0.9509806632995605,
'text_box_position': [[283, 162], [352, 162], [352, 202], [283, 202]]
},
{
'text': 'THANKS', 'confidence': 0.9943129420280457,
'text_box_position': [[261, 202], [376, 202], [376, 239], [261, 239]]
}])
def test_recognize_text3(self):
results = self.module.recognize_text(
images=[cv2.imread('tests/test.jpg')],
use_gpu=True,
visualization=False,
)
self.assertEqual(results[0]['data'], [
{
'text': 'GIVE.', 'confidence': 0.9509806632995605,
'text_box_position': [[283, 162], [352, 162], [352, 202], [283, 202]]
},
{
'text': 'THANKS', 'confidence': 0.9943129420280457,
'text_box_position': [[261, 202], [376, 202], [376, 239], [261, 239]]
}])
def test_recognize_text4(self):
results = self.module.recognize_text(
images=[cv2.imread('tests/test.jpg')],
use_gpu=False,
visualization=True,
)
self.assertEqual(results[0]['data'], [
{
'text': 'GIVE.', 'confidence': 0.9509806632995605,
'text_box_position': [[283, 162], [352, 162], [352, 202], [283, 202]]
},
{
'text': 'THANKS', 'confidence': 0.9943129420280457,
'text_box_position': [[261, 202], [376, 202], [376, 239], [261, 239]]
}])
def test_recognize_text5(self):
self.assertRaises(
AttributeError,
self.module.recognize_text,
images=['tests/test.jpg']
)
def test_recognize_text6(self):
self.assertRaises(
AssertionError,
self.module.recognize_text,
paths=['no.jpg']
)
def test_save_inference_model(self):
self.module.save_inference_model('./inference/model')
self.assertTrue(os.path.exists('./inference/model/model.pdmodel'))
self.assertTrue(os.path.exists('./inference/model/model.pdiparams'))
self.assertTrue(os.path.exists('./inference/model/_text_detector_module.pdmodel'))
self.assertTrue(os.path.exists('./inference/model/_text_detector_module.pdiparams'))
if __name__ == "__main__":
unittest.main()
......@@ -69,7 +69,7 @@ def text_visual(texts, scores, font_file, img_h=400, img_w=600, threshold=0.):
assert len(texts) == len(scores), "The number of txts and corresponding scores must match"
def create_blank_img():
blank_img = np.ones(shape=[img_h, img_w], dtype=np.int8) * 255
blank_img = np.ones(shape=[img_h, img_w], dtype=np.uint8) * 255
blank_img[:, img_w - 1:] = 0
blank_img = Image.fromarray(blank_img).convert("RGB")
draw_txt = ImageDraw.Draw(blank_img)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册