diff --git a/python/examples/blazeface/README.md b/python/examples/blazeface/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f569841ce4a3ae69b1ff16041f7fb7d4617177f7 --- /dev/null +++ b/python/examples/blazeface/README.md @@ -0,0 +1,23 @@ +# Blazeface + +## Get Model +``` +python -m paddle_serving_app.package --get_model blazeface +tar -xzvf blazeface.tar.gz +``` + +## RPC Service + +### Start Service + +``` +python -m paddle_serving_server.serve --model serving_server --port 9494 +``` + +### Client Prediction + +``` +python test_client.py serving_client/serving_client_conf.prototxt test.jpg +``` + +the result is in `output` folder, including a json file and image file with bounding boxes. diff --git a/python/examples/ocr/test_ocr_rec_client.py b/python/examples/blazeface/test_client.py similarity index 53% rename from python/examples/ocr/test_ocr_rec_client.py rename to python/examples/blazeface/test_client.py index b61256d03202374ada5b0d50a075fef156eca2ea..27eb185ea90ce72641cef44d9066c46945ad2629 100644 --- a/python/examples/ocr/test_ocr_rec_client.py +++ b/python/examples/blazeface/test_client.py @@ -13,19 +13,26 @@ # limitations under the License. from paddle_serving_client import Client -from paddle_serving_app.reader import OCRReader -import cv2 +from paddle_serving_app.reader import * +import sys +import numpy as np +preprocess = Sequential([ + File2Image(), + Normalize([104, 117, 123], [127.502231, 127.502231, 127.502231], False) +]) + +postprocess = BlazeFacePostprocess("label_list.txt", "output") client = Client() -client.load_client_config("ocr_rec_client/serving_client_conf.prototxt") -client.connect(["127.0.0.1:9292"]) -image_file_list = ["./test_rec.jpg"] -img = cv2.imread(image_file_list[0]) -ocr_reader = OCRReader() -feed = {"image": ocr_reader.preprocess([img])} -fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] -fetch_map = client.predict(feed=feed, fetch=fetch) -rec_res = ocr_reader.postprocess(fetch_map) -print(image_file_list[0]) -print(rec_res[0][0]) +client.load_client_config(sys.argv[1]) +client.connect(['127.0.0.1:9494']) + +im_0 = preprocess(sys.argv[2]) +tmp = Transpose((2, 0, 1)) +im = tmp(im_0) +fetch_map = client.predict( + feed={"image": im}, fetch=["detection_output_0.tmp_0"]) +fetch_map["image"] = sys.argv[2] +fetch_map["im_shape"] = im_0.shape +postprocess(fetch_map) diff --git a/python/examples/criteo_ctr_with_cube/benchmark.py b/python/examples/criteo_ctr_with_cube/benchmark.py index e5bde9f996fccc41027fa6d255ca227cba212e22..a850d244b0a5a1a01e98a6207fa9674b6ea0af1a 100755 --- a/python/examples/criteo_ctr_with_cube/benchmark.py +++ b/python/examples/criteo_ctr_with_cube/benchmark.py @@ -29,6 +29,7 @@ args = benchmark_args() def single_func(idx, resource): client = Client() + print([resource["endpoint"][idx % len(resource["endpoint"])]]) client.load_client_config('ctr_client_conf/serving_client_conf.prototxt') client.connect(['127.0.0.1:9292']) batch = 1 @@ -40,27 +41,29 @@ def single_func(idx, resource): ] reader = dataset.infer_reader(test_filelists[len(test_filelists) - 40:], batch, buf_size) - args.batch_size = 1 if args.request == "rpc": fetch = ["prob"] - print("Start Time") start = time.time() itr = 1000 for ei in range(itr): - if args.batch_size == 1: - data = reader().next() - feed_dict = {} - feed_dict['dense_input'] = data[0][0] - for i in range(1, 27): - feed_dict["embedding_{}.tmp_0".format(i - 1)] = data[0][i] - result = client.predict(feed=feed_dict, fetch=fetch) + if args.batch_size > 0: + feed_batch = [] + for bi in range(args.batch_size): + data = reader().next() + feed_dict = {} + feed_dict['dense_input'] = data[0][0] + for i in range(1, 27): + feed_dict["embedding_{}.tmp_0".format(i - 1)] = data[0][ + i] + feed_batch.append(feed_dict) + result = client.predict(feed=feed_batch, fetch=fetch) else: print("unsupport batch size {}".format(args.batch_size)) elif args.request == "http": raise ("Not support http service.") end = time.time() - qps = itr / (end - start) + qps = itr * args.batch_size / (end - start) return [[end - start, qps]] @@ -70,6 +73,7 @@ if __name__ == '__main__': #result = single_func(0, {"endpoint": endpoint_list}) result = multi_thread_runner.run(single_func, args.thread, {"endpoint": endpoint_list}) + print(result) avg_cost = 0 qps = 0 for i in range(args.thread): diff --git a/python/examples/criteo_ctr_with_cube/benchmark.sh b/python/examples/criteo_ctr_with_cube/benchmark.sh index 4bea258a5cfa4e12ed6848c61270fe44bbc7ba44..35b19b637d9e8dec10fd3b59224c5c17e3ba5f53 100755 --- a/python/examples/criteo_ctr_with_cube/benchmark.sh +++ b/python/examples/criteo_ctr_with_cube/benchmark.sh @@ -1,10 +1,16 @@ rm profile_log -batch_size=1 +export FLAGS_profile_client=1 +export FLAGS_profile_server=1 for thread_num in 1 2 4 8 16 do - $PYTHONROOT/bin/python benchmark.py --thread $thread_num --model ctr_client_conf/serving_client_conf.prototxt --request rpc > profile 2>&1 +for batch_size in 1 4 16 64 256 +do + $PYTHONROOT/bin/python benchmark.py --thread $thread_num --batch_size $batch_size --model serving_client_conf/serving_client_conf.prototxt --request rpc > profile 2>&1 + echo "batch size : $batch_size" + echo "thread num : $thread_num" echo "========================================" echo "batch size : $batch_size" >> profile_log $PYTHONROOT/bin/python ../util/show_profile.py profile $thread_num >> profile_log tail -n 2 profile >> profile_log done +done diff --git a/python/examples/criteo_ctr_with_cube/benchmark_batch.py b/python/examples/criteo_ctr_with_cube/benchmark_batch.py deleted file mode 100755 index df5c6b90badb36fd7e349555973ccbd7ea0a8b70..0000000000000000000000000000000000000000 --- a/python/examples/criteo_ctr_with_cube/benchmark_batch.py +++ /dev/null @@ -1,84 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# pylint: disable=doc-string-missing - -from paddle_serving_client import Client -import sys -import os -import criteo as criteo -import time -from paddle_serving_client.utils import MultiThreadRunner -from paddle_serving_client.utils import benchmark_args -from paddle_serving_client.metric import auc - -args = benchmark_args() - - -def single_func(idx, resource): - client = Client() - print([resource["endpoint"][idx % len(resource["endpoint"])]]) - client.load_client_config('ctr_client_conf/serving_client_conf.prototxt') - client.connect(['127.0.0.1:9292']) - batch = 1 - buf_size = 100 - dataset = criteo.CriteoDataset() - dataset.setup(1000001) - test_filelists = [ - "./raw_data/part-%d" % x for x in range(len(os.listdir("./raw_data"))) - ] - reader = dataset.infer_reader(test_filelists[len(test_filelists) - 40:], - batch, buf_size) - if args.request == "rpc": - fetch = ["prob"] - start = time.time() - itr = 1000 - for ei in range(itr): - if args.batch_size > 1: - feed_batch = [] - for bi in range(args.batch_size): - data = reader().next() - feed_dict = {} - feed_dict['dense_input'] = data[0][0] - for i in range(1, 27): - feed_dict["embedding_{}.tmp_0".format(i - 1)] = data[0][ - i] - feed_batch.append(feed_dict) - result = client.predict(feed=feed_batch, fetch=fetch) - else: - print("unsupport batch size {}".format(args.batch_size)) - - elif args.request == "http": - raise ("Not support http service.") - end = time.time() - qps = itr * args.batch_size / (end - start) - return [[end - start, qps]] - - -if __name__ == '__main__': - multi_thread_runner = MultiThreadRunner() - endpoint_list = ["127.0.0.1:9292"] - #result = single_func(0, {"endpoint": endpoint_list}) - result = multi_thread_runner.run(single_func, args.thread, - {"endpoint": endpoint_list}) - print(result) - avg_cost = 0 - qps = 0 - for i in range(args.thread): - avg_cost += result[0][i * 2 + 0] - qps += result[0][i * 2 + 1] - avg_cost = avg_cost / args.thread - print("average total cost {} s.".format(avg_cost)) - print("qps {} ins/s".format(qps)) diff --git a/python/examples/criteo_ctr_with_cube/benchmark_batch.sh b/python/examples/criteo_ctr_with_cube/benchmark_batch.sh deleted file mode 100755 index 3a51c0de68bf47fb798c165d2fb34868056ddab6..0000000000000000000000000000000000000000 --- a/python/examples/criteo_ctr_with_cube/benchmark_batch.sh +++ /dev/null @@ -1,12 +0,0 @@ -rm profile_log -for thread_num in 1 2 4 8 16 -do -for batch_size in 1 2 4 8 16 32 64 128 256 512 -do - $PYTHONROOT/bin/python benchmark_batch.py --thread $thread_num --batch_size $batch_size --model serving_client_conf/serving_client_conf.prototxt --request rpc > profile 2>&1 - echo "========================================" - echo "batch size : $batch_size" >> profile_log - $PYTHONROOT/bin/python ../util/show_profile.py profile $thread_num >> profile_log - tail -n 2 profile >> profile_log -done -done diff --git a/python/examples/criteo_ctr_with_cube/cube_prepare.sh b/python/examples/criteo_ctr_with_cube/cube_prepare.sh index 1417254a54e2194ab3a0194f2ec970f480787acd..773baba4d91b02b244e766cd8ebf899cc740dbbc 100755 --- a/python/examples/criteo_ctr_with_cube/cube_prepare.sh +++ b/python/examples/criteo_ctr_with_cube/cube_prepare.sh @@ -16,7 +16,5 @@ mkdir -p cube_model mkdir -p cube/data -./seq_generator ctr_serving_model/SparseFeatFactors ./cube_model/feature ./cube/cube-builder -dict_name=test_dict -job_mode=base -last_version=0 -cur_version=0 -depend_version=0 -input_path=./cube_model -output_path=${PWD}/cube/data -shard_num=1 -only_build=false -mv ./cube/data/0_0/test_dict_part0/* ./cube/data/ -cd cube && ./cube +cd cube && ./cube diff --git a/python/examples/ocr/README.md b/python/examples/ocr/README.md index 04c4fd3eaa304e55d980a2cf4fc34dda50f5009c..3535ed80eb27291aa4da4bb2683923c9e4082acf 100644 --- a/python/examples/ocr/README.md +++ b/python/examples/ocr/README.md @@ -4,18 +4,42 @@ ``` python -m paddle_serving_app.package --get_model ocr_rec tar -xzvf ocr_rec.tar.gz +python -m paddle_serving_app.package --get_model ocr_det +tar -xzvf ocr_det.tar.gz ``` ## RPC Service ### Start Service +For the following two code block, please check your devices and pick one +for GPU device +``` +python -m paddle_serving_server_gpu.serve --model ocr_rec_model --port 9292 --gpu_id 0 +python -m paddle_serving_server_gpu.serve --model ocr_det_model --port 9293 --gpu_id 0 +``` +for CPU device ``` python -m paddle_serving_server.serve --model ocr_rec_model --port 9292 +python -m paddle_serving_server.serve --model ocr_det_model --port 9293 ``` ### Client Prediction ``` -python test_ocr_rec_client.py +python ocr_rpc_client.py +``` + +## Web Service + +### Start Service + +``` +python -m paddle_serving_server_gpu.serve --model ocr_det_model --port 9293 --gpu_id 0 +python ocr_web_server.py +``` + +### Client Prediction +``` +sh ocr_web_client.sh ``` diff --git a/python/examples/ocr/ocr_rpc_client.py b/python/examples/ocr/ocr_rpc_client.py new file mode 100644 index 0000000000000000000000000000000000000000..212d46c2b226f91bcb0582e76e31ca2acdc8b948 --- /dev/null +++ b/python/examples/ocr/ocr_rpc_client.py @@ -0,0 +1,193 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle_serving_client import Client +from paddle_serving_app.reader import OCRReader +import cv2 +import sys +import numpy as np +import os +from paddle_serving_client import Client +from paddle_serving_app.reader import Sequential, File2Image, ResizeByFactor +from paddle_serving_app.reader import Div, Normalize, Transpose +from paddle_serving_app.reader import DBPostProcess, FilterBoxes +import time +import re + + +def sorted_boxes(dt_boxes): + """ + Sort text boxes in order from top to bottom, left to right + args: + dt_boxes(array):detected text boxes with shape [4, 2] + return: + sorted boxes(array) with shape [4, 2] + """ + num_boxes = dt_boxes.shape[0] + sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0])) + _boxes = list(sorted_boxes) + + for i in range(num_boxes - 1): + if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \ + (_boxes[i + 1][0][0] < _boxes[i][0][0]): + tmp = _boxes[i] + _boxes[i] = _boxes[i + 1] + _boxes[i + 1] = tmp + return _boxes + + +def get_rotate_crop_image(img, points): + #img = cv2.imread(img) + img_height, img_width = img.shape[0:2] + left = int(np.min(points[:, 0])) + right = int(np.max(points[:, 0])) + top = int(np.min(points[:, 1])) + bottom = int(np.max(points[:, 1])) + img_crop = img[top:bottom, left:right, :].copy() + points[:, 0] = points[:, 0] - left + points[:, 1] = points[:, 1] - top + img_crop_width = int(np.linalg.norm(points[0] - points[1])) + img_crop_height = int(np.linalg.norm(points[0] - points[3])) + pts_std = np.float32([[0, 0], [img_crop_width, 0], \ + [img_crop_width, img_crop_height], [0, img_crop_height]]) + M = cv2.getPerspectiveTransform(points, pts_std) + dst_img = cv2.warpPerspective( + img_crop, + M, (img_crop_width, img_crop_height), + borderMode=cv2.BORDER_REPLICATE) + dst_img_height, dst_img_width = dst_img.shape[0:2] + if dst_img_height * 1.0 / dst_img_width >= 1.5: + dst_img = np.rot90(dst_img) + return dst_img + + +def read_det_box_file(filename): + with open(filename, 'r') as f: + line = f.readline() + a, b, c = int(line.split(' ')[0]), int(line.split(' ')[1]), int( + line.split(' ')[2]) + dt_boxes = np.zeros((a, b, c)).astype(np.float32) + line = f.readline() + for i in range(a): + for j in range(b): + line = f.readline() + dt_boxes[i, j, 0], dt_boxes[i, j, 1] = float( + line.split(' ')[0]), float(line.split(' ')[1]) + line = f.readline() + + +def resize_norm_img(img, max_wh_ratio): + import math + imgC, imgH, imgW = 3, 32, 320 + imgW = int(32 * max_wh_ratio) + h = img.shape[0] + w = img.shape[1] + ratio = w / float(h) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + resized_image = cv2.resize(img, (resized_w, imgH)) + resized_image = resized_image.astype('float32') + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) + padding_im[:, :, 0:resized_w] = resized_image + return padding_im + + +def main(): + client1 = Client() + client1.load_client_config("ocr_det_client/serving_client_conf.prototxt") + client1.connect(["127.0.0.1:9293"]) + + client2 = Client() + client2.load_client_config("ocr_rec_client/serving_client_conf.prototxt") + client2.connect(["127.0.0.1:9292"]) + + read_image_file = File2Image() + preprocess = Sequential([ + ResizeByFactor(32, 960), Div(255), + Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose( + (2, 0, 1)) + ]) + post_func = DBPostProcess({ + "thresh": 0.3, + "box_thresh": 0.5, + "max_candidates": 1000, + "unclip_ratio": 1.5, + "min_size": 3 + }) + + filter_func = FilterBoxes(10, 10) + ocr_reader = OCRReader() + files = [ + "./imgs/{}".format(f) for f in os.listdir('./imgs') + if re.match(r'[0-9]+.*\.jpg|[0-9]+.*\.png', f) + ] + #files = ["2.jpg"]*30 + #files = ["rctw/rctw/train/images/image_{}.jpg".format(i) for i in range(500)] + time_all = 0 + time_det_all = 0 + time_rec_all = 0 + for name in files: + #print(name) + im = read_image_file(name) + ori_h, ori_w, _ = im.shape + time1 = time.time() + img = preprocess(im) + _, new_h, new_w = img.shape + ratio_list = [float(new_h) / ori_h, float(new_w) / ori_w] + #print(new_h, new_w, ori_h, ori_w) + time_before_det = time.time() + outputs = client1.predict(feed={"image": img}, fetch=["concat_1.tmp_0"]) + time_after_det = time.time() + time_det_all += (time_after_det - time_before_det) + #print(outputs) + dt_boxes_list = post_func(outputs["concat_1.tmp_0"], [ratio_list]) + dt_boxes = filter_func(dt_boxes_list[0], [ori_h, ori_w]) + dt_boxes = sorted_boxes(dt_boxes) + feed_list = [] + img_list = [] + max_wh_ratio = 0 + for i, dtbox in enumerate(dt_boxes): + boximg = get_rotate_crop_image(im, dt_boxes[i]) + img_list.append(boximg) + h, w = boximg.shape[0:2] + wh_ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, wh_ratio) + for img in img_list: + norm_img = resize_norm_img(img, max_wh_ratio) + #norm_img = norm_img[np.newaxis, :] + feed = {"image": norm_img} + feed_list.append(feed) + #fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] + fetch = ["ctc_greedy_decoder_0.tmp_0"] + time_before_rec = time.time() + if len(feed_list) == 0: + continue + fetch_map = client2.predict(feed=feed_list, fetch=fetch) + time_after_rec = time.time() + time_rec_all += (time_after_rec - time_before_rec) + rec_res = ocr_reader.postprocess(fetch_map) + #for res in rec_res: + # print(res[0].encode("utf-8")) + time2 = time.time() + time_all += (time2 - time1) + print("rpc+det time: {}".format(time_all / len(files))) + + +if __name__ == '__main__': + main() diff --git a/python/examples/ocr/ocr_web_client.sh b/python/examples/ocr/ocr_web_client.sh new file mode 100644 index 0000000000000000000000000000000000000000..5f4f1d7d1fb00dc63b3235533850f56f998a647f --- /dev/null +++ b/python/examples/ocr/ocr_web_client.sh @@ -0,0 +1 @@ + curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"image": "https://paddle-serving.bj.bcebos.com/others/1.jpg"}], "fetch": ["res"]}' http://127.0.0.1:9292/ocr/prediction diff --git a/python/examples/ocr/ocr_web_server.py b/python/examples/ocr/ocr_web_server.py new file mode 100644 index 0000000000000000000000000000000000000000..b55027d84252f8590f1e62839ad8cbd25e56c8fe --- /dev/null +++ b/python/examples/ocr/ocr_web_server.py @@ -0,0 +1,158 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle_serving_client import Client +from paddle_serving_app.reader import OCRReader +import cv2 +import sys +import numpy as np +import os +from paddle_serving_client import Client +from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor +from paddle_serving_app.reader import Div, Normalize, Transpose +from paddle_serving_app.reader import DBPostProcess, FilterBoxes +from paddle_serving_server_gpu.web_service import WebService +import time +import re + + +class OCRService(WebService): + def init_det_client(self, det_port, det_client_config): + self.det_preprocess = Sequential([ + ResizeByFactor(32, 960), Div(255), + Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose( + (2, 0, 1)) + ]) + self.det_client = Client() + self.det_client.load_client_config(det_client_config) + self.det_client.connect(["127.0.0.1:{}".format(det_port)]) + + def preprocess(self, feed=[], fetch=[]): + img_url = feed[0]["image"] + #print(feed, img_url) + read_from_url = URL2Image() + im = read_from_url(img_url) + ori_h, ori_w, _ = im.shape + det_img = self.det_preprocess(im) + #print("det_img", det_img, det_img.shape) + det_out = self.det_client.predict( + feed={"image": det_img}, fetch=["concat_1.tmp_0"]) + + #print("det_out", det_out) + def sorted_boxes(dt_boxes): + num_boxes = dt_boxes.shape[0] + sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0])) + _boxes = list(sorted_boxes) + for i in range(num_boxes - 1): + if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \ + (_boxes[i + 1][0][0] < _boxes[i][0][0]): + tmp = _boxes[i] + _boxes[i] = _boxes[i + 1] + _boxes[i + 1] = tmp + return _boxes + + def get_rotate_crop_image(img, points): + img_height, img_width = img.shape[0:2] + left = int(np.min(points[:, 0])) + right = int(np.max(points[:, 0])) + top = int(np.min(points[:, 1])) + bottom = int(np.max(points[:, 1])) + img_crop = img[top:bottom, left:right, :].copy() + points[:, 0] = points[:, 0] - left + points[:, 1] = points[:, 1] - top + img_crop_width = int(np.linalg.norm(points[0] - points[1])) + img_crop_height = int(np.linalg.norm(points[0] - points[3])) + pts_std = np.float32([[0, 0], [img_crop_width, 0], \ + [img_crop_width, img_crop_height], [0, img_crop_height]]) + M = cv2.getPerspectiveTransform(points, pts_std) + dst_img = cv2.warpPerspective( + img_crop, + M, (img_crop_width, img_crop_height), + borderMode=cv2.BORDER_REPLICATE) + dst_img_height, dst_img_width = dst_img.shape[0:2] + if dst_img_height * 1.0 / dst_img_width >= 1.5: + dst_img = np.rot90(dst_img) + return dst_img + + def resize_norm_img(img, max_wh_ratio): + import math + imgC, imgH, imgW = 3, 32, 320 + imgW = int(32 * max_wh_ratio) + h = img.shape[0] + w = img.shape[1] + ratio = w / float(h) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + resized_image = cv2.resize(img, (resized_w, imgH)) + resized_image = resized_image.astype('float32') + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) + padding_im[:, :, 0:resized_w] = resized_image + return padding_im + + _, new_h, new_w = det_img.shape + filter_func = FilterBoxes(10, 10) + post_func = DBPostProcess({ + "thresh": 0.3, + "box_thresh": 0.5, + "max_candidates": 1000, + "unclip_ratio": 1.5, + "min_size": 3 + }) + ratio_list = [float(new_h) / ori_h, float(new_w) / ori_w] + dt_boxes_list = post_func(det_out["concat_1.tmp_0"], [ratio_list]) + dt_boxes = filter_func(dt_boxes_list[0], [ori_h, ori_w]) + dt_boxes = sorted_boxes(dt_boxes) + feed_list = [] + img_list = [] + max_wh_ratio = 0 + for i, dtbox in enumerate(dt_boxes): + boximg = get_rotate_crop_image(im, dt_boxes[i]) + img_list.append(boximg) + h, w = boximg.shape[0:2] + wh_ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, wh_ratio) + for img in img_list: + norm_img = resize_norm_img(img, max_wh_ratio) + feed = {"image": norm_img} + feed_list.append(feed) + fetch = ["ctc_greedy_decoder_0.tmp_0"] + #print("feed_list", feed_list) + return feed_list, fetch + + def postprocess(self, feed={}, fetch=[], fetch_map=None): + #print(fetch_map) + ocr_reader = OCRReader() + rec_res = ocr_reader.postprocess(fetch_map) + res_lst = [] + for res in rec_res: + res_lst.append(res[0]) + fetch_map["res"] = res_lst + del fetch_map["ctc_greedy_decoder_0.tmp_0"] + del fetch_map["ctc_greedy_decoder_0.tmp_0.lod"] + return fetch_map + + +ocr_service = OCRService(name="ocr") +ocr_service.load_model_config("ocr_rec_model") +ocr_service.prepare_server(workdir="workdir", port=9292) +ocr_service.init_det_client( + det_port=9293, + det_client_config="ocr_det_client/serving_client_conf.prototxt") +ocr_service.run_rpc_service() +ocr_service.run_web_service() diff --git a/python/examples/ocr/test_rec.jpg b/python/examples/ocr/test_rec.jpg deleted file mode 100644 index 2c34cd33eac5766a072fde041fa6c9b1d612f1db..0000000000000000000000000000000000000000 Binary files a/python/examples/ocr/test_rec.jpg and /dev/null differ diff --git a/python/paddle_serving_app/models/model_list.py b/python/paddle_serving_app/models/model_list.py index 79b3f91bd6584d17ddbc4124584cf40bd586b965..3b0c3cb9c4927df7ba55830657318073b1a3a7cc 100644 --- a/python/paddle_serving_app/models/model_list.py +++ b/python/paddle_serving_app/models/model_list.py @@ -24,14 +24,15 @@ class ServingModels(object): "SentimentAnalysis"] = ["senta_bilstm", "senta_bow", "senta_cnn"] self.model_dict["SemanticRepresentation"] = ["ernie"] self.model_dict["ChineseWordSegmentation"] = ["lac"] - self.model_dict["ObjectDetection"] = ["faster_rcnn", "yolov4"] + self.model_dict[ + "ObjectDetection"] = ["faster_rcnn", "yolov4", "blazeface"] self.model_dict["ImageSegmentation"] = [ "unet", "deeplabv3", "deeplabv3+cityscapes" ] self.model_dict["ImageClassification"] = [ "resnet_v2_50_imagenet", "mobilenet_v2_imagenet" ] - self.model_dict["TextDetection"] = ["ocr_detection"] + self.model_dict["TextDetection"] = ["ocr_det"] self.model_dict["OCR"] = ["ocr_rec"] image_class_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/image/ImageClassification/" diff --git a/python/paddle_serving_app/reader/functional.py b/python/paddle_serving_app/reader/functional.py index 4240641dd99fceb278ff60a5ba1dbb5275e534aa..7bab279c7f1aa71a2d55a8cb7b12bcb38607eb70 100644 --- a/python/paddle_serving_app/reader/functional.py +++ b/python/paddle_serving_app/reader/functional.py @@ -29,6 +29,7 @@ def normalize(img, mean, std, channel_first): else: img_mean = np.array(mean).reshape((1, 1, 3)) img_std = np.array(std).reshape((1, 1, 3)) + img = np.array(img).astype("float32") img -= img_mean img /= img_std return img diff --git a/python/paddle_serving_app/reader/image_reader.py b/python/paddle_serving_app/reader/image_reader.py index a44ca5de84da2bafce9b4cea37fb88095debabc6..096f46549af137cb04a87e26a3b28c8d42e33daa 100644 --- a/python/paddle_serving_app/reader/image_reader.py +++ b/python/paddle_serving_app/reader/image_reader.py @@ -440,6 +440,30 @@ class RCNNPostprocess(object): self.label_file, self.output_dir) +class BlazeFacePostprocess(RCNNPostprocess): + def clip_bbox(self, bbox, im_size=None): + h = 1. if im_size is None else im_size[0] + w = 1. if im_size is None else im_size[1] + xmin = max(min(bbox[0], w), 0.) + ymin = max(min(bbox[1], h), 0.) + xmax = max(min(bbox[2], w), 0.) + ymax = max(min(bbox[3], h), 0.) + return xmin, ymin, xmax, ymax + + def _get_bbox_result(self, fetch_map, fetch_name, clsid2catid): + result = {} + is_bbox_normalized = True #for blaze face, set true here + output = fetch_map[fetch_name] + lod = [fetch_map[fetch_name + '.lod']] + lengths = self._offset_to_lengths(lod) + np_data = np.array(output) + result['bbox'] = (np_data, lengths) + result['im_id'] = np.array([[0]]) + result["im_shape"] = np.array(fetch_map["im_shape"]).astype(np.int32) + bbox_results = self._bbox2out([result], clsid2catid, is_bbox_normalized) + return bbox_results + + class Sequential(object): """ Args: diff --git a/python/paddle_serving_app/reader/ocr_reader.py b/python/paddle_serving_app/reader/ocr_reader.py index e5dc88482bd5e0a7a26873fd5cb60c43dc5104c9..72a2918f89a8ccc913894f3f46fab08f51cf9460 100644 --- a/python/paddle_serving_app/reader/ocr_reader.py +++ b/python/paddle_serving_app/reader/ocr_reader.py @@ -182,22 +182,26 @@ class OCRReader(object): return norm_img_batch[0] - def postprocess(self, outputs): + def postprocess(self, outputs, with_score=False): rec_res = [] rec_idx_lod = outputs["ctc_greedy_decoder_0.tmp_0.lod"] - predict_lod = outputs["softmax_0.tmp_0.lod"] rec_idx_batch = outputs["ctc_greedy_decoder_0.tmp_0"] + if with_score: + predict_lod = outputs["softmax_0.tmp_0.lod"] for rno in range(len(rec_idx_lod) - 1): beg = rec_idx_lod[rno] end = rec_idx_lod[rno + 1] rec_idx_tmp = rec_idx_batch[beg:end, 0] preds_text = self.char_ops.decode(rec_idx_tmp) - beg = predict_lod[rno] - end = predict_lod[rno + 1] - probs = outputs["softmax_0.tmp_0"][beg:end, :] - ind = np.argmax(probs, axis=1) - blank = probs.shape[1] - valid_ind = np.where(ind != (blank - 1))[0] - score = np.mean(probs[valid_ind, ind[valid_ind]]) - rec_res.append([preds_text, score]) + if with_score: + beg = predict_lod[rno] + end = predict_lod[rno + 1] + probs = outputs["softmax_0.tmp_0"][beg:end, :] + ind = np.argmax(probs, axis=1) + blank = probs.shape[1] + valid_ind = np.where(ind != (blank - 1))[0] + score = np.mean(probs[valid_ind, ind[valid_ind]]) + rec_res.append([preds_text, score]) + else: + rec_res.append([preds_text]) return rec_res