diff --git a/python/examples/imdb/test_py_server.py b/python/examples/imdb/test_py_server.py
index e3ea39154650e237120bd73045448f9c7f46677a..31d34e19febdae14b9a9c38c0d087e75f8e6843b 100644
--- a/python/examples/imdb/test_py_server.py
+++ b/python/examples/imdb/test_py_server.py
@@ -57,4 +57,5 @@ pyserver.add_cnannel(combine_out_channel)
 pyserver.add_op(cnn_op)
 pyserver.add_op(bow_op)
 pyserver.add_op(combine_op)
+pyserver.prepare_server(port=8080, worker_num=4)
 pyserver.run_server()
diff --git a/python/paddle_serving_server/pserving.py b/python/paddle_serving_server/pserving.py
index 6d8112deb27efa5f8f2cf7cb96a9d122adf2236d..9102b0a09d32282c5bcf78a86e50db9b8a8e8dd1 100644
--- a/python/paddle_serving_server/pserving.py
+++ b/python/paddle_serving_server/pserving.py
@@ -18,6 +18,7 @@ import queue
 import os
 import paddle_serving_server
 from paddle_serving_client import Client
+from concurrent import futures
 import grpc
 import general_python_service_pb2
 import general_python_service_pb2_grpc
@@ -82,11 +83,20 @@ class Op(object):
         self._client.connect([server_name])
         self._fetch_names = fetch_names

+    def with_serving(self):
+        return self._client is not None
+
+    def get_inputs(self):
+        return self._inputs
+
     def set_inputs(self, channels):
         if not isinstance(channels, list):
             raise TypeError('channels must be list type')
         self._inputs = channels

+    def get_outputs(self):
+        return self._outputs
+
     def set_outputs(self, channels):
         if not isinstance(channels, list):
             raise TypeError('channels must be list type')
@@ -114,7 +124,7 @@
             input_data.append(channel.front())
         data = self.preprocess(input_data)

-        if self._client is not None:
+        if self.with_serving():
             fetch_map = self.midprocess(data)
             output_data = self.postprocess(fetch_map)
         else:
@@ -124,11 +134,25 @@
             channel.push(output_data)


+class GeneralPythonService(
+        general_python_service_pb2_grpc.GeneralPythonService):
+    def __init__(self, channel):
+        self._channel = channel
+
+    def Request(self, request, context):
+        pass
+
+    def Response(self, request, context):
+        pass
+
+
 class PyServer(object):
     def __init__(self):
         self._channels = []
         self._ops = []
         self._op_threads = []
+        self._port = None
+        self._worker_num = None

     def add_channel(self, channel):
         self._channels.append(channel)
@@ -139,16 +163,42 @@ class PyServer(object):
     def gen_desc(self):
         pass

+    def prepare_server(self, port, worker_num):
+        self._port = port
+        self._worker_num = worker_num
+        self.gen_desc()
+
     def run_server(self):
+        inputs = []
+        outputs = []
         for op in self._ops:
-            self.prepare_server(op)
+            inputs += op.get_inputs()
+            outputs += op.get_outputs()
+            if op.with_serving():
+                self.prepare_serving(op)
             th = multiprocessing.Process(target=op.start, args=(op, ))
             th.start()
             self._op_threads.append(th)
-        for th in self._op_threads:
-            th.join()

-    def prepare_server(self, op):
+        input_channel = []
+        for channel in inputs:
+            if channel not in outputs:
+                input_channel.append(channel)
+        if len(input_channel) != 1:
+            raise Exception("input_channel more than 1 or no input_channel")
+
+        server = grpc.server(
+            futures.ThreadPoolExecutor(max_workers=self._worker_num))
+        general_python_service_pb2_grpc.add_GeneralPythonService_to_server(
+            GeneralPythonService(input_channel[0]), server)
+        server.start()
+        try:
+            for th in self._op_threads:
+                th.join()
+        except KeyboardInterrupt:
+            server.stop(0)
+
+    def prepare_serving(self, op):
         model_path = op._server_model
         port = op._server_port
         device = op._device
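
run_server() now also infers where external requests enter the Op graph: it collects every Op's input and output channels and keeps only the channels that are consumed but never produced, requiring exactly one to remain. A small worked illustration of that rule, with stand-in channel objects (the project's Channel class is not shown in this diff):

    # Hypothetical graph: request_channel -> op_a -> mid_channel -> op_b -> out_channel.
    class FakeChannel(object):
        """Stand-in for pserving's Channel; object identity is all the check needs."""

    request_channel, mid_channel, out_channel = FakeChannel(), FakeChannel(), FakeChannel()

    inputs = [request_channel, mid_channel]  # gathered via op.get_inputs()
    outputs = [mid_channel, out_channel]     # gathered via op.get_outputs()

    input_channel = [c for c in inputs if c not in outputs]
    assert input_channel == [request_channel]  # exactly one entry channel remains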
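
The new GeneralPythonService servicer leaves Request and Response as stubs, and run_server() registers it via add_GeneralPythonService_to_server. For a proto service actually named GeneralPythonService, grpcio's code generator conventionally exposes a GeneralPythonServiceServicer base class and an add_GeneralPythonServiceServicer_to_server hook, so a filled-in front end might look like the sketch below. The Response message name and the add_insecure_port() call are assumptions, not taken from this diff; push() is the Channel method Op.start() already uses.

    from concurrent import futures
    import grpc
    import general_python_service_pb2
    import general_python_service_pb2_grpc


    class GeneralPythonService(
            general_python_service_pb2_grpc.GeneralPythonServiceServicer):
        def __init__(self, channel):
            # Entry channel of the Op graph, as selected by run_server().
            self._channel = channel

        def Request(self, request, context):
            # Hand the request to the first Op through the shared channel.
            self._channel.push(request)
            return general_python_service_pb2.Response()  # hypothetical message name


    def serve(input_channel, port, worker_num):
        # Mirrors what run_server() does, plus binding the port recorded by prepare_server().
        server = grpc.server(futures.ThreadPoolExecutor(max_workers=worker_num))
        general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
            GeneralPythonService(input_channel), server)
        server.add_insecure_port('[::]:%d' % port)
        server.start()
        return server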