Unverified commit 3d13232e authored by jm_12138 and committed by GitHub

update chinese_ocr_db_crnn_mobile (#2169)

* update chinese_ocr_db_crnn_mobile

* update README
Parent b0a93a2e
@@ -69,7 +69,7 @@
```
- Invoke the text recognition model from the command line. For more details, see [PaddleHub command-line usage](../../../../docs/docs_ch/tutorial/cmd_usage.rst)
- ### 2. Code Example
- ### 2. Prediction Code Example
- ```python
import paddlehub as hub
@@ -170,6 +170,9 @@
print(r.json()["results"])
```
- ### Gradio App Support
Starting with PaddleHub 2.3.1, the Gradio App for chinese_ocr_db_crnn_mobile can be accessed in a browser at http://127.0.0.1:8866/gradio/chinese_ocr_db_crnn_mobile.
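A minimal sketch of the serving side (assuming the default serving port 8866; the Gradio page is only reachable while the serving process is running):

```shell
$ hub serving start -m chinese_ocr_db_crnn_mobile
```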
## V. Release Note
* 1.0.0
@@ -190,16 +193,20 @@
* 1.1.1
Supports recognition of spaces in text.
* 1.1.2
Fixed an issue where only 30 fields could be detected.
* 1.1.3
Removed the fluid APIs.
* 1.2.0
Adapt to PaddleHub 2.x. Add Gradio App support.
- ```shell
$ hub install chinese_ocr_db_crnn_mobile==1.1.3
$ hub install chinese_ocr_db_crnn_mobile==1.2.0
```
@@ -171,6 +171,9 @@
print(r.json()["results"])
```
- ### Gradio APP support
Starting with PaddleHub 2.3.1, the Gradio APP for chinese_ocr_db_crnn_mobile can be accessed in a browser at http://127.0.0.1:8866/gradio/chinese_ocr_db_crnn_mobile.
## V. Release Note
* 1.0.0
@@ -191,15 +194,19 @@
* 1.1.1
Supports recognition of spaces in text.
* 1.1.2
Fixed an issue where only 30 fields could be detected.
* 1.1.3
Removed the fluid APIs.
* 1.2.0
Support PaddleHub 2.x. Add Gradio APP support.
- ```shell
$ hub install chinese_ocr_db_crnn_mobile==1.1.3
@@ -11,9 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import string
import numpy as np
class CharacterOps(object):
# -*- coding:utf-8 -*-
import argparse
import ast
import copy
@@ -9,33 +8,33 @@ import time
import cv2
import numpy as np
import paddle
from chinese_ocr_db_crnn_mobile.character import CharacterOps
from chinese_ocr_db_crnn_mobile.utils import base64_to_cv2
from chinese_ocr_db_crnn_mobile.utils import draw_ocr
from chinese_ocr_db_crnn_mobile.utils import get_image_ext
from chinese_ocr_db_crnn_mobile.utils import sorted_boxes
from paddle.inference import Config
from paddle.inference import create_predictor
from PIL import Image
import paddlehub as hub
from paddlehub.common.logger import logger
from .character import CharacterOps
from .utils import base64_to_cv2
from .utils import draw_ocr
from .utils import get_image_ext
from .utils import sorted_boxes
from paddlehub.module.module import moduleinfo
from paddlehub.module.module import runnable
from paddlehub.module.module import serving
from paddlehub.utils.log import logger
@moduleinfo(
name="chinese_ocr_db_crnn_mobile",
version="1.1.3",
version="1.2.0",
summary="The module can recognize the chinese texts in an image. Firstly, it will detect the text box positions \
based on the differentiable_binarization_chn module. Then it classifies the text angle and recognizes the chinese texts. ",
author="paddle-dev",
author_email="paddle-dev@baidu.com",
type="cv/text_recognition")
class ChineseOCRDBCRNN(hub.Module):
class ChineseOCRDBCRNN:
def _initialize(self, text_detector_module=None, enable_mkldnn=False):
def __init__(self, text_detector_module=None, enable_mkldnn=False):
"""
initialize with the necessary elements
"""
@@ -53,8 +52,8 @@ class ChineseOCRDBCRNN(hub.Module):
self.font_file = os.path.join(self.directory, 'assets', 'simfang.ttf')
self.enable_mkldnn = enable_mkldnn
self.rec_pretrained_model_path = os.path.join(self.directory, 'inference_model', 'character_rec')
self.cls_pretrained_model_path = os.path.join(self.directory, 'inference_model', 'angle_cls')
self.rec_pretrained_model_path = os.path.join(self.directory, 'inference_model', 'character_rec', 'model')
self.cls_pretrained_model_path = os.path.join(self.directory, 'inference_model', 'angle_cls', 'model')
self.rec_predictor, self.rec_input_tensor, self.rec_output_tensors = self._set_config(
self.rec_pretrained_model_path)
self.cls_predictor, self.cls_input_tensor, self.cls_output_tensors = self._set_config(
@@ -64,8 +63,8 @@ class ChineseOCRDBCRNN(hub.Module):
"""
predictor config path
"""
model_file_path = os.path.join(pretrained_model_path, 'model')
params_file_path = os.path.join(pretrained_model_path, 'params')
model_file_path = pretrained_model_path + '.pdmodel'
params_file_path = pretrained_model_path + '.pdiparams'
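# The Paddle 2.x combined inference format stores the graph as '<prefix>.pdmodel' and the weights as '<prefix>.pdiparams'.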
config = Config(model_file_path, params_file_path)
try:
@@ -107,8 +106,7 @@ class ChineseOCRDBCRNN(hub.Module):
"""
if not self._text_detector_module:
self._text_detector_module = hub.Module(name='chinese_text_detection_db_mobile',
enable_mkldnn=self.enable_mkldnn,
version='1.0.4')
enable_mkldnn=self.enable_mkldnn)
return self._text_detector_module
def read_images(self, paths=[]):
@@ -407,63 +405,46 @@ class ChineseOCRDBCRNN(hub.Module):
return rec_res
def save_inference_model(self, dirname, model_filename=None, params_filename=None, combined=True):
def save_inference_model(self, dirname):
detector_dir = os.path.join(dirname, 'text_detector')
classifier_dir = os.path.join(dirname, 'angle_classifier')
recognizer_dir = os.path.join(dirname, 'text_recognizer')
self._save_detector_model(detector_dir, model_filename, params_filename, combined)
self._save_classifier_model(classifier_dir, model_filename, params_filename, combined)
self._save_recognizer_model(recognizer_dir, model_filename, params_filename, combined)
self._save_detector_model(detector_dir)
self._save_classifier_model(classifier_dir)
self._save_recognizer_model(recognizer_dir)
logger.info("The inference model has been saved in the path {}".format(os.path.realpath(dirname)))
def _save_detector_model(self, dirname, model_filename=None, params_filename=None, combined=True):
self.text_detector_module.save_inference_model(dirname, model_filename, params_filename, combined)
def _save_detector_model(self, dirname):
self.text_detector_module.save_inference_model(dirname)
def _save_recognizer_model(self, dirname, model_filename=None, params_filename=None, combined=True):
if combined:
model_filename = "__model__" if not model_filename else model_filename
params_filename = "__params__" if not params_filename else params_filename
def _save_recognizer_model(self, dirname):
place = paddle.CPUPlace()
exe = paddle.Executor(place)
model_file_path = os.path.join(self.rec_pretrained_model_path, 'model')
params_file_path = os.path.join(self.rec_pretrained_model_path, 'params')
program, feeded_var_names, target_vars = paddle.static.load_inference_model(
dirname=self.rec_pretrained_model_path,
model_filename=model_file_path,
params_filename=params_file_path,
executor=exe)
paddle.static.save_inference_model(dirname=dirname,
main_program=program,
exe = paddle.static.Executor(place)
program, feeded_var_names, target_vars = paddle.static.load_inference_model(self.rec_pretrained_model_path,
executor=exe)
global_block = program.global_block()
feed_vars = [global_block.var(item) for item in feeded_var_names]
paddle.static.save_inference_model(dirname,
feed_vars=feed_vars,
fetch_vars=target_vars,
executor=exe,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
model_filename=model_filename,
params_filename=params_filename)
def _save_classifier_model(self, dirname, model_filename=None, params_filename=None, combined=True):
if combined:
model_filename = "__model__" if not model_filename else model_filename
params_filename = "__params__" if not params_filename else params_filename
program=program)
def _save_classifier_model(self, dirname):
place = paddle.CPUPlace()
exe = paddle.Executor(place)
model_file_path = os.path.join(self.cls_pretrained_model_path, 'model')
params_file_path = os.path.join(self.cls_pretrained_model_path, 'params')
program, feeded_var_names, target_vars = paddle.static.load_inference_model(
dirname=self.cls_pretrained_model_path,
model_filename=model_file_path,
params_filename=params_file_path,
executor=exe)
paddle.static.save_inference_model(dirname=dirname,
main_program=program,
exe = paddle.static.Executor(place)
program, feeded_var_names, target_vars = paddle.static.load_inference_model(self.cls_pretrained_model_path,
executor=exe)
global_block = program.global_block()
feed_vars = [global_block.var(item) for item in feeded_var_names]
paddle.static.save_inference_model(dirname,
feed_vars=feed_vars,
fetch_vars=target_vars,
executor=exe,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
model_filename=model_filename,
params_filename=params_filename)
program=program)
@runnable
def run_cmd(self, argvs):
@@ -511,3 +492,25 @@ class ChineseOCRDBCRNN(hub.Module):
Add the command input options
"""
self.arg_input_group.add_argument('--input_path', type=str, default=None, help="path to the input image")
def create_gradio_app(self):
import gradio as gr
def inference(image, use_gpu=False, box_thresh=0.5, text_thresh=0.5, angle_classification_thresh=0.9):
return self.recognize_text(paths=[image],
use_gpu=use_gpu,
output_dir=None,
visualization=False,
box_thresh=box_thresh,
text_thresh=text_thresh,
angle_classification_thresh=angle_classification_thresh)
return gr.Interface(inference, [
gr.Image(type='filepath'),
gr.Checkbox(),
gr.Slider(0, 1.0, 0.5, step=0.01),
gr.Slider(0, 1.0, 0.5, step=0.01),
gr.Slider(0, 1.0, 0.5, step=0.01)
], [gr.JSON(label='results')],
title='chinese_ocr_db_crnn_mobile',
allow_flagging=False)
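For illustration, a hypothetical local sketch of how this interface might be launched directly, outside of PaddleHub serving (assumes the module and gradio are installed; `demo` is an illustrative name):

```python
import paddlehub as hub

# Hypothetical usage sketch: load the installed module, build the Gradio
# interface defined by create_gradio_app(), and serve it locally.
module = hub.Module(name='chinese_ocr_db_crnn_mobile')
demo = module.create_gradio_app()
demo.launch()  # Gradio serves on a local port (7860 by default) unless server_port is given
```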
import os
import shutil
import unittest
import cv2
import requests
import paddlehub as hub
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://unsplash.com/photos/KTzZVDjUsXw/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8MzM3fHx0ZXh0fGVufDB8fHx8MTY2MzUxMTExMQ&force=true&w=640'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
cls.module = hub.Module(name="chinese_ocr_db_crnn_mobile")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('inference')
shutil.rmtree('ocr_result')
def test_recognize_text1(self):
results = self.module.recognize_text(
paths=['tests/test.jpg'],
use_gpu=False,
visualization=False,
)
self.assertEqual(results[0]['data'], [{
'text': 'GIVE',
'confidence': 0.9329867362976074,
'text_box_position': [[282, 163], [351, 163], [351, 200], [282, 200]]
}, {
'text': 'THANKS',
'confidence': 0.9966865181922913,
'text_box_position': [[259, 201], [376, 199], [376, 238], [259, 240]]
}])
def test_recognize_text2(self):
results = self.module.recognize_text(
images=[cv2.imread('tests/test.jpg')],
use_gpu=False,
visualization=False,
)
self.assertEqual(results[0]['data'], [{
'text': 'GIVE',
'confidence': 0.9329867362976074,
'text_box_position': [[282, 163], [351, 163], [351, 200], [282, 200]]
}, {
'text': 'THANKS',
'confidence': 0.9966865181922913,
'text_box_position': [[259, 201], [376, 199], [376, 238], [259, 240]]
}])
def test_recognize_text3(self):
results = self.module.recognize_text(
images=[cv2.imread('tests/test.jpg')],
use_gpu=True,
visualization=False,
)
self.assertEqual(results[0]['data'], [{
'text': 'GIVE',
'confidence': 0.9329867362976074,
'text_box_position': [[282, 163], [351, 163], [351, 200], [282, 200]]
}, {
'text': 'THANKS',
'confidence': 0.9966865181922913,
'text_box_position': [[259, 201], [376, 199], [376, 238], [259, 240]]
}])
def test_recognize_text4(self):
results = self.module.recognize_text(
images=[cv2.imread('tests/test.jpg')],
use_gpu=False,
visualization=True,
)
self.assertEqual(results[0]['data'], [{
'text': 'GIVE',
'confidence': 0.9329867362976074,
'text_box_position': [[282, 163], [351, 163], [351, 200], [282, 200]]
}, {
'text': 'THANKS',
'confidence': 0.9966865181922913,
'text_box_position': [[259, 201], [376, 199], [376, 238], [259, 240]]
}])
def test_recognize_text5(self):
self.assertRaises(AttributeError, self.module.recognize_text, images=['tests/test.jpg'])
def test_recognize_text6(self):
self.assertRaises(AssertionError, self.module.recognize_text, paths=['no.jpg'])
def test_save_inference_model(self):
self.module.save_inference_model('./inference/model')
self.assertTrue(os.path.exists('./inference/model/angle_classifier.pdmodel'))
self.assertTrue(os.path.exists('./inference/model/angle_classifier.pdiparams'))
self.assertTrue(os.path.exists('./inference/model/text_detector.pdmodel'))
self.assertTrue(os.path.exists('./inference/model/text_detector.pdiparams'))
self.assertTrue(os.path.exists('./inference/model/text_recognizer.pdmodel'))
self.assertTrue(os.path.exists('./inference/model/text_recognizer.pdiparams'))
if __name__ == "__main__":
unittest.main()
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import base64
import math
from io import BytesIO
from PIL import Image, ImageDraw, ImageFont
import cv2
import numpy as np
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
def draw_ocr(image, boxes, txts, scores, font_file, draw_txt=True, drop_score=0.5):
@@ -174,4 +176,17 @@ def base64_to_cv2(b64str):
data = base64.b64decode(b64str.encode('utf8'))
data = np.fromstring(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
if data is None:
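# cv2.imdecode failed on the raw bytes: fall back to decoding with PIL, convert to RGB, re-encode as JPEG, and decode that result with OpenCV.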
buf = BytesIO()
image_decode = base64.b64decode(b64str.encode('utf8'))
image = BytesIO(image_decode)
im = Image.open(image)
rgb = im.convert('RGB')
rgb.save(buf, 'jpeg')
buf.seek(0)
image_bytes = buf.read()
data_base64 = str(base64.b64encode(image_bytes), encoding="utf-8")
image_decode = base64.b64decode(data_base64)
img_array = np.frombuffer(image_decode, np.uint8)
data = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
return data