未验证 提交 41791d9b 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update chinese_text_detection_db_server (#2170)

* update chinese_text_detection_db_server

* update README
上级 626b77eb
......@@ -67,7 +67,7 @@
```
- 通过命令行方式实现文字识别模型的调用,更多请见 [PaddleHub命令行指令](../../../../docs/docs_ch/tutorial/cmd_usage.rst)
- ### 2、代码示例
- ### 2、预测代码示例
- ```python
import paddlehub as hub
......@@ -174,7 +174,11 @@
* 1.0.3
移除 fluid api
* 1.1.0
适配 PaddleHub 2.x 版本
- ```shell
$ hub install chinese_text_detection_db_server==1.0.3
$ hub install chinese_text_detection_db_server==1.1.0
```
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......@@ -9,43 +8,55 @@ import base64
import math
import os
import time
from io import BytesIO
import cv2
import numpy as np
import paddle
from paddle.inference import Config
from paddle.inference import create_predictor
from PIL import Image
import paddlehub as hub
from paddlehub.common.logger import logger
from paddlehub.module.module import moduleinfo
from paddlehub.module.module import runnable
from paddlehub.module.module import serving
from paddlehub.utils.log import logger
def base64_to_cv2(b64str):
data = base64.b64decode(b64str.encode('utf8'))
data = np.fromstring(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
if data is None:
buf = BytesIO()
image_decode = base64.b64decode(b64str.encode('utf8'))
image = BytesIO(image_decode)
im = Image.open(image)
rgb = im.convert('RGB')
rgb.save(buf, 'jpeg')
buf.seek(0)
image_bytes = buf.read()
data_base64 = str(base64.b64encode(image_bytes), encoding="utf-8")
image_decode = base64.b64decode(data_base64)
img_array = np.frombuffer(image_decode, np.uint8)
data = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
return data
@moduleinfo(
name="chinese_text_detection_db_server",
version="1.0.3",
version="1.1.0",
summary=
"The module aims to detect chinese text position in the image, which is based on differentiable_binarization algorithm.",
author="paddle-dev",
author_email="paddle-dev@baidu.com",
type="cv/text_recognition")
class ChineseTextDetectionDBServer(hub.Module):
class ChineseTextDetectionDBServer:
def _initialize(self, enable_mkldnn=False):
def __init__(self, enable_mkldnn=False):
"""
initialize with the necessary elements
"""
self.pretrained_model_path = os.path.join(self.directory, 'inference_model')
self.pretrained_model_path = os.path.join(self.directory, 'inference_model', 'model')
self.enable_mkldnn = enable_mkldnn
self._set_config()
......@@ -62,8 +73,8 @@ class ChineseTextDetectionDBServer(hub.Module):
"""
predictor config setting
"""
model_file_path = os.path.join(self.pretrained_model_path, 'model')
params_file_path = os.path.join(self.pretrained_model_path, 'params')
model_file_path = self.pretrained_model_path + '.pdmodel'
params_file_path = self.pretrained_model_path + '.pdiparams'
config = Config(model_file_path, params_file_path)
try:
......@@ -211,7 +222,7 @@ class ChineseTextDetectionDBServer(hub.Module):
data_out = self.output_tensors[0].copy_to_cpu()
dt_boxes_list = postprocessor(data_out, [ratio_list])
boxes = self.filter_tag_det_res(dt_boxes_list[0], original_image.shape)
res['data'] = boxes.astype(np.int).tolist()
res['data'] = boxes.astype(np.int64).tolist()
all_imgs.append(im)
all_ratios.append(ratio_list)
......@@ -230,28 +241,6 @@ class ChineseTextDetectionDBServer(hub.Module):
return all_results
def save_inference_model(self, dirname, model_filename=None, params_filename=None, combined=True):
if combined:
model_filename = "__model__" if not model_filename else model_filename
params_filename = "__params__" if not params_filename else params_filename
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
model_file_path = os.path.join(self.pretrained_model_path, 'model')
params_file_path = os.path.join(self.pretrained_model_path, 'params')
program, feeded_var_names, target_vars = paddle.static.load_inference_model(dirname=self.pretrained_model_path,
model_filename=model_file_path,
params_filename=params_file_path,
executor=exe)
paddle.static.save_inference_model(dirname=dirname,
main_program=program,
executor=exe,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
model_filename=model_filename,
params_filename=params_filename)
@serving
def serving_method(self, images, **kwargs):
"""
......
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
from PIL import Image, ImageDraw, ImageFont
from shapely.geometry import Polygon
import cv2
import numpy as np
import pyclipper
from PIL import ImageDraw
from shapely.geometry import Polygon
class DBPreProcess(object):
def __init__(self, max_side_len=960):
self.max_side_len = max_side_len
......@@ -103,7 +103,7 @@ class DBPostProcess(object):
contours, _ = outs[0], outs[1]
num_contours = min(len(contours), self.max_candidates)
boxes = np.zeros((num_contours, 4, 2), dtype=np.int16)
boxes = np.zeros((num_contours, 4, 2), dtype=np.int64)
scores = np.zeros((num_contours, ), dtype=np.float32)
for index in range(num_contours):
......@@ -127,7 +127,7 @@ class DBPostProcess(object):
box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height)
boxes[index, :, :] = box.astype(np.int16)
boxes[index, :, :] = box.astype(np.int64)
scores[index] = score
return boxes, scores
......@@ -163,15 +163,15 @@ class DBPostProcess(object):
def box_score_fast(self, bitmap, _box):
h, w = bitmap.shape[:2]
box = _box.copy()
xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1)
xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1)
ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1)
ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1)
xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int64), 0, w - 1)
xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int64), 0, w - 1)
ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int64), 0, h - 1)
ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int64), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
box[:, 0] = box[:, 0] - xmin
box[:, 1] = box[:, 1] - ymin
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int64), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def __call__(self, predictions, ratio_list):
......
import os
import shutil
import unittest
import cv2
import requests
import paddlehub as hub
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://unsplash.com/photos/KTzZVDjUsXw/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8MzM3fHx0ZXh0fGVufDB8fHx8MTY2MzUxMTExMQ&force=true&w=640'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
cls.module = hub.Module(name="chinese_text_detection_db_server")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('inference')
shutil.rmtree('detection_result')
def test_detect_text1(self):
results = self.module.detect_text(
paths=['tests/test.jpg'],
use_gpu=False,
visualization=False,
)
self.assertEqual(
results[0]['data'],
[[[258, 199], [382, 199], [382, 240], [258, 240]], [[281, 159], [359, 159], [359, 202], [281, 202]]])
def test_detect_text2(self):
results = self.module.detect_text(
images=[cv2.imread('tests/test.jpg')],
use_gpu=False,
visualization=False,
)
self.assertEqual(
results[0]['data'],
[[[258, 199], [382, 199], [382, 240], [258, 240]], [[281, 159], [359, 159], [359, 202], [281, 202]]])
def test_detect_text3(self):
results = self.module.detect_text(
images=[cv2.imread('tests/test.jpg')],
use_gpu=True,
visualization=False,
)
self.assertEqual(
results[0]['data'],
[[[258, 199], [382, 199], [382, 240], [258, 240]], [[281, 159], [359, 159], [359, 202], [281, 202]]])
def test_detect_text4(self):
results = self.module.detect_text(
images=[cv2.imread('tests/test.jpg')],
use_gpu=False,
visualization=True,
)
self.assertEqual(
results[0]['data'],
[[[258, 199], [382, 199], [382, 240], [258, 240]], [[281, 159], [359, 159], [359, 202], [281, 202]]])
def test_detect_text5(self):
self.assertRaises(AttributeError, self.module.detect_text, images=['tests/test.jpg'])
def test_detect_text6(self):
self.assertRaises(AssertionError, self.module.detect_text, paths=['no.jpg'])
def test_save_inference_model(self):
self.module.save_inference_model('./inference/model')
self.assertTrue(os.path.exists('./inference/model.pdmodel'))
self.assertTrue(os.path.exists('./inference/model.pdiparams'))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册