diff --git a/python/examples/ocr/ocr_web_server.py b/python/examples/ocr/ocr_web_server.py index b55027d84252f8590f1e62839ad8cbd25e56c8fe..a9af2696ebb7cfdb402a81b836b466262162ae04 100644 --- a/python/examples/ocr/ocr_web_server.py +++ b/python/examples/ocr/ocr_web_server.py @@ -21,11 +21,11 @@ import os from paddle_serving_client import Client from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor from paddle_serving_app.reader import Div, Normalize, Transpose -from paddle_serving_app.reader import DBPostProcess, FilterBoxes +from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes from paddle_serving_server_gpu.web_service import WebService import time import re - +import base64 class OCRService(WebService): def init_det_client(self, det_port, det_client_config): @@ -37,74 +37,16 @@ class OCRService(WebService): self.det_client = Client() self.det_client.load_client_config(det_client_config) self.det_client.connect(["127.0.0.1:{}".format(det_port)]) + self.ocr_reader = OCRReader() def preprocess(self, feed=[], fetch=[]): - img_url = feed[0]["image"] - #print(feed, img_url) - read_from_url = URL2Image() - im = read_from_url(img_url) + data = base64.b64decode(feed[0]["image"].encode('utf8')) + data = np.fromstring(data, np.uint8) + im = cv2.imdecode(data, cv2.IMREAD_COLOR) ori_h, ori_w, _ = im.shape det_img = self.det_preprocess(im) - #print("det_img", det_img, det_img.shape) det_out = self.det_client.predict( - feed={"image": det_img}, fetch=["concat_1.tmp_0"]) - - #print("det_out", det_out) - def sorted_boxes(dt_boxes): - num_boxes = dt_boxes.shape[0] - sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0])) - _boxes = list(sorted_boxes) - for i in range(num_boxes - 1): - if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \ - (_boxes[i + 1][0][0] < _boxes[i][0][0]): - tmp = _boxes[i] - _boxes[i] = _boxes[i + 1] - _boxes[i + 1] = tmp - return _boxes - - def get_rotate_crop_image(img, points): - img_height, img_width = img.shape[0:2] - left = int(np.min(points[:, 0])) - right = int(np.max(points[:, 0])) - top = int(np.min(points[:, 1])) - bottom = int(np.max(points[:, 1])) - img_crop = img[top:bottom, left:right, :].copy() - points[:, 0] = points[:, 0] - left - points[:, 1] = points[:, 1] - top - img_crop_width = int(np.linalg.norm(points[0] - points[1])) - img_crop_height = int(np.linalg.norm(points[0] - points[3])) - pts_std = np.float32([[0, 0], [img_crop_width, 0], \ - [img_crop_width, img_crop_height], [0, img_crop_height]]) - M = cv2.getPerspectiveTransform(points, pts_std) - dst_img = cv2.warpPerspective( - img_crop, - M, (img_crop_width, img_crop_height), - borderMode=cv2.BORDER_REPLICATE) - dst_img_height, dst_img_width = dst_img.shape[0:2] - if dst_img_height * 1.0 / dst_img_width >= 1.5: - dst_img = np.rot90(dst_img) - return dst_img - - def resize_norm_img(img, max_wh_ratio): - import math - imgC, imgH, imgW = 3, 32, 320 - imgW = int(32 * max_wh_ratio) - h = img.shape[0] - w = img.shape[1] - ratio = w / float(h) - if math.ceil(imgH * ratio) > imgW: - resized_w = imgW - else: - resized_w = int(math.ceil(imgH * ratio)) - resized_image = cv2.resize(img, (resized_w, imgH)) - resized_image = resized_image.astype('float32') - resized_image = resized_image.transpose((2, 0, 1)) / 255 - resized_image -= 0.5 - resized_image /= 0.5 - padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) - padding_im[:, :, 0:resized_w] = resized_image - return padding_im - + feed={"image": det_img}, fetch=["concat_1.tmp_0"]) _, new_h, new_w = det_img.shape filter_func = FilterBoxes(10, 10) post_func = DBPostProcess({ @@ -114,10 +56,12 @@ class OCRService(WebService): "unclip_ratio": 1.5, "min_size": 3 }) + sorted_boxes = SortedBoxes() ratio_list = [float(new_h) / ori_h, float(new_w) / ori_w] dt_boxes_list = post_func(det_out["concat_1.tmp_0"], [ratio_list]) dt_boxes = filter_func(dt_boxes_list[0], [ori_h, ori_w]) dt_boxes = sorted_boxes(dt_boxes) + get_rotate_crop_image = GetRotateCropImage() feed_list = [] img_list = [] max_wh_ratio = 0 @@ -128,24 +72,20 @@ class OCRService(WebService): wh_ratio = w * 1.0 / h max_wh_ratio = max(max_wh_ratio, wh_ratio) for img in img_list: - norm_img = resize_norm_img(img, max_wh_ratio) + norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio) feed = {"image": norm_img} feed_list.append(feed) - fetch = ["ctc_greedy_decoder_0.tmp_0"] - #print("feed_list", feed_list) + fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] + print(feed_list) return feed_list, fetch def postprocess(self, feed={}, fetch=[], fetch_map=None): - #print(fetch_map) - ocr_reader = OCRReader() - rec_res = ocr_reader.postprocess(fetch_map) + rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True) res_lst = [] for res in rec_res: res_lst.append(res[0]) - fetch_map["res"] = res_lst - del fetch_map["ctc_greedy_decoder_0.tmp_0"] - del fetch_map["ctc_greedy_decoder_0.tmp_0.lod"] - return fetch_map + res = {"res": res_lst} + return res ocr_service = OCRService(name="ocr") diff --git a/python/paddle_serving_app/local_predict.py b/python/paddle_serving_app/local_predict.py index 93039c6fdd467357b589bbb2889f3c2d3208b538..18acc8228122de145b4e970d7eb5a90b95be8d44 100644 --- a/python/paddle_serving_app/local_predict.py +++ b/python/paddle_serving_app/local_predict.py @@ -122,11 +122,13 @@ class Debugger(object): feed[name] = feed[name].astype("int64") else: feed[name] = feed[name].astype("float32") - inputs.append(PaddleTensor(feed[name][np.newaxis, :])) + inputs.append(PaddleTensor(feed[name])) outputs = self.predictor.run(inputs) fetch_map = {} for name in fetch: fetch_map[name] = outputs[self.fetch_names_to_idx_[ name]].as_ndarray() + if len(outputs[self.fetch_names_to_idx_[name]].lod) > 0: + fetch_map[name+".lod"] = outputs[self.fetch_names_to_idx_[name]].lod[0] return fetch_map diff --git a/python/paddle_serving_app/reader/__init__.py b/python/paddle_serving_app/reader/__init__.py index e15a93084cbd437531129b48b51fe852ce17d19b..93e2cd76102d93f52955060055afda34f9576ed8 100644 --- a/python/paddle_serving_app/reader/__init__.py +++ b/python/paddle_serving_app/reader/__init__.py @@ -15,7 +15,7 @@ from .chinese_bert_reader import ChineseBertReader from .image_reader import ImageReader, File2Image, URL2Image, Sequential, Normalize from .image_reader import CenterCrop, Resize, Transpose, Div, RGB2BGR, BGR2RGB, ResizeByFactor from .image_reader import RCNNPostprocess, SegPostprocess, PadStride -from .image_reader import DBPostProcess, FilterBoxes +from .image_reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes from .lac_reader import LACReader from .senta_reader import SentaReader from .imdb_reader import IMDBDataset diff --git a/python/paddle_serving_app/reader/image_reader.py b/python/paddle_serving_app/reader/image_reader.py index 096f46549af137cb04a87e26a3b28c8d42e33daa..830e7703dc290fa25ade5bb70b92b7e125e0beb9 100644 --- a/python/paddle_serving_app/reader/image_reader.py +++ b/python/paddle_serving_app/reader/image_reader.py @@ -781,6 +781,55 @@ class Transpose(object): "({})".format(self.transpose_target) return format_string +class SortedBoxes(object): + """ + Sorted bounding boxes from Detection + """ + def __init__(self): + pass + + def __call__(self, dt_boxes): + num_boxes = dt_boxes.shape[0] + sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0])) + _boxes = list(sorted_boxes) + for i in range(num_boxes - 1): + if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \ + (_boxes[i + 1][0][0] < _boxes[i][0][0]): + tmp = _boxes[i] + _boxes[i] = _boxes[i + 1] + _boxes[i + 1] = tmp + return _boxes + +class GetRotateCropImage(object): + """ + Rotate and Crop image from OCR Det output + """ + def __init__(self): + pass + + def __call__(self, img, points): + img_height, img_width = img.shape[0:2] + left = int(np.min(points[:, 0])) + right = int(np.max(points[:, 0])) + top = int(np.min(points[:, 1])) + bottom = int(np.max(points[:, 1])) + img_crop = img[top:bottom, left:right, :].copy() + points[:, 0] = points[:, 0] - left + points[:, 1] = points[:, 1] - top + img_crop_width = int(np.linalg.norm(points[0] - points[1])) + img_crop_height = int(np.linalg.norm(points[0] - points[3])) + pts_std = np.float32([[0, 0], [img_crop_width, 0], \ + [img_crop_width, img_crop_height], [0, img_crop_height]]) + M = cv2.getPerspectiveTransform(points, pts_std) + dst_img = cv2.warpPerspective( + img_crop, + M, (img_crop_width, img_crop_height), + borderMode=cv2.BORDER_REPLICATE) + dst_img_height, dst_img_width = dst_img.shape[0:2] + if dst_img_height * 1.0 / dst_img_width >= 1.5: + dst_img = np.rot90(dst_img) + return dst_img + class ImageReader(): def __init__(self, diff --git a/python/paddle_serving_app/reader/ocr_reader.py b/python/paddle_serving_app/reader/ocr_reader.py index 72a2918f89a8ccc913894f3f46fab08f51cf9460..e10ff72562791e1cc3dfc8c9088f3d0e60f34cc3 100644 --- a/python/paddle_serving_app/reader/ocr_reader.py +++ b/python/paddle_serving_app/reader/ocr_reader.py @@ -120,29 +120,16 @@ class CharacterOps(object): class OCRReader(object): - def __init__(self): - args = self.parse_args() - image_shape = [int(v) for v in args.rec_image_shape.split(",")] + def __init__(self, algorithm="CRNN", image_shape=[3,32,320], char_type="ch", batch_num=1, char_dict_path="./ppocr_keys_v1.txt"): self.rec_image_shape = image_shape - self.character_type = args.rec_char_type - self.rec_batch_num = args.rec_batch_num + self.character_type = char_type + self.rec_batch_num = batch_num char_ops_params = {} - char_ops_params["character_type"] = args.rec_char_type - char_ops_params["character_dict_path"] = args.rec_char_dict_path + char_ops_params["character_type"] = char_type + char_ops_params["character_dict_path"] = char_dict_path char_ops_params['loss_type'] = 'ctc' self.char_ops = CharacterOps(char_ops_params) - - def parse_args(self): - parser = argparse.ArgumentParser() - parser.add_argument("--rec_algorithm", type=str, default='CRNN') - parser.add_argument("--rec_model_dir", type=str) - parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") - parser.add_argument("--rec_char_type", type=str, default='ch') - parser.add_argument("--rec_batch_num", type=int, default=1) - parser.add_argument( - "--rec_char_dict_path", type=str, default="./ppocr_keys_v1.txt") - return parser.parse_args() - + def resize_norm_img(self, img, max_wh_ratio): imgC, imgH, imgW = self.rec_image_shape if self.character_type == "ch": @@ -154,17 +141,17 @@ class OCRReader(object): resized_w = imgW else: resized_w = int(math.ceil(imgH * ratio)) - - seq = Sequential([ - Resize(imgH, resized_w), Transpose((2, 0, 1)), Div(255), - Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], True) - ]) - resized_image = seq(img) + resized_image = cv2.resize(img, (resized_w, imgH)) + resized_image = resized_image.astype('float32') + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) - padding_im[:, :, 0:resized_w] = resized_image + padding_im[:, :, 0:resized_w] = resized_image return padding_im + def preprocess(self, img_list): img_num = len(img_list) norm_img_batch = [] @@ -191,11 +178,16 @@ class OCRReader(object): for rno in range(len(rec_idx_lod) - 1): beg = rec_idx_lod[rno] end = rec_idx_lod[rno + 1] - rec_idx_tmp = rec_idx_batch[beg:end, 0] + if isinstance(rec_idx_batch, list): + rec_idx_tmp = [x[0] for x in rec_idx_batch[beg:end]] + else: #nd array + rec_idx_tmp = rec_idx_batch[beg:end, 0] preds_text = self.char_ops.decode(rec_idx_tmp) if with_score: beg = predict_lod[rno] end = predict_lod[rno + 1] + if isinstance(outputs["softmax_0.tmp_0"], list): + outputs["softmax_0.tmp_0"] = np.array(outputs["softmax_0.tmp_0"]).astype(np.float32) probs = outputs["softmax_0.tmp_0"][beg:end, :] ind = np.argmax(probs, axis=1) blank = probs.shape[1] diff --git a/python/paddle_serving_server_gpu/web_service.py b/python/paddle_serving_server_gpu/web_service.py index 0eff9c72df84b30ded7dbc7c2e82a96bbd591162..4de310152436f3de499368a67e4336d173031f98 100644 --- a/python/paddle_serving_server_gpu/web_service.py +++ b/python/paddle_serving_server_gpu/web_service.py @@ -129,7 +129,8 @@ class WebService(object): del feed["fetch"] fetch_map = self.client.predict(feed=feed, fetch=fetch) for key in fetch_map: - fetch_map[key] = fetch_map[key].tolist() + if isinstance(fetch_map[key], np.ndarray): + fetch_map[key] = fetch_map[key].tolist() result = self.postprocess( feed=request.json["feed"], fetch=fetch, fetch_map=fetch_map) result = {"result": result} @@ -164,6 +165,32 @@ class WebService(object): self.app_instance = app_instance + # TODO: maybe change another API name: maybe run_local_predictor? + def run_debugger_service(self, gpu=False): + import socket + localIP = socket.gethostbyname(socket.gethostname()) + print("web service address:") + print("http://{}:{}/{}/prediction".format(localIP, self.port, + self.name)) + app_instance = Flask(__name__) + + @app_instance.before_first_request + def init(): + self._launch_local_predictor(gpu) + + service_name = "/" + self.name + "/prediction" + + @app_instance.route(service_name, methods=["POST"]) + def run(): + return self.get_prediction(request) + + self.app_instance = app_instance + + def _launch_local_predictor(self, gpu): + from paddle_serving_app.local_predict import Debugger + self.client = Debugger() + self.client.load_model_config("{}".format(self.model_config), gpu=gpu, profile=False) + def run_web_service(self): self.app_instance.run(host="0.0.0.0", port=self.port,