未验证 提交 4382eee6 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update ch_pp-ocrv3_det (#2173)

* update ch_pp-ocrv3_det

* update
上级 3d13232e
...@@ -21,22 +21,36 @@ import ast ...@@ -21,22 +21,36 @@ import ast
import base64 import base64
import os import os
import time import time
from io import BytesIO
import cv2 import cv2
import numpy as np import numpy as np
import paddle.inference as paddle_infer import paddle.inference as paddle_infer
from PIL import Image from PIL import Image
from paddlehub.utils.utils import logger
from paddlehub.module.module import moduleinfo from paddlehub.module.module import moduleinfo
from paddlehub.module.module import runnable from paddlehub.module.module import runnable
from paddlehub.module.module import serving from paddlehub.module.module import serving
from paddlehub.utils.utils import logger
def base64_to_cv2(b64str): def base64_to_cv2(b64str):
data = base64.b64decode(b64str.encode('utf8')) data = base64.b64decode(b64str.encode('utf8'))
data = np.fromstring(data, np.uint8) data = np.fromstring(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR) data = cv2.imdecode(data, cv2.IMREAD_COLOR)
if data is None:
buf = BytesIO()
image_decode = base64.b64decode(b64str.encode('utf8'))
image = BytesIO(image_decode)
im = Image.open(image)
rgb = im.convert('RGB')
rgb.save(buf, 'jpeg')
buf.seek(0)
image_bytes = buf.read()
data_base64 = str(base64.b64encode(image_bytes), encoding="utf-8")
image_decode = base64.b64decode(data_base64)
img_array = np.frombuffer(image_decode, np.uint8)
data = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
return data return data
...@@ -49,6 +63,7 @@ def base64_to_cv2(b64str): ...@@ -49,6 +63,7 @@ def base64_to_cv2(b64str):
author_email="paddle-dev@baidu.com", author_email="paddle-dev@baidu.com",
type="cv/text_recognition") type="cv/text_recognition")
class ChPPOCRv3Det: class ChPPOCRv3Det:
def __init__(self, enable_mkldnn=False): def __init__(self, enable_mkldnn=False):
""" """
initialize with the necessary elements initialize with the necessary elements
......
...@@ -4,13 +4,14 @@ import unittest ...@@ -4,13 +4,14 @@ import unittest
import cv2 import cv2
import requests import requests
import paddlehub as hub
import paddlehub as hub
os.environ['CUDA_VISIBLE_DEVICES'] = '0' os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase): class TestHubModule(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls) -> None: def setUpClass(cls) -> None:
img_url = 'https://unsplash.com/photos/KTzZVDjUsXw/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8MzM3fHx0ZXh0fGVufDB8fHx8MTY2MzUxMTExMQ&force=true&w=640' img_url = 'https://unsplash.com/photos/KTzZVDjUsXw/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8MzM3fHx0ZXh0fGVufDB8fHx8MTY2MzUxMTExMQ&force=true&w=640'
...@@ -34,8 +35,9 @@ class TestHubModule(unittest.TestCase): ...@@ -34,8 +35,9 @@ class TestHubModule(unittest.TestCase):
use_gpu=False, use_gpu=False,
visualization=False, visualization=False,
) )
self.assertEqual(results[0]['data'], [[[261, 202], [376, 202], [376, 239], [ self.assertEqual(
261, 239]], [[283, 162], [352, 162], [352, 202], [283, 202]]]) results[0]['data'],
[[[261, 202], [376, 202], [376, 239], [261, 239]], [[283, 162], [352, 162], [352, 202], [283, 202]]])
def test_detect_text2(self): def test_detect_text2(self):
results = self.module.detect_text( results = self.module.detect_text(
...@@ -43,8 +45,9 @@ class TestHubModule(unittest.TestCase): ...@@ -43,8 +45,9 @@ class TestHubModule(unittest.TestCase):
use_gpu=False, use_gpu=False,
visualization=False, visualization=False,
) )
self.assertEqual(results[0]['data'], [[[261, 202], [376, 202], [376, 239], [ self.assertEqual(
261, 239]], [[283, 162], [352, 162], [352, 202], [283, 202]]]) results[0]['data'],
[[[261, 202], [376, 202], [376, 239], [261, 239]], [[283, 162], [352, 162], [352, 202], [283, 202]]])
def test_detect_text3(self): def test_detect_text3(self):
results = self.module.detect_text( results = self.module.detect_text(
...@@ -52,8 +55,9 @@ class TestHubModule(unittest.TestCase): ...@@ -52,8 +55,9 @@ class TestHubModule(unittest.TestCase):
use_gpu=True, use_gpu=True,
visualization=False, visualization=False,
) )
self.assertEqual(results[0]['data'], [[[261, 202], [376, 202], [376, 239], [ self.assertEqual(
261, 239]], [[283, 162], [352, 162], [352, 202], [283, 202]]]) results[0]['data'],
[[[261, 202], [376, 202], [376, 239], [261, 239]], [[283, 162], [352, 162], [352, 202], [283, 202]]])
def test_detect_text4(self): def test_detect_text4(self):
results = self.module.detect_text( results = self.module.detect_text(
...@@ -61,22 +65,15 @@ class TestHubModule(unittest.TestCase): ...@@ -61,22 +65,15 @@ class TestHubModule(unittest.TestCase):
use_gpu=False, use_gpu=False,
visualization=True, visualization=True,
) )
self.assertEqual(results[0]['data'], [[[261, 202], [376, 202], [376, 239], [ self.assertEqual(
261, 239]], [[283, 162], [352, 162], [352, 202], [283, 202]]]) results[0]['data'],
[[[261, 202], [376, 202], [376, 239], [261, 239]], [[283, 162], [352, 162], [352, 202], [283, 202]]])
def test_detect_text5(self): def test_detect_text5(self):
self.assertRaises( self.assertRaises(AttributeError, self.module.detect_text, images=['tests/test.jpg'])
AttributeError,
self.module.detect_text,
images=['tests/test.jpg']
)
def test_detect_text6(self): def test_detect_text6(self):
self.assertRaises( self.assertRaises(AssertionError, self.module.detect_text, paths=['no.jpg'])
AssertionError,
self.module.detect_text,
paths=['no.jpg']
)
def test_save_inference_model(self): def test_save_inference_model(self):
self.module.save_inference_model('./inference/model') self.module.save_inference_model('./inference/model')
...@@ -87,4 +84,3 @@ class TestHubModule(unittest.TestCase): ...@@ -87,4 +84,3 @@ class TestHubModule(unittest.TestCase):
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册