提交 9b3119df 编写于 作者: C chenjian

add module

上级 7f9274d9
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import ast
import base64
import math
import os
import time
import cv2
import numpy as np
import paddle.fluid as fluid
import paddle.inference as paddle_infer
from paddle.fluid.core import AnalysisConfig
from paddle.fluid.core import create_paddle_predictor
from paddle.fluid.core import PaddleTensor
from PIL import Image
import paddlehub as hub
from paddlehub.common.logger import logger
from paddlehub.module.module import moduleinfo
from paddlehub.module.module import runnable
from paddlehub.module.module import serving
def base64_to_cv2(b64str):
data = base64.b64decode(b64str.encode('utf8'))
data = np.fromstring(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
return data
@moduleinfo(
name="ppocrv3_det_ch",
version="1.0.0",
summary=
"The module aims to detect chinese text position in the image, which is based on differentiable_binarization algorithm.",
author="paddle-dev",
author_email="paddle-dev@baidu.com",
type="cv/text_recognition")
class ChineseTextDetectionDB(hub.Module):
def _initialize(self, enable_mkldnn=False):
"""
initialize with the necessary elements
"""
self.pretrained_model_path = os.path.join(self.directory, 'inference_model', 'ppocrv3_det')
self.enable_mkldnn = enable_mkldnn
self._set_config()
def check_requirements(self):
try:
import shapely, pyclipper
except:
raise ImportError(
'This module requires the shapely, pyclipper tools. The running environment does not meet the requirements. Please install the two packages.'
)
def _set_config(self):
"""
predictor config setting
"""
model_file_path = self.pretrained_model_path + '.pdmodel'
params_file_path = self.pretrained_model_path + '.pdiparams'
config = paddle_infer.Config(model_file_path, params_file_path)
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
use_gpu = True
except:
use_gpu = False
if use_gpu:
config.enable_use_gpu(8000, 0)
else:
config.disable_gpu()
config.set_cpu_math_library_num_threads(6)
if self.enable_mkldnn:
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn()
config.disable_glog_info()
# use zero copy
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
config.switch_use_feed_fetch_ops(False)
self.predictor = paddle_infer.create_predictor(config)
input_names = self.predictor.get_input_names()
self.input_tensor = self.predictor.get_input_handle(input_names[0])
output_names = self.predictor.get_output_names()
self.output_tensors = []
for output_name in output_names:
output_tensor = self.predictor.get_output_handle(output_name)
self.output_tensors.append(output_tensor)
def read_images(self, paths=[]):
images = []
for img_path in paths:
assert os.path.isfile(img_path), "The {} isn't a valid file.".format(img_path)
img = cv2.imread(img_path)
if img is None:
logger.info("error in loading image:{}".format(img_path))
continue
images.append(img)
return images
def order_points_clockwise(self, pts):
rect = np.zeros((4, 2), dtype="float32")
s = pts.sum(axis=1)
rect[0] = pts[np.argmin(s)]
rect[2] = pts[np.argmax(s)]
diff = np.diff(pts, axis=1)
rect[1] = pts[np.argmin(diff)]
rect[3] = pts[np.argmax(diff)]
return rect
def clip_det_res(self, points, img_height, img_width):
for pno in range(points.shape[0]):
points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
return points
def filter_tag_det_res(self, dt_boxes, image_shape):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
box = self.order_points_clockwise(box)
box = self.clip_det_res(box, img_height, img_width)
rect_width = int(np.linalg.norm(box[0] - box[1]))
rect_height = int(np.linalg.norm(box[0] - box[3]))
if rect_width <= 3 or rect_height <= 3:
continue
dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new)
return dt_boxes
def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
box = self.clip_det_res(box, img_height, img_width)
dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new)
return dt_boxes
def detect_text(self,
images=[],
paths=[],
use_gpu=False,
output_dir='detection_result',
visualization=False,
box_thresh=0.5):
"""
Get the text box in the predicted images.
Args:
images (list(numpy.ndarray)): images data, shape of each is [H, W, C]. If images not paths
paths (list[str]): The paths of images. If paths not images
use_gpu (bool): Whether to use gpu. Default false.
output_dir (str): The directory to store output images.
visualization (bool): Whether to save image or not.
box_thresh(float): the threshold of the detected text box's confidence
Returns:
res (list): The result of text detection box and save path of images.
"""
self.check_requirements()
from .processor import DBProcessTest, DBPostProcess, draw_boxes, get_image_ext
if use_gpu:
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
except:
raise RuntimeError(
"Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES via export CUDA_VISIBLE_DEVICES=cuda_device_id."
)
if images != [] and isinstance(images, list) and paths == []:
predicted_data = images
elif images == [] and isinstance(paths, list) and paths != []:
predicted_data = self.read_images(paths)
else:
raise TypeError("The input data is inconsistent with expectations.")
assert predicted_data != [], "There is not any image to be predicted. Please check the input data."
preprocessor = DBProcessTest(params={'max_side_len': 960})
postprocessor = DBPostProcess(params={
'thresh': 0.3,
'box_thresh': 0.6,
'max_candidates': 1000,
'unclip_ratio': 1.5
})
all_imgs = []
all_ratios = []
all_results = []
for original_image in predicted_data:
ori_im = original_image.copy()
im, ratio_list = preprocessor(original_image)
print('after preprocess int det, shape{}'.format(im.shape))
res = {'save_path': ''}
if im is None:
res['data'] = []
else:
im = im.copy()
self.input_tensor.copy_from_cpu(im)
self.predictor.run()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
outs_dict = {}
outs_dict['maps'] = outputs[0]
# data_out = self.output_tensors[0].copy_to_cpu()
print('Outputs[0] in det, shape: {}'.format(outputs[0].shape))
dt_boxes_list = postprocessor(outs_dict, [ratio_list])
dt_boxes = dt_boxes_list[0]
print('after postprocess int det, shape{}'.format(dt_boxes.shape))
boxes = self.filter_tag_det_res(dt_boxes_list[0], original_image.shape)
print('after fitler tag int det, shape{}'.format(boxes.shape))
res['data'] = boxes.astype(np.int).tolist()
print('boxes: {}'.format(boxes))
all_imgs.append(im)
all_ratios.append(ratio_list)
if visualization:
img = Image.fromarray(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB))
draw_img = draw_boxes(img, boxes)
draw_img = np.array(draw_img)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
ext = get_image_ext(original_image)
saved_name = 'ndarray_{}{}'.format(time.time(), ext)
cv2.imwrite(os.path.join(output_dir, saved_name), draw_img[:, :, ::-1])
res['save_path'] = os.path.join(output_dir, saved_name)
all_results.append(res)
return all_results
@serving
def serving_method(self, images, **kwargs):
"""
Run as a service.
"""
images_decode = [base64_to_cv2(image) for image in images]
results = self.detect_text(images=images_decode, **kwargs)
return results
@runnable
def run_cmd(self, argvs):
"""
Run as a command
"""
self.parser = argparse.ArgumentParser(description="Run the %s module." % self.name,
prog='hub run %s' % self.name,
usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group(
title="Config options", description="Run configuration for controlling module behavior, not required.")
self.add_module_config_arg()
self.add_module_input_arg()
args = self.parser.parse_args(argvs)
results = self.detect_text(paths=[args.input_path],
use_gpu=args.use_gpu,
output_dir=args.output_dir,
visualization=args.visualization)
return results
def add_module_config_arg(self):
"""
Add the command config options
"""
self.arg_config_group.add_argument('--use_gpu',
type=ast.literal_eval,
default=False,
help="whether use GPU or not")
self.arg_config_group.add_argument('--output_dir',
type=str,
default='detection_result',
help="The directory to save output images.")
self.arg_config_group.add_argument('--visualization',
type=ast.literal_eval,
default=False,
help="whether to save output as images.")
def add_module_input_arg(self):
"""
Add the command input options
"""
self.arg_input_group.add_argument('--input_path', type=str, default=None, help="diretory to image")
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import cv2
import numpy as np
import paddle
import pyclipper
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
from shapely.geometry import Polygon
class DBProcessTest(object):
"""
DB pre-process for Test mode
"""
def __init__(self, params):
super(DBProcessTest, self).__init__()
self.resize_type = 0
if 'test_image_shape' in params:
self.image_shape = params['test_image_shape']
# print(self.image_shape)
self.resize_type = 1
if 'max_side_len' in params:
self.max_side_len = params['max_side_len']
else:
self.max_side_len = 2400
def resize_image_type0(self, img):
"""
resize image to a size multiple of 32 which is required by the network
args:
img(array): array with shape [h, w, c]
return(tuple):
img, (ratio_h, ratio_w)
"""
limit_side_len = self.max_side_len
h, w, _ = img.shape
# limit the max side
if max(h, w) > limit_side_len:
if h > w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
ratio = 1.
resize_h = int(h * ratio)
resize_w = int(w * ratio)
resize_h = int(round(resize_h / 32) * 32)
resize_w = int(round(resize_w / 32) * 32)
try:
if int(resize_w) <= 0 or int(resize_h) <= 0:
return None, (None, None)
img = cv2.resize(img, (int(resize_w), int(resize_h)))
except:
print(img.shape, resize_w, resize_h)
sys.exit(0)
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
# return img, np.array([h, w])
return img, [ratio_h, ratio_w]
def resize_image_type1(self, im):
resize_h, resize_w = self.image_shape
ori_h, ori_w = im.shape[:2] # (h, w, c)
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = float(resize_h) / ori_h
ratio_w = float(resize_w) / ori_w
return im, (ratio_h, ratio_w)
def normalize(self, im):
img_mean = [0.485, 0.456, 0.406]
img_std = [0.229, 0.224, 0.225]
im = im.astype(np.float32, copy=False)
im = im / 255
im[:, :, 0] -= img_mean[0]
im[:, :, 1] -= img_mean[1]
im[:, :, 2] -= img_mean[2]
im[:, :, 0] /= img_std[0]
im[:, :, 1] /= img_std[1]
im[:, :, 2] /= img_std[2]
channel_swap = (2, 0, 1)
im = im.transpose(channel_swap)
return im
def __call__(self, im):
if self.resize_type == 0:
im, (ratio_h, ratio_w) = self.resize_image_type0(im)
else:
im, (ratio_h, ratio_w) = self.resize_image_type1(im)
im = self.normalize(im)
im = im[np.newaxis, :]
return [im, (ratio_h, ratio_w)]
class DBPostProcess(object):
"""
The post process for Differentiable Binarization (DB).
"""
def __init__(self, params):
self.thresh = params['thresh']
self.box_thresh = params['box_thresh']
self.max_candidates = params['max_candidates']
self.unclip_ratio = params['unclip_ratio']
self.min_size = 3
self.dilation_kernel = None
self.score_mode = 'fast'
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
'''
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
'''
bitmap = _bitmap
height, width = bitmap.shape
outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
if len(outs) == 3:
img, contours, _ = outs[0], outs[1], outs[2]
elif len(outs) == 2:
contours, _ = outs[0], outs[1]
num_contours = min(len(contours), self.max_candidates)
boxes = []
scores = []
for index in range(num_contours):
contour = contours[index]
points, sside = self.get_mini_boxes(contour)
if sside < self.min_size:
continue
points = np.array(points)
if self.score_mode == "fast":
score = self.box_score_fast(pred, points.reshape(-1, 2))
else:
score = self.box_score_slow(pred, contour)
if self.box_thresh > score:
continue
box = self.unclip(points).reshape(-1, 1, 2)
box, sside = self.get_mini_boxes(box)
if sside < self.min_size + 2:
continue
box = np.array(box)
box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height)
boxes.append(box.astype(np.int16))
scores.append(score)
return np.array(boxes, dtype=np.int16), scores
def unclip(self, box):
unclip_ratio = self.unclip_ratio
poly = Polygon(box)
distance = poly.area * unclip_ratio / poly.length
offset = pyclipper.PyclipperOffset()
offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
expanded = np.array(offset.Execute(distance))
return expanded
def get_mini_boxes(self, contour):
bounding_box = cv2.minAreaRect(contour)
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
index_1, index_2, index_3, index_4 = 0, 1, 2, 3
if points[1][1] > points[0][1]:
index_1 = 0
index_4 = 1
else:
index_1 = 1
index_4 = 0
if points[3][1] > points[2][1]:
index_2 = 2
index_3 = 3
else:
index_2 = 3
index_3 = 2
box = [points[index_1], points[index_2], points[index_3], points[index_4]]
return box, min(bounding_box[1])
def box_score_fast(self, bitmap, _box):
'''
box_score_fast: use bbox mean score as the mean score
'''
h, w = bitmap.shape[:2]
box = _box.copy()
xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1)
xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1)
ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1)
ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
box[:, 0] = box[:, 0] - xmin
box[:, 1] = box[:, 1] - ymin
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def box_score_slow(self, bitmap, contour):
'''
box_score_slow: use polyon mean score as the mean score
'''
h, w = bitmap.shape[:2]
contour = contour.copy()
contour = np.reshape(contour, (-1, 2))
xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
contour[:, 0] = contour[:, 0] - xmin
contour[:, 1] = contour[:, 1] - ymin
cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def __call__(self, outs_dict, ratio_list):
pred = outs_dict['maps']
pred = pred[:, 0, :, :]
segmentation = pred > self.thresh
boxes_batch = []
for batch_index in range(pred.shape[0]):
height, width = pred.shape[-2:]
mask = segmentation[batch_index]
tmp_boxes, tmp_scores = self.boxes_from_bitmap(pred[batch_index], mask, width, height)
boxes_batch.append(tmp_boxes)
return boxes_batch
def draw_boxes(image, boxes, scores=None, drop_score=0.5):
img = image.copy()
draw = ImageDraw.Draw(img)
if scores is None:
scores = [1] * len(boxes)
for (box, score) in zip(boxes, scores):
if score < drop_score:
continue
draw.line([(box[0][0], box[0][1]), (box[1][0], box[1][1])], fill='red')
draw.line([(box[1][0], box[1][1]), (box[2][0], box[2][1])], fill='red')
draw.line([(box[2][0], box[2][1]), (box[3][0], box[3][1])], fill='red')
draw.line([(box[3][0], box[3][1]), (box[0][0], box[0][1])], fill='red')
draw.line([(box[0][0] - 1, box[0][1] + 1), (box[1][0] - 1, box[1][1] + 1)], fill='red')
draw.line([(box[1][0] - 1, box[1][1] + 1), (box[2][0] - 1, box[2][1] + 1)], fill='red')
draw.line([(box[2][0] - 1, box[2][1] + 1), (box[3][0] - 1, box[3][1] + 1)], fill='red')
draw.line([(box[3][0] - 1, box[3][1] + 1), (box[0][0] - 1, box[0][1] + 1)], fill='red')
return img
def get_image_ext(image):
if image.shape[2] == 4:
return ".png"
return ".jpg"
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import string
import numpy as np
class CharacterOps(object):
""" Convert between text-label and text-index
Args:
config: config from yaml file
"""
def __init__(self, config):
self.character_type = config['character_type']
self.max_text_len = config['max_text_length']
if self.character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
# use the custom dictionary
elif self.character_type == "ch":
character_dict_path = config['character_dict_path']
add_space = False
if 'use_space_char' in config:
add_space = config['use_space_char']
self.character_str = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
self.character_str.append(line)
if add_space:
self.character_str.append(" ")
dict_character = list(self.character_str)
elif self.character_type == "en_sensitive":
# same with ASTER setting (use 94 char).
self.character_str = string.printable[:-6]
dict_character = list(self.character_str)
else:
self.character_str = None
self.beg_str = "sos"
self.end_str = "eos"
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
self.character = dict_character
def add_special_char(self, dict_character):
return dict_character
def encode(self, text):
"""convert text-label into text-index.
input:
text: text labels of each image. [batch_size]
output:
text: concatenated text index for CTCLoss.
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
length: length of each text. [batch_size]
"""
if self.character_type == "en":
text = text.lower()
text_list = []
for char in text:
if char not in self.dict:
continue
text_list.append(self.dict[char])
text = np.array(text_list)
return text
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
selection = np.ones(len(text_index[batch_idx]), dtype=bool)
if is_remove_duplicate:
selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
for ignored_token in ignored_tokens:
selection &= text_index[batch_idx] != ignored_token
# print(text_index)
# print(batch_idx)
# print(selection)
# for text_id in text_index[batch_idx][selection]:
# print(text_id)
# print(self.character[text_id])
char_list = [self.character[text_id] for text_id in text_index[batch_idx][selection]]
if text_prob is not None:
conf_list = text_prob[batch_idx][selection]
else:
conf_list = [1] * len(selection)
if len(conf_list) == 0:
conf_list = [0]
text = ''.join(char_list)
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
def get_char_num(self):
return len(self.character)
def get_beg_end_flag_idx(self, beg_or_end):
if self.loss_type == "attention":
if beg_or_end == "beg":
idx = np.array(self.dict[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict[self.end_str])
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx"\
% beg_or_end
return idx
else:
err = "error in get_beg_end_flag_idx when using the loss %s"\
% (self.loss_type)
assert False, err
def get_ignored_tokens(self):
return [0] # for ctc blank
def cal_predicts_accuracy(char_ops, preds, preds_lod, labels, labels_lod, is_remove_duplicate=False):
"""
Calculate prediction accuracy
Args:
char_ops: CharacterOps
preds: preds result,text index
preds_lod: lod tensor of preds
labels: label of input image, text index
labels_lod: lod tensor of label
is_remove_duplicate: Whether to remove duplicate characters,
The default is False
Return:
acc: The accuracy of test set
acc_num: The correct number of samples predicted
img_num: The total sample number of the test set
"""
acc_num = 0
img_num = 0
for ino in range(len(labels_lod) - 1):
beg_no = preds_lod[ino]
end_no = preds_lod[ino + 1]
preds_text = preds[beg_no:end_no].reshape(-1)
preds_text = char_ops.decode(preds_text, is_remove_duplicate)
beg_no = labels_lod[ino]
end_no = labels_lod[ino + 1]
labels_text = labels[beg_no:end_no].reshape(-1)
labels_text = char_ops.decode(labels_text, is_remove_duplicate)
img_num += 1
if preds_text == labels_text:
acc_num += 1
acc = acc_num * 1.0 / img_num
return acc, acc_num, img_num
def cal_predicts_accuracy_srn(char_ops, preds, labels, max_text_len, is_debug=False):
acc_num = 0
img_num = 0
char_num = char_ops.get_char_num()
total_len = preds.shape[0]
img_num = int(total_len / max_text_len)
for i in range(img_num):
cur_label = []
cur_pred = []
for j in range(max_text_len):
if labels[j + i * max_text_len] != int(char_num - 1): #0
cur_label.append(labels[j + i * max_text_len][0])
else:
break
for j in range(max_text_len + 1):
if j < len(cur_label) and preds[j + i * max_text_len][0] != cur_label[j]:
break
elif j == len(cur_label) and j == max_text_len:
acc_num += 1
break
elif j == len(cur_label) and preds[j + i * max_text_len][0] == int(char_num - 1):
acc_num += 1
break
acc = acc_num * 1.0 / img_num
return acc, acc_num, img_num
def convert_rec_attention_infer_res(preds):
img_num = preds.shape[0]
target_lod = [0]
convert_ids = []
for ino in range(img_num):
end_pos = np.where(preds[ino, :] == 1)[0]
if len(end_pos) <= 1:
text_list = preds[ino, 1:]
else:
text_list = preds[ino, 1:end_pos[1]]
target_lod.append(target_lod[ino] + len(text_list))
convert_ids = convert_ids + list(text_list)
convert_ids = np.array(convert_ids)
convert_ids = convert_ids.reshape((-1, 1))
return convert_ids, target_lod
def convert_rec_label_to_lod(ori_labels):
img_num = len(ori_labels)
target_lod = [0]
convert_ids = []
for ino in range(img_num):
target_lod.append(target_lod[ino] + len(ori_labels[ino]))
convert_ids = convert_ids + list(ori_labels[ino])
convert_ids = np.array(convert_ids)
convert_ids = convert_ids.reshape((-1, 1))
return convert_ids, target_lod
# -*- coding:utf-8 -*-
import argparse
import ast
import copy
import math
import os
import time
import cv2
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.inference as paddle_infer
from paddle.fluid.core import AnalysisConfig
from paddle.fluid.core import create_paddle_predictor
from paddle.fluid.core import PaddleTensor
from PIL import Image
import paddlehub as hub
from .character import CharacterOps
from .utils import base64_to_cv2
from .utils import draw_ocr
from .utils import get_image_ext
from .utils import sorted_boxes
from paddlehub.common.logger import logger
from paddlehub.module.module import moduleinfo
from paddlehub.module.module import runnable
from paddlehub.module.module import serving
@moduleinfo(
name="ppocrv3_rec_ch",
version="1.0.0",
summary="The module can recognize the chinese texts in an image. Firstly, it will detect the text box positions \
based on the differentiable_binarization_chn module. Then it classifies the text angle and recognizes the chinese texts. ",
author="paddle-dev",
author_email="paddle-dev@baidu.com",
type="cv/text_recognition")
class ChineseOCRDBCRNN(hub.Module):
def _initialize(self, text_detector_module=None, enable_mkldnn=False):
"""
initialize with the necessary elements
"""
self.character_dict_path = os.path.join(self.directory, 'assets', 'ppocr_keys_v1.txt')
char_ops_params = {
'character_type': 'ch',
'character_dict_path': self.character_dict_path,
'loss_type': 'ctc',
'max_text_length': 25,
'use_space_char': True
}
self.char_ops = CharacterOps(char_ops_params)
self.rec_image_shape = [3, 32, 320]
self._text_detector_module = text_detector_module
self.font_file = os.path.join(self.directory, 'assets', 'simfang.ttf')
self.enable_mkldnn = enable_mkldnn
self.rec_pretrained_model_path = os.path.join(self.directory, 'inference_model', 'ppocrv3_rec')
self.cls_pretrained_model_path = os.path.join(self.directory, 'inference_model', 'ppocr_cls')
self.rec_predictor, self.rec_input_tensor, self.rec_output_tensors = self._set_config(
self.rec_pretrained_model_path)
self.cls_predictor, self.cls_input_tensor, self.cls_output_tensors = self._set_config(
self.cls_pretrained_model_path)
def _set_config(self, pretrained_model_path):
"""
predictor config path
"""
model_file_path = pretrained_model_path + '.pdmodel'
params_file_path = pretrained_model_path + '.pdiparams'
config = paddle_infer.Config(model_file_path, params_file_path)
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
use_gpu = True
except:
use_gpu = False
if use_gpu:
config.enable_use_gpu(8000, 0)
else:
config.disable_gpu()
if self.enable_mkldnn:
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn()
config.disable_glog_info()
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
config.switch_use_feed_fetch_ops(False)
predictor = paddle_infer.create_predictor(config)
input_names = predictor.get_input_names()
input_handle = predictor.get_input_handle(input_names[0])
output_names = predictor.get_output_names()
output_handles = []
for output_name in output_names:
output_handle = predictor.get_output_handle(output_name)
output_handles.append(output_handle)
return predictor, input_handle, output_handles
@property
def text_detector_module(self):
"""
text detect module
"""
if not self._text_detector_module:
self._text_detector_module = hub.Module(name='ppocrv3_det_ch',
enable_mkldnn=self.enable_mkldnn,
version='1.0.0')
return self._text_detector_module
def read_images(self, paths=[]):
images = []
for img_path in paths:
assert os.path.isfile(img_path), "The {} isn't a valid file.".format(img_path)
img = cv2.imread(img_path)
if img is None:
logger.info("error in loading image:{}".format(img_path))
continue
images.append(img)
return images
def get_rotate_crop_image(self, img, points):
'''
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
top = int(np.min(points[:, 1]))
bottom = int(np.max(points[:, 1]))
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
'''
img_crop_width = int(max(np.linalg.norm(points[0] - points[1]), np.linalg.norm(points[2] - points[3])))
img_crop_height = int(max(np.linalg.norm(points[0] - points[3]), np.linalg.norm(points[1] - points[2])))
pts_std = np.float32([[0, 0], [img_crop_width, 0], [img_crop_width, img_crop_height], [0, img_crop_height]])
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(img,
M, (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE,
flags=cv2.INTER_CUBIC)
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img)
return dst_img
def resize_norm_img_rec(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
assert imgC == img.shape[2]
imgW = int((32 * max_wh_ratio))
h, w = img.shape[:2]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def resize_norm_img_cls(self, img):
cls_image_shape = [3, 48, 192]
imgC, imgH, imgW = cls_image_shape
h = img.shape[0]
w = img.shape[1]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
if cls_image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def recognize_text(self,
images=[],
paths=[],
use_gpu=False,
output_dir='ocr_result',
visualization=False,
box_thresh=0.5,
text_thresh=0.5,
angle_classification_thresh=0.9):
"""
Get the chinese texts in the predicted images.
Args:
images (list(numpy.ndarray)): images data, shape of each is [H, W, C]. If images not paths
paths (list[str]): The paths of images. If paths not images
use_gpu (bool): Whether to use gpu.
batch_size(int): the program deals once with one
output_dir (str): The directory to store output images.
visualization (bool): Whether to save image or not.
box_thresh(float): the threshold of the detected text box's confidence
text_thresh(float): the threshold of the chinese text recognition confidence
angle_classification_thresh(float): the threshold of the angle classification confidence
Returns:
res (list): The result of chinese texts and save path of images.
"""
if use_gpu:
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
except:
raise RuntimeError(
"Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES via export CUDA_VISIBLE_DEVICES=cuda_device_id."
)
self.use_gpu = use_gpu
if images != [] and isinstance(images, list) and paths == []:
predicted_data = images
elif images == [] and isinstance(paths, list) and paths != []:
predicted_data = self.read_images(paths)
else:
raise TypeError("The input data is inconsistent with expectations.")
assert predicted_data != [], "There is not any image to be predicted. Please check the input data."
detection_results = self.text_detector_module.detect_text(images=predicted_data,
use_gpu=self.use_gpu,
box_thresh=box_thresh)
boxes = [np.array(item['data']).astype(np.float32) for item in detection_results]
print("dt_boxes num : {}".format(len(boxes[0])))
all_results = []
for index, img_boxes in enumerate(boxes):
original_image = predicted_data[index].copy()
result = {'save_path': ''}
if img_boxes.size == 0:
result['data'] = []
else:
img_crop_list = []
boxes = sorted_boxes(img_boxes)
for num_box in range(len(boxes)):
tmp_box = copy.deepcopy(boxes[num_box])
img_crop = self.get_rotate_crop_image(original_image, tmp_box)
img_crop_list.append(img_crop)
print('img_crop shape {}'.format(img_crop.shape))
img_crop_list, angle_list = self._classify_text(img_crop_list,
angle_classification_thresh=angle_classification_thresh)
rec_results = self._recognize_text(img_crop_list)
# if the recognized text confidence score is lower than text_thresh, then drop it
rec_res_final = []
for index, res in enumerate(rec_results):
text, score = res
if score >= text_thresh:
rec_res_final.append({
'text': text,
'confidence': float(score),
'text_box_position': boxes[index].astype(np.int).tolist()
})
result['data'] = rec_res_final
if visualization and result['data']:
result['save_path'] = self.save_result_image(original_image, boxes, rec_results, output_dir,
text_thresh)
all_results.append(result)
return all_results
@serving
def serving_method(self, images, **kwargs):
"""
Run as a service.
"""
images_decode = [base64_to_cv2(image) for image in images]
results = self.recognize_text(images_decode, **kwargs)
return results
def save_result_image(
self,
original_image,
detection_boxes,
rec_results,
output_dir='ocr_result',
text_thresh=0.5,
):
image = Image.fromarray(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB))
txts = [item[0] for item in rec_results]
scores = [item[1] for item in rec_results]
draw_img = draw_ocr(image,
detection_boxes,
txts,
scores,
font_file=self.font_file,
draw_txt=True,
drop_score=text_thresh)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
ext = get_image_ext(original_image)
saved_name = 'ndarray_{}{}'.format(time.time(), ext)
save_file_path = os.path.join(output_dir, saved_name)
cv2.imwrite(save_file_path, draw_img[:, :, ::-1])
return save_file_path
def _classify_text(self, image_list, angle_classification_thresh=0.9):
img_list = copy.deepcopy(image_list)
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
width_list = []
for img in img_list:
width_list.append(img.shape[1] / float(img.shape[0]))
# Sorting can speed up the cls process
indices = np.argsort(np.array(width_list))
cls_res = [['', 0.0]] * img_num
batch_num = 6
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
max_wh_ratio = 0
for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img_cls(img_list[indices[ino]])
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
self.cls_input_tensor.copy_from_cpu(norm_img_batch)
self.cls_predictor.run()
prob_out = self.cls_output_tensors[0].copy_to_cpu()
## post process
label_list = ['0', '180']
pred_idxs = prob_out.argmax(axis=1)
cls_result = [(label_list[idx], prob_out[i, idx]) for i, idx in enumerate(pred_idxs)]
for rno in range(len(cls_result)):
label, score = cls_result[rno]
cls_res[indices[beg_img_no + rno]] = [label, score]
if '180' in label and score > angle_classification_thresh:
img_list[indices[beg_img_no + rno]] = cv2.rotate(img_list[indices[beg_img_no + rno]], 1)
return img_list, cls_res
def _recognize_text(self, img_list):
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
width_list = []
for img in img_list:
width_list.append(img.shape[1] / float(img.shape[0]))
# Sorting can speed up the recognition process
indices = np.argsort(np.array(width_list))
rec_res = [['', 0.0]] * img_num
batch_num = 6
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
max_wh_ratio = 0
for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img_rec(img_list[indices[ino]], max_wh_ratio)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch, axis=0)
norm_img_batch = norm_img_batch.copy()
self.rec_input_tensor.copy_from_cpu(norm_img_batch)
self.rec_predictor.run()
##
outputs = []
for output_tensor in self.rec_output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
if len(outputs) != 1:
preds = outputs
else:
preds = outputs[0]
if isinstance(preds, tuple) or isinstance(preds, list):
preds = preds[-1]
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
print('preds.shape: {}', preds.shape)
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
# print('preds_idx: {} \n preds_prob: {}'.format(preds_idx, preds_prob) )
rec_result = self.char_ops.decode(preds_idx, preds_prob, is_remove_duplicate=True)
for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
return rec_res
@runnable
def run_cmd(self, argvs):
"""
Run as a command
"""
self.parser = argparse.ArgumentParser(description="Run the %s module." % self.name,
prog='hub run %s' % self.name,
usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group(
title="Config options", description="Run configuration for controlling module behavior, not required.")
self.add_module_config_arg()
self.add_module_input_arg()
args = self.parser.parse_args(argvs)
results = self.recognize_text(paths=[args.input_path],
use_gpu=args.use_gpu,
output_dir=args.output_dir,
visualization=args.visualization)
return results
def add_module_config_arg(self):
"""
Add the command config options
"""
self.arg_config_group.add_argument('--use_gpu',
type=ast.literal_eval,
default=False,
help="whether use GPU or not")
self.arg_config_group.add_argument('--output_dir',
type=str,
default='ocr_result',
help="The directory to save output images.")
self.arg_config_group.add_argument('--visualization',
type=ast.literal_eval,
default=False,
help="whether to save output as images.")
def add_module_input_arg(self):
"""
Add the command input options
"""
self.arg_input_group.add_argument('--input_path', type=str, default=None, help="diretory to image")
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import base64
import math
import cv2
import numpy as np
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
def draw_ocr(image, boxes, txts, scores, font_file, draw_txt=True, drop_score=0.5):
"""
Visualize the results of OCR detection and recognition
args:
image(Image|array): RGB image
boxes(list): boxes with shape(N, 4, 2)
txts(list): the texts
scores(list): txxs corresponding scores
draw_txt(bool): whether draw text or not
drop_score(float): only scores greater than drop_threshold will be visualized
return(array):
the visualized img
"""
if scores is None:
scores = [1] * len(boxes)
for (box, score) in zip(boxes, scores):
if score < drop_score or math.isnan(score):
continue
box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
if draw_txt:
img = np.array(resize_img(image, input_size=600))
txt_img = text_visual(txts, scores, font_file, img_h=img.shape[0], img_w=600, threshold=drop_score)
img = np.concatenate([np.array(img), np.array(txt_img)], axis=1)
return img
return image
def text_visual(texts, scores, font_file, img_h=400, img_w=600, threshold=0.):
"""
create new blank img and draw txt on it
args:
texts(list): the text will be draw
scores(list|None): corresponding score of each txt
img_h(int): the height of blank img
img_w(int): the width of blank img
return(array):
"""
if scores is not None:
assert len(texts) == len(scores), "The number of txts and corresponding scores must match"
def create_blank_img():
blank_img = np.ones(shape=[img_h, img_w], dtype=np.int8) * 255
blank_img[:, img_w - 1:] = 0
blank_img = Image.fromarray(blank_img).convert("RGB")
draw_txt = ImageDraw.Draw(blank_img)
return blank_img, draw_txt
blank_img, draw_txt = create_blank_img()
font_size = 20
txt_color = (0, 0, 0)
font = ImageFont.truetype(font_file, font_size, encoding="utf-8")
gap = font_size + 5
txt_img_list = []
count, index = 1, 0
for idx, txt in enumerate(texts):
index += 1
if scores[idx] < threshold or math.isnan(scores[idx]):
index -= 1
continue
first_line = True
while str_count(txt) >= img_w // font_size - 4:
tmp = txt
txt = tmp[:img_w // font_size - 4]
if first_line:
new_txt = str(index) + ': ' + txt
first_line = False
else:
new_txt = ' ' + txt
draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
txt = tmp[img_w // font_size - 4:]
if count >= img_h // gap - 1:
txt_img_list.append(np.array(blank_img))
blank_img, draw_txt = create_blank_img()
count = 0
count += 1
if first_line:
new_txt = str(index) + ': ' + txt + ' ' + '%.3f' % (scores[idx])
else:
new_txt = " " + txt + " " + '%.3f' % (scores[idx])
draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
# whether add new blank img or not
if count >= img_h // gap - 1 and idx + 1 < len(texts):
txt_img_list.append(np.array(blank_img))
blank_img, draw_txt = create_blank_img()
count = 0
count += 1
txt_img_list.append(np.array(blank_img))
if len(txt_img_list) == 1:
blank_img = np.array(txt_img_list[0])
else:
blank_img = np.concatenate(txt_img_list, axis=1)
return np.array(blank_img)
def str_count(s):
"""
Count the number of Chinese characters,
a single English character and a single number
equal to half the length of Chinese characters.
args:
s(string): the input of string
return(int):
the number of Chinese characters
"""
import string
count_zh = count_pu = 0
s_len = len(s)
en_dg_count = 0
for c in s:
if c in string.ascii_letters or c.isdigit() or c.isspace():
en_dg_count += 1
elif c.isalpha():
count_zh += 1
else:
count_pu += 1
return s_len - math.ceil(en_dg_count / 2)
def resize_img(img, input_size=600):
img = np.array(img)
im_shape = img.shape
im_size_min = np.min(im_shape[0:2])
im_size_max = np.max(im_shape[0:2])
im_scale = float(input_size) / float(im_size_max)
im = cv2.resize(img, None, None, fx=im_scale, fy=im_scale)
return im
def get_image_ext(image):
if image.shape[2] == 4:
return ".png"
return ".jpg"
def sorted_boxes(dt_boxes):
"""
Sort text boxes in order from top to bottom, left to right
args:
dt_boxes(array):detected text boxes with shape [4, 2]
return:
sorted boxes(array) with shape [4, 2]
"""
num_boxes = dt_boxes.shape[0]
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
tmp = _boxes[i]
_boxes[i] = _boxes[i + 1]
_boxes[i + 1] = tmp
return _boxes
def base64_to_cv2(b64str):
data = base64.b64decode(b64str.encode('utf8'))
data = np.fromstring(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
return data
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册