diff --git a/python/examples/ocr_detection/7.jpg b/python/examples/ocr_detection/7.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a9483bb74f66d88699b09545366c32a4fe108e54
Binary files /dev/null and b/python/examples/ocr_detection/7.jpg differ
diff --git a/python/examples/ocr_detection/text_det_client.py b/python/examples/ocr_detection/text_det_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..aaa1c5b1179fcbf1d010bb9f6335ef2886435a83
--- /dev/null
+++ b/python/examples/ocr_detection/text_det_client.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+from paddle_serving_client import Client
+from paddle_serving_app.reader import Sequential, File2Image, ResizeByFactor
+from paddle_serving_app.reader import Div, Normalize, Transpose
+from paddle_serving_app.reader import DBPostProcess, FilterBoxes
+
+client = Client()
+client.load_client_config("ocr_det_client/serving_client_conf.prototxt")
+client.connect(["127.0.0.1:9494"])
+
+read_image_file = File2Image()
+preprocess = Sequential([
+    ResizeByFactor(32, 960), Div(255),
+    Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose(
+        (2, 0, 1))
+])
+post_func = DBPostProcess({
+    "thresh": 0.3,
+    "box_thresh": 0.5,
+    "max_candidates": 1000,
+    "unclip_ratio": 1.5,
+    "min_size": 3
+})
+filter_func = FilterBoxes(10, 10)
+
+# read the image whose path is given on the command line
+img = read_image_file(sys.argv[1])
+ori_h, ori_w, _ = img.shape
+img = preprocess(img)
+# after Transpose((2, 0, 1)) the layout is CHW, so height and width come last
+_, new_h, new_w = img.shape
+ratio_list = [float(new_h) / ori_h, float(new_w) / ori_w]
+outputs = client.predict(feed={"image": img}, fetch=["concat_1.tmp_0"])
+dt_boxes_list = post_func(outputs["concat_1.tmp_0"], [ratio_list])
+dt_boxes = filter_func(dt_boxes_list[0], [ori_h, ori_w])
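Note for anyone trying this example locally: it assumes the ocr_detection client config has been unpacked to ocr_det_client/ and that a Paddle Serving instance for the detection model is already listening on 127.0.0.1:9494 (the server-side launch command is not part of this diff). The client then runs against the bundled test image:

    python text_det_client.py 7.jpg

On success, dt_boxes is an (N, 4, 2) numpy array: one clockwise-ordered quadrilateral of (x, y) corners per detected text region, mapped back into the original image's coordinate system.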
pack_url(self.model_dict, "OCR", ocr_url) + pack_url(self.model_dict, "TextDetection", ocr_det_url) def get_model_list(self): return self.model_dict diff --git a/python/paddle_serving_app/reader/__init__.py b/python/paddle_serving_app/reader/__init__.py index b2b5e75ac430ecf897e34ec7afc994c9ccf8ee66..e15a93084cbd437531129b48b51fe852ce17d19b 100644 --- a/python/paddle_serving_app/reader/__init__.py +++ b/python/paddle_serving_app/reader/__init__.py @@ -13,8 +13,9 @@ # limitations under the License. from .chinese_bert_reader import ChineseBertReader from .image_reader import ImageReader, File2Image, URL2Image, Sequential, Normalize -from .image_reader import CenterCrop, Resize, Transpose, Div, RGB2BGR, BGR2RGB +from .image_reader import CenterCrop, Resize, Transpose, Div, RGB2BGR, BGR2RGB, ResizeByFactor from .image_reader import RCNNPostprocess, SegPostprocess, PadStride +from .image_reader import DBPostProcess, FilterBoxes from .lac_reader import LACReader from .senta_reader import SentaReader from .imdb_reader import IMDBDataset diff --git a/python/paddle_serving_app/reader/image_reader.py b/python/paddle_serving_app/reader/image_reader.py index 7f4a795513447d74e7f02d7741344ccae81c7c9d..59b9ee41442dd5e8a7c11ba5fb25e8ffed601ad7 100644 --- a/python/paddle_serving_app/reader/image_reader.py +++ b/python/paddle_serving_app/reader/image_reader.py @@ -11,6 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function import cv2 import os import numpy as np @@ -18,6 +21,8 @@ import base64 import sys from . import functional as F from PIL import Image, ImageDraw +from shapely.geometry import Polygon +import pyclipper import json _cv2_interpolation_to_str = {cv2.INTER_LINEAR: "cv2.INTER_LINEAR", None: "None"} @@ -43,6 +48,196 @@ def generate_colormap(num_classes): return color_map +class DBPostProcess(object): + """ + The post process for Differentiable Binarization (DB). 
+ """ + + def __init__(self, params): + self.thresh = params['thresh'] + self.box_thresh = params['box_thresh'] + self.max_candidates = params['max_candidates'] + self.unclip_ratio = params['unclip_ratio'] + self.min_size = 3 + + def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): + ''' + _bitmap: single map with shape (1, H, W), + whose values are binarized as {0, 1} + ''' + + bitmap = _bitmap + height, width = bitmap.shape + + outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, + cv2.CHAIN_APPROX_SIMPLE) + if len(outs) == 3: + img, contours, _ = outs[0], outs[1], outs[2] + elif len(outs) == 2: + contours, _ = outs[0], outs[1] + + num_contours = min(len(contours), self.max_candidates) + boxes = np.zeros((num_contours, 4, 2), dtype=np.int16) + scores = np.zeros((num_contours, ), dtype=np.float32) + + for index in range(num_contours): + contour = contours[index] + points, sside = self.get_mini_boxes(contour) + if sside < self.min_size: + continue + points = np.array(points) + score = self.box_score_fast(pred, points.reshape(-1, 2)) + if self.box_thresh > score: + continue + + box = self.unclip(points).reshape(-1, 1, 2) + box, sside = self.get_mini_boxes(box) + if sside < self.min_size + 2: + continue + box = np.array(box) + if not isinstance(dest_width, int): + dest_width = dest_width.item() + dest_height = dest_height.item() + + box[:, 0] = np.clip( + np.round(box[:, 0] / width * dest_width), 0, dest_width) + box[:, 1] = np.clip( + np.round(box[:, 1] / height * dest_height), 0, dest_height) + boxes[index, :, :] = box.astype(np.int16) + scores[index] = score + return boxes, scores + + def unclip(self, box): + unclip_ratio = self.unclip_ratio + poly = Polygon(box) + distance = poly.area * unclip_ratio / poly.length + offset = pyclipper.PyclipperOffset() + offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + expanded = np.array(offset.Execute(distance)) + return expanded + + def get_mini_boxes(self, contour): + bounding_box = cv2.minAreaRect(contour) + points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0]) + + index_1, index_2, index_3, index_4 = 0, 1, 2, 3 + if points[1][1] > points[0][1]: + index_1 = 0 + index_4 = 1 + else: + index_1 = 1 + index_4 = 0 + if points[3][1] > points[2][1]: + index_2 = 2 + index_3 = 3 + else: + index_2 = 3 + index_3 = 2 + + box = [ + points[index_1], points[index_2], points[index_3], points[index_4] + ] + return box, min(bounding_box[1]) + + def box_score_fast(self, bitmap, _box): + h, w = bitmap.shape[:2] + box = _box.copy() + xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1) + xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1) + ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1) + ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1) + + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + box[:, 0] = box[:, 0] - xmin + box[:, 1] = box[:, 1] - ymin + cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) + return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] + + def __call__(self, pred, ratio_list): + pred = pred[:, 0, :, :] + segmentation = pred > self.thresh + + boxes_batch = [] + for batch_index in range(pred.shape[0]): + height, width = pred.shape[-2:] + tmp_boxes, tmp_scores = self.boxes_from_bitmap( + pred[batch_index], segmentation[batch_index], width, height) + + boxes = [] + for k in range(len(tmp_boxes)): + if tmp_scores[k] > self.box_thresh: + boxes.append(tmp_boxes[k]) + if len(boxes) > 
+                boxes = np.array(boxes)
+
+                ratio_h, ratio_w = ratio_list[batch_index]
+                # map the boxes back to the original image scale
+                boxes[:, :, 0] = boxes[:, :, 0] / ratio_w
+                boxes[:, :, 1] = boxes[:, :, 1] / ratio_h
+
+            boxes_batch.append(boxes)
+        return boxes_batch
+
+    def __repr__(self):
+        return self.__class__.__name__ + \
+            " thresh: {0}, box_thresh: {1}, max_candidates: {2}, unclip_ratio: {3}, min_size: {4}".format(
+                self.thresh, self.box_thresh, self.max_candidates, self.unclip_ratio, self.min_size)
+
+
+class FilterBoxes(object):
+    def __init__(self, width, height):
+        self.filter_width = width
+        self.filter_height = height
+
+    def order_points_clockwise(self, pts):
+        """
+        Sort four points into top-left, top-right, bottom-right, bottom-left
+        order. Adapted from:
+        https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
+        """
+        # sort the points based on their x-coordinates
+        xSorted = pts[np.argsort(pts[:, 0]), :]
+
+        # grab the left-most and right-most points from the sorted
+        # x-coordinate points
+        leftMost = xSorted[:2, :]
+        rightMost = xSorted[2:, :]
+
+        # now, sort the left-most coordinates according to their
+        # y-coordinates so we can grab the top-left and bottom-left
+        # points, respectively
+        leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
+        (tl, bl) = leftMost
+
+        rightMost = rightMost[np.argsort(rightMost[:, 1]), :]
+        (tr, br) = rightMost
+
+        rect = np.array([tl, tr, br, bl], dtype="float32")
+        return rect
+
+    def clip_det_res(self, points, img_height, img_width):
+        for pno in range(4):
+            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
+            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
+        return points
+
+    def __call__(self, dt_boxes, image_shape):
+        img_height, img_width = image_shape[0:2]
+        dt_boxes_new = []
+        for box in dt_boxes:
+            box = self.order_points_clockwise(box)
+            box = self.clip_det_res(box, img_height, img_width)
+            rect_width = int(np.linalg.norm(box[0] - box[1]))
+            rect_height = int(np.linalg.norm(box[0] - box[3]))
+            # drop boxes that are too small in either dimension
+            if rect_width <= self.filter_width or \
+                    rect_height <= self.filter_height:
+                continue
+            dt_boxes_new.append(box)
+        dt_boxes = np.array(dt_boxes_new)
+        return dt_boxes
+
+    def __repr__(self):
+        return self.__class__.__name__ + " filter_width: {0}, filter_height: {1}".format(
+            self.filter_width, self.filter_height)
+
+
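The two new post-processing stages can be exercised without a running server. A minimal sketch with a synthetic probability map standing in for the detector's concat_1.tmp_0 output (the rectangle, sizes, and thresholds here are illustrative, not taken from the test image):

    import numpy as np
    from paddle_serving_app.reader import DBPostProcess, FilterBoxes

    # fake 1x1x160x160 probability map with one high-confidence region
    pred = np.zeros((1, 1, 160, 160), dtype=np.float32)
    pred[0, 0, 40:80, 30:120] = 0.9

    post_func = DBPostProcess({
        "thresh": 0.3,
        "box_thresh": 0.5,
        "max_candidates": 1000,
        "unclip_ratio": 1.5,
        "min_size": 3
    })
    # ratios are (new_h / ori_h, new_w / ori_w); 1.0 means no resize happened
    boxes = post_func(pred, [[1.0, 1.0]])[0]
    boxes = FilterBoxes(10, 10)(boxes, [160, 160])
    print(boxes.shape)  # (1, 4, 2): one quadrilateral, four corner points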
 class SegPostprocess(object):
     def __init__(self, class_num):
         self.class_num = class_num
@@ -473,6 +668,57 @@ class Resize(object):
             _cv2_interpolation_to_str[self.interpolation])
 
 
+class ResizeByFactor(object):
+    """Resize a numpy-array image so that both height and width are multiples
+    of `factor`, as text-detection networks typically require.
+
+    Args:
+        factor (int): height and width of the output are rounded to multiples
+            of this value. Default is 32.
+        max_side_len (int): upper bound on the longer side; if the image's
+            height or width exceeds it, the image is first scaled down so the
+            longer side equals max_side_len. Default is 2400.
+    """
+
+    def __init__(self, factor=32, max_side_len=2400):
+        self.factor = factor
+        self.max_side_len = max_side_len
+
+    def __call__(self, img):
+        h, w, _ = img.shape
+        resize_w = w
+        resize_h = h
+        # first cap the longer side at max_side_len
+        if max(resize_h, resize_w) > self.max_side_len:
+            if resize_h > resize_w:
+                ratio = float(self.max_side_len) / resize_h
+            else:
+                ratio = float(self.max_side_len) / resize_w
+        else:
+            ratio = 1.
+        resize_h = int(resize_h * ratio)
+        resize_w = int(resize_w * ratio)
+        # then round each side down to a multiple of factor (at least factor)
+        if resize_h % self.factor != 0:
+            if resize_h // self.factor <= 1:
+                resize_h = self.factor
+            else:
+                resize_h = (resize_h // self.factor - 1) * self.factor
+        if resize_w % self.factor != 0:
+            if resize_w // self.factor <= 1:
+                resize_w = self.factor
+            else:
+                resize_w = (resize_w // self.factor - 1) * self.factor
+        if resize_h <= 0 or resize_w <= 0:
+            raise ValueError("invalid resize target ({}, {}) for image of shape {}"
+                             .format(resize_w, resize_h, img.shape))
+        im = cv2.resize(img, (int(resize_w), int(resize_h)))
+        return im
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(factor={0}, max_side_len={1})'.format(
+            self.factor, self.max_side_len)
+
+
 class PadStride(object):
     def __init__(self, stride):
         self.coarsest_stride = stride
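One behavior of ResizeByFactor worth flagging for reviewers: when a side is not already a multiple of factor, it rounds down to the multiple *below* the floor (e.g. 525 becomes 480, not 512), so outputs can be smaller than expected. A small sketch of the arithmetic (shapes follow from the logic above):

    import numpy as np
    from paddle_serving_app.reader import ResizeByFactor

    img = np.zeros((700, 1280, 3), dtype=np.uint8)  # H x W x C
    resized = ResizeByFactor(factor=32, max_side_len=960)(img)
    # longer side capped: 1280 -> 960, so height becomes int(700 * 0.75) = 525,
    # which is then rounded down to (525 // 32 - 1) * 32 = 480
    print(resized.shape)  # (480, 960, 3)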