From b0a93a2e28ec8c7f2ffdecaf168fc384b667829b Mon Sep 17 00:00:00 2001 From: jm12138 <2286040843@qq.com> Date: Thu, 29 Dec 2022 10:15:46 +0800 Subject: [PATCH] update chinese_text_detection_db_mobile (#2168) * update chinese_text_detection_db_mobile * update README --- .../README.md | 8 +- .../module.py | 58 +++++-------- .../processor.py | 19 ++-- .../requirements.txt | 2 + .../chinese_text_detection_db_mobile/test.py | 86 +++++++++++++++++++ 5 files changed, 126 insertions(+), 47 deletions(-) create mode 100644 modules/image/text_recognition/chinese_text_detection_db_mobile/requirements.txt create mode 100644 modules/image/text_recognition/chinese_text_detection_db_mobile/test.py diff --git a/modules/image/text_recognition/chinese_text_detection_db_mobile/README.md b/modules/image/text_recognition/chinese_text_detection_db_mobile/README.md index eccf5688..2bebc158 100644 --- a/modules/image/text_recognition/chinese_text_detection_db_mobile/README.md +++ b/modules/image/text_recognition/chinese_text_detection_db_mobile/README.md @@ -67,7 +67,7 @@ ``` - 通过命令行方式实现文字识别模型的调用,更多请见 [PaddleHub命令行指令](../../../../docs/docs_ch/tutorial/cmd_usage.rst) -- ### 2、代码示例 +- ### 2、预测代码示例 - ```python import paddlehub as hub @@ -186,6 +186,10 @@ 移除 fluid api +* 1.1.0 + + 适配 PaddleHub 2.x 版本 + - ```shell - $ hub install chinese_text_detection_db_mobile==1.0.5 + $ hub install chinese_text_detection_db_mobile==1.1.0 ``` diff --git a/modules/image/text_recognition/chinese_text_detection_db_mobile/module.py b/modules/image/text_recognition/chinese_text_detection_db_mobile/module.py index c5e1b1b0..be874db6 100644 --- a/modules/image/text_recognition/chinese_text_detection_db_mobile/module.py +++ b/modules/image/text_recognition/chinese_text_detection_db_mobile/module.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -6,46 +5,57 @@ from __future__ import print_function import argparse import ast import base64 -import math import os import time +from io import BytesIO import cv2 import numpy as np -import paddle from paddle.inference import Config from paddle.inference import create_predictor from PIL import Image -import paddlehub as hub -from paddlehub.common.logger import logger from paddlehub.module.module import moduleinfo from paddlehub.module.module import runnable from paddlehub.module.module import serving +from paddlehub.utils.log import logger def base64_to_cv2(b64str): data = base64.b64decode(b64str.encode('utf8')) data = np.fromstring(data, np.uint8) data = cv2.imdecode(data, cv2.IMREAD_COLOR) + if data is None: + buf = BytesIO() + image_decode = base64.b64decode(b64str.encode('utf8')) + image = BytesIO(image_decode) + im = Image.open(image) + rgb = im.convert('RGB') + rgb.save(buf, 'jpeg') + buf.seek(0) + image_bytes = buf.read() + data_base64 = str(base64.b64encode(image_bytes), encoding="utf-8") + image_decode = base64.b64decode(data_base64) + img_array = np.frombuffer(image_decode, np.uint8) + data = cv2.imdecode(img_array, cv2.IMREAD_COLOR) return data @moduleinfo( name="chinese_text_detection_db_mobile", - version="1.0.5", + version="1.1.0", summary= "The module aims to detect chinese text position in the image, which is based on differentiable_binarization algorithm.", author="paddle-dev", author_email="paddle-dev@baidu.com", type="cv/text_recognition") -class ChineseTextDetectionDB(hub.Module): +class ChineseTextDetectionDB: - def _initialize(self, enable_mkldnn=False): + def __init__(self, enable_mkldnn=False): """ initialize with the necessary elements """ - self.pretrained_model_path = os.path.join(self.directory, 'inference_model') + self.pretrained_model_path = os.path.join(self.directory, 'inference_model', 'model') self.enable_mkldnn = enable_mkldnn self._set_config() @@ -62,8 +72,8 @@ class ChineseTextDetectionDB(hub.Module): """ predictor config setting """ - model_file_path = os.path.join(self.pretrained_model_path, 'model') - params_file_path = os.path.join(self.pretrained_model_path, 'params') + model_file_path = self.pretrained_model_path + '.pdmodel' + params_file_path = self.pretrained_model_path + '.pdiparams' config = Config(model_file_path, params_file_path) try: @@ -205,7 +215,7 @@ class ChineseTextDetectionDB(hub.Module): preprocessor = DBProcessTest(params={'max_side_len': 960}) postprocessor = DBPostProcess(params={ 'thresh': 0.3, - 'box_thresh': 0.5, + 'box_thresh': box_thresh, 'max_candidates': 1000, 'unclip_ratio': 1.6 }) @@ -237,7 +247,7 @@ class ChineseTextDetectionDB(hub.Module): dt_boxes_list = postprocessor(outs_dict, [ratio_list]) dt_boxes = dt_boxes_list[0] boxes = self.filter_tag_det_res(dt_boxes_list[0], original_image.shape) - res['data'] = boxes.astype(np.int).tolist() + res['data'] = boxes.astype(np.int64).tolist() all_imgs.append(im) all_ratios.append(ratio_list) @@ -256,28 +266,6 @@ class ChineseTextDetectionDB(hub.Module): return all_results - def save_inference_model(self, dirname, model_filename=None, params_filename=None, combined=True): - if combined: - model_filename = "__model__" if not model_filename else model_filename - params_filename = "__params__" if not params_filename else params_filename - place = paddle.CPUPlace() - exe = paddle.Executor(place) - - model_file_path = os.path.join(self.pretrained_model_path, 'model') - params_file_path = os.path.join(self.pretrained_model_path, 'params') - program, feeded_var_names, target_vars = paddle.static.load_inference_model(dirname=self.pretrained_model_path, - model_filename=model_file_path, - params_filename=params_file_path, - executor=exe) - - paddle.static.save_inference_model(dirname=dirname, - main_program=program, - executor=exe, - feeded_var_names=feeded_var_names, - target_vars=target_vars, - model_filename=model_filename, - params_filename=params_filename) - @serving def serving_method(self, images, **kwargs): """ diff --git a/modules/image/text_recognition/chinese_text_detection_db_mobile/processor.py b/modules/image/text_recognition/chinese_text_detection_db_mobile/processor.py index b5e76cbe..af50f11e 100644 --- a/modules/image/text_recognition/chinese_text_detection_db_mobile/processor.py +++ b/modules/image/text_recognition/chinese_text_detection_db_mobile/processor.py @@ -1,15 +1,14 @@ -# -*- coding:utf-8 -*- from __future__ import absolute_import from __future__ import division from __future__ import print_function import sys -from PIL import Image, ImageDraw, ImageFont -from shapely.geometry import Polygon import cv2 import numpy as np import pyclipper +from PIL import ImageDraw +from shapely.geometry import Polygon class DBProcessTest(object): @@ -138,7 +137,7 @@ class DBPostProcess(object): contours, _ = outs[0], outs[1] num_contours = min(len(contours), self.max_candidates) - boxes = np.zeros((num_contours, 4, 2), dtype=np.int16) + boxes = np.zeros((num_contours, 4, 2), dtype=np.int64) scores = np.zeros((num_contours, ), dtype=np.float32) for index in range(num_contours): @@ -162,7 +161,7 @@ class DBPostProcess(object): box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width) box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height) - boxes[index, :, :] = box.astype(np.int16) + boxes[index, :, :] = box.astype(np.int64) scores[index] = score return boxes, scores @@ -199,15 +198,15 @@ class DBPostProcess(object): def box_score_fast(self, bitmap, _box): h, w = bitmap.shape[:2] box = _box.copy() - xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1) - xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1) - ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1) - ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1) + xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int64), 0, w - 1) + xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int64), 0, w - 1) + ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int64), 0, h - 1) + ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int64), 0, h - 1) mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) box[:, 0] = box[:, 0] - xmin box[:, 1] = box[:, 1] - ymin - cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) + cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int64), 1) return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] def __call__(self, outs_dict, ratio_list): diff --git a/modules/image/text_recognition/chinese_text_detection_db_mobile/requirements.txt b/modules/image/text_recognition/chinese_text_detection_db_mobile/requirements.txt new file mode 100644 index 00000000..7159e62c --- /dev/null +++ b/modules/image/text_recognition/chinese_text_detection_db_mobile/requirements.txt @@ -0,0 +1,2 @@ +shapely +pyclipper diff --git a/modules/image/text_recognition/chinese_text_detection_db_mobile/test.py b/modules/image/text_recognition/chinese_text_detection_db_mobile/test.py new file mode 100644 index 00000000..106033ae --- /dev/null +++ b/modules/image/text_recognition/chinese_text_detection_db_mobile/test.py @@ -0,0 +1,86 @@ +import os +import shutil +import unittest + +import cv2 +import requests + +import paddlehub as hub + +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + + +class TestHubModule(unittest.TestCase): + + @classmethod + def setUpClass(cls) -> None: + img_url = 'https://unsplash.com/photos/KTzZVDjUsXw/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8MzM3fHx0ZXh0fGVufDB8fHx8MTY2MzUxMTExMQ&force=true&w=640' + if not os.path.exists('tests'): + os.makedirs('tests') + response = requests.get(img_url) + assert response.status_code == 200, 'Network Error.' + with open('tests/test.jpg', 'wb') as f: + f.write(response.content) + cls.module = hub.Module(name="chinese_text_detection_db_mobile") + + @classmethod + def tearDownClass(cls) -> None: + shutil.rmtree('tests') + shutil.rmtree('inference') + shutil.rmtree('detection_result') + + def test_detect_text1(self): + results = self.module.detect_text( + paths=['tests/test.jpg'], + use_gpu=False, + visualization=False, + ) + self.assertEqual( + results[0]['data'], + [[[259, 201], [376, 199], [376, 238], [259, 240]], [[282, 163], [351, 163], [351, 200], [282, 200]]]) + + def test_detect_text2(self): + results = self.module.detect_text( + images=[cv2.imread('tests/test.jpg')], + use_gpu=False, + visualization=False, + ) + self.assertEqual( + results[0]['data'], + [[[259, 201], [376, 199], [376, 238], [259, 240]], [[282, 163], [351, 163], [351, 200], [282, 200]]]) + + def test_detect_text3(self): + results = self.module.detect_text( + images=[cv2.imread('tests/test.jpg')], + use_gpu=True, + visualization=False, + ) + self.assertEqual( + results[0]['data'], + [[[259, 201], [376, 199], [376, 238], [259, 240]], [[282, 163], [351, 163], [351, 200], [282, 200]]]) + + def test_detect_text4(self): + results = self.module.detect_text( + images=[cv2.imread('tests/test.jpg')], + use_gpu=False, + visualization=True, + ) + self.assertEqual( + results[0]['data'], + [[[259, 201], [376, 199], [376, 238], [259, 240]], [[282, 163], [351, 163], [351, 200], [282, 200]]]) + + def test_detect_text5(self): + self.assertRaises(AttributeError, self.module.detect_text, images=['tests/test.jpg']) + + def test_detect_text6(self): + self.assertRaises(AssertionError, self.module.detect_text, paths=['no.jpg']) + + def test_save_inference_model(self): + self.module.save_inference_model('./inference/model') + + self.assertTrue(os.path.exists('./inference/model.pdmodel')) + self.assertTrue(os.path.exists('./inference/model.pdiparams')) + + +if __name__ == "__main__": + unittest.main() -- GitLab