未验证 提交 1247278f 编写于 作者: S Steffy-zxf 提交者: GitHub

Add ocr modules (#611)

* add ocr modules (chinese_ocr_db_rcnn & chinese_text_detection_db)
上级 9fdbfa49
## 概述
chinese_ocr_db_rcnn Module用于识别图片当中的汉字。其基于[chinese_text_detection_db Module](https://www.paddlepaddle.org.cn/hubdetail?name=chinese_text_detection_db&en_category=TextRecognition)检测得到的文本框,继续识别文本框中的中文文字。识别文字算法采用CRNN(Convolutional Recurrent Neural Network)即卷积递归神经网络。其是DCNN和RNN的组合,专门用于识别图像中的序列式对象。与CTC loss配合使用,进行文字识别,可以直接从文本词级或行级的标注中学习,不需要详细的字符级的标注。该Module支持直接预测。
<p align="center">
<img src="https://bj.bcebos.com/paddlehub/model/image/ocr/rcnn.png" hspace='10'/> <br />
</p>
更多详情参考[An end-to-end trainable neural network for image-based sequence recognition and its application to scene text recognition](https://arxiv.org/pdf/1507.05717.pdf)
## 命令行预测
```shell
$ hub run chinese_ocr_db_rcnn --input_path "/PATH/TO/IMAGE"
```
## API
```python
def recognize_text(images=[],
paths=[],
use_gpu=False,
output_dir='ocr_result',
visualization=False,
box_thresh=0.5,
text_thresh=0.5)
```
预测API,检测输入图片中的所有中文文本的位置。
**参数**
* paths (list\[str\]): 图片的路径;
* images (list\[numpy.ndarray\]): 图片数据,ndarray.shape 为 \[H, W, C\],BGR格式;
* use\_gpu (bool): 是否使用 GPU;**若使用GPU,请先设置CUDA_VISIBLE_DEVICES环境变量**
* box\_thresh (float): 检测文本框置信度的阈值;
* text\_thresh (float): 识别中文文本置信度的阈值;
* visualization (bool): 是否将识别结果保存为图片文件;
* output\_dir (str): 图片的保存路径,默认设为 detection\_result;
**返回**
* res (list\[dict\]): 识别结果的列表,列表中每一个元素为 dict,各字段为:
* data (list\[dict\]): 识别文本结果,列表中每一个元素为 dict,各字段为:
* text(str): 识别得到的文本
* confidence(float): 识别文本结果置信度
* text_box_position(numpy.ndarray): 文本框在原图中的像素坐标,4*2的矩阵,依次表示文本框左下、右下、右上、左上顶点的坐标
如果无识别结果则data为\[\]
* save_path (str, optional): 识别结果的保存路径,如不保存图片则save_path为''
### 代码示例
```python
import paddlehub as hub
import cv2
ocr = hub.Module(name="chinese_ocr_db_rcnn")
result = ocr.recognize_text(images=[cv2.imread('/PATH/TO/IMAGE')])
# or
# result = ocr.recognize_text(paths=['/PATH/TO/IMAGE'])
```
* 样例结果示例
<p align="center">
<img src="https://bj.bcebos.com/paddlehub/model/image/ocr/ocr_res.jpg" hspace='10'/> <br />
</p>
## 服务部署
PaddleHub Serving 可以部署一个目标检测的在线服务。
### 第一步:启动PaddleHub Serving
运行启动命令:
```shell
$ hub serving start -m chinese_ocr_db_rcnn
```
这样就完成了一个目标检测的服务化API的部署,默认端口号为8866。
**NOTE:** 如使用GPU预测,则需要在启动服务之前,请设置CUDA\_VISIBLE\_DEVICES环境变量,否则不用设置。
### 第二步:发送预测请求
配置好服务端,以下数行代码即可实现发送预测请求,获取预测结果
```python
import requests
import json
import cv2
import base64
def cv2_to_base64(image):
data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(data.tostring()).decode('utf8')
# 发送HTTP请求
data = {'images':[cv2_to_base64(cv2.imread("/PATH/TO/IMAGE"))]}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/chinese_ocr_db_rcnn"
r = requests.post(url=url, headers=headers, data=json.dumps(data))
# 打印预测结果
print(r.json()["results"])
```
## 查看代码
https://github.com/PaddlePaddle/PaddleOCR
### 依赖
paddlepaddle >= 1.7.2
paddlehub >= 1.6.0
## 更新历史
* 1.0.0
初始发布
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import string
class CharacterOps(object):
""" Convert between text-label and text-index """
def __init__(self, config):
self.character_type = config['character_type']
self.loss_type = config['loss_type']
if self.character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
elif self.character_type == "ch":
character_dict_path = config['character_dict_path']
self.character_str = ""
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n")
self.character_str += line
dict_character = list(self.character_str)
elif self.character_type == "en_sensitive":
# same with ASTER setting (use 94 char).
self.character_str = string.printable[:-6]
dict_character = list(self.character_str)
else:
self.character_str = None
assert self.character_str is not None, \
"Nonsupport type of the character: {}".format(self.character_str)
self.beg_str = "sos"
self.end_str = "eos"
if self.loss_type == "attention":
dict_character = [self.beg_str, self.end_str] + dict_character
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
self.character = dict_character
def encode(self, text):
"""convert text-label into text-index.
input:
text: text labels of each image. [batch_size]
output:
text: concatenated text index for CTCLoss.
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
length: length of each text. [batch_size]
"""
if self.character_type == "en":
text = text.lower()
text_list = []
for char in text:
if char not in self.dict:
continue
text_list.append(self.dict[char])
text = np.array(text_list)
return text
def decode(self, text_index, is_remove_duplicate=False):
""" convert text-index into text-label. """
char_list = []
char_num = self.get_char_num()
if self.loss_type == "attention":
beg_idx = self.get_beg_end_flag_idx("beg")
end_idx = self.get_beg_end_flag_idx("end")
ignored_tokens = [beg_idx, end_idx]
else:
ignored_tokens = [char_num]
for idx in range(len(text_index)):
if text_index[idx] in ignored_tokens:
continue
if is_remove_duplicate:
if idx > 0 and text_index[idx - 1] == text_index[idx]:
continue
char_list.append(self.character[text_index[idx]])
text = ''.join(char_list)
return text
def get_char_num(self):
return len(self.character)
def get_beg_end_flag_idx(self, beg_or_end):
if self.loss_type == "attention":
if beg_or_end == "beg":
idx = np.array(self.dict[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict[self.end_str])
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx"\
% beg_or_end
return idx
else:
err = "error in get_beg_end_flag_idx when using the loss %s"\
% (self.loss_type)
assert False, err
def cal_predicts_accuracy(char_ops,
preds,
preds_lod,
labels,
labels_lod,
is_remove_duplicate=False):
acc_num = 0
img_num = 0
for ino in range(len(labels_lod) - 1):
beg_no = preds_lod[ino]
end_no = preds_lod[ino + 1]
preds_text = preds[beg_no:end_no].reshape(-1)
preds_text = char_ops.decode(preds_text, is_remove_duplicate)
beg_no = labels_lod[ino]
end_no = labels_lod[ino + 1]
labels_text = labels[beg_no:end_no].reshape(-1)
labels_text = char_ops.decode(labels_text, is_remove_duplicate)
img_num += 1
if preds_text == labels_text:
acc_num += 1
acc = acc_num * 1.0 / img_num
return acc, acc_num, img_num
def convert_rec_attention_infer_res(preds):
img_num = preds.shape[0]
target_lod = [0]
convert_ids = []
for ino in range(img_num):
end_pos = np.where(preds[ino, :] == 1)[0]
if len(end_pos) <= 1:
text_list = preds[ino, 1:]
else:
text_list = preds[ino, 1:end_pos[1]]
target_lod.append(target_lod[ino] + len(text_list))
convert_ids = convert_ids + list(text_list)
convert_ids = np.array(convert_ids)
convert_ids = convert_ids.reshape((-1, 1))
return convert_ids, target_lod
def convert_rec_label_to_lod(ori_labels):
img_num = len(ori_labels)
target_lod = [0]
convert_ids = []
for ino in range(img_num):
target_lod.append(target_lod[ino] + len(ori_labels[ino]))
convert_ids = convert_ids + list(ori_labels[ino])
convert_ids = np.array(convert_ids)
convert_ids = convert_ids.reshape((-1, 1))
return convert_ids, target_lod
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import ast
import copy
import math
import os
import time
from paddle.fluid.core import AnalysisConfig, create_paddle_predictor, PaddleTensor
from paddlehub.common.logger import logger
from paddlehub.module.module import moduleinfo, runnable, serving
from PIL import Image
import cv2
import numpy as np
import paddle.fluid as fluid
import paddlehub as hub
from chinese_ocr_db_rcnn.character import CharacterOps
from chinese_ocr_db_rcnn.utils import draw_ocr, get_image_ext, sorted_boxes
@moduleinfo(
name="chinese_ocr_db_rcnn",
version="1.0.0",
summary=
"The module can recognize the chinese texts in an image. Firstly, it will detect the text box positions based on the differentiable_binarization_chn module. Then it recognizes the chinese texts. ",
author="paddle-dev",
author_email="paddle-dev@baidu.com",
type="cv/text_recognition")
class ChineseOCRDBRCNN(hub.Module):
def _initialize(self, text_detector_module=None):
"""
initialize with the necessary elements
"""
self.character_dict_path = os.path.join(self.directory, 'assets',
'ppocr_keys_v1.txt')
char_ops_params = {
'character_type': 'ch',
'character_dict_path': self.character_dict_path,
'loss_type': 'ctc'
}
self.char_ops = CharacterOps(char_ops_params)
self.rec_image_shape = [3, 32, 320]
self._text_detector_module = text_detector_module
self.font_file = os.path.join(self.directory, 'assets', 'simfang.ttf')
self.pretrained_model_path = os.path.join(self.directory,
'inference_model')
self._set_config()
def _set_config(self):
"""
predictor config setting
"""
model_file_path = os.path.join(self.pretrained_model_path, 'model')
params_file_path = os.path.join(self.pretrained_model_path, 'params')
config = AnalysisConfig(model_file_path, params_file_path)
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
use_gpu = True
except:
use_gpu = False
if use_gpu:
config.enable_use_gpu(8000, 0)
else:
config.disable_gpu()
config.disable_glog_info()
# use zero copy
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
config.switch_use_feed_fetch_ops(False)
self.predictor = create_paddle_predictor(config)
input_names = self.predictor.get_input_names()
self.input_tensor = self.predictor.get_input_tensor(input_names[0])
output_names = self.predictor.get_output_names()
self.output_tensors = []
for output_name in output_names:
output_tensor = self.predictor.get_output_tensor(output_name)
self.output_tensors.append(output_tensor)
@property
def text_detector_module(self):
"""
text detect module
"""
if not self._text_detector_module:
self._text_detector_module = hub.Module(
name='chinese_text_detection_db')
return self._text_detector_module
def read_images(self, paths=[]):
images = []
for img_path in paths:
assert os.path.isfile(
img_path), "The {} isn't a valid file.".format(img_path)
img = cv2.imread(img_path)
if img is None:
logger.info("error in loading image:{}".format(img_path))
continue
images.append(img)
return images
def get_rotate_crop_image(self, img, points):
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
top = int(np.min(points[:, 1]))
bottom = int(np.max(points[:, 1]))
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
img_crop_width = int(np.linalg.norm(points[0] - points[1]))
img_crop_height = int(np.linalg.norm(points[0] - points[3]))
pts_std = np.float32([[0, 0], [img_crop_width, 0],\
[img_crop_width, img_crop_height], [0, img_crop_height]])
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img_crop,
M, (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE)
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img)
return dst_img
def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
imgW = int(32 * max_wh_ratio)
h = img.shape[0]
w = img.shape[1]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
@serving
def recognize_text(self,
images=[],
paths=[],
use_gpu=False,
output_dir='ocr_result',
visualization=False,
box_thresh=0.5,
text_thresh=0.5):
"""
Get the chinese texts in the predicted images.
Args:
images (list(numpy.ndarray)): images data, shape of each is [H, W, C]. If images not paths
paths (list[str]): The paths of images. If paths not images
use_gpu (bool): Whether to use gpu.
batch_size(int): the program deals once with one
output_dir (str): The directory to store output images.
visualization (bool): Whether to save image or not.
box_thresh(float): the threshold of the detected text box's confidence
text_thresh(float): the threshold of the recognize chinese texts' confidence
Returns:
res (list): The result of chinese texts and save path of images.
"""
if use_gpu:
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
except:
raise RuntimeError(
"Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES via export CUDA_VISIBLE_DEVICES=cuda_device_id."
)
self.use_gpu = use_gpu
if images != [] and isinstance(images, list) and paths == []:
predicted_data = images
elif images == [] and isinstance(paths, list) and paths != []:
predicted_data = self.read_images(paths)
else:
raise TypeError("The input data is inconsistent with expectations.")
assert predicted_data != [], "There is not any image to be predicted. Please check the input data."
detection_results = self.text_detector_module.detect_text(
images=predicted_data, use_gpu=self.use_gpu, box_thresh=box_thresh)
boxes = [item['data'] for item in detection_results]
all_results = []
for index, img_boxes in enumerate(boxes):
original_image = predicted_data[index].copy()
result = {'save_path': ''}
if img_boxes is None:
result['data'] = []
else:
img_crop_list = []
boxes = sorted_boxes(img_boxes)
for num_box in range(len(boxes)):
tmp_box = copy.deepcopy(boxes[num_box])
img_crop = self.get_rotate_crop_image(
original_image, tmp_box)
img_crop_list.append(img_crop)
rec_results = self._recognize_text(img_crop_list)
# if the recognized text confidence score is lower than text_thresh, then drop it
rec_res_final = []
for index, res in enumerate(rec_results):
text, score = res
if score >= text_thresh:
rec_res_final.append({
'text': text,
'confidence': score,
'text_box_position': boxes[index]
})
result['data'] = rec_res_final
if visualization and result['data']:
result['save_path'] = self.save_result_image(
original_image, boxes, rec_results, output_dir,
text_thresh)
all_results.append(result)
return all_results
def save_result_image(self,
original_image,
detection_boxes,
rec_results,
output_dir='ocr_result',
text_thresh=0.5):
image = Image.fromarray(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB))
txts = [item[0] for item in rec_results]
scores = [item[1] for item in rec_results]
draw_img = draw_ocr(
image,
detection_boxes,
txts,
scores,
font_file=self.font_file,
draw_txt=True,
drop_score=text_thresh)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
ext = get_image_ext(original_image)
saved_name = 'ndarray_{}{}'.format(time.time(), ext)
save_file_path = os.path.join(output_dir, saved_name)
cv2.imwrite(save_file_path, draw_img[:, :, ::-1])
return save_file_path
def _recognize_text(self, image_list):
img_num = len(image_list)
batch_num = 30
rec_res = []
predict_time = 0
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
max_wh_ratio = 0
for ino in range(beg_img_no, end_img_no):
h, w = image_list[ino].shape[0:2]
wh_ratio = w / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img(image_list[ino], max_wh_ratio)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.zero_copy_run()
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
rec_idx_lod = self.output_tensors[0].lod()[0]
predict_batch = self.output_tensors[1].copy_to_cpu()
predict_lod = self.output_tensors[1].lod()[0]
for rno in range(len(rec_idx_lod) - 1):
beg = rec_idx_lod[rno]
end = rec_idx_lod[rno + 1]
rec_idx_tmp = rec_idx_batch[beg:end, 0]
preds_text = self.char_ops.decode(rec_idx_tmp)
beg = predict_lod[rno]
end = predict_lod[rno + 1]
probs = predict_batch[beg:end, :]
ind = np.argmax(probs, axis=1)
blank = probs.shape[1]
valid_ind = np.where(ind != (blank - 1))[0]
score = np.mean(probs[valid_ind, ind[valid_ind]])
rec_res.append([preds_text, score])
return rec_res
def save_inference_model(self,
dirname,
model_filename=None,
params_filename=None,
combined=True):
detector_dir = os.path.join(dirname, 'text_detector')
recognizer_dir = os.path.join(dirname, 'text_recognizer')
self._save_detector_model(detector_dir, model_filename, params_filename,
combined)
self._save_recognizer_model(recognizer_dir, model_filename,
params_filename, combined)
logger.info("The inference model has been saved in the path {}".format(
os.path.realpath(dirname)))
def _save_detector_model(self,
dirname,
model_filename=None,
params_filename=None,
combined=True):
self.text_detector_module.save_inference_model(
dirname, model_filename, params_filename, combined)
def _save_recognizer_model(self,
dirname,
model_filename=None,
params_filename=None,
combined=True):
if combined:
model_filename = "__model__" if not model_filename else model_filename
params_filename = "__params__" if not params_filename else params_filename
place = fluid.CPUPlace()
exe = fluid.Executor(place)
model_file_path = os.path.join(self.pretrained_model_path, 'model')
params_file_path = os.path.join(self.pretrained_model_path, 'params')
program, feeded_var_names, target_vars = fluid.io.load_inference_model(
dirname=self.pretrained_model_path,
model_filename=model_file_path,
params_filename=params_file_path,
executor=exe)
fluid.io.save_inference_model(
dirname=dirname,
main_program=program,
executor=exe,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
model_filename=model_filename,
params_filename=params_filename)
@runnable
def run_cmd(self, argvs):
"""
Run as a command
"""
self.parser = argparse.ArgumentParser(
description="Run the chinese_ocr_db_rcnn module.",
prog='hub run chinese_ocr_db_rcnn',
usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(
title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group(
title="Config options",
description=
"Run configuration for controlling module behavior, not required.")
self.add_module_config_arg()
self.add_module_input_arg()
args = self.parser.parse_args(argvs)
results = self.recognize_texts(
paths=[args.input_path],
use_gpu=args.use_gpu,
output_dir=args.output_dir,
visualization=args.visualization)
return results
def add_module_config_arg(self):
"""
Add the command config options
"""
self.arg_config_group.add_argument(
'--use_gpu',
type=ast.literal_eval,
default=False,
help="whether use GPU or not")
self.arg_config_group.add_argument(
'--output_dir',
type=str,
default='ocr_result',
help="The directory to save output images.")
self.arg_config_group.add_argument(
'--visualization',
type=ast.literal_eval,
default=False,
help="whether to save output as images.")
def add_module_input_arg(self):
"""
Add the command input options
"""
self.arg_input_group.add_argument(
'--input_path', type=str, default=None, help="diretory to image")
if __name__ == '__main__':
ocr = ChineseOCRDBRCNN()
image_path = [
'../doc/imgs/11.jpg', '../doc/imgs/12.jpg', '../test_image.jpg'
]
res = ocr.recognize_text(paths=image_path, visualization=True)
ocr.save_inference_model('save')
print(res)
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from PIL import Image, ImageDraw, ImageFont
import cv2
import numpy as np
def draw_ocr(image,
boxes,
txts,
scores,
font_file,
draw_txt=True,
drop_score=0.5):
img = image.copy()
draw = ImageDraw.Draw(img)
if scores is None:
scores = [1] * len(boxes)
for (box, score) in zip(boxes, scores):
if score < drop_score:
continue
draw.line([(box[0][0], box[0][1]), (box[1][0], box[1][1])], fill='red')
draw.line([(box[1][0], box[1][1]), (box[2][0], box[2][1])], fill='red')
draw.line([(box[2][0], box[2][1]), (box[3][0], box[3][1])], fill='red')
draw.line([(box[3][0], box[3][1]), (box[0][0], box[0][1])], fill='red')
draw.line([(box[0][0] - 1, box[0][1] + 1),
(box[1][0] - 1, box[1][1] + 1)],
fill='red')
draw.line([(box[1][0] - 1, box[1][1] + 1),
(box[2][0] - 1, box[2][1] + 1)],
fill='red')
draw.line([(box[2][0] - 1, box[2][1] + 1),
(box[3][0] - 1, box[3][1] + 1)],
fill='red')
draw.line([(box[3][0] - 1, box[3][1] + 1),
(box[0][0] - 1, box[0][1] + 1)],
fill='red')
if draw_txt:
txt_color = (0, 0, 0)
img = np.array(resize_img(img))
_h = img.shape[0]
blank_img = np.ones(shape=[_h, 600], dtype=np.int8) * 255
blank_img = Image.fromarray(blank_img).convert("RGB")
draw_txt = ImageDraw.Draw(blank_img)
font_size = 20
gap = 20
title = "index text score"
font = ImageFont.truetype(font_file, font_size, encoding="utf-8")
draw_txt.text((20, 0), title, txt_color, font=font)
count = 0
for idx, txt in enumerate(txts):
if scores[idx] < drop_score:
continue
font = ImageFont.truetype(font_file, font_size, encoding="utf-8")
new_txt = str(count) + ': ' + txt + ' ' + str(scores[count])
draw_txt.text((20, gap * (count + 1)),
new_txt,
txt_color,
font=font)
count += 1
img = np.concatenate([np.array(img), np.array(blank_img)], axis=1)
return img
def resize_img(img, input_size=600):
img = np.array(img)
im_shape = img.shape
im_size_min = np.min(im_shape[0:2])
im_size_max = np.max(im_shape[0:2])
im_scale = float(input_size) / float(im_size_max)
im = cv2.resize(img, None, None, fx=im_scale, fy=im_scale)
return im
def get_image_ext(image):
if image.shape[2] == 4:
return ".png"
return ".jpg"
def sorted_boxes(dt_boxes):
"""
Sort text boxes in order from top to bottom, left to right
args:
dt_boxes(array):detected text boxes with shape [4, 2]
return:
sorted boxes(array) with shape [4, 2]
"""
num_boxes = dt_boxes.shape[0]
sorted_boxes = sorted(dt_boxes, key=lambda x: x[0][1])
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
tmp = _boxes[i]
_boxes[i] = _boxes[i + 1]
_boxes[i + 1] = tmp
return _boxes
## 概述
Differentiable Binarization(简称DB)是一种基于分割的文本检测算法。在各种文本检测算法中,基于分割的检测算法可以更好地处理弯曲等不规则形状文本,因此往往能取得更好的检测效果。但分割法后处理步骤中将分割结果转化为检测框的流程复杂,耗时严重。DB将二值化阈值加入训练中学习,可以获得更准确的检测边界,从而简化后处理流程。该Module支持直接预测。
<p align="center">
<img src="https://bj.bcebos.com/paddlehub/model/image/ocr/db_algo.png" hspace='10'/> <br />
</p>
更多详情参考[Real-time Scene Text Detection with Differentiable Binarization](https://arxiv.org/pdf/1911.08947.pdf)
## 命令行预测
```shell
$ hub run chinese_text_detection_db --input_path "/PATH/TO/IMAGE"
```
**该Module依赖于第三方库shapely和pyclipper,使用该Module之前,请先安装shapely和pyclipper。**
## API
```python
def detect_text(paths=[],
images=[],
use_gpu=False,
output_dir='detection_result',
box_thresh=0.5,
visualization=False)
```
预测API,检测输入图片中的所有中文文本的位置。
**参数**
* paths (list\[str\]): 图片的路径;
* images (list\[numpy.ndarray\]): 图片数据,ndarray.shape 为 \[H, W, C\],BGR格式;
* use\_gpu (bool): 是否使用 GPU;**若使用GPU,请先设置CUDA_VISIBLE_DEVICES环境变量**
* box\_thresh (float): 检测文本框置信度的阈值;
* visualization (bool): 是否将识别结果保存为图片文件;
* output\_dir (str): 图片的保存路径,默认设为 detection\_result;
**返回**
* res (list\[dict\]): 识别结果的列表,列表中每一个元素为 dict,各字段为:
* data (list): 检测文本框结果,numpy.ndarray,文本框在原图中的像素坐标,4*2的矩阵,依次表示文本框左下、右下、右上、左上顶点的坐标
* save_path (str): 识别结果的保存路径, 如不保存图片则save_path为''
### 代码示例
```python
import paddlehub as hub
import cv2
text_detector = hub.Module(name="chinese_text_detection_db")
result = text_detector.detect_text(images=[cv2.imread('/PATH/TO/IMAGE')])
# or
# result =text_detector.detect_text(paths=['/PATH/TO/IMAGE'])
```
## 服务部署
PaddleHub Serving 可以部署一个目标检测的在线服务。
### 第一步:启动PaddleHub Serving
运行启动命令:
```shell
$ hub serving start -m chinese_text_detection_db
```
这样就完成了一个目标检测的服务化API的部署,默认端口号为8866。
**NOTE:** 如使用GPU预测,则需要在启动服务之前,请设置CUDA\_VISIBLE\_DEVICES环境变量,否则不用设置。
### 第二步:发送预测请求
配置好服务端,以下数行代码即可实现发送预测请求,获取预测结果
```python
import requests
import json
import cv2
import base64
def cv2_to_base64(image):
data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(data.tostring()).decode('utf8')
# 发送HTTP请求
data = {'images':[cv2_to_base64(cv2.imread("/PATH/TO/IMAGE"))]}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/chinese_text_detection_db"
r = requests.post(url=url, headers=headers, data=json.dumps(data))
# 打印预测结果
print(r.json()["results"])
```
## 查看代码
https://github.com/PaddlePaddle/PaddleOCR
## 依赖
paddlepaddle >= 1.7.2
paddlehub >= 1.6.0
shapely
pyclipper
## 更新历史
* 1.0.0
初始发布
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import ast
import math
import os
import time
from paddle.fluid.core import AnalysisConfig, create_paddle_predictor, PaddleTensor
from paddlehub.common.logger import logger
from paddlehub.module.module import moduleinfo, runnable, serving
from PIL import Image
import cv2
import numpy as np
import paddle.fluid as fluid
import paddlehub as hub
from chinese_text_detection_db.processor import DBPreProcess, DBPostProcess, draw_boxes, get_image_ext
@moduleinfo(
name="chinese_text_detection_db",
version="1.0.0",
summary=
"The module aims to detect chinese text position in the image, which is based on differentiable_binarization algorithm.",
author="paddle-dev",
author_email="paddle-dev@baidu.com",
type="cv/text_recognition")
class ChineseTextDetectionDB(hub.Module):
def _initialize(self):
"""
initialize with the necessary elements
"""
self.check_requirements()
self.pretrained_model_path = os.path.join(self.directory,
'inference_model')
self._set_config()
def _set_config(self):
"""
predictor config setting
"""
model_file_path = os.path.join(self.pretrained_model_path, 'model')
params_file_path = os.path.join(self.pretrained_model_path, 'params')
config = AnalysisConfig(model_file_path, params_file_path)
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
use_gpu = True
except:
use_gpu = False
if use_gpu:
config.enable_use_gpu(8000, 0)
else:
config.disable_gpu()
config.disable_glog_info()
# use zero copy
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
config.switch_use_feed_fetch_ops(False)
self.predictor = create_paddle_predictor(config)
input_names = self.predictor.get_input_names()
self.input_tensor = self.predictor.get_input_tensor(input_names[0])
output_names = self.predictor.get_output_names()
self.output_tensors = []
for output_name in output_names:
output_tensor = self.predictor.get_output_tensor(output_name)
self.output_tensors.append(output_tensor)
def check_requirements(self):
try:
import shapely, pyclipper
except:
logger.error(
'This module requires the shapely, pyclipper tools. The running enviroment does not meet the requirments. Please install the two packages.'
)
exit()
def read_images(self, paths=[]):
images = []
for img_path in paths:
assert os.path.isfile(
img_path), "The {} isn't a valid file.".format(img_path)
img = cv2.imread(img_path)
if img is None:
logger.info("error in loading image:{}".format(img_path))
continue
images.append(img)
return images
def filter_tag_det_res(self, dt_boxes, image_shape):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
box = self.order_points_clockwise(box)
left = int(np.min(box[:, 0]))
right = int(np.max(box[:, 0]))
top = int(np.min(box[:, 1]))
bottom = int(np.max(box[:, 1]))
bbox_height = bottom - top
bbox_width = right - left
diffh = math.fabs(box[0, 1] - box[1, 1])
diffw = math.fabs(box[0, 0] - box[3, 0])
rect_width = int(np.linalg.norm(box[0] - box[1]))
rect_height = int(np.linalg.norm(box[0] - box[3]))
if rect_width <= 10 or rect_height <= 10:
continue
dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new)
return dt_boxes
def order_points_clockwise(self, pts):
"""
reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
# sort the points based on their x-coordinates
"""
xSorted = pts[np.argsort(pts[:, 0]), :]
# grab the left-most and right-most points from the sorted
# x-roodinate points
leftMost = xSorted[:2, :]
rightMost = xSorted[2:, :]
# now, sort the left-most coordinates according to their
# y-coordinates so we can grab the top-left and bottom-left
# points, respectively
leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
(tl, bl) = leftMost
rightMost = rightMost[np.argsort(rightMost[:, 1]), :]
(tr, br) = rightMost
rect = np.array([tl, tr, br, bl], dtype="float32")
return rect
@serving
def detect_text(self,
images=[],
paths=[],
use_gpu=False,
output_dir='detection_result',
visualization=False,
box_thresh=0.5):
"""
Get the text box in the predicted images.
Args:
images (list(numpy.ndarray)): images data, shape of each is [H, W, C]. If images not paths
paths (list[str]): The paths of images. If paths not images
use_gpu (bool): Whether to use gpu. Default false.
output_dir (str): The directory to store output images.
visualization (bool): Whether to save image or not.
box_thresh(float): the threshold of the detected text box's confidence
Returns:
res (list): The result of text detection box and save path of images.
"""
if use_gpu:
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
except:
raise RuntimeError(
"Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES via export CUDA_VISIBLE_DEVICES=cuda_device_id."
)
if images != [] and isinstance(images, list) and paths == []:
predicted_data = images
elif images == [] and isinstance(paths, list) and paths != []:
predicted_data = self.read_images(paths)
else:
raise TypeError("The input data is inconsistent with expectations.")
assert predicted_data != [], "There is not any image to be predicted. Please check the input data."
preprocessor = DBPreProcess()
postprocessor = DBPostProcess(box_thresh)
all_imgs = []
all_ratios = []
all_results = []
for original_image in predicted_data:
im, ratio_list = preprocessor(original_image)
res = {'save_path': ''}
if im is None:
res['data'] = []
else:
im = im.copy()
starttime = time.time()
self.input_tensor.copy_from_cpu(im)
self.predictor.zero_copy_run()
data_out = self.output_tensors[0].copy_to_cpu()
dt_boxes_list = postprocessor(data_out, [ratio_list])
boxes = self.filter_tag_det_res(dt_boxes_list[0],
original_image.shape)
res['data'] = boxes
all_imgs.append(im)
all_ratios.append(ratio_list)
if visualization:
img = Image.fromarray(
cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB))
draw_img = draw_boxes(img, boxes)
draw_img = np.array(draw_img)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
ext = get_image_ext(original_image)
saved_name = 'ndarray_{}{}'.format(time.time(), ext)
cv2.imwrite(
os.path.join(output_dir, saved_name),
draw_img[:, :, ::-1])
res['save_path'] = os.path.join(output_dir, saved_name)
all_results.append(res)
return all_results
def save_inference_model(self,
dirname,
model_filename=None,
params_filename=None,
combined=True):
if combined:
model_filename = "__model__" if not model_filename else model_filename
params_filename = "__params__" if not params_filename else params_filename
place = fluid.CPUPlace()
exe = fluid.Executor(place)
model_file_path = os.path.join(self.pretrained_model_path, 'model')
params_file_path = os.path.join(self.pretrained_model_path, 'params')
program, feeded_var_names, target_vars = fluid.io.load_inference_model(
dirname=self.pretrained_model_path,
model_filename=model_file_path,
params_filename=params_file_path,
executor=exe)
fluid.io.save_inference_model(
dirname=dirname,
main_program=program,
executor=exe,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
model_filename=model_filename,
params_filename=params_filename)
@runnable
def run_cmd(self, argvs):
"""
Run as a command
"""
self.parser = argparse.ArgumentParser(
description="Run the chinese_text_detection_db module.",
prog='hub run chinese_text_detection_db',
usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(
title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group(
title="Config options",
description=
"Run configuration for controlling module behavior, not required.")
self.add_module_config_arg()
self.add_module_input_arg()
args = self.parser.parse_args(argvs)
results = self.detect_text(
paths=[args.input_path],
use_gpu=args.use_gpu,
output_dir=args.output_dir,
visualization=args.visualization)
return results
def add_module_config_arg(self):
"""
Add the command config options
"""
self.arg_config_group.add_argument(
'--use_gpu',
type=ast.literal_eval,
default=False,
help="whether use GPU or not")
self.arg_config_group.add_argument(
'--output_dir',
type=str,
default='detection_result',
help="The directory to save output images.")
self.arg_config_group.add_argument(
'--visualization',
type=ast.literal_eval,
default=False,
help="whether to save output as images.")
def add_module_input_arg(self):
"""
Add the command input options
"""
self.arg_input_group.add_argument(
'--input_path', type=str, default=None, help="diretory to image")
if __name__ == '__main__':
db = ChineseTextDetectionDB()
image_path = ['../doc/imgs/11.jpg', '../doc/imgs/12.jpg']
res = db.detect_text(paths=image_path, visualization=True)
db.save_inference_model('save')
print(res)
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
from PIL import Image, ImageDraw, ImageFont
from shapely.geometry import Polygon
import cv2
import numpy as np
import pyclipper
class DBPreProcess(object):
def __init__(self, max_side_len=960):
self.max_side_len = max_side_len
def resize_image_type(self, im):
"""
resize image to a size multiple of 32 which is required by the network
"""
h, w, _ = im.shape
resize_w = w
resize_h = h
# limit the max side
if max(resize_h, resize_w) > self.max_side_len:
if resize_h > resize_w:
ratio = float(self.max_side_len) / resize_h
else:
ratio = float(self.max_side_len) / resize_w
else:
ratio = 1.
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
if resize_h % 32 == 0:
resize_h = resize_h
elif resize_h // 32 <= 1:
resize_h = 32
else:
resize_h = (resize_h // 32 - 1) * 32
if resize_w % 32 == 0:
resize_w = resize_w
elif resize_w // 32 <= 1:
resize_w = 32
else:
resize_w = (resize_w // 32 - 1) * 32
try:
if int(resize_w) <= 0 or int(resize_h) <= 0:
return None, (None, None)
im = cv2.resize(im, (int(resize_w), int(resize_h)))
except:
print(im.shape, resize_w, resize_h)
sys.exit(0)
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
def normalize(self, im):
img_mean = [0.485, 0.456, 0.406]
img_std = [0.229, 0.224, 0.225]
im = im.astype(np.float32, copy=False)
im = im / 255
im -= img_mean
im /= img_std
channel_swap = (2, 0, 1)
im = im.transpose(channel_swap)
return im
def __call__(self, im):
im, (ratio_h, ratio_w) = self.resize_image_type(im)
im = self.normalize(im)
im = im[np.newaxis, :]
return [im, (ratio_h, ratio_w)]
class DBPostProcess(object):
"""
The post process for Differentiable Binarization (DB).
"""
def __init__(self, thresh=0.3, box_thresh=0.5, max_candidates=1000):
self.thresh = thresh
self.box_thresh = box_thresh
self.max_candidates = max_candidates
self.min_size = 3
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
'''
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
'''
bitmap = _bitmap
height, width = bitmap.shape
outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
cv2.CHAIN_APPROX_SIMPLE)
if len(outs) == 3:
img, contours, _ = outs[0], outs[1], outs[2]
elif len(outs) == 2:
contours, _ = outs[0], outs[1]
num_contours = min(len(contours), self.max_candidates)
boxes = np.zeros((num_contours, 4, 2), dtype=np.int16)
scores = np.zeros((num_contours, ), dtype=np.float32)
for index in range(num_contours):
contour = contours[index]
points, sside = self.get_mini_boxes(contour)
if sside < self.min_size:
continue
points = np.array(points)
score = self.box_score_fast(pred, points.reshape(-1, 2))
if self.box_thresh > score:
continue
box = self.unclip(points).reshape(-1, 1, 2)
box, sside = self.get_mini_boxes(box)
if sside < self.min_size + 2:
continue
box = np.array(box)
if not isinstance(dest_width, int):
dest_width = dest_width.item()
dest_height = dest_height.item()
box[:, 0] = np.clip(
np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(
np.round(box[:, 1] / height * dest_height), 0, dest_height)
boxes[index, :, :] = box.astype(np.int16)
scores[index] = score
return boxes, scores
def unclip(self, box, unclip_ratio=2.0):
poly = Polygon(box)
distance = poly.area * unclip_ratio / poly.length
offset = pyclipper.PyclipperOffset()
offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
expanded = np.array(offset.Execute(distance))
return expanded
def get_mini_boxes(self, contour):
bounding_box = cv2.minAreaRect(contour)
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
index_1, index_2, index_3, index_4 = 0, 1, 2, 3
if points[1][1] > points[0][1]:
index_1 = 0
index_4 = 1
else:
index_1 = 1
index_4 = 0
if points[3][1] > points[2][1]:
index_2 = 2
index_3 = 3
else:
index_2 = 3
index_3 = 2
box = [
points[index_1], points[index_2], points[index_3], points[index_4]
]
return box, min(bounding_box[1])
def box_score_fast(self, bitmap, _box):
h, w = bitmap.shape[:2]
box = _box.copy()
xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1)
xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1)
ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1)
ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
box[:, 0] = box[:, 0] - xmin
box[:, 1] = box[:, 1] - ymin
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def __call__(self, predictions, ratio_list):
pred = predictions[:, 0, :, :]
segmentation = pred > self.thresh
boxes_batch = []
for batch_index in range(pred.shape[0]):
height, width = pred.shape[-2:]
tmp_boxes, tmp_scores = self.boxes_from_bitmap(
pred[batch_index], segmentation[batch_index], width, height)
boxes = []
for k in range(len(tmp_boxes)):
if tmp_scores[k] > self.box_thresh:
boxes.append(tmp_boxes[k])
if len(boxes) > 0:
boxes = np.array(boxes)
ratio_h, ratio_w = ratio_list[batch_index]
boxes[:, :, 0] = boxes[:, :, 0] / ratio_w
boxes[:, :, 1] = boxes[:, :, 1] / ratio_h
boxes_batch.append(boxes)
return boxes_batch
def draw_boxes(image, boxes, scores=None, drop_score=0.5):
img = image.copy()
draw = ImageDraw.Draw(img)
if scores is None:
scores = [1] * len(boxes)
for (box, score) in zip(boxes, scores):
if score < drop_score:
continue
draw.line([(box[0][0], box[0][1]), (box[1][0], box[1][1])], fill='red')
draw.line([(box[1][0], box[1][1]), (box[2][0], box[2][1])], fill='red')
draw.line([(box[2][0], box[2][1]), (box[3][0], box[3][1])], fill='red')
draw.line([(box[3][0], box[3][1]), (box[0][0], box[0][1])], fill='red')
draw.line([(box[0][0] - 1, box[0][1] + 1),
(box[1][0] - 1, box[1][1] + 1)],
fill='red')
draw.line([(box[1][0] - 1, box[1][1] + 1),
(box[2][0] - 1, box[2][1] + 1)],
fill='red')
draw.line([(box[2][0] - 1, box[2][1] + 1),
(box[3][0] - 1, box[3][1] + 1)],
fill='red')
draw.line([(box[3][0] - 1, box[3][1] + 1),
(box[0][0] - 1, box[0][1] + 1)],
fill='red')
return img
def get_image_ext(image):
if image.shape[2] == 4:
return ".png"
return ".jpg"
name: chinese_ocr_db_rcnn
dir: "modules/image/text_recognition/chinese_ocr_db_rcnn"
exclude:
- README.md
resources:
-
url: https://bj.bcebos.com/paddlehub/model/image/ocr/chinese_ocr_db_rcnn_infer_model.tar.gz
dest: .
uncompress: True
name: chinese_text_detection_db
dir: "modules/image/text_recognition/chinese_text_detection_db"
exclude:
- README.md
resources:
-
url: https://bj.bcebos.com/paddlehub/model/image/ocr/chinese_text_detection_db_infer_model.tar.gz
dest: .
uncompress: True
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import TestCase, main
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import cv2
import paddlehub as hub
class ChineseOCRDBRCNNTestCase(TestCase):
def setUp(self):
self.module = hub.Module(name='chinese_ocr_db_rcnn')
self.test_images = [
"../image_dataset/text_recognition/11.jpg",
"../image_dataset/text_recognition/test_image.jpg"
]
def test_detect_text(self):
results_1 = self.module.recognize_text(
paths=self.test_images, use_gpu=True)
results_2 = self.module.recognize_text(
paths=self.test_images, use_gpu=False)
test_images = [cv2.imread(img) for img in self.test_images]
results_3 = self.module.recognize_text(
images=test_images, use_gpu=False)
for i, res in enumerate(results_1):
self.assertEqual(res['save_path'], '')
for j, item in enumerate(res['data']):
self.assertEqual(item['confidence'],
results_2[i]['data'][j]['confidence'])
self.assertEqual(item['confidence'],
results_3[i]['data'][j]['confidence'])
self.assertEqual(item['text'], results_2[i]['data'][j]['text'])
self.assertEqual(item['text'], results_3[i]['data'][j]['text'])
self.assertEqual(
(item['text_box_position'].all() == results_2[i]['data'][j]
['text_box_position'].all()), True)
self.assertEqual(
(item['text_box_position'].all() == results_3[i]['data'][j]
['text_box_position'].all()), True)
if __name__ == '__main__':
main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import TestCase, main
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import cv2
import paddlehub as hub
class ChineseTextDetectionDBTestCase(TestCase):
def setUp(self):
self.module = hub.Module(name='chinese_text_detection_db')
self.test_images = [
"../image_dataset/text_recognition/11.jpg",
"../image_dataset/text_recognition/test_image.jpg"
]
def test_detect_text(self):
results_1 = self.module.detect_text(
paths=self.test_images, use_gpu=True)
results_2 = self.module.detect_text(
paths=self.test_images, use_gpu=False)
test_images = [cv2.imread(img) for img in self.test_images]
results_3 = self.module.detect_text(images=test_images, use_gpu=False)
for index, res in enumerate(results_1):
self.assertEqual(res['save_path'], '')
self.assertEqual(
(res['data'].all() == results_2[index]['data'].all()), True)
self.assertEqual(
(res['data'].all() == results_3[index]['data'].all()), True)
if __name__ == '__main__':
main()
...@@ -4,7 +4,7 @@ yapf == 0.26.0 ...@@ -4,7 +4,7 @@ yapf == 0.26.0
six >= 1.10.0 six >= 1.10.0
flask >= 1.1.0 flask >= 1.1.0
flake8 flake8
visualdl == 2.0.0a0 visualdl >= 2.0.0b
cma >= 2.7.0 cma >= 2.7.0
sentencepiece sentencepiece
colorlog colorlog
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册