diff --git a/deploy/pdserving/clas_local_server.py b/deploy/pdserving/clas_local_server.py new file mode 100644 index 0000000000000000000000000000000000000000..abf5e2d8af37e05f05a9945ffe194fe1963b613e --- /dev/null +++ b/deploy/pdserving/clas_local_server.py @@ -0,0 +1,128 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle_serving_client import Client +from paddle_serving_app.reader import OCRReader +import cv2 +import sys +import numpy as np +import os +import time +import re +import base64 +from tools.infer.predict_cls import TextClassifier +import tools.infer.utility as utility + +global_args = utility.parse_args() +if global_args.use_gpu: + from paddle_serving_server_gpu.web_service import WebService +else: + from paddle_serving_server.web_service import WebService + + +class TextClassifierHelper(TextClassifier): + def __init__(self, args): + self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")] + self.cls_batch_num = args.rec_batch_num + self.label_list = args.label_list + self.cls_thresh = args.cls_thresh + self.fetch = [ + "save_infer_model/scale_0.tmp_0", "save_infer_model/scale_1.tmp_0" + ] + + def preprocess(self, img_list): + args = {} + img_num = len(img_list) + args["img_list"] = img_list + # Calculate the aspect ratio of all text bars + width_list = [] + for img in img_list: + width_list.append(img.shape[1] / float(img.shape[0])) + # Sorting can speed up the cls process + indices = np.argsort(np.array(width_list)) + args["indices"] = indices + cls_res = [['', 0.0]] * img_num + batch_num = self.cls_batch_num + predict_time = 0 + beg_img_no, end_img_no = 0, img_num + norm_img_batch = [] + max_wh_ratio = 0 + for ino in range(beg_img_no, end_img_no): + h, w = img_list[indices[ino]].shape[0:2] + wh_ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, wh_ratio) + for ino in range(beg_img_no, end_img_no): + norm_img = self.resize_norm_img(img_list[indices[ino]]) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) + norm_img_batch = np.concatenate(norm_img_batch) + feed = {"image": norm_img_batch.copy()} + return feed, self.fetch, args + + def postprocess(self, outputs, args): + prob_out = outputs[0] + label_out = outputs[1] + indices = args["indices"] + cls_res = [['', 0.0]] * len(label_out) + if len(label_out.shape) != 1: + prob_out, label_out = label_out, prob_out + for rno in range(len(label_out)): + label_idx = label_out[rno] + score = prob_out[rno][label_idx] + label = self.label_list[label_idx] + cls_res[indices[rno]] = [label, score] + if '180' in label and score > self.cls_thresh: + img_list[indices[rno]] = cv2.rotate(img_list[indices[rno]], 1) + return args["img_list"], cls_res + + +class OCRService(WebService): + def init_rec(self): + self.ocr_reader = OCRReader() + self.text_classifier = TextClassifierHelper(global_args) + + def preprocess(self, feed=[], fetch=[]): + img_list = [] + for feed_data in feed: + data = base64.b64decode(feed_data["image"].encode('utf8')) + data = np.fromstring(data, np.uint8) + im = cv2.imdecode(data, cv2.IMREAD_COLOR) + img_list.append(im) + feed, fetch, self.tmp_args = self.text_classifier.preprocess(img_list) + return feed, fetch + + def postprocess(self, feed={}, fetch=[], fetch_map=None): + outputs = [fetch_map[x] for x in self.text_classifier.fetch] + for x in fetch_map.keys(): + if ".lod" in x: + self.tmp_args[x] = fetch_map[x] + _, rec_res = self.text_classifier.postprocess(outputs, self.tmp_args) + res = { + "pred_text": [x[0] for x in rec_res], + "score": [str(x[1]) for x in rec_res] + } + return res + + +if __name__ == "__main__": + ocr_service = OCRService(name="ocr") + ocr_service.load_model_config("cls_server") + ocr_service.init_rec() + if global_args.use_gpu: + ocr_service.prepare_server( + workdir="workdir", port=9292, device="gpu", gpuid=0) + else: + ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu") + ocr_service.run_debugger_service() + ocr_service.run_web_service() diff --git a/deploy/pdserving/clas_rpc_server.py b/deploy/pdserving/clas_rpc_server.py new file mode 100644 index 0000000000000000000000000000000000000000..cc96199b5f7a8cddb64cf9a3c5b4e4e82e0ca95c --- /dev/null +++ b/deploy/pdserving/clas_rpc_server.py @@ -0,0 +1,134 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle_serving_client import Client +from paddle_serving_app.reader import OCRReader +import cv2 +import sys +import numpy as np +import os +import time +import re +import base64 +from tools.infer.predict_cls import TextClassifier +import tools.infer.utility as utility + +global_args = utility.parse_args() +if global_args.use_gpu: + from paddle_serving_server_gpu.web_service import WebService +else: + from paddle_serving_server.web_service import WebService + + +class TextClassifierHelper(TextClassifier): + def __init__(self, args): + self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")] + self.cls_batch_num = args.rec_batch_num + self.label_list = args.label_list + self.cls_thresh = args.cls_thresh + self.fetch = [ + "save_infer_model/scale_0.tmp_0", "save_infer_model/scale_1.tmp_0" + ] + + def preprocess(self, img_list): + args = {} + img_num = len(img_list) + args["img_list"] = img_list + # Calculate the aspect ratio of all text bars + width_list = [] + for img in img_list: + width_list.append(img.shape[1] / float(img.shape[0])) + # Sorting can speed up the cls process + indices = np.argsort(np.array(width_list)) + args["indices"] = indices + cls_res = [['', 0.0]] * img_num + batch_num = self.cls_batch_num + predict_time = 0 + beg_img_no, end_img_no = 0, img_num + norm_img_batch = [] + max_wh_ratio = 0 + for ino in range(beg_img_no, end_img_no): + h, w = img_list[indices[ino]].shape[0:2] + wh_ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, wh_ratio) + for ino in range(beg_img_no, end_img_no): + norm_img = self.resize_norm_img(img_list[indices[ino]]) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) + norm_img_batch = np.concatenate(norm_img_batch) + if img_num > 1: + feed = [{ + "image": norm_img_batch[x] + } for x in range(norm_img_batch.shape[0])] + else: + feed = {"image": norm_img_batch[0]} + return feed, self.fetch, args + + def postprocess(self, outputs, args): + prob_out = outputs[0] + label_out = outputs[1] + indices = args["indices"] + cls_res = [['', 0.0]] * len(label_out) + if len(label_out.shape) != 1: + prob_out, label_out = label_out, prob_out + for rno in range(len(label_out)): + label_idx = label_out[rno] + score = prob_out[rno][label_idx] + label = self.label_list[label_idx] + cls_res[indices[rno]] = [label, score] + if '180' in label and score > self.cls_thresh: + img_list[indices[rno]] = cv2.rotate(img_list[indices[rno]], 1) + return args["img_list"], cls_res + + +class OCRService(WebService): + def init_rec(self): + self.ocr_reader = OCRReader() + self.text_classifier = TextClassifierHelper(global_args) + + def preprocess(self, feed=[], fetch=[]): + # TODO: to handle batch rec images + img_list = [] + for feed_data in feed: + data = base64.b64decode(feed_data["image"].encode('utf8')) + data = np.fromstring(data, np.uint8) + im = cv2.imdecode(data, cv2.IMREAD_COLOR) + img_list.append(im) + feed, fetch, self.tmp_args = self.text_classifier.preprocess(img_list) + return feed, fetch + + def postprocess(self, feed={}, fetch=[], fetch_map=None): + outputs = [fetch_map[x] for x in self.text_classifier.fetch] + for x in fetch_map.keys(): + if ".lod" in x: + self.tmp_args[x] = fetch_map[x] + _, rec_res = self.text_classifier.postprocess(outputs, self.tmp_args) + res = { + "direction": [x[0] for x in rec_res], + "score": [str(x[1]) for x in rec_res] + } + return res + + +if __name__ == "__main__": + ocr_service = OCRService(name="ocr") + ocr_service.load_model_config(global_args.cls_model_dir) + ocr_service.init_rec() + if global_args.use_gpu: + ocr_service.prepare_server( + workdir="workdir", port=9292, device="gpu", gpuid=0) + else: + ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu") + ocr_service.run_rpc_service() + ocr_service.run_web_service() diff --git a/deploy/pdserving/clas_web_client.py b/deploy/pdserving/clas_web_client.py new file mode 100644 index 0000000000000000000000000000000000000000..9bcd929ef0538f70c77ed5f879afaf526468635c --- /dev/null +++ b/deploy/pdserving/clas_web_client.py @@ -0,0 +1,40 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -*- coding: utf-8 -*- + +import requests +import json +import cv2 +import base64 +import os, sys +import time + + +def cv2_to_base64(image): + #data = cv2.imencode('.jpg', image)[1] + return base64.b64encode(image).decode( + 'utf8') #data.tostring()).decode('utf8') + + +headers = {"Content-type": "application/json"} +url = "http://127.0.0.1:9292/ocr/prediction" +test_img_dir = "../../doc/imgs_words/ch/" +for img_file in os.listdir(test_img_dir): + with open(os.path.join(test_img_dir, img_file), 'rb') as file: + image_data1 = file.read() + image = cv2_to_base64(image_data1) + data = {"feed": [{"image": image}], "fetch": ["res"]} + r = requests.post(url=url, headers=headers, data=json.dumps(data)) + print(r.json()) + break diff --git a/deploy/pdserving/det_local_server.py b/deploy/pdserving/det_local_server.py index eb7948daadd018810997bba78367e86aa3398e31..f79b9994624e3030ca2cd7ee8b3400be1e947947 100644 --- a/deploy/pdserving/det_local_server.py +++ b/deploy/pdserving/det_local_server.py @@ -17,63 +17,91 @@ import cv2 import sys import numpy as np import os -from paddle_serving_client import Client -from paddle_serving_app.reader import Sequential, ResizeByFactor -from paddle_serving_app.reader import Div, Normalize, Transpose -from paddle_serving_app.reader import DBPostProcess, FilterBoxes -if sys.argv[1] == 'gpu': - from paddle_serving_server_gpu.web_service import WebService -elif sys.argv[1] == 'cpu': - from paddle_serving_server.web_service import WebService import time import re import base64 +from tools.infer.predict_det import TextDetector +import tools.infer.utility as utility + +global_args = utility.parse_args() +if global_args.use_gpu: + from paddle_serving_server_gpu.web_service import WebService +else: + from paddle_serving_server.web_service import WebService + + +class TextDetectorHelper(TextDetector): + def __init__(self, args): + super(TextDetectorHelper, self).__init__(args) + if self.det_algorithm == "SAST": + self.fetch = [ + "bn_f_border4.output.tmp_2", "bn_f_tco4.output.tmp_2", + "bn_f_tvo4.output.tmp_2", "sigmoid_0.tmp_0" + ] + elif self.det_algorithm == "EAST": + self.fetch = ["sigmoid_0.tmp_0", "tmp_2"] + elif self.det_algorithm == "DB": + self.fetch = ["sigmoid_0.tmp_0"] + + def preprocess(self, img): + img = img.copy() + im, ratio_list = self.preprocess_op(img) + if im is None: + return None, 0 + return { + "image": im.copy() + }, self.fetch, { + "ratio_list": [ratio_list], + "ori_im": img + } + + def postprocess(self, outputs, args): + outs_dict = {} + if self.det_algorithm == "EAST": + outs_dict['f_geo'] = outputs[0] + outs_dict['f_score'] = outputs[1] + elif self.det_algorithm == 'SAST': + outs_dict['f_border'] = outputs[0] + outs_dict['f_score'] = outputs[1] + outs_dict['f_tco'] = outputs[2] + outs_dict['f_tvo'] = outputs[3] + else: + outs_dict['maps'] = outputs[0] + dt_boxes_list = self.postprocess_op(outs_dict, args["ratio_list"]) + dt_boxes = dt_boxes_list[0] + if self.det_algorithm == "SAST" and self.det_sast_polygon: + dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, + args["ori_im"].shape) + else: + dt_boxes = self.filter_tag_det_res(dt_boxes, args["ori_im"].shape) + return dt_boxes -class OCRService(WebService): +class DetService(WebService): def init_det(self): - self.det_preprocess = Sequential([ - ResizeByFactor(32, 960), Div(255), - Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose( - (2, 0, 1)) - ]) - self.filter_func = FilterBoxes(10, 10) - self.post_func = DBPostProcess({ - "thresh": 0.3, - "box_thresh": 0.5, - "max_candidates": 1000, - "unclip_ratio": 1.5, - "min_size": 3 - }) + self.text_detector = TextDetectorHelper(global_args) def preprocess(self, feed=[], fetch=[]): data = base64.b64decode(feed[0]["image"].encode('utf8')) data = np.fromstring(data, np.uint8) im = cv2.imdecode(data, cv2.IMREAD_COLOR) - self.ori_h, self.ori_w, _ = im.shape - det_img = self.det_preprocess(im) - _, self.new_h, self.new_w = det_img.shape - return {"image": det_img[np.newaxis, :].copy()}, ["concat_1.tmp_0"] + feed, fetch, self.tmp_args = self.text_detector.preprocess(im) + return feed, fetch def postprocess(self, feed={}, fetch=[], fetch_map=None): - det_out = fetch_map["concat_1.tmp_0"] - ratio_list = [ - float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w - ] - dt_boxes_list = self.post_func(det_out, [ratio_list]) - dt_boxes = self.filter_func(dt_boxes_list[0], [self.ori_h, self.ori_w]) - return {"dt_boxes": dt_boxes.tolist()} + outputs = [fetch_map[x] for x in fetch] + res = self.text_detector.postprocess(outputs, self.tmp_args) + return {"boxes": res.tolist()} -ocr_service = OCRService(name="ocr") -ocr_service.load_model_config("ocr_det_model") -ocr_service.init_det() -if sys.argv[1] == 'gpu': - ocr_service.set_gpus("0") - ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0) - ocr_service.run_debugger_service(gpu=True) -elif sys.argv[1] == 'cpu': - ocr_service.prepare_server(workdir="workdir", port=9292) +if __name__ == "__main__": + ocr_service = DetService(name="ocr") + ocr_service.load_model_config("serving_server_dir") + ocr_service.init_det() + if global_args.use_gpu: + ocr_service.prepare_server( + workdir="workdir", port=9292, device="gpu", gpuid=0) + else: + ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu") ocr_service.run_debugger_service() -ocr_service.init_det() -ocr_service.run_web_service() + ocr_service.run_web_service() diff --git a/deploy/pdserving/det_rpc_server.py b/deploy/pdserving/det_rpc_server.py new file mode 100644 index 0000000000000000000000000000000000000000..ef6d135b278b7152fa67235b823148f7e10c8f19 --- /dev/null +++ b/deploy/pdserving/det_rpc_server.py @@ -0,0 +1,107 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle_serving_client import Client +import cv2 +import sys +import numpy as np +import os +import time +import re +import base64 +from tools.infer.predict_det import TextDetector +import tools.infer.utility as utility + +global_args = utility.parse_args() +if global_args.use_gpu: + from paddle_serving_server_gpu.web_service import WebService +else: + from paddle_serving_server.web_service import WebService + + +class TextDetectorHelper(TextDetector): + def __init__(self, args): + super(TextDetectorHelper, self).__init__(args) + if self.det_algorithm == "SAST": + self.fetch = [ + "bn_f_border4.output.tmp_2", "bn_f_tco4.output.tmp_2", + "bn_f_tvo4.output.tmp_2", "sigmoid_0.tmp_0" + ] + elif self.det_algorithm == "EAST": + self.fetch = ["sigmoid_0.tmp_0", "tmp_2"] + elif self.det_algorithm == "DB": + self.fetch = ["sigmoid_0.tmp_0"] + + def preprocess(self, img): + im, ratio_list = self.preprocess_op(img) + if im is None: + return None, 0 + return { + "image": im[0] + }, self.fetch, { + "ratio_list": [ratio_list], + "ori_im": img + } + + def postprocess(self, outputs, args): + outs_dict = {} + if self.det_algorithm == "EAST": + outs_dict['f_geo'] = outputs[0] + outs_dict['f_score'] = outputs[1] + elif self.det_algorithm == 'SAST': + outs_dict['f_border'] = outputs[0] + outs_dict['f_score'] = outputs[1] + outs_dict['f_tco'] = outputs[2] + outs_dict['f_tvo'] = outputs[3] + else: + outs_dict['maps'] = outputs[0] + dt_boxes_list = self.postprocess_op(outs_dict, args["ratio_list"]) + dt_boxes = dt_boxes_list[0] + if self.det_algorithm == "SAST" and self.det_sast_polygon: + dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, + args["ori_im"].shape) + else: + dt_boxes = self.filter_tag_det_res(dt_boxes, args["ori_im"].shape) + return dt_boxes + + +class DetService(WebService): + def init_det(self): + self.text_detector = TextDetectorHelper(global_args) + print("init finish") + + def preprocess(self, feed=[], fetch=[]): + data = base64.b64decode(feed[0]["image"].encode('utf8')) + data = np.fromstring(data, np.uint8) + im = cv2.imdecode(data, cv2.IMREAD_COLOR) + feed, fetch, self.tmp_args = self.text_detector.preprocess(im) + return feed, fetch + + def postprocess(self, feed={}, fetch=[], fetch_map=None): + outputs = [fetch_map[x] for x in fetch] + res = self.text_detector.postprocess(outputs, self.tmp_args) + return {"boxes": res.tolist()} + + +if __name__ == "__main__": + ocr_service = DetService(name="ocr") + ocr_service.load_model_config("serving_server_dir") + ocr_service.init_det() + if global_args.use_gpu: + ocr_service.prepare_server( + workdir="workdir", port=9292, device="gpu", gpuid=0) + else: + ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu") + ocr_service.run_rpc_service() + ocr_service.run_web_service() diff --git a/deploy/pdserving/det_web_server.py b/deploy/pdserving/det_web_server.py deleted file mode 100644 index 14be74130dcb413c31a3e76c150d74f65575f451..0000000000000000000000000000000000000000 --- a/deploy/pdserving/det_web_server.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from paddle_serving_client import Client -import cv2 -import sys -import numpy as np -import os -from paddle_serving_client import Client -from paddle_serving_app.reader import Sequential, ResizeByFactor -from paddle_serving_app.reader import Div, Normalize, Transpose -from paddle_serving_app.reader import DBPostProcess, FilterBoxes -if sys.argv[1] == 'gpu': - from paddle_serving_server_gpu.web_service import WebService -elif sys.argv[1] == 'cpu': - from paddle_serving_server.web_service import WebService -import time -import re -import base64 - - -class OCRService(WebService): - def init_det(self): - self.det_preprocess = Sequential([ - ResizeByFactor(32, 960), Div(255), - Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose( - (2, 0, 1)) - ]) - self.filter_func = FilterBoxes(10, 10) - self.post_func = DBPostProcess({ - "thresh": 0.3, - "box_thresh": 0.5, - "max_candidates": 1000, - "unclip_ratio": 1.5, - "min_size": 3 - }) - - def preprocess(self, feed=[], fetch=[]): - data = base64.b64decode(feed[0]["image"].encode('utf8')) - data = np.fromstring(data, np.uint8) - im = cv2.imdecode(data, cv2.IMREAD_COLOR) - self.ori_h, self.ori_w, _ = im.shape - det_img = self.det_preprocess(im) - _, self.new_h, self.new_w = det_img.shape - print(det_img) - return {"image": det_img}, ["concat_1.tmp_0"] - - def postprocess(self, feed={}, fetch=[], fetch_map=None): - det_out = fetch_map["concat_1.tmp_0"] - ratio_list = [ - float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w - ] - dt_boxes_list = self.post_func(det_out, [ratio_list]) - dt_boxes = self.filter_func(dt_boxes_list[0], [self.ori_h, self.ori_w]) - return {"dt_boxes": dt_boxes.tolist()} - - -ocr_service = OCRService(name="ocr") -ocr_service.load_model_config("ocr_det_model") -if sys.argv[1] == 'gpu': - ocr_service.set_gpus("0") - ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0) -elif sys.argv[1] == 'cpu': - ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu") -ocr_service.init_det() -ocr_service.run_rpc_service() -ocr_service.run_web_service() diff --git a/deploy/pdserving/ocr_local_server.py b/deploy/pdserving/ocr_local_server.py index de5b3d13f12afd4a84c5d46625682c42f418d6bb..dae7137437b14bcc15a6858d366c01bb61d00440 100644 --- a/deploy/pdserving/ocr_local_server.py +++ b/deploy/pdserving/ocr_local_server.py @@ -18,97 +18,107 @@ import cv2 import sys import numpy as np import os -from paddle_serving_client import Client -from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor -from paddle_serving_app.reader import Div, Normalize, Transpose -from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes -if sys.argv[1] == 'gpu': - from paddle_serving_server_gpu.web_service import WebService -elif sys.argv[1] == 'cpu': - from paddle_serving_server.web_service import WebService -from paddle_serving_app.local_predict import Debugger import time import re import base64 +from clas_local_server import TextClassifierHelper +from det_local_server import TextDetectorHelper +from rec_local_server import TextRecognizerHelper +import tools.infer.utility as utility +from tools.infer.predict_system import TextSystem, sorted_boxes +from paddle_serving_app.local_predict import Debugger +import copy +global_args = utility.parse_args() +if global_args.use_gpu: + from paddle_serving_server_gpu.web_service import WebService +else: + from paddle_serving_server.web_service import WebService -class OCRService(WebService): - def init_det_debugger(self, det_model_config): - self.det_preprocess = Sequential([ - ResizeByFactor(32, 960), Div(255), - Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose( - (2, 0, 1)) - ]) + +class TextSystemHelper(TextSystem): + def __init__(self, args): + self.text_detector = TextDetectorHelper(args) + self.text_recognizer = TextRecognizerHelper(args) + self.use_angle_cls = args.use_angle_cls + if self.use_angle_cls: + self.clas_client = Debugger() + self.clas_client.load_model_config( + "ocr_clas_server", gpu=True, profile=False) + self.text_classifier = TextClassifierHelper(args) self.det_client = Debugger() - if sys.argv[1] == 'gpu': - self.det_client.load_model_config( - det_model_config, gpu=True, profile=False) - elif sys.argv[1] == 'cpu': - self.det_client.load_model_config( - det_model_config, gpu=False, profile=False) - self.ocr_reader = OCRReader() + self.det_client.load_model_config( + "serving_server_dir", gpu=True, profile=False) + self.fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] + + def preprocess(self, img): + feed, fetch, self.tmp_args = self.text_detector.preprocess(img) + fetch_map = self.det_client.predict(feed, fetch) + print("det fetch_map", fetch_map) + outputs = [fetch_map[x] for x in fetch] + dt_boxes = self.text_detector.postprocess(outputs, self.tmp_args) + if dt_boxes is None: + return None, None + img_crop_list = [] + dt_boxes = sorted_boxes(dt_boxes) + for bno in range(len(dt_boxes)): + tmp_box = copy.deepcopy(dt_boxes[bno]) + img_crop = self.get_rotate_crop_image(img, tmp_box) + img_crop_list.append(img_crop) + if self.use_angle_cls: + feed, fetch, self.tmp_args = self.text_classifier.preprocess( + img_crop_list) + fetch_map = self.clas_client.predict(feed, fetch) + outputs = [fetch_map[x] for x in self.text_classifier.fetch] + for x in fetch_map.keys(): + if ".lod" in x: + self.tmp_args[x] = fetch_map[x] + img_crop_list, _ = self.text_classifier.postprocess(outputs, + self.tmp_args) + feed, fetch, self.tmp_args = self.text_recognizer.preprocess( + img_crop_list) + return feed, self.fetch, self.tmp_args + + def postprocess(self, outputs, args): + return self.text_recognizer.postprocess(outputs, args) + + +class OCRService(WebService): + def init_rec(self): + args = utility.parse_args() + self.text_system = TextSystemHelper(args) def preprocess(self, feed=[], fetch=[]): + # TODO: to handle batch rec images + print("start preprocess") data = base64.b64decode(feed[0]["image"].encode('utf8')) data = np.fromstring(data, np.uint8) im = cv2.imdecode(data, cv2.IMREAD_COLOR) - ori_h, ori_w, _ = im.shape - det_img = self.det_preprocess(im) - _, new_h, new_w = det_img.shape - det_img = det_img[np.newaxis, :] - det_img = det_img.copy() - det_out = self.det_client.predict( - feed={"image": det_img}, fetch=["concat_1.tmp_0"]) - filter_func = FilterBoxes(10, 10) - post_func = DBPostProcess({ - "thresh": 0.3, - "box_thresh": 0.5, - "max_candidates": 1000, - "unclip_ratio": 1.5, - "min_size": 3 - }) - sorted_boxes = SortedBoxes() - ratio_list = [float(new_h) / ori_h, float(new_w) / ori_w] - dt_boxes_list = post_func(det_out["concat_1.tmp_0"], [ratio_list]) - dt_boxes = filter_func(dt_boxes_list[0], [ori_h, ori_w]) - dt_boxes = sorted_boxes(dt_boxes) - get_rotate_crop_image = GetRotateCropImage() - img_list = [] - max_wh_ratio = 0 - for i, dtbox in enumerate(dt_boxes): - boximg = get_rotate_crop_image(im, dt_boxes[i]) - img_list.append(boximg) - h, w = boximg.shape[0:2] - wh_ratio = w * 1.0 / h - max_wh_ratio = max(max_wh_ratio, wh_ratio) - if len(img_list) == 0: - return [], [] - _, w, h = self.ocr_reader.resize_norm_img(img_list[0], - max_wh_ratio).shape - imgs = np.zeros((len(img_list), 3, w, h)).astype('float32') - for id, img in enumerate(img_list): - norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio) - imgs[id] = norm_img - feed = {"image": imgs.copy()} - fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] + feed, fetch, self.tmp_args = self.text_system.preprocess(im) + print("ocr preprocess done") return feed, fetch def postprocess(self, feed={}, fetch=[], fetch_map=None): - rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True) - res_lst = [] - for res in rec_res: - res_lst.append(res[0]) - res = {"res": res_lst} + outputs = [fetch_map[x] for x in self.text_system.fetch] + for x in fetch_map.keys(): + if ".lod" in x: + self.tmp_args[x] = fetch_map[x] + rec_res = self.text_system.postprocess(outputs, self.tmp_args) + res = { + "pred_text": [x[0] for x in rec_res], + "score": [str(x[1]) for x in rec_res] + } return res -ocr_service = OCRService(name="ocr") -ocr_service.load_model_config("ocr_rec_model") -ocr_service.init_det_debugger(det_model_config="ocr_det_model") -if sys.argv[1] == 'gpu': - ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0) - ocr_service.run_debugger_service(gpu=True) -elif sys.argv[1] == 'cpu': - ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu") +if __name__ == "__main__": + ocr_service = OCRService(name="ocr") + ocr_service.load_model_config("ocr_rec_model") + ocr_service.init_rec() + if global_args.use_gpu: + ocr_service.prepare_server( + workdir="workdir", port=9292, device="gpu", gpuid=0) + else: + ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu") ocr_service.run_debugger_service() -ocr_service.run_web_service() + ocr_service.run_web_service() diff --git a/deploy/pdserving/ocr_rpc_server.py b/deploy/pdserving/ocr_rpc_server.py new file mode 100644 index 0000000000000000000000000000000000000000..3ed8810eb8c13dee7d62dd37447ab7eed5d72cbc --- /dev/null +++ b/deploy/pdserving/ocr_rpc_server.py @@ -0,0 +1,123 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle_serving_client import Client +from paddle_serving_app.reader import OCRReader +import cv2 +import sys +import numpy as np +import os +import time +import re +import base64 +from clas_rpc_server import TextClassifierHelper +from det_rpc_server import TextDetectorHelper +from rec_rpc_server import TextRecognizerHelper +import tools.infer.utility as utility +from tools.infer.predict_system import TextSystem +import copy + +global_args = utility.parse_args() +if global_args.use_gpu: + from paddle_serving_server_gpu.web_service import WebService +else: + from paddle_serving_server.web_service import WebService + + +class TextSystemHelper(TextSystem): + def __init__(self, args): + self.text_detector = TextDetectorHelper(args) + self.text_recognizer = TextRecognizerHelper(args) + self.use_angle_cls = args.use_angle_cls + if self.use_angle_cls: + self.clas_client = Client() + self.clas_client.load_client_config( + "ocr_clas_client/serving_client_conf.prototxt") + self.clas_client.connect(["127.0.0.1:9294"]) + self.text_classifier = TextClassifierHelper(args) + self.det_client = Client() + self.det_client.load_client_config( + "ocr_det_server/serving_client_conf.prototxt") + self.det_client.connect(["127.0.0.1:9293"]) + self.fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] + + def preprocess(self, img): + feed, fetch, self.tmp_args = self.text_detector.preprocess(img) + fetch_map = self.det_client.predict(feed, fetch) + outputs = [fetch_map[x] for x in fetch] + dt_boxes = self.text_detector.postprocess(outputs, self.tmp_args) + if dt_boxes is None: + return None, None + img_crop_list = [] + sorted_boxes = SortedBoxes() + dt_boxes = sorted_boxes(dt_boxes) + for bno in range(len(dt_boxes)): + tmp_box = copy.deepcopy(dt_boxes[bno]) + img_crop = self.get_rotate_crop_image(img, tmp_box) + img_crop_list.append(img_crop) + if self.use_angle_cls: + feed, fetch, self.tmp_args = self.text_classifier.preprocess( + img_crop_list) + fetch_map = self.clas_client.predict(feed, fetch) + outputs = [fetch_map[x] for x in self.text_classifier.fetch] + for x in fetch_map.keys(): + if ".lod" in x: + self.tmp_args[x] = fetch_map[x] + img_crop_list, _ = self.text_classifier.postprocess(outputs, + self.tmp_args) + feed, fetch, self.tmp_args = self.text_recognizer.preprocess( + img_crop_list) + return feed, self.fetch, self.tmp_args + + def postprocess(self, outputs, args): + return self.text_recognizer.postprocess(outputs, args) + + +class OCRService(WebService): + def init_rec(self): + args = utility.parse_args() + self.text_system = TextSystemHelper(args) + + def preprocess(self, feed=[], fetch=[]): + # TODO: to handle batch rec images + data = base64.b64decode(feed[0]["image"].encode('utf8')) + data = np.fromstring(data, np.uint8) + im = cv2.imdecode(data, cv2.IMREAD_COLOR) + feed, fetch, self.tmp_args = self.text_system.preprocess(im) + return feed, fetch + + def postprocess(self, feed={}, fetch=[], fetch_map=None): + outputs = [fetch_map[x] for x in self.text_system.fetch] + for x in fetch_map.keys(): + if ".lod" in x: + self.tmp_args[x] = fetch_map[x] + rec_res = self.text_system.postprocess(outputs, self.tmp_args) + res = { + "pred_text": [x[0] for x in rec_res], + "score": [str(x[1]) for x in rec_res] + } + return res + + +if __name__ == "__main__": + ocr_service = OCRService(name="ocr") + ocr_service.load_model_config(global_args.rec_model_dir) + ocr_service.init_rec() + if global_args.use_gpu: + ocr_service.prepare_server( + workdir="workdir", port=9292, device="gpu", gpuid=0) + else: + ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu") + ocr_service.run_rpc_service() + ocr_service.run_web_service() diff --git a/deploy/pdserving/ocr_web_client.py b/deploy/pdserving/ocr_web_client.py index e2a92eb8ee4aa62059be184dd7e67237ed460f13..036f730edfa764ba15f0a9cf8a217932a3af6da5 100644 --- a/deploy/pdserving/ocr_web_client.py +++ b/deploy/pdserving/ocr_web_client.py @@ -20,11 +20,13 @@ import base64 import os, sys import time + def cv2_to_base64(image): #data = cv2.imencode('.jpg', image)[1] return base64.b64encode(image).decode( 'utf8') #data.tostring()).decode('utf8') + headers = {"Content-type": "application/json"} url = "http://127.0.0.1:9292/ocr/prediction" test_img_dir = "../../doc/imgs/" @@ -34,4 +36,8 @@ for img_file in os.listdir(test_img_dir): image = cv2_to_base64(image_data1) data = {"feed": [{"image": image}], "fetch": ["res"]} r = requests.post(url=url, headers=headers, data=json.dumps(data)) - print(r.json()) + print(r) + rjson = r.json() + print(rjson) + #for x in rjson["result"]["pred_text"]: + # print(x) diff --git a/deploy/pdserving/ocr_web_server.py b/deploy/pdserving/ocr_web_server.py deleted file mode 100644 index 6c0de44661958a6425f57039261969551ff552c5..0000000000000000000000000000000000000000 --- a/deploy/pdserving/ocr_web_server.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from paddle_serving_client import Client -from paddle_serving_app.reader import OCRReader -import cv2 -import sys -import numpy as np -import os -from paddle_serving_client import Client -from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor -from paddle_serving_app.reader import Div, Normalize, Transpose -from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes -if sys.argv[1] == 'gpu': - from paddle_serving_server_gpu.web_service import WebService -elif sys.argv[1] == 'cpu': - from paddle_serving_server.web_service import WebService -import time -import re -import base64 - - -class OCRService(WebService): - def init_det_client(self, det_port, det_client_config): - self.det_preprocess = Sequential([ - ResizeByFactor(32, 960), Div(255), - Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose( - (2, 0, 1)) - ]) - self.det_client = Client() - self.det_client.load_client_config(det_client_config) - self.det_client.connect(["127.0.0.1:{}".format(det_port)]) - self.ocr_reader = OCRReader() - - def preprocess(self, feed=[], fetch=[]): - data = base64.b64decode(feed[0]["image"].encode('utf8')) - data = np.fromstring(data, np.uint8) - im = cv2.imdecode(data, cv2.IMREAD_COLOR) - ori_h, ori_w, _ = im.shape - det_img = self.det_preprocess(im) - det_out = self.det_client.predict( - feed={"image": det_img}, fetch=["concat_1.tmp_0"]) - _, new_h, new_w = det_img.shape - filter_func = FilterBoxes(10, 10) - post_func = DBPostProcess({ - "thresh": 0.3, - "box_thresh": 0.5, - "max_candidates": 1000, - "unclip_ratio": 1.5, - "min_size": 3 - }) - sorted_boxes = SortedBoxes() - ratio_list = [float(new_h) / ori_h, float(new_w) / ori_w] - dt_boxes_list = post_func(det_out["concat_1.tmp_0"], [ratio_list]) - dt_boxes = filter_func(dt_boxes_list[0], [ori_h, ori_w]) - dt_boxes = sorted_boxes(dt_boxes) - get_rotate_crop_image = GetRotateCropImage() - feed_list = [] - img_list = [] - max_wh_ratio = 0 - for i, dtbox in enumerate(dt_boxes): - boximg = get_rotate_crop_image(im, dt_boxes[i]) - img_list.append(boximg) - h, w = boximg.shape[0:2] - wh_ratio = w * 1.0 / h - max_wh_ratio = max(max_wh_ratio, wh_ratio) - for img in img_list: - norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio) - feed = {"image": norm_img} - feed_list.append(feed) - fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] - return feed_list, fetch - - def postprocess(self, feed={}, fetch=[], fetch_map=None): - rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True) - res_lst = [] - for res in rec_res: - res_lst.append(res[0]) - res = {"res": res_lst} - return res - - -ocr_service = OCRService(name="ocr") -ocr_service.load_model_config("ocr_rec_model") -if sys.argv[1] == 'gpu': - ocr_service.set_gpus("0") - ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0) -elif sys.argv[1] == 'cpu': - ocr_service.prepare_server(workdir="workdir", port=9292) -ocr_service.init_det_client( - det_port=9293, - det_client_config="ocr_det_client/serving_client_conf.prototxt") -ocr_service.run_rpc_service() -ocr_service.run_web_service() diff --git a/deploy/pdserving/readme.md b/deploy/pdserving/readme.md index af12d508ba9c04e6032f2a392701e72b41462395..c3b7be454c4446bd9fda661b545b8ad167165459 100644 --- a/deploy/pdserving/readme.md +++ b/deploy/pdserving/readme.md @@ -1,7 +1,7 @@ [English](readme_en.md) | 简体中文 PaddleOCR提供2种服务部署方式: -- 基于PaddleHub Serving的部署:代码路径为"`./deploy/hubserving`",使用方法参考[文档](../hubserving/readme.md)。 +- 基于PaddleHub Serving的部署:代码路径为"`./deploy/hubserving`",使用方法参考[文档](../hubserving/readme.md)。 - 基于PaddleServing的部署:代码路径为"`./deploy/pdserving`",按照本教程使用。 # Paddle Serving 服务部署 @@ -11,7 +11,7 @@ PaddleOCR提供2种服务部署方式: ### 1. 准备环境 我们先安装Paddle Serving相关组件 -我们推荐用户使用GPU来做Paddle Serving的OCR服务部署 +我们推荐用户使用GPU来做Paddle Serving的OCR服务部署 **CUDA版本:9.0** @@ -39,7 +39,7 @@ python -m pip install paddle_serving_app paddle_serving_client python -m paddle_serving_app.package --get_model ocr_rec tar -xzvf ocr_rec.tar.gz python -m paddle_serving_app.package --get_model ocr_det -tar -xzvf ocr_det.tar.gz +tar -xzvf ocr_det.tar.gz ``` 执行上述命令会下载`db_crnn_mobile`的模型,如果想要下载规模更大的`db_crnn_server`模型,可以在下载预测模型并解压之后。参考[如何从Paddle保存的预测模型转为Paddle Serving格式可部署的模型](https://github.com/PaddlePaddle/Serving/blob/develop/doc/INFERENCE_TO_SERVING_CN.md)。 @@ -72,7 +72,7 @@ feed_var_names, fetch_var_names = inference_model_to_serving( ``` # cpu,gpu启动二选一,以下是cpu启动 -python -m paddle_serving_server.serve --model ocr_det_model --port 9293 +python -m paddle_serving_server.serve --model ocr_det_model --port 9293 python ocr_web_server.py cpu # gpu启动 python -m paddle_serving_server_gpu.serve --model ocr_det_model --port 9293 --gpu_id 0 diff --git a/deploy/pdserving/readme_en.md b/deploy/pdserving/readme_en.md index 9a0c684fb6fb4f0eeff2552af70f62053d3351fb..bfd88e63eb6a7be01145efd9aebe821bc441a901 100644 --- a/deploy/pdserving/readme_en.md +++ b/deploy/pdserving/readme_en.md @@ -1,6 +1,6 @@ English | [简体中文](readme.md) -PaddleOCR provides 2 service deployment methods: +PaddleOCR provides 2 service deployment methods: - Based on **PaddleHub Serving**: Code path is "`./deploy/hubserving`". Please refer to the [tutorial](../hubserving/readme_en.md) for usage. - Based on **PaddleServing**: Code path is "`./deploy/pdserving`". Please follow this tutorial. @@ -37,7 +37,7 @@ You can directly use converted model provided by `paddle_serving_app` for conven python -m paddle_serving_app.package --get_model ocr_rec tar -xzvf ocr_rec.tar.gz python -m paddle_serving_app.package --get_model ocr_det -tar -xzvf ocr_det.tar.gz +tar -xzvf ocr_det.tar.gz ``` Executing the above command will download the `db_crnn_mobile` model, which is in different format with inference model. If you want to use other models for deployment, you can refer to the [tutorial](https://github.com/PaddlePaddle/Serving/blob/develop/doc/INFERENCE_TO_SERVING_CN.md) to convert your inference model to a model which is deployable for Paddle Serving. @@ -71,7 +71,7 @@ Start the standard version or the fast version service according to your actual ``` # start with CPU -python -m paddle_serving_server.serve --model ocr_det_model --port 9293 +python -m paddle_serving_server.serve --model ocr_det_model --port 9293 python ocr_web_server.py cpu # or, with GPU diff --git a/deploy/pdserving/rec_local_server.py b/deploy/pdserving/rec_local_server.py index ba021c1cd5054071eb115b3e6e9c64cb572ff871..5021cdd97edcc5ef82f5c08228d09aaff95b83c4 100644 --- a/deploy/pdserving/rec_local_server.py +++ b/deploy/pdserving/rec_local_server.py @@ -18,62 +18,159 @@ import cv2 import sys import numpy as np import os -from paddle_serving_client import Client -from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor -from paddle_serving_app.reader import Div, Normalize, Transpose -from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes -if sys.argv[1] == 'gpu': - from paddle_serving_server_gpu.web_service import WebService -elif sys.argv[1] == 'cpu': - from paddle_serving_server.web_service import WebService import time import re import base64 +from tools.infer.predict_rec import TextRecognizer +import tools.infer.utility as utility + +global_args = utility.parse_args() +if global_args.use_gpu: + from paddle_serving_server_gpu.web_service import WebService +else: + from paddle_serving_server.web_service import WebService + + +class TextRecognizerHelper(TextRecognizer): + def __init__(self, args): + super(TextRecognizerHelper, self).__init__(args) + if self.loss_type == "ctc": + self.fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] + + def preprocess(self, img_list): + img_num = len(img_list) + args = {} + # Calculate the aspect ratio of all text bars + width_list = [] + for img in img_list: + width_list.append(img.shape[1] / float(img.shape[0])) + indices = np.argsort(np.array(width_list)) + args["indices"] = indices + predict_time = 0 + beg_img_no = 0 + end_img_no = img_num + norm_img_batch = [] + max_wh_ratio = 0 + for ino in range(beg_img_no, end_img_no): + h, w = img_list[indices[ino]].shape[0:2] + wh_ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, wh_ratio) + for ino in range(beg_img_no, end_img_no): + if self.loss_type != "srn": + norm_img = self.resize_norm_img(img_list[indices[ino]], + max_wh_ratio) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) + else: + norm_img = self.process_image_srn(img_list[indices[ino]], + self.rec_image_shape, 8, 25, + self.char_ops) + encoder_word_pos_list = [] + gsrm_word_pos_list = [] + gsrm_slf_attn_bias1_list = [] + gsrm_slf_attn_bias2_list = [] + encoder_word_pos_list.append(norm_img[1]) + gsrm_word_pos_list.append(norm_img[2]) + gsrm_slf_attn_bias1_list.append(norm_img[3]) + gsrm_slf_attn_bias2_list.append(norm_img[4]) + norm_img_batch.append(norm_img[0]) + norm_img_batch = np.concatenate(norm_img_batch, axis=0).copy() + feed = {"image": norm_img_batch.copy()} + return feed, self.fetch, args + + def postprocess(self, outputs, args): + if self.loss_type == "ctc": + rec_idx_batch = outputs[0] + predict_batch = outputs[1] + rec_idx_lod = args["ctc_greedy_decoder_0.tmp_0.lod"] + predict_lod = args["softmax_0.tmp_0.lod"] + indices = args["indices"] + print("indices", indices, rec_idx_lod) + rec_res = [['', 0.0]] * (len(rec_idx_lod) - 1) + for rno in range(len(rec_idx_lod) - 1): + beg = rec_idx_lod[rno] + end = rec_idx_lod[rno + 1] + rec_idx_tmp = rec_idx_batch[beg:end, 0] + preds_text = self.char_ops.decode(rec_idx_tmp) + beg = predict_lod[rno] + end = predict_lod[rno + 1] + probs = predict_batch[beg:end, :] + ind = np.argmax(probs, axis=1) + blank = probs.shape[1] + valid_ind = np.where(ind != (blank - 1))[0] + if len(valid_ind) == 0: + continue + score = np.mean(probs[valid_ind, ind[valid_ind]]) + rec_res[indices[rno]] = [preds_text, score] + elif self.loss_type == 'srn': + char_num = self.char_ops.get_char_num() + preds = rec_idx_batch.reshape(-1) + elapse = time.time() - starttime + predict_time += elapse + total_preds = preds.copy() + for ino in range(int(len(rec_idx_batch) / self.text_len)): + preds = total_preds[ino * self.text_len:(ino + 1) * + self.text_len] + ind = np.argmax(probs, axis=1) + valid_ind = np.where(preds != int(char_num - 1))[0] + if len(valid_ind) == 0: + continue + score = np.mean(probs[valid_ind, ind[valid_ind]]) + preds = preds[:valid_ind[-1] + 1] + preds_text = self.char_ops.decode(preds) + rec_res[indices[ino]] = [preds_text, score] + else: + for rno in range(len(rec_idx_batch)): + end_pos = np.where(rec_idx_batch[rno, :] == 1)[0] + if len(end_pos) <= 1: + preds = rec_idx_batch[rno, 1:] + score = np.mean(predict_batch[rno, 1:]) + else: + preds = rec_idx_batch[rno, 1:end_pos[1]] + score = np.mean(predict_batch[rno, 1:end_pos[1]]) + preds_text = self.char_ops.decode(preds) + rec_res[indices[rno]] = [preds_text, score] + return rec_res class OCRService(WebService): def init_rec(self): self.ocr_reader = OCRReader() + self.text_recognizer = TextRecognizerHelper(global_args) def preprocess(self, feed=[], fetch=[]): + # TODO: to handle batch rec images img_list = [] for feed_data in feed: data = base64.b64decode(feed_data["image"].encode('utf8')) data = np.fromstring(data, np.uint8) im = cv2.imdecode(data, cv2.IMREAD_COLOR) img_list.append(im) - max_wh_ratio = 0 - for i, boximg in enumerate(img_list): - h, w = boximg.shape[0:2] - wh_ratio = w * 1.0 / h - max_wh_ratio = max(max_wh_ratio, wh_ratio) - _, w, h = self.ocr_reader.resize_norm_img(img_list[0], - max_wh_ratio).shape - imgs = np.zeros((len(img_list), 3, w, h)).astype('float32') - for i, img in enumerate(img_list): - norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio) - imgs[i] = norm_img - feed = {"image": imgs.copy()} - fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] + feed, fetch, self.tmp_args = self.text_recognizer.preprocess(img_list) return feed, fetch def postprocess(self, feed={}, fetch=[], fetch_map=None): - rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True) - res_lst = [] - for res in rec_res: - res_lst.append(res[0]) - res = {"res": res_lst} + outputs = [fetch_map[x] for x in self.text_recognizer.fetch] + for x in fetch_map.keys(): + if ".lod" in x: + self.tmp_args[x] = fetch_map[x] + rec_res = self.text_recognizer.postprocess(outputs, self.tmp_args) + print("rec_res", rec_res) + res = { + "pred_text": [x[0] for x in rec_res], + "score": [str(x[1]) for x in rec_res] + } return res -ocr_service = OCRService(name="ocr") -ocr_service.load_model_config("ocr_rec_model") -ocr_service.init_rec() -if sys.argv[1] == 'gpu': - ocr_service.set_gpus("0") - ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0) - ocr_service.run_debugger_service(gpu=True) -elif sys.argv[1] == 'cpu': - ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu") +if __name__ == "__main__": + ocr_service = OCRService(name="ocr") + ocr_service.load_model_config("ocr_rec_model") + ocr_service.init_rec() + if global_args.use_gpu: + ocr_service.prepare_server( + workdir="workdir", port=9292, device="gpu", gpuid=0) + else: + ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu") ocr_service.run_debugger_service() -ocr_service.run_web_service() + ocr_service.run_web_service() diff --git a/deploy/pdserving/rec_rpc_server.py b/deploy/pdserving/rec_rpc_server.py new file mode 100644 index 0000000000000000000000000000000000000000..b1a9df9e8d383805b695448b736b67cd5cd45601 --- /dev/null +++ b/deploy/pdserving/rec_rpc_server.py @@ -0,0 +1,182 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle_serving_client import Client +from paddle_serving_app.reader import OCRReader +import cv2 +import sys +import numpy as np +import os +import time +import re +import base64 +from tools.infer.predict_rec import TextRecognizer +import tools.infer.utility as utility + +global_args = utility.parse_args() +if global_args.use_gpu: + from paddle_serving_server_gpu.web_service import WebService +else: + from paddle_serving_server.web_service import WebService + + +class TextRecognizerHelper(TextRecognizer): + def __init__(self, args): + super(TextRecognizerHelper, self).__init__(args) + if self.loss_type == "ctc": + self.fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] + + def preprocess(self, img_list): + img_num = len(img_list) + args = {} + # Calculate the aspect ratio of all text bars + width_list = [] + for img in img_list: + width_list.append(img.shape[1] / float(img.shape[0])) + indices = np.argsort(np.array(width_list)) + args["indices"] = indices + predict_time = 0 + beg_img_no = 0 + end_img_no = img_num + norm_img_batch = [] + max_wh_ratio = 0 + for ino in range(beg_img_no, end_img_no): + h, w = img_list[indices[ino]].shape[0:2] + wh_ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, wh_ratio) + for ino in range(beg_img_no, end_img_no): + if self.loss_type != "srn": + norm_img = self.resize_norm_img(img_list[indices[ino]], + max_wh_ratio) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) + else: + norm_img = self.process_image_srn(img_list[indices[ino]], + self.rec_image_shape, 8, 25, + self.char_ops) + encoder_word_pos_list = [] + gsrm_word_pos_list = [] + gsrm_slf_attn_bias1_list = [] + gsrm_slf_attn_bias2_list = [] + encoder_word_pos_list.append(norm_img[1]) + gsrm_word_pos_list.append(norm_img[2]) + gsrm_slf_attn_bias1_list.append(norm_img[3]) + gsrm_slf_attn_bias2_list.append(norm_img[4]) + norm_img_batch.append(norm_img[0]) + + norm_img_batch = np.concatenate(norm_img_batch, axis=0) + if img_num > 1: + feed = [{ + "image": norm_img_batch[x] + } for x in range(norm_img_batch.shape[0])] + else: + feed = {"image": norm_img_batch[0]} + return feed, self.fetch, args + + def postprocess(self, outputs, args): + if self.loss_type == "ctc": + rec_idx_batch = outputs[0] + predict_batch = outputs[1] + rec_idx_lod = args["ctc_greedy_decoder_0.tmp_0.lod"] + predict_lod = args["softmax_0.tmp_0.lod"] + indices = args["indices"] + print("indices", indices, rec_idx_lod) + rec_res = [['', 0.0]] * (len(rec_idx_lod) - 1) + for rno in range(len(rec_idx_lod) - 1): + beg = rec_idx_lod[rno] + end = rec_idx_lod[rno + 1] + rec_idx_tmp = rec_idx_batch[beg:end, 0] + preds_text = self.char_ops.decode(rec_idx_tmp) + beg = predict_lod[rno] + end = predict_lod[rno + 1] + probs = predict_batch[beg:end, :] + ind = np.argmax(probs, axis=1) + blank = probs.shape[1] + valid_ind = np.where(ind != (blank - 1))[0] + if len(valid_ind) == 0: + continue + score = np.mean(probs[valid_ind, ind[valid_ind]]) + rec_res[indices[rno]] = [preds_text, score] + elif self.loss_type == 'srn': + char_num = self.char_ops.get_char_num() + preds = rec_idx_batch.reshape(-1) + elapse = time.time() - starttime + predict_time += elapse + total_preds = preds.copy() + for ino in range(int(len(rec_idx_batch) / self.text_len)): + preds = total_preds[ino * self.text_len:(ino + 1) * + self.text_len] + ind = np.argmax(probs, axis=1) + valid_ind = np.where(preds != int(char_num - 1))[0] + if len(valid_ind) == 0: + continue + score = np.mean(probs[valid_ind, ind[valid_ind]]) + preds = preds[:valid_ind[-1] + 1] + preds_text = self.char_ops.decode(preds) + rec_res[indices[ino]] = [preds_text, score] + else: + for rno in range(len(rec_idx_batch)): + end_pos = np.where(rec_idx_batch[rno, :] == 1)[0] + if len(end_pos) <= 1: + preds = rec_idx_batch[rno, 1:] + score = np.mean(predict_batch[rno, 1:]) + else: + preds = rec_idx_batch[rno, 1:end_pos[1]] + score = np.mean(predict_batch[rno, 1:end_pos[1]]) + preds_text = self.char_ops.decode(preds) + rec_res[indices[rno]] = [preds_text, score] + return rec_res + + +class OCRService(WebService): + def init_rec(self): + self.ocr_reader = OCRReader() + self.text_recognizer = TextRecognizerHelper(global_args) + + def preprocess(self, feed=[], fetch=[]): + # TODO: to handle batch rec images + img_list = [] + for feed_data in feed: + data = base64.b64decode(feed_data["image"].encode('utf8')) + data = np.fromstring(data, np.uint8) + im = cv2.imdecode(data, cv2.IMREAD_COLOR) + img_list.append(im) + feed, fetch, self.tmp_args = self.text_recognizer.preprocess(img_list) + return feed, fetch + + def postprocess(self, feed={}, fetch=[], fetch_map=None): + outputs = [fetch_map[x] for x in self.text_recognizer.fetch] + for x in fetch_map.keys(): + if ".lod" in x: + self.tmp_args[x] = fetch_map[x] + rec_res = self.text_recognizer.postprocess(outputs, self.tmp_args) + print("rec_res", rec_res) + res = { + "pred_text": [x[0] for x in rec_res], + "score": [str(x[1]) for x in rec_res] + } + return res + + +if __name__ == "__main__": + ocr_service = OCRService(name="ocr") + ocr_service.load_model_config(global_args.rec_model_dir) + ocr_service.init_rec() + if global_args.use_gpu: + ocr_service.prepare_server( + workdir="workdir", port=9292, device="gpu", gpuid=0) + else: + ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu") + ocr_service.run_rpc_service() + ocr_service.run_web_service() diff --git a/deploy/pdserving/rec_web_client.py b/deploy/pdserving/rec_web_client.py new file mode 100644 index 0000000000000000000000000000000000000000..9bcd929ef0538f70c77ed5f879afaf526468635c --- /dev/null +++ b/deploy/pdserving/rec_web_client.py @@ -0,0 +1,40 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -*- coding: utf-8 -*- + +import requests +import json +import cv2 +import base64 +import os, sys +import time + + +def cv2_to_base64(image): + #data = cv2.imencode('.jpg', image)[1] + return base64.b64encode(image).decode( + 'utf8') #data.tostring()).decode('utf8') + + +headers = {"Content-type": "application/json"} +url = "http://127.0.0.1:9292/ocr/prediction" +test_img_dir = "../../doc/imgs_words/ch/" +for img_file in os.listdir(test_img_dir): + with open(os.path.join(test_img_dir, img_file), 'rb') as file: + image_data1 = file.read() + image = cv2_to_base64(image_data1) + data = {"feed": [{"image": image}], "fetch": ["res"]} + r = requests.post(url=url, headers=headers, data=json.dumps(data)) + print(r.json()) + break diff --git a/deploy/pdserving/rec_web_server.py b/deploy/pdserving/rec_web_server.py deleted file mode 100644 index 0f4e9f6d264ed602f387bfaf0303cd59af7823fa..0000000000000000000000000000000000000000 --- a/deploy/pdserving/rec_web_server.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from paddle_serving_client import Client -from paddle_serving_app.reader import OCRReader -import cv2 -import sys -import numpy as np -import os -from paddle_serving_client import Client -from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor -from paddle_serving_app.reader import Div, Normalize, Transpose -from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes -if sys.argv[1] == 'gpu': - from paddle_serving_server_gpu.web_service import WebService -elif sys.argv[1] == 'cpu': - from paddle_serving_server.web_service import WebService -import time -import re -import base64 - - -class OCRService(WebService): - def init_rec(self): - self.ocr_reader = OCRReader() - - def preprocess(self, feed=[], fetch=[]): - # TODO: to handle batch rec images - img_list = [] - for feed_data in feed: - data = base64.b64decode(feed_data["image"].encode('utf8')) - data = np.fromstring(data, np.uint8) - im = cv2.imdecode(data, cv2.IMREAD_COLOR) - img_list.append(im) - feed_list = [] - max_wh_ratio = 0 - for i, boximg in enumerate(img_list): - h, w = boximg.shape[0:2] - wh_ratio = w * 1.0 / h - max_wh_ratio = max(max_wh_ratio, wh_ratio) - for img in img_list: - norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio) - feed = {"image": norm_img} - feed_list.append(feed) - fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] - return feed_list, fetch - - def postprocess(self, feed={}, fetch=[], fetch_map=None): - rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True) - res_lst = [] - for res in rec_res: - res_lst.append(res[0]) - res = {"res": res_lst} - return res - - -ocr_service = OCRService(name="ocr") -ocr_service.load_model_config("ocr_rec_model") -ocr_service.init_rec() -if sys.argv[1] == 'gpu': - ocr_service.set_gpus("0") - ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0) -elif sys.argv[1] == 'cpu': - ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu") -ocr_service.run_rpc_service() -ocr_service.run_web_service() diff --git a/doc/doc_ch/serving_inference.md b/doc/doc_ch/serving_inference.md new file mode 100644 index 0000000000000000000000000000000000000000..4f0a998056bcb37e364451126b10458d27e30d83 --- /dev/null +++ b/doc/doc_ch/serving_inference.md @@ -0,0 +1,236 @@ +# 使用Paddle Serving预测推理 + +阅读本文档之前,请先阅读文档 [基于Python预测引擎推理](./inference.md) + +同本地执行预测一样,我们需要保存一份可以用于Paddle Serving的模型。 + +接下来首先介绍如何将训练的模型转换成Paddle Serving模型,然后将依次介绍文本检测、文本识别以及两者串联基于预测引擎推理。 + + + +## 一、训练模型转Serving模型 + +### 检测模型转Serving模型 + +下载超轻量级中文检测模型: + +``` +wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db.tar && tar xf ./ch_lite/ch_det_mv3_db.tar -C ./ch_lite/ +``` + +上述模型是以MobileNetV3为backbone训练的DB算法,将训练好的模型转换成Serving模型只需要运行如下命令: + +``` +# -c后面设置训练算法的yml配置文件 +# -o配置可选参数 +# Global.checkpoints参数设置待转换的训练模型地址,不用添加文件后缀.pdmodel,.pdopt或.pdparams。 +# Global.save_inference_dir参数设置转换的模型将保存的地址。 + +python tools/export_serving_model.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=./ch_lite/det_mv3_db/best_accuracy Global.save_inference_dir=./inference/det_db/ +``` + +转Serving模型时,使用的配置文件和训练时使用的配置文件相同。另外,还需要设置配置文件中的`Global.checkpoints`、`Global.save_inference_dir`参数。 其中`Global.checkpoints`指向训练中保存的模型参数文件,`Global.save_inference_dir`是生成的inference模型要保存的目录。 转换成功后,在`save_inference_dir`目录下有两个文件: + +``` +inference/det_db/ +├── serving_client_dir # 客户端配置文件夹 +└── serving_server_dir # 服务端配置文件夹 + +``` + +### 识别模型转Serving模型 + +下载超轻量中文识别模型: + +``` +wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn.tar && tar xf ./ch_lite/ch_rec_mv3_crnn.tar -C ./ch_lite/ +``` + +识别模型转inference模型与检测的方式相同,如下: + +``` +# -c后面设置训练算法的yml配置文件 +# -o配置可选参数 +# Global.checkpoints参数设置待转换的训练模型地址,不用添加文件后缀.pdmodel,.pdopt或.pdparams。 +# Global.save_inference_dir参数设置转换的模型将保存的地址。 + +python3 tools/export_serving_model.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkpoints=./ch_lite/rec_mv3_crnn/best_accuracy \ + Global.save_inference_dir=./inference/rec_crnn/ +``` + +**注意:**如果您是在自己的数据集上训练的模型,并且调整了中文字符的字典文件,请注意修改配置文件中的`character_dict_path`是否是所需要的字典文件。 + +转换成功后,在目录下有两个文件: + +``` +/inference/rec_crnn/ +├── serving_client_dir # 客户端配置文件夹 +└── serving_server_dir # 服务端配置文件夹 +``` + +### 方向分类模型转Serving模型 + +下载方向分类模型: + +``` +wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/20-09-22/cls/ch_ppocr_mobile-v1.1.cls_pre.tar && tar xf ./ch_lite/ch_ppocr_mobile-v1.1.cls_pre.tar -C ./ch_lite/ +``` + +方向分类模型转inference模型与检测的方式相同,如下: + +``` +# -c后面设置训练算法的yml配置文件 +# -o配置可选参数 +# Global.checkpoints参数设置待转换的训练模型地址,不用添加文件后缀.pdmodel,.pdopt或.pdparams。 +# Global.save_inference_dir参数设置转换的模型将保存的地址。 + +python3 tools/export_serving_model.py -c configs/cls/cls_mv3.yml -o Global.checkpoints=./ch_lite/cls_model/best_accuracy \ + Global.save_inference_dir=./inference/cls/ +``` + +转换成功后,在目录下有两个文件: + +``` +/inference/cls/ +├── serving_client_dir # 客户端配置文件夹 +└── serving_server_dir # 服务端配置文件夹 +``` + +在接下来的教程中,我们将给出推理的demo模型下载链接。 + +``` +wget --no-check-certificate ocr_serving_model_zoo.tar.gz +tar zxf ocr_serving_model_zoo.tar.gz +``` + + + +## 二、文本检测模型Serving推理 + +文本检测模型推理,默认使用DB模型的配置参数。当不使用DB模型时,在推理时,需要通过传入相应的参数进行算法适配,细节参考下文。 + +与本地预测不同的是,Serving预测需要一个客户端和一个服务端,因此接下来的教程都是两行代码。所有的 + +### 1. 超轻量中文检测模型推理 + +超轻量中文检测模型推理,可以执行如下命令启动服务端: + +``` +#根据环境只需要启动其中一个就可以 +python det_rpc_server.py --use_serving True #标准版,Linux用户 +python det_local_server.py --use_serving True #快速版,Windows/Linux用户 +``` + +客户端 + +``` +python det_web_client.py +``` + + + +Serving的推测和本地预测不同点在于,客户端发送请求到服务端,服务端需要检测到文字框之后返回框的坐标,此处没有后处理的图片,只能看到坐标值。 + +### 2. DB文本检测模型推理 + +首先将DB文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/det_r50_vd_db.tar)),可以使用如下命令进行转换: + +``` +# -c后面设置训练算法的yml配置文件 +# Global.checkpoints参数设置待转换的训练模型地址,不用添加文件后缀.pdmodel,.pdopt或.pdparams。 +# Global.save_inference_dir参数设置转换的模型将保存的地址。 + +python3 tools/export_serving_model.py -c configs/det/det_r50_vd_db.yml -o Global.checkpoints="./models/det_r50_vd_db/best_accuracy" Global.save_inference_dir="./inference/det_db" +``` + +经过转换之后,会在`./inference/det_db` 目录下出现`serving_server_dir`和`serving_client_dir`,然后指定`det_model_dir` 。 + +## 三、文本识别模型Serving推理 + +下面将介绍超轻量中文识别模型推理、基于CTC损失的识别模型推理和基于Attention损失的识别模型推理。对于中文文本识别,建议优先选择基于CTC损失的识别模型,实践中也发现基于Attention损失的效果不如基于CTC损失的识别模型。此外,如果训练时修改了文本的字典,请参考下面的自定义文本识别字典的推理。 + +### 1. 超轻量中文识别模型推理 + +超轻量中文识别模型推理,可以执行如下命令启动服务端: + +``` +#根据环境只需要启动其中一个就可以 +python rec_rpc_server.py --use_serving True #标准版,Linux用户 +python rec_local_server.py --use_serving True #快速版,Windows/Linux用户 +``` + +客户端 + +``` +python rec_web_client.py +``` + + + +执行命令后,上面图像的预测结果(识别的文本和得分)会打印到屏幕上,示例如下: + +``` +{u'result': {u'score': [u'0.89547354'], u'pred_text': ['实力活力']}} +``` + + + +## 四、方向分类模型推理 + +下面将介绍方向分类模型推理。 + + + +### 1. 方向分类模型推理 + +方向分类模型推理, 可以执行如下命令启动服务端: + +``` +#根据环境只需要启动其中一个就可以 +python clas_rpc_server.py --use_serving True #标准版,Linux用户 +python clas_local_server.py --use_serving True #快速版,Windows/Linux用户 +``` + +客户端 + +``` +python rec_web_client.py +``` + +![](../imgs_words/ch/word_4.jpg) + +执行命令后,上面图像的预测结果(分类的方向和得分)会打印到屏幕上,示例如下: + +``` +{u'result': {u'direction': [u'0'], u'score': [u'0.9999963']}} +``` + + +## 五、文本检测、方向分类和文字识别串联Serving推理 + +### 1. 超轻量中文OCR模型推理 + +在执行预测时,需要通过参数`image_dir`指定单张图像或者图像集合的路径、参数`det_model_dir`,`cls_model_dir`和`rec_model_dir`分别指定检测,方向分类和识别的inference模型路径。参数`use_angle_cls`用于控制是否启用方向分类模型。与本地预测不同的是,为了减少网络传输耗时,可视化识别结果目前不做处理,用户收到的是推理得到的文字字段。 + +执行如下命令启动服务端: + +``` +#标准版,Linux用户 +#GPU用户 +python -m paddle_serving_server_gpu.serve --model ocr_det_model --port 9293 --gpu_id 0 +python -m paddle_serving_server_gpu.serve --model ocr_cls_model --port 9294 --gpu_id 0 +python ocr_rpc_server.py --use_serving True --use_gpu True +#CPU用户 +python -m paddle_serving_server.serve --model ocr_det_model --port 9293 +python -m paddle_serving_server.serve --model ocr_cls_model --port 9294 +python ocr_rpc_server.py --use_serving True --use_gpu False + +#快速版,Windows/Linux用户 +python ocr_local_server.py --use_serving True +``` + +客户端 + +``` +python rec_web_client.py +``` diff --git a/ppocr/data/det/__init__.py b/ppocr/data/det/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tools/__init__.py b/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tools/export_serving_model.py b/tools/export_serving_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b43fac67999de1e39e6439300afa8867880b55c2 --- /dev/null +++ b/tools/export_serving_model.py @@ -0,0 +1,78 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) + + +def set_paddle_flags(**kwargs): + for key, value in kwargs.items(): + if os.environ.get(key, None) is None: + os.environ[key] = str(value) + + +# NOTE(paddle-dev): All of these flags should be +# set before `import paddle`. Otherwise, it would +# not take any effect. +set_paddle_flags( + FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory +) + +import program +from paddle import fluid +from ppocr.utils.utility import initial_logger +logger = initial_logger() +from ppocr.utils.save_load import init_model +from paddle_serving_client.io import save_model + + +def main(): + startup_prog, eval_program, place, config, _ = program.preprocess() + + feeded_var_names, target_vars, fetches_var_name = program.build_export( + config, eval_program, startup_prog) + eval_program = eval_program.clone(for_test=True) + exe = fluid.Executor(place) + exe.run(startup_prog) + + init_model(config, eval_program, exe) + + save_inference_dir = config['Global']['save_inference_dir'] + if not os.path.exists(save_inference_dir): + os.makedirs(save_inference_dir) + serving_client_dir = "{}/serving_client_dir".format(save_inference_dir) + serving_server_dir = "{}/serving_server_dir".format(save_inference_dir) + + feed_dict = { + x: eval_program.global_block().var(x) + for x in feeded_var_names + } + fetch_dict = {x.name: x for x in target_vars} + save_model(serving_server_dir, serving_client_dir, feed_dict, fetch_dict, + eval_program) + print( + "paddle serving model saved in {}/serving_server_dir and {}/serving_client_dir". + format(save_inference_dir, save_inference_dir)) + print("save success, output_name_list:", fetches_var_name) + + +if __name__ == '__main__': + main() diff --git a/tools/infer/__init__.py b/tools/infer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tools/infer/predict_cls.py b/tools/infer/predict_cls.py index 3c14011a24cf5afcecc5edd5a54e395a0f171f53..74029b2474b0cbd923988bbe7e8578893ee5d276 100755 --- a/tools/infer/predict_cls.py +++ b/tools/infer/predict_cls.py @@ -33,8 +33,9 @@ from paddle import fluid class TextClassifier(object): def __init__(self, args): - self.predictor, self.input_tensor, self.output_tensors = \ - utility.create_predictor(args, mode="cls") + if args.use_serving is False: + self.predictor, self.input_tensor, self.output_tensors = \ + utility.create_predictor(args, mode="cls") self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")] self.cls_batch_num = args.rec_batch_num self.label_list = args.label_list @@ -103,7 +104,6 @@ class TextClassifier(object): label_out = self.output_tensors[1].copy_to_cpu() if len(label_out.shape) != 1: prob_out, label_out = label_out, prob_out - elapse = time.time() - starttime predict_time += elapse for rno in range(len(label_out)): diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index 625f87abc39fc0e9d7683f72dafec1d53324873a..18ea4bffd96aa0f1ccf795de4d9c2fdc3ea3f4a9 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -75,9 +75,9 @@ class TextDetector(object): else: logger.info("unknown det_algorithm:{}".format(self.det_algorithm)) sys.exit(0) - - self.predictor, self.input_tensor, self.output_tensors =\ - utility.create_predictor(args, mode="det") + if args.use_gpu is False: + self.predictor, self.input_tensor, self.output_tensors =\ + utility.create_predictor(args, mode="det") def order_points_clockwise(self, pts): """ diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 06273e9f9e5b42a9ecc829c435662e9aabcdd224..c09a14f90638f524efc660269f9495117f822257 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -34,8 +34,9 @@ from ppocr.utils.character import CharacterOps class TextRecognizer(object): def __init__(self, args): - self.predictor, self.input_tensor, self.output_tensors =\ - utility.create_predictor(args, mode="rec") + if args.use_serving is False: + self.predictor, self.input_tensor, self.output_tensors =\ + utility.create_predictor(args, mode="rec") self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")] self.character_type = args.rec_char_type self.rec_batch_num = args.rec_batch_num @@ -320,7 +321,7 @@ def main(args): print(e) logger.info( "ERROR!!!! \n" - "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n" + "Please read the FAQ: https://github.com/PaddlePaddle/PaddleOCR#faq \n" "If your model has tps module: " "TPS does not support variable shape.\n" "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ") diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 45d7b73707904d3bb2df665a1cf348a32c70f852..d85322d84b0cce7e30073fb28b13f2fef6426729 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -37,6 +37,7 @@ def parse_args(): parser.add_argument("--ir_optim", type=str2bool, default=True) parser.add_argument("--use_tensorrt", type=str2bool, default=False) parser.add_argument("--gpu_mem", type=int, default=8000) + parser.add_argument("--use_serving", type=str2bool, default=False) # params for text detector parser.add_argument("--image_dir", type=str)