diff --git a/python/examples/imdb/test_py_client.py b/python/examples/imdb/test_py_client.py index 82cbfb92105aa49a6a2669ed5170d078a9bc7345..a4a2c3194089cc94edae8b95fd823d45a9776a01 100644 --- a/python/examples/imdb/test_py_client.py +++ b/python/examples/imdb/test_py_client.py @@ -43,8 +43,9 @@ data = np.ndarray.tobytes(x) req.feed_var_names.append("x") req.feed_insts.append(data) -resp = stub.inference(req) -for idx, name in enumerate(resp.fetch_var_names): - print('{}: {}'.format( - name, np.frombuffer( - resp.fetch_insts[idx], dtype='float'))) +for i in range(100): + resp = stub.inference(req) + for idx, name in enumerate(resp.fetch_var_names): + print('{}: {}'.format( + name, np.frombuffer( + resp.fetch_insts[idx], dtype='float'))) diff --git a/python/examples/imdb/test_py_server.py b/python/examples/imdb/test_py_server.py index dd8ea7bcba9bf95c2ce5eba0ab144660342b6824..fa1fa0d763497e05e4f21a11a5f4e29c0b3b1869 100644 --- a/python/examples/imdb/test_py_server.py +++ b/python/examples/imdb/test_py_server.py @@ -29,83 +29,80 @@ logging.basicConfig( class CombineOp(Op): - #TODO: different id of data def preprocess(self, input_data): - data_id = None cnt = 0 - for input in input_data: - data = input[0] # batchsize=1 + for op_name, data in input_data.items(): + logging.debug("CombineOp preprocess: {}".format(op_name)) cnt += np.frombuffer(data.insts[0].data, dtype='float') - if data_id is None: - data_id = data.id - if data_id != data.id: - raise Exception("id not match: {} vs {}".format(data_id, - data.id)) data = python_service_channel_pb2.ChannelData() inst = python_service_channel_pb2.Inst() inst.data = np.ndarray.tobytes(cnt) inst.name = "resp" data.insts.append(inst) - data.id = data_id return data + def postprocess(self, output_data): + return output_data + class UciOp(Op): def postprocess(self, output_data): - data_ids = self.get_data_ids() data = python_service_channel_pb2.ChannelData() inst = python_service_channel_pb2.Inst() pred = np.array(output_data["price"][0][0], dtype='float') inst.data = np.ndarray.tobytes(pred) inst.name = "prediction" data.insts.append(inst) - data.id = data_ids[0] return data -read_channel = Channel() -cnn_out_channel = Channel() -bow_out_channel = Channel() -combine_out_channel = Channel() +read_channel = Channel(name="read_channel") +combine_channel = Channel(name="combine_channel") +out_channel = Channel(name="out_channel") cnn_op = UciOp( name="cnn_op", - inputs=[read_channel], + input=read_channel, in_dtype='float', - outputs=[cnn_out_channel], + outputs=[combine_channel], out_dtype='float', server_model="./uci_housing_model", server_port="9393", device="cpu", client_config="uci_housing_client/serving_client_conf.prototxt", server_name="127.0.0.1:9393", - fetch_names=["price"]) + fetch_names=["price"], + concurrency=2) bow_op = UciOp( name="bow_op", - inputs=[read_channel], + input=read_channel, in_dtype='float', - outputs=[bow_out_channel], + outputs=[combine_channel], out_dtype='float', server_model="./uci_housing_model", server_port="9292", device="cpu", client_config="uci_housing_client/serving_client_conf.prototxt", server_name="127.0.0.1:9393", - fetch_names=["price"]) + fetch_names=["price"], + concurrency=2) combine_op = CombineOp( name="combine_op", - inputs=[cnn_out_channel, bow_out_channel], + input=combine_channel, in_dtype='float', - outputs=[combine_out_channel], - out_dtype='float') + outputs=[out_channel], + out_dtype='float', + concurrency=2) +logging.info(read_channel.debug()) +logging.info(combine_channel.debug()) +logging.info(out_channel.debug()) pyserver = PyServer() pyserver.add_channel(read_channel) -pyserver.add_channel(cnn_out_channel) -pyserver.add_channel(bow_out_channel) -pyserver.add_channel(combine_out_channel) +pyserver.add_channel(combine_channel) +pyserver.add_channel(out_channel) pyserver.add_op(cnn_op) pyserver.add_op(bow_op) pyserver.add_op(combine_op) diff --git a/python/paddle_serving_server/pyserver.py b/python/paddle_serving_server/pyserver.py index c33a5d7c295e51dde72967d874b08678efa357e3..a55e48d8e2010047877503691ce8617663ef70ae 100644 --- a/python/paddle_serving_server/pyserver.py +++ b/python/paddle_serving_server/pyserver.py @@ -25,81 +25,202 @@ import general_python_service_pb2 import general_python_service_pb2_grpc import python_service_channel_pb2 import logging +import random import time class Channel(Queue.Queue): - def __init__(self, maxsize=-1, timeout=None, batchsize=1): + """ + The channel used for communication between Ops. + + 1. Support multiple different Op feed data (multiple producer) + Different types of data will be packaged through the data ID + 2. Support multiple different Op fetch data (multiple consumer) + Only when all types of Ops get the data of the same ID, + the data will be poped; The Op of the same type will not + get the data of the same ID. + 3. (TODO) Timeout and BatchSize are not fully supported. + + Note: + 1. The ID of the data in the channel must be different. + 2. The function add_producer() and add_consumer() are not thread safe, + and can only be called during initialization. + """ + + def __init__(self, name=None, maxsize=-1, timeout=None, batchsize=1): Queue.Queue.__init__(self, maxsize=maxsize) self._maxsize = maxsize self._timeout = timeout - self._batchsize = batchsize - self._pushlock = threading.Lock() - self._frontlock = threading.Lock() - self._pushbatch = [] + self._name = name + #self._batchsize = batchsize + # self._pushbatch = [] - self._consumer = {} # {op_name: idx} + self._cv = threading.Condition() + + self._producers = [] + self._producer_res_count = {} # {data_id: count} + self._push_res = {} # {data_id: {op_name: data}} + + self._front_wait_interval = 0.1 # second + self._consumers = {} # {op_name: idx} + self._idx_consumer_num = {} # {idx: num} self._consumer_base_idx = 0 - self._frontbatch = [] - self._idx_consumer_num = {} + self._front_res = [] + + def get_producers(self): + return self._producers + + def get_consumers(self): + return self._consumers.keys() + + def _log(self, info_str): + return "[{}] {}".format(self._name, info_str) + + def debug(self): + return self._log("p: {}, c: {}".format(self.get_producers(), + self.get_consumers())) + + def add_producer(self, op_name): + """ not thread safe, and can only be called during initialization """ + if op_name in self._producers: + raise ValueError( + self._log("producer({}) is already in channel".format(op_name))) + self._producers.append(op_name) def add_consumer(self, op_name): - """ not thread safe """ - if op_name in self._consumer: - raise ValueError("op_name({}) is already in channel".format( - op_name)) - self._consumer_id[op_name] = 0 + """ not thread safe, and can only be called during initialization """ + if op_name in self._consumers: + raise ValueError( + self._log("consumer({}) is already in channel".format(op_name))) + self._consumers[op_name] = 0 if self._idx_consumer_num.get(0) is None: self._idx_consumer_num[0] = 0 self._idx_consumer_num[0] += 1 - def push(self, item): - with self._pushlock: - self._pushbatch.append(item) - if len(self._pushbatch) == self._batchsize: - self.put(self._pushbatch, timeout=self._timeout) - self._pushbatch = [] - - def front(self, op_name): - if len(self._consumer) == 0: + def push(self, data, op_name=None): + logging.debug( + self._log("{} try to push data: {}".format(op_name, data))) + if len(self._producers) == 0: raise Exception( - "expected number of consumers to be greater than 0, but the it is 0." - ) - elif len(self._consumer) == 1: - return self.get(timeout=self._timeout) + self._log( + "expected number of producers to be greater than 0, but the it is 0." + )) + elif len(self._producers) == 1: + self._cv.acquire() + while True: + try: + self.put(data, timeout=0) + break + except Queue.Empty: + self._cv.wait() + self._cv.notify_all() + self._cv.release() + logging.debug(self._log("{} push data succ!".format(op_name))) + return True + elif op_name is None: + raise Exception( + self._log( + "There are multiple producers, so op_name cannot be None.")) - with self._frontlock: - consumer_idx = self._consumer[op_name] - base_idx = self._consumer_base_idx - data_idx = consumer_idx - base_idx + producer_num = len(self._producers) + data_id = data.id + put_data = None + self._cv.acquire() + logging.debug(self._log("{} get lock ~".format(op_name))) + if data_id not in self._push_res: + self._push_res[data_id] = {name: None for name in self._producers} + self._producer_res_count[data_id] = 0 + self._push_res[data_id][op_name] = data + if self._producer_res_count[data_id] + 1 == producer_num: + put_data = self._push_res[data_id] + self._push_res.pop(data_id) + self._producer_res_count.pop(data_id) + else: + self._producer_res_count[data_id] += 1 - if data_idx >= len(self._frontbatch): - batch_data = self.get(timeout=self._timeout) - self._frontbatch.append(batch_data) + if put_data is None: + logging.debug( + self._log("{} push data succ, not not push to queue.".format( + op_name))) + else: + while True: + try: + self.put(put_data, timeout=0) + break + except Queue.Empty: + self._cv.wait() + + logging.debug( + self._log("multi | {} push data succ!".format(op_name))) + self._cv.notify_all() + self._cv.release() + return True - resp = self._frontbatch[data_idx] + def front(self, op_name=None): + logging.debug(self._log("{} try to get data".format(op_name))) + if len(self._consumers) == 0: + raise Exception( + self._log( + "expected number of consumers to be greater than 0, but the it is 0." + )) + elif len(self._consumers) == 1: + self._cv.acquire() + resp = None + while resp is None: + try: + resp = self.get(timeout=0) + break + except Queue.Empty: + self._cv.wait() + logging.debug(self._log("{} get data succ!".format(op_name))) + return resp + elif op_name is None: + raise Exception( + self._log( + "There are multiple consumers, so op_name cannot be None.")) - self._idx_consumer_num[consumer_idx] -= 1 - if consumer_idx == base_idx and self._idx_consumer_num[ - consumer_idx] == 0: - self._idx_consumer_num.pop(consumer_idx) - self._frontbatch.pop(0) - self._consumer_base_idx += 1 + self._cv.acquire() + # data_idx = consumer_idx - base_idx + while self._consumers[op_name] - self._consumer_base_idx >= len( + self._front_res): + try: + data = self.get(timeout=0) + self._front_res.append(data) + break + except Queue.Empty: + self._cv.wait() + + consumer_idx = self._consumers[op_name] + base_idx = self._consumer_base_idx + data_idx = consumer_idx - base_idx + resp = self._front_res[data_idx] + logging.debug(self._log("{} get data: {}".format(op_name, resp))) + + self._idx_consumer_num[consumer_idx] -= 1 + if consumer_idx == base_idx and self._idx_consumer_num[ + consumer_idx] == 0: + self._idx_consumer_num.pop(consumer_idx) + self._front_res.pop(0) + self._consumer_base_idx += 1 + + self._consumers[op_name] += 1 + new_consumer_idx = self._consumers[op_name] + if self._idx_consumer_num.get(new_consumer_idx) is None: + self._idx_consumer_num[new_consumer_idx] = 0 + self._idx_consumer_num[new_consumer_idx] += 1 - self._consumer[op_name] += 1 - new_consumer_idx = self._consumer[op_name] - if self._idx_consumer_num.get(new_consumer_idx) is None: - self._idx_consumer_num[new_consumer_idx] = 0 - self._idx_consumer_num[new_consumer_idx] += 1 + self._cv.notify_all() + self._cv.release() + logging.debug(self._log("multi | {} get data succ!".format(op_name))) return resp # reference, read only class Op(object): def __init__(self, name, - inputs, + input, in_dtype, outputs, out_dtype, @@ -115,11 +236,11 @@ class Op(object): # TODO: globally unique check self._name = name # to identify the type of OP, it must be globally unique self._concurrency = concurrency # amount of concurrency - self.set_inputs(inputs) + self.set_input(input) self._in_dtype = in_dtype self.set_outputs(outputs) self._out_dtype = out_dtype - self._batch_size = batchsize + # self._batch_size = batchsize self._client = None if client_config is not None and \ server_name is not None and \ @@ -128,7 +249,6 @@ class Op(object): self._server_model = server_model self._server_port = server_port self._device = device - self._data_ids = [] def set_client(self, client_config, server_name, fetch_names): self._client = Client() @@ -139,59 +259,57 @@ class Op(object): def with_serving(self): return self._client is not None - def get_inputs(self): - return self._inputs + def get_input(self): + return self._input - def set_inputs(self, channels): - if not isinstance(channels, list): - raise TypeError('channels must be list type') - for channel in channels: - channel.add_consumer(self._name) - self._inputs = channels + def set_input(self, channel): + if not isinstance(channel, Channel): + raise TypeError( + self._log('input channel must be Channel type, not {}'.format( + type(channel)))) + channel.add_consumer(self._name) + self._input = channel def get_outputs(self): return self._outputs def set_outputs(self, channels): if not isinstance(channels, list): - raise TypeError('channels must be list type') + raise TypeError( + self._log('output channels must be list type, not {}'.format( + type(channels)))) + for channel in channels: + channel.add_producer(self._name) self._outputs = channels - def get_data_ids(self): - return self._data_ids - - def clear_data_ids(self): - self._data_ids = [] - - def append_id_to_data_ids(self, data_id): - self._data_ids.append(data_id) - - def preprocess(self, input_data): - if len(input_data) != 1: + def preprocess(self, data): + if isinstance(data, dict): raise Exception( - 'this Op has multiple previous channels. Please override this method' - ) - feed_batch = [] - self.clear_data_ids() - for data in input_data: - if len(data.insts) != self._batch_size: - raise Exception('len(data_insts) != self._batch_size') - feed = {} - for inst in data.insts: - feed[inst.name] = np.frombuffer(inst.data, dtype=self._in_dtype) - feed_batch.append(feed) - self.append_id_to_data_ids(data.id) - return feed_batch + self._log( + 'this Op has multiple previous inputs. Please override this method' + )) + feed = {} + for inst in data.insts: + feed[inst.name] = np.frombuffer(inst.data, dtype=self._in_dtype) + return feed def midprocess(self, data): - # data = preprocess(input), which must be a dict - logging.debug('data: {}'.format(data)) - logging.debug('fetch: {}'.format(self._fetch_names)) + if not isinstance(data, dict): + raise Exception( + self._log( + 'data must be dict type(the output of preprocess()), but get {}'. + format(type(data)))) + logging.debug(self._log('data: {}'.format(data))) + logging.debug(self._log('fetch: {}'.format(self._fetch_names))) fetch_map = self._client.predict(feed=data, fetch=self._fetch_names) + logging.debug(self._log("finish predict")) return fetch_map def postprocess(self, output_data): - return output_data + raise Exception( + self._log( + 'Please override this method to convert data to the format in channel.' + )) def stop(self): self._run = False @@ -199,22 +317,33 @@ class Op(object): def start(self): self._run = True while self._run: - input_data = [] - for channel in self._inputs: - input_data.append(channel.front(self._name)) - if len(input_data) > 1: - data = self.preprocess(input_data) + input_data = self._input.front(self._name) + data_id = None + logging.debug(self._log("input_data: {}".format(input_data))) + if isinstance(input_data, dict): + key = input_data.keys()[0] + data_id = input_data[key].id else: - data = self.preprocess(input_data[0]) + data_id = input_data.id + data = self.preprocess(input_data) if self.with_serving(): - fetch_map = self.midprocess(data) - output_data = self.postprocess(fetch_map) - else: - output_data = self.postprocess(data) + data = self.midprocess(data) + output_data = self.postprocess(data) + + if not isinstance(output_data, + python_service_channel_pb2.ChannelData): + raise TypeError( + self._log( + 'output_data must be ChannelData type, but get {}'. + format(type(output_data)))) + output_data.id = data_id for channel in self._outputs: - channel.push(output_data) + channel.push(output_data, self._name) + + def _log(self, info_str): + return "[{}] {}".format(self._name, info_str) def get_concurrency(self): return self._concurrency @@ -224,8 +353,11 @@ class GeneralPythonService( general_python_service_pb2_grpc.GeneralPythonService): def __init__(self, in_channel, out_channel): super(GeneralPythonService, self).__init__() + self._name = "__GeneralPythonService__" self.set_in_channel(in_channel) self.set_out_channel(out_channel) + logging.debug(self._log(in_channel.debug())) + logging.debug(self._log(out_channel.debug())) #TODO: # multi-lock for different clients # diffenert lock for server and client @@ -236,29 +368,35 @@ class GeneralPythonService( self._recive_func = threading.Thread( target=GeneralPythonService._recive_out_channel_func, args=(self, )) self._recive_func.start() - logging.debug('succ init') + + def _log(self, info_str): + return "[{}] {}".format(self._name, info_str) def set_in_channel(self, in_channel): + if not isinstance(in_channel, Channel): + raise TypeError( + self._log('in_channel must be Channel type, but get {}'.format( + type(in_channel)))) + in_channel.add_producer(self._name) self._in_channel = in_channel def set_out_channel(self, out_channel): - if isinstance(out_channel, list): - raise TypeError('out_channel can not be list type') - out_channel.add_consumer("__GeneralPythonService__") + if not isinstance(out_channel, Channel): + raise TypeError( + self._log('out_channel must be Channel type, but get {}'.format( + type(out_channel)))) + out_channel.add_consumer(self._name) self._out_channel = out_channel def _recive_out_channel_func(self): while True: - data = self._out_channel.front() - data_id = None - for d in data: - if data_id is None: - data_id = d.id - if data_id != d.id: - raise Exception("id not match: {} vs {}".format(data_id, - d.id)) + data = self._out_channel.front(self._name) + if not isinstance(data, python_service_channel_pb2.ChannelData): + raise TypeError( + self._log('data must be ChannelData type, but get {}'. + format(type(data)))) self._cv.acquire() - self._globel_resp_dict[data_id] = data + self._globel_resp_dict[data.id] = data self._cv.notify_all() self._cv.release() @@ -277,13 +415,14 @@ class GeneralPythonService( return resp def _pack_data_for_infer(self, request): - logging.debug('start inferce') + logging.debug(self._log('start inferce')) data = python_service_channel_pb2.ChannelData() data_id = self._get_next_id() data.id = data_id for idx, name in enumerate(request.feed_var_names): - logging.debug('name: {}'.format(request.feed_var_names[idx])) - logging.debug('data: {}'.format(request.feed_insts[idx])) + logging.debug( + self._log('name: {}'.format(request.feed_var_names[idx]))) + logging.debug(self._log('data: {}'.format(request.feed_insts[idx]))) inst = python_service_channel_pb2.Inst() inst.data = request.feed_insts[idx] inst.name = name @@ -291,23 +430,22 @@ class GeneralPythonService( return data, data_id def _pack_data_for_resp(self, data): - data = data[0] #TODO batchsize = 1 - logging.debug('get data') + logging.debug(self._log('get data')) resp = general_python_service_pb2.Response() - logging.debug('gen resp') + logging.debug(self._log('gen resp')) logging.debug(data) for inst in data.insts: - logging.debug('append data') + logging.debug(self._log('append data')) resp.fetch_insts.append(inst.data) - logging.debug('append name') + logging.debug(self._log('append name')) resp.fetch_var_names.append(inst.name) return resp def inference(self, request, context): data, data_id = self._pack_data_for_infer(request) - logging.debug('push data') - self._in_channel.push(data) - logging.debug('wait for infer') + logging.debug(self._log('push data')) + self._in_channel.push(data, self._name) + logging.debug(self._log('wait for infer')) resp_data = None resp_data = self._get_data_in_globel_resp_dict(data_id) resp = self._pack_data_for_resp(resp_data) @@ -340,7 +478,7 @@ class PyServer(object): inputs = set() outputs = set() for op in self._ops: - inputs |= set(op.get_inputs()) + inputs |= set([op.get_input()]) outputs |= set(op.get_outputs()) if op.with_serving(): self.prepare_serving(op) @@ -360,6 +498,8 @@ class PyServer(object): def _run_ops(self): for op in self._ops: op_concurrency = op.get_concurrency() + logging.debug("run op: {}, op_concurrency: {}".format( + op._name, op_concurrency)) for c in range(op_concurrency): # th = multiprocessing.Process(target=self._op_start_wrapper, args=(op, )) th = threading.Thread( @@ -387,13 +527,13 @@ class PyServer(object): port = op._server_port device = op._device - # run a server (not in PyServing) if device == "cpu": cmd = "python -m paddle_serving_server.serve --model {} --thread 4 --port {} &>/dev/null &".format( model_path, port) else: cmd = "python -m paddle_serving_server_gpu.serve --model {} --thread 4 --port {} &>/dev/null &".format( model_path, port) - logging.info(cmd) + # run a server (not in PyServing) + logging.info("run a server (not in PyServing): {}".format(cmd)) return # os.system(cmd)