未验证 提交 4889b32a 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update ch_pp-ocrv3_det (#2032)

* update ch_pp-ocrv3_det

* Add det_db_score_mode (#1)
Co-authored-by: littletomatodonkey's avatarlittletomatodonkey <dazhiningsibuqu@163.com>

* fix a unused param

* update defalut value

* update README

* fix a typo
Co-authored-by: littletomatodonkey's avatarlittletomatodonkey <dazhiningsibuqu@163.com>
上级 cdfcee74
...@@ -84,26 +84,29 @@ ...@@ -84,26 +84,29 @@
- ```python - ```python
def detect_text(paths=[], def detect_text(images=[],
images=[], paths=[],
use_gpu=False, use_gpu=False,
output_dir='detection_result', output_dir='detection_result',
box_thresh=0.5,
visualization=False, visualization=False,
det_db_unclip_ratio=1.5) box_thresh=0.6,
det_db_unclip_ratio=1.5,
det_db_score_mode="fast")
``` ```
- 预测API,检测输入图片中的所有中文文本的位置。 - 预测API,检测输入图片中的所有中文文本的位置。
- **参数** - **参数**
- paths (list\[str\]): 图片的路径;
- images (list\[numpy.ndarray\]): 图片数据,ndarray.shape 为 \[H, W, C\],BGR格式; - images (list\[numpy.ndarray\]): 图片数据,ndarray.shape 为 \[H, W, C\],BGR格式;
- paths (list\[str\]): 图片的路径;
- use\_gpu (bool): 是否使用 GPU;**若使用GPU,请先设置CUDA_VISIBLE_DEVICES环境变量** - use\_gpu (bool): 是否使用 GPU;**若使用GPU,请先设置CUDA_VISIBLE_DEVICES环境变量**
- box\_thresh (float): 检测文本框置信度的阈值; - box\_thresh (float): 检测文本框置信度的阈值;
- visualization (bool): 是否将识别结果保存为图片文件; - visualization (bool): 是否将识别结果保存为图片文件;
- output\_dir (str): 图片的保存路径,默认设为 detection\_result; - output\_dir (str): 图片的保存路径,默认设为 detection\_result;
- det\_db\_unclip\_ratio: 设置检测框的大小; - det\_db\_unclip\_ratio: 设置检测框的大小;
- det\_db\_score\_mode: 设置检测得分计算方式,“fast” / “slow”
- **返回** - **返回**
- res (list\[dict\]): 识别结果的列表,列表中每一个元素为 dict,各字段为: - res (list\[dict\]): 识别结果的列表,列表中每一个元素为 dict,各字段为:
...@@ -158,6 +161,10 @@ ...@@ -158,6 +161,10 @@
初始发布 初始发布
* 1.1.0
移除 Fluid API
- ```shell - ```shell
$ hub install ch_pp-ocrv3_det==1.0.0 $ hub install ch_pp-ocrv3_det==1.1.0
``` ```
...@@ -19,21 +19,15 @@ from __future__ import print_function ...@@ -19,21 +19,15 @@ from __future__ import print_function
import argparse import argparse
import ast import ast
import base64 import base64
import math
import os import os
import time import time
import cv2 import cv2
import numpy as np import numpy as np
import paddle.fluid as fluid
import paddle.inference as paddle_infer import paddle.inference as paddle_infer
from paddle.fluid.core import AnalysisConfig
from paddle.fluid.core import create_paddle_predictor
from paddle.fluid.core import PaddleTensor
from PIL import Image from PIL import Image
import paddlehub as hub from paddlehub.utils.utils import logger
from paddlehub.common.logger import logger
from paddlehub.module.module import moduleinfo from paddlehub.module.module import moduleinfo
from paddlehub.module.module import runnable from paddlehub.module.module import runnable
from paddlehub.module.module import serving from paddlehub.module.module import serving
...@@ -48,15 +42,14 @@ def base64_to_cv2(b64str): ...@@ -48,15 +42,14 @@ def base64_to_cv2(b64str):
@moduleinfo( @moduleinfo(
name="ch_pp-ocrv3_det", name="ch_pp-ocrv3_det",
version="1.0.0", version="1.1.0",
summary= summary=
"The module aims to detect chinese text position in the image, which is based on differentiable_binarization algorithm.", "The module aims to detect chinese text position in the image, which is based on differentiable_binarization algorithm.",
author="paddle-dev", author="paddle-dev",
author_email="paddle-dev@baidu.com", author_email="paddle-dev@baidu.com",
type="cv/text_recognition") type="cv/text_recognition")
class ChPPOCRv3Det(hub.Module): class ChPPOCRv3Det:
def __init__(self, enable_mkldnn=False):
def _initialize(self, enable_mkldnn=False):
""" """
initialize with the necessary elements initialize with the necessary elements
""" """
...@@ -211,7 +204,7 @@ class ChPPOCRv3Det(hub.Module): ...@@ -211,7 +204,7 @@ class ChPPOCRv3Det(hub.Module):
postprocessor = DBPostProcess( postprocessor = DBPostProcess(
params={ params={
'thresh': 0.3, 'thresh': 0.3,
'box_thresh': 0.6, 'box_thresh': box_thresh,
'max_candidates': 1000, 'max_candidates': 1000,
'unclip_ratio': det_db_unclip_ratio, 'unclip_ratio': det_db_unclip_ratio,
'det_db_score_mode': det_db_score_mode, 'det_db_score_mode': det_db_score_mode,
...@@ -243,7 +236,7 @@ class ChPPOCRv3Det(hub.Module): ...@@ -243,7 +236,7 @@ class ChPPOCRv3Det(hub.Module):
dt_boxes_list = postprocessor(outs_dict, [ratio_list]) dt_boxes_list = postprocessor(outs_dict, [ratio_list])
dt_boxes = dt_boxes_list[0] dt_boxes = dt_boxes_list[0]
boxes = self.filter_tag_det_res(dt_boxes_list[0], original_image.shape) boxes = self.filter_tag_det_res(dt_boxes_list[0], original_image.shape)
res['data'] = boxes.astype(np.int).tolist() res['data'] = boxes.astype(np.int64).tolist()
all_imgs.append(im) all_imgs.append(im)
all_ratios.append(ratio_list) all_ratios.append(ratio_list)
if visualization: if visualization:
...@@ -319,7 +312,7 @@ class ChPPOCRv3Det(hub.Module): ...@@ -319,7 +312,7 @@ class ChPPOCRv3Det(hub.Module):
self.arg_config_group.add_argument( self.arg_config_group.add_argument(
'--det_db_score_mode', '--det_db_score_mode',
type=str, type=str,
default="str", default="fast",
help="method to calc the final det score, one of fast(using box) and slow(using poly).") help="method to calc the final det score, one of fast(using box) and slow(using poly).")
def add_module_input_arg(self): def add_module_input_arg(self):
......
...@@ -20,11 +20,8 @@ import sys ...@@ -20,11 +20,8 @@ import sys
import cv2 import cv2
import numpy as np import numpy as np
import paddle
import pyclipper import pyclipper
from PIL import Image
from PIL import ImageDraw from PIL import ImageDraw
from PIL import ImageFont
from shapely.geometry import Polygon from shapely.geometry import Polygon
...@@ -168,9 +165,9 @@ class DBPostProcess(object): ...@@ -168,9 +165,9 @@ class DBPostProcess(object):
box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width) box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height) box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height)
boxes.append(box.astype(np.int16)) boxes.append(box.astype(np.int64))
scores.append(score) scores.append(score)
return np.array(boxes, dtype=np.int16), scores return np.array(boxes, dtype=np.int64), scores
def unclip(self, box): def unclip(self, box):
unclip_ratio = self.unclip_ratio unclip_ratio = self.unclip_ratio
...@@ -208,15 +205,15 @@ class DBPostProcess(object): ...@@ -208,15 +205,15 @@ class DBPostProcess(object):
''' '''
h, w = bitmap.shape[:2] h, w = bitmap.shape[:2]
box = _box.copy() box = _box.copy()
xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1) xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int64), 0, w - 1)
xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1) xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int64), 0, w - 1)
ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1) ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int64), 0, h - 1)
ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1) ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int64), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
box[:, 0] = box[:, 0] - xmin box[:, 0] = box[:, 0] - xmin
box[:, 1] = box[:, 1] - ymin box[:, 1] = box[:, 1] - ymin
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int64), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def box_score_slow(self, bitmap, contour): def box_score_slow(self, bitmap, contour):
...@@ -237,7 +234,7 @@ class DBPostProcess(object): ...@@ -237,7 +234,7 @@ class DBPostProcess(object):
contour[:, 0] = contour[:, 0] - xmin contour[:, 0] = contour[:, 0] - xmin
contour[:, 1] = contour[:, 1] - ymin contour[:, 1] = contour[:, 1] - ymin
cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1) cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int64), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def __call__(self, outs_dict, shape_list): def __call__(self, outs_dict, shape_list):
......
import os
import shutil
import unittest
import cv2
import requests
import paddlehub as hub
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://unsplash.com/photos/KTzZVDjUsXw/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8MzM3fHx0ZXh0fGVufDB8fHx8MTY2MzUxMTExMQ&force=true&w=640'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
cls.module = hub.Module(name="ch_pp-ocrv3_det")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('inference')
shutil.rmtree('detection_result')
def test_detect_text1(self):
results = self.module.detect_text(
paths=['tests/test.jpg'],
use_gpu=False,
visualization=False,
)
self.assertEqual(results[0]['data'], [[[261, 202], [376, 202], [376, 239], [
261, 239]], [[283, 162], [352, 162], [352, 202], [283, 202]]])
def test_detect_text2(self):
results = self.module.detect_text(
images=[cv2.imread('tests/test.jpg')],
use_gpu=False,
visualization=False,
)
self.assertEqual(results[0]['data'], [[[261, 202], [376, 202], [376, 239], [
261, 239]], [[283, 162], [352, 162], [352, 202], [283, 202]]])
def test_detect_text3(self):
results = self.module.detect_text(
images=[cv2.imread('tests/test.jpg')],
use_gpu=True,
visualization=False,
)
self.assertEqual(results[0]['data'], [[[261, 202], [376, 202], [376, 239], [
261, 239]], [[283, 162], [352, 162], [352, 202], [283, 202]]])
def test_detect_text4(self):
results = self.module.detect_text(
images=[cv2.imread('tests/test.jpg')],
use_gpu=False,
visualization=True,
)
self.assertEqual(results[0]['data'], [[[261, 202], [376, 202], [376, 239], [
261, 239]], [[283, 162], [352, 162], [352, 202], [283, 202]]])
def test_detect_text5(self):
self.assertRaises(
AttributeError,
self.module.detect_text,
images=['tests/test.jpg']
)
def test_detect_text6(self):
self.assertRaises(
AssertionError,
self.module.detect_text,
paths=['no.jpg']
)
def test_save_inference_model(self):
self.module.save_inference_model('./inference/model')
self.assertTrue(os.path.exists('./inference/model.pdmodel'))
self.assertTrue(os.path.exists('./inference/model.pdiparams'))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册