diff --git a/python/examples/imdb/test_py_server.py b/python/examples/imdb/test_py_server.py index 3d90161c88e53f3edde9557e69885950782f4459..dd8ea7bcba9bf95c2ce5eba0ab144660342b6824 100644 --- a/python/examples/imdb/test_py_server.py +++ b/python/examples/imdb/test_py_server.py @@ -26,21 +26,10 @@ logging.basicConfig( level=logging.INFO) # channel data: {name(str): data(bytes)} -""" -class ImdbOp(Op): - def postprocess(self, output_data): - data = python_service_channel_pb2.ChannelData() - inst = python_service_channel_pb2.Inst() - pred = np.array(output_data["prediction"][0][0], dtype='float') - inst.data = np.ndarray.tobytes(pred) - inst.name = "prediction" - inst.id = 0 #TODO - data.insts.append(inst) - return data -""" class CombineOp(Op): + #TODO: different id of data def preprocess(self, input_data): data_id = None cnt = 0 @@ -58,7 +47,6 @@ class CombineOp(Op): inst.name = "resp" data.insts.append(inst) data.id = data_id - print(data) return data @@ -75,11 +63,13 @@ class UciOp(Op): return data -read_channel = Channel(consumer=2) +read_channel = Channel() cnn_out_channel = Channel() bow_out_channel = Channel() combine_out_channel = Channel() + cnn_op = UciOp( + name="cnn_op", inputs=[read_channel], in_dtype='float', outputs=[cnn_out_channel], @@ -90,7 +80,9 @@ cnn_op = UciOp( client_config="uci_housing_client/serving_client_conf.prototxt", server_name="127.0.0.1:9393", fetch_names=["price"]) + bow_op = UciOp( + name="bow_op", inputs=[read_channel], in_dtype='float', outputs=[bow_out_channel], @@ -101,27 +93,9 @@ bow_op = UciOp( client_config="uci_housing_client/serving_client_conf.prototxt", server_name="127.0.0.1:9393", fetch_names=["price"]) -''' -cnn_op = ImdbOp( - inputs=[read_channel], - outputs=[cnn_out_channel], - server_model="./imdb_cnn_model", - server_port="9393", - device="cpu", - client_config="imdb_cnn_client_conf/serving_client_conf.prototxt", - server_name="127.0.0.1:9393", - fetch_names=["acc", "cost", "prediction"]) -bow_op = ImdbOp( - inputs=[read_channel], - outputs=[bow_out_channel], - server_model="./imdb_bow_model", - server_port="9292", - device="cpu", - client_config="imdb_bow_client_conf/serving_client_conf.prototxt", - server_name="127.0.0.1:9292", - fetch_names=["acc", "cost", "prediction"]) -''' + combine_op = CombineOp( + name="combine_op", inputs=[cnn_out_channel, bow_out_channel], in_dtype='float', outputs=[combine_out_channel], diff --git a/python/paddle_serving_server/pyserver.py b/python/paddle_serving_server/pyserver.py index 3a8006a44f6427b24799462aee30288e063c10fe..c33a5d7c295e51dde72967d874b08678efa357e3 100644 --- a/python/paddle_serving_server/pyserver.py +++ b/python/paddle_serving_server/pyserver.py @@ -29,43 +29,76 @@ import time class Channel(Queue.Queue): - def __init__(self, consumer=1, maxsize=-1, timeout=None, batchsize=1): + def __init__(self, maxsize=-1, timeout=None, batchsize=1): Queue.Queue.__init__(self, maxsize=maxsize) - # super(Channel, self).__init__(maxsize=maxsize) self._maxsize = maxsize self._timeout = timeout self._batchsize = batchsize - self._consumer = consumer self._pushlock = threading.Lock() self._frontlock = threading.Lock() self._pushbatch = [] - self._frontbatch = None - self._count = 0 - self._order = 0 + + self._consumer = {} # {op_name: idx} + self._consumer_base_idx = 0 + self._frontbatch = [] + self._idx_consumer_num = {} + + def add_consumer(self, op_name): + """ not thread safe """ + if op_name in self._consumer: + raise ValueError("op_name({}) is already in channel".format( + op_name)) + self._consumer_id[op_name] = 0 + + if self._idx_consumer_num.get(0) is None: + self._idx_consumer_num[0] = 0 + self._idx_consumer_num[0] += 1 def push(self, item): with self._pushlock: self._pushbatch.append(item) - self._order += 1 if len(self._pushbatch) == self._batchsize: self.put(self._pushbatch, timeout=self._timeout) - # self.put(self._pushbatch) self._pushbatch = [] - def front(self): - if self._consumer == 1: + def front(self, op_name): + if len(self._consumer) == 0: + raise Exception( + "expected number of consumers to be greater than 0, but the it is 0." + ) + elif len(self._consumer) == 1: return self.get(timeout=self._timeout) + with self._frontlock: - if self._count == 0: - self._frontbatch = self.get(timeout=self._timeout) - self._count += 1 - if self._count == self._consumer: - self._count = 0 - return self._frontbatch + consumer_idx = self._consumer[op_name] + base_idx = self._consumer_base_idx + data_idx = consumer_idx - base_idx + + if data_idx >= len(self._frontbatch): + batch_data = self.get(timeout=self._timeout) + self._frontbatch.append(batch_data) + + resp = self._frontbatch[data_idx] + + self._idx_consumer_num[consumer_idx] -= 1 + if consumer_idx == base_idx and self._idx_consumer_num[ + consumer_idx] == 0: + self._idx_consumer_num.pop(consumer_idx) + self._frontbatch.pop(0) + self._consumer_base_idx += 1 + + self._consumer[op_name] += 1 + new_consumer_idx = self._consumer[op_name] + if self._idx_consumer_num.get(new_consumer_idx) is None: + self._idx_consumer_num[new_consumer_idx] = 0 + self._idx_consumer_num[new_consumer_idx] += 1 + + return resp # reference, read only class Op(object): def __init__(self, + name, inputs, in_dtype, outputs, @@ -76,8 +109,12 @@ class Op(object): device=None, client_config=None, server_name=None, - fetch_names=None): + fetch_names=None, + concurrency=1): self._run = False + # TODO: globally unique check + self._name = name # to identify the type of OP, it must be globally unique + self._concurrency = concurrency # amount of concurrency self.set_inputs(inputs) self._in_dtype = in_dtype self.set_outputs(outputs) @@ -108,6 +145,8 @@ class Op(object): def set_inputs(self, channels): if not isinstance(channels, list): raise TypeError('channels must be list type') + for channel in channels: + channel.add_consumer(self._name) self._inputs = channels def get_outputs(self): @@ -162,7 +201,7 @@ class Op(object): while self._run: input_data = [] for channel in self._inputs: - input_data.append(channel.front()) + input_data.append(channel.front(self._name)) if len(input_data) > 1: data = self.preprocess(input_data) else: @@ -177,13 +216,16 @@ class Op(object): for channel in self._outputs: channel.push(output_data) + def get_concurrency(self): + return self._concurrency + class GeneralPythonService( general_python_service_pb2_grpc.GeneralPythonService): def __init__(self, in_channel, out_channel): super(GeneralPythonService, self).__init__() - self._in_channel = in_channel - self._out_channel = out_channel + self.set_in_channel(in_channel) + self.set_out_channel(out_channel) #TODO: # multi-lock for different clients # diffenert lock for server and client @@ -196,6 +238,15 @@ class GeneralPythonService( self._recive_func.start() logging.debug('succ init') + def set_in_channel(self, in_channel): + self._in_channel = in_channel + + def set_out_channel(self, out_channel): + if isinstance(out_channel, list): + raise TypeError('out_channel can not be list type') + out_channel.add_consumer("__GeneralPythonService__") + self._out_channel = out_channel + def _recive_out_channel_func(self): while True: data = self._out_channel.front() @@ -303,15 +354,21 @@ class PyServer(object): self._out_channel = out_channel.pop() self.gen_desc() - def op_start_wrapper(self, op): + def _op_start_wrapper(self, op): return op.start() - def run_server(self): + def _run_ops(self): for op in self._ops: - # th = multiprocessing.Process(target=self.op_start_wrapper, args=(op, )) - th = threading.Thread(target=self.op_start_wrapper, args=(op, )) - th.start() - self._op_threads.append(th) + op_concurrency = op.get_concurrency() + for c in range(op_concurrency): + # th = multiprocessing.Process(target=self._op_start_wrapper, args=(op, )) + th = threading.Thread( + target=self._op_start_wrapper, args=(op, )) + th.start() + self._op_threads.append(th) + + def run_server(self): + self._run_ops() server = grpc.server( futures.ThreadPoolExecutor(max_workers=self._worker_num)) general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server( @@ -339,4 +396,4 @@ class PyServer(object): model_path, port) logging.info(cmd) return - os.system(cmd) + # os.system(cmd)