diff --git a/python/examples/imdb/test_py_server.py b/python/examples/imdb/test_py_server.py index f2d1e450940eb1ae17641c8c86fbf738eb7da34e..3d90161c88e53f3edde9557e69885950782f4459 100644 --- a/python/examples/imdb/test_py_server.py +++ b/python/examples/imdb/test_py_server.py @@ -18,9 +18,15 @@ from pyserver import Channel from pyserver import PyServer import numpy as np import python_service_channel_pb2 +import logging +logging.basicConfig( + format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s', + datefmt='%Y-%m-%d %H:%M', + level=logging.INFO) # channel data: {name(str): data(bytes)} +""" class ImdbOp(Op): def postprocess(self, output_data): data = python_service_channel_pb2.ChannelData() @@ -28,36 +34,44 @@ class ImdbOp(Op): pred = np.array(output_data["prediction"][0][0], dtype='float') inst.data = np.ndarray.tobytes(pred) inst.name = "prediction" - inst.id = 0 #TODO + inst.id = 0 #TODO data.insts.append(inst) return data +""" class CombineOp(Op): def preprocess(self, input_data): + data_id = None cnt = 0 for input in input_data: data = input[0] # batchsize=1 cnt += np.frombuffer(data.insts[0].data, dtype='float') + if data_id is None: + data_id = data.id + if data_id != data.id: + raise Exception("id not match: {} vs {}".format(data_id, + data.id)) data = python_service_channel_pb2.ChannelData() inst = python_service_channel_pb2.Inst() inst.data = np.ndarray.tobytes(cnt) inst.name = "resp" - inst.id = 0 #TODO data.insts.append(inst) + data.id = data_id print(data) return data class UciOp(Op): def postprocess(self, output_data): + data_ids = self.get_data_ids() data = python_service_channel_pb2.ChannelData() inst = python_service_channel_pb2.Inst() pred = np.array(output_data["price"][0][0], dtype='float') inst.data = np.ndarray.tobytes(pred) inst.name = "prediction" - inst.id = 0 #TODO data.insts.append(inst) + data.id = data_ids[0] return data @@ -121,5 +135,5 @@ pyserver.add_channel(combine_out_channel) pyserver.add_op(cnn_op) pyserver.add_op(bow_op) pyserver.add_op(combine_op) -pyserver.prepare_server(port=8080, worker_num=1) +pyserver.prepare_server(port=8080, worker_num=2) pyserver.run_server() diff --git a/python/paddle_serving_server/pyserver.py b/python/paddle_serving_server/pyserver.py index b64ac968ae027c60ed3fdc660ff62c8ec2d56fe3..4849f4e0fc7333fca486a2a16d5343131f5adf8f 100644 --- a/python/paddle_serving_server/pyserver.py +++ b/python/paddle_serving_server/pyserver.py @@ -24,6 +24,8 @@ import grpc import general_python_service_pb2 import general_python_service_pb2_grpc import python_service_channel_pb2 +import logging +import time class Channel(Queue.Queue): @@ -39,10 +41,12 @@ class Channel(Queue.Queue): self._pushbatch = [] self._frontbatch = None self._count = 0 + self._order = 0 def push(self, item): with self._pushlock: self._pushbatch.append(item) + self._order += 1 if len(self._pushbatch) == self._batchsize: self.put(self._pushbatch, timeout=self._timeout) # self.put(self._pushbatch) @@ -87,6 +91,7 @@ class Op(object): self._server_model = server_model self._server_port = server_port self._device = device + self._data_ids = [] def set_client(self, client_config, server_name, fetch_names): self._client = Client() @@ -113,12 +118,22 @@ class Op(object): raise TypeError('channels must be list type') self._outputs = channels + def get_data_ids(self): + return self._data_ids + + def clear_data_ids(self): + self._data_ids = [] + + def append_id_to_data_ids(self, data_id): + self._data_ids.append(data_id) + def preprocess(self, input_data): if len(input_data) != 1: raise Exception( 'this Op has multiple previous channels. Please override this method' ) feed_batch = [] + self.clear_data_ids() for data in input_data: if len(data.insts) != self._batch_size: raise Exception('len(data_insts) != self._batch_size') @@ -126,12 +141,13 @@ class Op(object): for inst in data.insts: feed[inst.name] = np.frombuffer(inst.data, dtype=self._in_dtype) feed_batch.append(feed) + self.append_id_to_data_ids(data.id) return feed_batch def midprocess(self, data): # data = preprocess(input), which must be a dict - print('data: {}'.format(data)) - print('fetch: {}'.format(self._fetch_names)) + logging.debug('data: {}'.format(data)) + logging.debug('fetch: {}'.format(self._fetch_names)) fetch_map = self._client.predict(feed=data, fetch=self._fetch_names) return fetch_map @@ -168,36 +184,80 @@ class GeneralPythonService( super(GeneralPythonService, self).__init__() self._in_channel = in_channel self._out_channel = out_channel - print('succ init') - - def inference(self, request, context): - print('start inferce') + self._lock = threading.Lock() + self._globel_resp_dict = {} + self._id_counter = 0 + self._recive_func = threading.Thread( + target=GeneralPythonService._recive_out_channel_func, args=(self, )) + self._recive_func.start() + logging.debug('succ init') + + def _recive_out_channel_func(self): + while True: + data = self._out_channel.front() + data_id = None + for d in data: + if data_id is None: + data_id = d.id + if data_id != d.id: + raise Exception("id not match: {} vs {}".format(data_id, + d.id)) + with self._lock: + self._globel_resp_dict[data_id] = data + #TODO wake up inference + + def _get_next_id(self): + with self._lock: + self._id_counter += 1 + return self._id_counter - 1 + + def _get_data_in_globel_resp_dict(self, data_id): + if data_id in self._globel_resp_dict: + with self._lock: + return self._globel_resp_dict.pop(data_id) + return None + + def _pack_data_for_infer(self, request): + logging.debug('start inferce') data = python_service_channel_pb2.ChannelData() - print('gen data: {}'.format(data)) + data_id = self._get_next_id() + data.id = data_id for idx, name in enumerate(request.feed_var_names): - print('name: {}'.format(request.feed_var_names[idx])) - print('data: {}'.format(request.feed_insts[idx])) + logging.debug('name: {}'.format(request.feed_var_names[idx])) + logging.debug('data: {}'.format(request.feed_insts[idx])) inst = python_service_channel_pb2.Inst() inst.data = request.feed_insts[idx] inst.name = name - inst.id = 0 #TODO data.insts.append(inst) - print('push data') - self._in_channel.push(data) - print('wait for infer') - data = self._out_channel.front() + return data, data_id + + def _pack_data_for_resp(self, data): data = data[0] #TODO batchsize = 1 - print('get data') + logging.debug('get data') resp = general_python_service_pb2.Response() - print('gen resp') - print(data) + logging.debug('gen resp') + logging.debug(data) for inst in data.insts: - print('append data') + logging.debug('append data') resp.fetch_insts.append(inst.data) - print('append name') + logging.debug('append name') resp.fetch_var_names.append(inst.name) return resp + def inference(self, request, context): + data, data_id = self._pack_data_for_infer(request) + logging.debug('push data') + self._in_channel.push(data) + logging.debug('wait for infer') + resp_data = None + while True: + resp_data = self._get_data_in_globel_resp_dict(data_id) + if resp_data is not None: + break + time.sleep(0.05) #TODO: wake up by _recive_out_channel_func + resp = self._pack_data_for_resp(resp_data) + return resp + class PyServer(object): def __init__(self): @@ -216,7 +276,7 @@ class PyServer(object): self._ops.append(op) def gen_desc(self): - print('here will generate desc for paas') + logging.info('here will generate desc for paas') pass def prepare_server(self, port, worker_num): @@ -273,6 +333,6 @@ class PyServer(object): else: cmd = "python -m paddle_serving_server_gpu.serve --model {} --thread 4 --port {} &>/dev/null &".format( model_path, port) - print(cmd) + logging.info(cmd) return os.system(cmd) diff --git a/python/paddle_serving_server/python_service_channel.proto b/python/paddle_serving_server/python_service_channel.proto index 2314368d7f50494f017890ff793460947f04e44f..dc97d3eb5bedf2d954e00840da458d39674de498 100644 --- a/python/paddle_serving_server/python_service_channel.proto +++ b/python/paddle_serving_server/python_service_channel.proto @@ -14,11 +14,13 @@ syntax = "proto2"; -message ChannelData { repeated Inst insts = 1; } +message ChannelData { + repeated Inst insts = 1; + required int32 id = 2; + optional string type = 3 [ default = "channel" ]; +} message Inst { required bytes data = 1; required string name = 2; - required int32 id = 3; - optional string type = 4 [ default = "channel" ]; }