未验证 提交 aa9f30fa 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update multi_languages_ocr_db_crnn (#2034)

* update multi_languages_ocr_db_crnn

* fix a version bug
上级 4889b32a
...@@ -218,6 +218,11 @@ ...@@ -218,6 +218,11 @@
* 1.0.0 * 1.0.0
初始发布 初始发布
* 1.1.0
移除 Fluid API
- ```shell - ```shell
$ hub install multi_languages_ocr_db_crnn==1.0.0 $ hub install multi_languages_ocr_db_crnn==1.1.0
``` ```
import argparse import argparse
import sys
import os import os
import ast import ast
import paddle import paddle
import paddle.static
import paddle2onnx import paddle2onnx
import paddle2onnx as p2o import paddle2onnx as p2o
import paddle.fluid as fluid
from paddleocr import PaddleOCR from paddleocr import PaddleOCR
from paddleocr.ppocr.utils.logging import get_logger from paddleocr.ppocr.utils.logging import get_logger
from paddleocr.tools.infer.utility import base64_to_cv2 from paddleocr.tools.infer.utility import base64_to_cv2
...@@ -17,7 +16,7 @@ from .utils import read_images, save_result_image, mkdir ...@@ -17,7 +16,7 @@ from .utils import read_images, save_result_image, mkdir
@moduleinfo( @moduleinfo(
name="multi_languages_ocr_db_crnn", name="multi_languages_ocr_db_crnn",
version="1.0.0", version="1.1.0",
summary="ocr service", summary="ocr service",
author="PaddlePaddle", author="PaddlePaddle",
type="cv/text_recognition") type="cv/text_recognition")
...@@ -45,8 +44,6 @@ class MultiLangOCR: ...@@ -45,8 +44,6 @@ class MultiLangOCR:
""" """
self.lang = lang self.lang = lang
self.logger = get_logger() self.logger = get_logger()
argc = len(sys.argv)
if argc == 1 or argc > 1 and sys.argv[1] == 'serving':
self.det = det self.det = det
self.rec = rec self.rec = rec
self.use_angle_cls = use_angle_cls self.use_angle_cls = use_angle_cls
...@@ -189,7 +186,7 @@ class MultiLangOCR: ...@@ -189,7 +186,7 @@ class MultiLangOCR:
opset_version(int): operator set opset_version(int): operator set
''' '''
v0, v1, v2 = paddle2onnx.__version__.split('.') v0, v1, v2 = paddle2onnx.__version__.split('.')
if int(v1) < 9: if int(v0) == 0 and int(v1) < 9:
raise ImportError("paddle2onnx>=0.9.0 is required") raise ImportError("paddle2onnx>=0.9.0 is required")
if input_shape_dict is not None and not isinstance(input_shape_dict, dict): if input_shape_dict is not None and not isinstance(input_shape_dict, dict):
...@@ -200,19 +197,11 @@ class MultiLangOCR: ...@@ -200,19 +197,11 @@ class MultiLangOCR:
path_dict = {"det": self.det_model_dir, "rec": self.rec_model_dir, "cls": self.cls_model_dir} path_dict = {"det": self.det_model_dir, "rec": self.rec_model_dir, "cls": self.cls_model_dir}
for (key, path) in path_dict.items(): for (key, path) in path_dict.items():
model_filename = 'inference.pdmodel'
params_filename = 'inference.pdiparams'
save_file = os.path.join(dirname, '{}_{}.onnx'.format(self.name, key)) save_file = os.path.join(dirname, '{}_{}.onnx'.format(self.name, key))
# convert model save with 'paddle.fluid.io.save_inference_model' exe = paddle.static.Executor(paddle.CPUPlace())
if hasattr(paddle, 'enable_static'): [program, feed_var_names, fetch_vars] = paddle.static.load_inference_model(
paddle.enable_static() os.path.join(path, 'inference'), exe)
exe = fluid.Executor(fluid.CPUPlace())
if model_filename is None and params_filename is None:
[program, feed_var_names, fetch_vars] = fluid.io.load_inference_model(path, exe)
else:
[program, feed_var_names, fetch_vars] = fluid.io.load_inference_model(
path, exe, model_filename=model_filename, params_filename=params_filename)
onnx_proto = p2o.run_convert(program, input_shape_dict=input_shape_dict, opset_version=opset_version) onnx_proto = p2o.run_convert(program, input_shape_dict=input_shape_dict, opset_version=opset_version)
mkdir(save_file) mkdir(save_file)
......
import os
import shutil
import unittest
import cv2
import requests
import paddlehub as hub
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://unsplash.com/photos/KTzZVDjUsXw/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8MzM3fHx0ZXh0fGVufDB8fHx8MTY2MzUxMTExMQ&force=true&w=640'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
cls.module = hub.Module(name="multi_languages_ocr_db_crnn")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('onnx')
shutil.rmtree('ocr_result')
def test_recognize_text1(self):
results = self.module.recognize_text(
paths=['tests/test.jpg'],
visualization=False,
)
self.assertEqual(results[0]['data'], [
{
'text': 'GIVE.', 'confidence': 0.9509806632995605,
'text_box_position': [[283, 162], [352, 162], [352, 202], [283, 202]]
},
{
'text': 'THANKS', 'confidence': 0.9943129420280457,
'text_box_position': [[261, 202], [376, 202], [376, 239], [261, 239]]
}])
def test_recognize_text2(self):
results = self.module.recognize_text(
images=[cv2.imread('tests/test.jpg')],
visualization=False,
)
self.assertEqual(results[0]['data'], [
{
'text': 'GIVE.', 'confidence': 0.9509806632995605,
'text_box_position': [[283, 162], [352, 162], [352, 202], [283, 202]]
},
{
'text': 'THANKS', 'confidence': 0.9943129420280457,
'text_box_position': [[261, 202], [376, 202], [376, 239], [261, 239]]
}])
def test_recognize_text3(self):
results = self.module.recognize_text(
images=[cv2.imread('tests/test.jpg')],
visualization=True,
)
self.assertEqual(results[0]['data'], [
{
'text': 'GIVE.', 'confidence': 0.9509806632995605,
'text_box_position': [[283, 162], [352, 162], [352, 202], [283, 202]]
},
{
'text': 'THANKS', 'confidence': 0.9943129420280457,
'text_box_position': [[261, 202], [376, 202], [376, 239], [261, 239]]
}])
def test_recognize_text4(self):
self.assertRaises(
AttributeError,
self.module.recognize_text,
images=['tests/test.jpg']
)
def test_recognize_text5(self):
self.assertRaises(
AssertionError,
self.module.recognize_text,
paths=['no.jpg']
)
def test_export_onnx_model(self):
self.module.export_onnx_model(dirname='onnx', input_shape_dict=None, opset_version=10)
self.assertTrue(os.path.isfile('onnx/multi_languages_ocr_db_crnn_cls.onnx'))
self.assertTrue(os.path.isfile('onnx/multi_languages_ocr_db_crnn_det.onnx'))
self.assertTrue(os.path.isfile('onnx/multi_languages_ocr_db_crnn_rec.onnx'))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册