From fdca422092dc9cda9f1fbbfa4deb2094ff8eed5a Mon Sep 17 00:00:00 2001 From: wangjiawei04 Date: Wed, 26 Aug 2020 14:21:34 +0000 Subject: [PATCH] support local on pipelien --- python/examples/pipeline/ocr/brpc_config.yml | 23 ++++ python/examples/pipeline/ocr/config.yml | 7 +- python/examples/pipeline/ocr/local_service.py | 115 ++++++++++++++++++ python/examples/pipeline/ocr/web_service.py | 4 +- python/paddle_serving_app/local_predict.py | 6 +- python/pipeline/dag.py | 13 +- python/pipeline/local_rpc_service_handler.py | 27 +++- python/pipeline/operator.py | 99 +++++++++------ 8 files changed, 243 insertions(+), 51 deletions(-) create mode 100644 python/examples/pipeline/ocr/brpc_config.yml create mode 100644 python/examples/pipeline/ocr/local_service.py diff --git a/python/examples/pipeline/ocr/brpc_config.yml b/python/examples/pipeline/ocr/brpc_config.yml new file mode 100644 index 00000000..6e8de736 --- /dev/null +++ b/python/examples/pipeline/ocr/brpc_config.yml @@ -0,0 +1,23 @@ +rpc_port: 18080 +worker_num: 4 +build_dag_each_worker: false +http_port: 9999 +dag: + is_thread_op: false + retry: 1 + use_profile: false +op: + det: + concurrency: 2 + local_service_conf: + client_type: brpc + model_config: ocr_det_model + devices: "" + rec: + concurrency: 1 + timeout: -1 + retry: 1 + local_service_conf: + client_type: brpc + model_config: ocr_rec_model + devices: "" diff --git a/python/examples/pipeline/ocr/config.yml b/python/examples/pipeline/ocr/config.yml index 48addccf..3b1fb357 100644 --- a/python/examples/pipeline/ocr/config.yml +++ b/python/examples/pipeline/ocr/config.yml @@ -4,19 +4,20 @@ build_dag_each_worker: false http_port: 9999 dag: is_thread_op: false - client_type: brpc retry: 1 use_profile: false op: det: concurrency: 2 local_service_conf: + client_type: local_predictor model_config: ocr_det_model - devices: "0" + devices: "" rec: concurrency: 1 timeout: -1 retry: 1 local_service_conf: + client_type: local_predictor model_config: ocr_rec_model - devices: "0" + devices: "" diff --git a/python/examples/pipeline/ocr/local_service.py b/python/examples/pipeline/ocr/local_service.py new file mode 100644 index 00000000..48a1667d --- /dev/null +++ b/python/examples/pipeline/ocr/local_service.py @@ -0,0 +1,115 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +try: + from paddle_serving_server.web_service import WebService, Op +except ImportError: + from paddle_serving_server.web_service import WebService, Op +import logging +import numpy as np +import cv2 +import base64 +from paddle_serving_app.reader import OCRReader +from paddle_serving_app.reader import Sequential, ResizeByFactor +from paddle_serving_app.reader import Div, Normalize, Transpose +from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes + +_LOGGER = logging.getLogger() + + +class DetOp(Op): + def init_op(self): + self.det_preprocess = Sequential([ + ResizeByFactor(32, 960), Div(255), + Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose( + (2, 0, 1)) + ]) + self.filter_func = FilterBoxes(10, 10) + self.post_func = DBPostProcess({ + "thresh": 0.3, + "box_thresh": 0.5, + "max_candidates": 1000, + "unclip_ratio": 1.5, + "min_size": 3 + }) + + def preprocess(self, input_dicts): + (_, input_dict), = input_dicts.items() + data = base64.b64decode(input_dict["image"].encode('utf8')) + data = np.fromstring(data, np.uint8) + # Note: class variables(self.var) can only be used in process op mode + self.im = cv2.imdecode(data, cv2.IMREAD_COLOR) + self.ori_h, self.ori_w, _ = self.im.shape + det_img = self.det_preprocess(self.im) + _, self.new_h, self.new_w = det_img.shape + return {"image": det_img[np.newaxis,:].copy()} + + def postprocess(self, input_dicts, fetch_dict): + det_out = fetch_dict["concat_1.tmp_0"] + ratio_list = [ + float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w + ] + dt_boxes_list = self.post_func(det_out, [ratio_list]) + dt_boxes = self.filter_func(dt_boxes_list[0], [self.ori_h, self.ori_w]) + out_dict = {"dt_boxes": dt_boxes, "image": self.im} + return out_dict + + +class RecOp(Op): + def init_op(self): + self.ocr_reader = OCRReader() + self.get_rotate_crop_image = GetRotateCropImage() + self.sorted_boxes = SortedBoxes() + + def preprocess(self, input_dicts): + (_, input_dict), = input_dicts.items() + im = input_dict["image"] + dt_boxes = input_dict["dt_boxes"] + dt_boxes = self.sorted_boxes(dt_boxes) + feed_list = [] + img_list = [] + max_wh_ratio = 0 + for i, dtbox in enumerate(dt_boxes): + boximg = self.get_rotate_crop_image(im, dt_boxes[i]) + img_list.append(boximg) + h, w = boximg.shape[0:2] + wh_ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, wh_ratio) + _, w, h = self.ocr_reader.resize_norm_img(img_list[0], + max_wh_ratio).shape + imgs = np.zeros((len(img_list), 3, w, h)).astype('float32') + for id, img in enumerate(img_list): + norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio) + imgs[id] = norm_img + feed = {"image": imgs.copy()} + return feed + + def postprocess(self, input_dicts, fetch_dict): + rec_res = self.ocr_reader.postprocess(fetch_dict, with_score=True) + res_lst = [] + for res in rec_res: + res_lst.append(res[0]) + res = {"res": str(res_lst)} + return res + + +class OcrService(WebService): + def get_pipeline_response(self, read_op): + det_op = DetOp(name="det", input_ops=[read_op]) + rec_op = RecOp(name="rec", input_ops=[det_op]) + return rec_op + + +uci_service = OcrService(name="ocr") +uci_service.prepare_pipeline_config("config.yml") +uci_service.run_service() diff --git a/python/examples/pipeline/ocr/web_service.py b/python/examples/pipeline/ocr/web_service.py index d1e6ec80..479b00e7 100644 --- a/python/examples/pipeline/ocr/web_service.py +++ b/python/examples/pipeline/ocr/web_service.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. try: - from paddle_serving_server_gpu.web_service import WebService, Op + from paddle_serving_server.web_service import WebService, Op except ImportError: from paddle_serving_server.web_service import WebService, Op import logging @@ -108,5 +108,5 @@ class OcrService(WebService): uci_service = OcrService(name="ocr") -uci_service.prepare_pipeline_config("config.yml") +uci_service.prepare_pipeline_config("brpc_config.yml") uci_service.run_service() diff --git a/python/paddle_serving_app/local_predict.py b/python/paddle_serving_app/local_predict.py index afe6d474..ce206273 100644 --- a/python/paddle_serving_app/local_predict.py +++ b/python/paddle_serving_app/local_predict.py @@ -76,7 +76,9 @@ class Debugger(object): config.switch_use_feed_fetch_ops(False) self.predictor = create_paddle_predictor(config) - def predict(self, feed=None, fetch=None): + def predict(self, feed=None, fetch=None, log_id=0): + print("feed", feed) + print("fetch", fetch) if feed is None or fetch is None: raise ValueError("You should specify feed and fetch for prediction") fetch_list = [] @@ -139,5 +141,5 @@ class Debugger(object): for i, name in enumerate(fetch): fetch_map[name] = outputs[i] if len(output_tensors[i].lod()) > 0: - fetch_map[name + ".lod"] = output_tensors[i].lod()[0] + fetch_map[name + ".lod"] = np.array(output_tensors[i].lod()[0]).astype('int32') return fetch_map diff --git a/python/pipeline/dag.py b/python/pipeline/dag.py index 272071f3..639cec82 100644 --- a/python/pipeline/dag.py +++ b/python/pipeline/dag.py @@ -43,7 +43,6 @@ class DAGExecutor(object): dag_conf = server_conf["dag"] self._retry = dag_conf["retry"] - client_type = dag_conf["client_type"] self._server_use_profile = dag_conf["use_profile"] channel_size = dag_conf["channel_size"] self._is_thread_op = dag_conf["is_thread_op"] @@ -61,7 +60,7 @@ class DAGExecutor(object): self._is_thread_op, tracer_interval_s, server_worker_num) self._dag = DAG(self.name, response_op, self._server_use_profile, - self._is_thread_op, client_type, channel_size, + self._is_thread_op, channel_size, build_dag_each_worker, self._tracer) (in_channel, out_channel, pack_rpc_func, unpack_rpc_func) = self._dag.build() @@ -324,13 +323,12 @@ class DAGExecutor(object): class DAG(object): def __init__(self, request_name, response_op, use_profile, is_thread_op, - client_type, channel_size, build_dag_each_worker, tracer): + channel_size, build_dag_each_worker, tracer): self._request_name = request_name self._response_op = response_op self._use_profile = use_profile self._is_thread_op = is_thread_op self._channel_size = channel_size - self._client_type = client_type self._build_dag_each_worker = build_dag_each_worker self._tracer = tracer if not self._is_thread_op: @@ -571,10 +569,10 @@ class DAG(object): op.set_tracer(self._tracer) if self._is_thread_op: self._threads_or_proces.extend( - op.start_with_thread(self._client_type)) + op.start_with_thread()) else: self._threads_or_proces.extend( - op.start_with_process(self._client_type)) + op.start_with_process()) _LOGGER.info("[DAG] start") # not join yet @@ -582,7 +580,8 @@ class DAG(object): def join(self): for x in self._threads_or_proces: - x.join() + if x is not None: + x.join() def stop(self): for chl in self._channels: diff --git a/python/pipeline/local_rpc_service_handler.py b/python/pipeline/local_rpc_service_handler.py index 376fcaf1..9d1946ca 100644 --- a/python/pipeline/local_rpc_service_handler.py +++ b/python/pipeline/local_rpc_service_handler.py @@ -21,7 +21,7 @@ try: except ImportError: from paddle_serving_server import OpMaker, OpSeqMaker, Server PACKAGE_VERSION = "CPU" -from . import util +import util _LOGGER = logging.getLogger(__name__) _workdir_name_gen = util.NameGenerator("workdir_") @@ -132,3 +132,28 @@ class LocalRpcServiceHandler(object): self._server_pros.append(p) for p in self._server_pros: p.start() + + +class LocalPredictorServiceHandler(LocalRpcServiceHandler): + def prepare_server(self): + from paddle_serving_app.local_predict import Debugger + gpuid = self._devices + if gpuid == -1: + gpu = False + else: + gpu = True + self.predictor = Debugger() + self.predictor.load_model_config(model_path=self._model_config, gpu=gpu, profile=False, cpu_num=1) + + def get_client(self): + if self.predictor is None: + raise ValueError("local predictor not yet created.") + return self.predictor + + def get_fetch_list(self): + if self.predictor is None: + raise ValueError("local predictor not yet created.") + return self.predictor.fetch_names_ + + def start_server(self): + pass diff --git a/python/pipeline/operator.py b/python/pipeline/operator.py index 3b928b9c..f503184b 100644 --- a/python/pipeline/operator.py +++ b/python/pipeline/operator.py @@ -128,29 +128,44 @@ class Op(object): _LOGGER.info("local_service_conf: {}".format( local_service_conf)) model_config = local_service_conf.get("model_config") + self.client_type = local_service_conf.get("client_type") _LOGGER.info("model_config: {}".format(model_config)) if model_config is None: self.with_serving = False else: # local rpc service self.with_serving = True - service_handler = local_rpc_service_handler.LocalRpcServiceHandler( - model_config=model_config, - workdir=local_service_conf["workdir"], - thread_num=local_service_conf["thread_num"], - devices=local_service_conf["devices"], - mem_optim=local_service_conf["mem_optim"], - ir_optim=local_service_conf["ir_optim"]) - service_handler.prepare_server() # get fetch_list - serivce_ports = service_handler.get_port_list() - self._server_endpoints = [ - "127.0.0.1:{}".format(p) for p in serivce_ports - ] - if self._client_config is None: - self._client_config = service_handler.get_client_config( - ) - if self._fetch_names is None: - self._fetch_names = service_handler.get_fetch_list() + if self.client_type == "brpc" or self.client_type == "grpc": + service_handler = local_rpc_service_handler.LocalRpcServiceHandler( + model_config=model_config, + workdir=local_service_conf["workdir"], + thread_num=local_service_conf["thread_num"], + devices=local_service_conf["devices"], + mem_optim=local_service_conf["mem_optim"], + ir_optim=local_service_conf["ir_optim"]) + service_handler.prepare_server() # get fetch_list + serivce_ports = service_handler.get_port_list() + self._server_endpoints = [ + "127.0.0.1:{}".format(p) for p in serivce_ports + ] + if self._client_config is None: + self._client_config = service_handler.get_client_config( + ) + if self._fetch_names is None: + self._fetch_names = service_handler.get_fetch_list() + elif self.client_type == "local_predictor": + service_handler = local_rpc_service_handler.LocalPredictorServiceHandler( + model_config=model_config, + workdir=local_service_conf["workdir"], + thread_num=local_service_conf["thread_num"], + devices=local_service_conf["devices"]) + service_handler.prepare_server() # get fetch_list + self.local_predictor = service_handler.get_client() + if self._client_config is None: + self._client_config = service_handler.get_client_config( + ) + if self._fetch_names is None: + self._fetch_names = service_handler.get_fetch_list() self._local_rpc_service_handler = service_handler else: self.with_serving = True @@ -215,21 +230,27 @@ class Op(object): def set_tracer(self, tracer): self._tracer = tracer - def init_client(self, client_type, client_config, server_endpoints, + def init_client(self, client_config, server_endpoints, fetch_names): + print("init client", fetch_names) if self.with_serving == False: _LOGGER.info("Op({}) has no client (and it also do not " "run the process function)".format(self.name)) return None - if client_type == 'brpc': + if self.client_type == 'brpc': client = Client() client.load_client_config(client_config) - elif client_type == 'grpc': + elif self.client_type == 'grpc': client = MultiLangClient() + elif self.client_type == 'local_predictor': + if self.local_predictor is None: + raise ValueError("local predictor not yet created") + client = self.local_predictor else: raise ValueError("Failed to init client: unknow client " - "type {}".format(client_type)) - client.connect(server_endpoints) + "type {}".format(self.client_type)) + if self.client_type != "local_predictor": + client.connect(server_endpoints) self._fetch_names = fetch_names return client @@ -292,14 +313,19 @@ class Op(object): return input_dict def process(self, feed_batch, typical_logid): + print("now we start process") err, err_info = ChannelData.check_batch_npdata(feed_batch) if err != 0: _LOGGER.critical( self._log("Failed to run process: {}. Please override " "preprocess func.".format(err_info))) os._exit(-1) - call_result = self.client.predict( - feed=feed_batch, fetch=self._fetch_names, log_id=typical_logid) + if self.client_type == "local_predictor": + call_result = self.client.predict(feed=feed_batch[0], fetch=self._fetch_names, log_id=typical_logid) + else: + call_result = self.client.predict( + feed=feed_batch, fetch=self._fetch_names, log_id=typical_logid) + print("now we end predict") if isinstance(self.client, MultiLangClient): if call_result is None or call_result["serving_status_code"] != 0: return None @@ -347,23 +373,23 @@ class Op(object): for channel in channels: channel.push(data, name) - def start_with_process(self, client_type): + def start_with_process(self): trace_buffer = None if self._tracer is not None: trace_buffer = self._tracer.data_buffer() - proces = [] + process= [] for concurrency_idx in range(self.concurrency): p = multiprocessing.Process( target=self._run, args=(concurrency_idx, self._get_input_channel(), - self._get_output_channels(), client_type, False, + self._get_output_channels(), False, trace_buffer)) p.daemon = True p.start() - proces.append(p) - return proces + process.append(p) + return process - def start_with_thread(self, client_type): + def start_with_thread(self): trace_buffer = None if self._tracer is not None: trace_buffer = self._tracer.data_buffer() @@ -372,7 +398,7 @@ class Op(object): t = threading.Thread( target=self._run, args=(concurrency_idx, self._get_input_channel(), - self._get_output_channels(), client_type, True, + self._get_output_channels(), True, trace_buffer)) # When a process exits, it attempts to terminate # all of its daemonic child processes. @@ -537,6 +563,7 @@ class Op(object): lod_offset_right = lod_offset[data_offset_right] midped_data_dict[data_id][name] = value[ lod_offset_left:lod_offset_right] + print(lod_offset[data_offset_left:data_offset_right + 1], lod_offset[data_offset_left]) midped_data_dict[data_id][lod_offset_name] = \ lod_offset[data_offset_left:data_offset_right + 1] - lod_offset[data_offset_left] else: @@ -652,7 +679,7 @@ class Op(object): return parsed_data_dict, need_profile_dict, profile_dict - def _run(self, concurrency_idx, input_channel, output_channels, client_type, + def _run(self, concurrency_idx, input_channel, output_channels, is_thread_op, trace_buffer): op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx) tid = threading.current_thread().ident @@ -660,7 +687,7 @@ class Op(object): # init op profiler = None try: - profiler = self._initialize(is_thread_op, client_type, + profiler = self._initialize(is_thread_op, concurrency_idx) except Exception as e: _LOGGER.critical( @@ -801,7 +828,7 @@ class Op(object): except Queue.Full: break - def _initialize(self, is_thread_op, client_type, concurrency_idx): + def _initialize(self, is_thread_op, concurrency_idx): if is_thread_op: with self._for_init_op_lock: if not self._succ_init_op: @@ -809,7 +836,7 @@ class Op(object): self.concurrency_idx = None # init client self.client = self.init_client( - client_type, self._client_config, + self._client_config, self._server_endpoints, self._fetch_names) # user defined self.init_op() @@ -818,7 +845,7 @@ class Op(object): else: self.concurrency_idx = concurrency_idx # init client - self.client = self.init_client(client_type, self._client_config, + self.client = self.init_client(self._client_config, self._server_endpoints, self._fetch_names) # user defined -- GitLab