提交 02f9966c 编写于 作者: B barrierye

Support multiple different Op feed data or fetch data

上级 20985a8c
......@@ -43,8 +43,9 @@ data = np.ndarray.tobytes(x)
req.feed_var_names.append("x")
req.feed_insts.append(data)
resp = stub.inference(req)
for idx, name in enumerate(resp.fetch_var_names):
print('{}: {}'.format(
name, np.frombuffer(
resp.fetch_insts[idx], dtype='float')))
for i in range(100):
resp = stub.inference(req)
for idx, name in enumerate(resp.fetch_var_names):
print('{}: {}'.format(
name, np.frombuffer(
resp.fetch_insts[idx], dtype='float')))
......@@ -29,83 +29,80 @@ logging.basicConfig(
class CombineOp(Op):
#TODO: different id of data
def preprocess(self, input_data):
data_id = None
cnt = 0
for input in input_data:
data = input[0] # batchsize=1
for op_name, data in input_data.items():
logging.debug("CombineOp preprocess: {}".format(op_name))
cnt += np.frombuffer(data.insts[0].data, dtype='float')
if data_id is None:
data_id = data.id
if data_id != data.id:
raise Exception("id not match: {} vs {}".format(data_id,
data.id))
data = python_service_channel_pb2.ChannelData()
inst = python_service_channel_pb2.Inst()
inst.data = np.ndarray.tobytes(cnt)
inst.name = "resp"
data.insts.append(inst)
data.id = data_id
return data
def postprocess(self, output_data):
return output_data
class UciOp(Op):
def postprocess(self, output_data):
data_ids = self.get_data_ids()
data = python_service_channel_pb2.ChannelData()
inst = python_service_channel_pb2.Inst()
pred = np.array(output_data["price"][0][0], dtype='float')
inst.data = np.ndarray.tobytes(pred)
inst.name = "prediction"
data.insts.append(inst)
data.id = data_ids[0]
return data
read_channel = Channel()
cnn_out_channel = Channel()
bow_out_channel = Channel()
combine_out_channel = Channel()
read_channel = Channel(name="read_channel")
combine_channel = Channel(name="combine_channel")
out_channel = Channel(name="out_channel")
cnn_op = UciOp(
name="cnn_op",
inputs=[read_channel],
input=read_channel,
in_dtype='float',
outputs=[cnn_out_channel],
outputs=[combine_channel],
out_dtype='float',
server_model="./uci_housing_model",
server_port="9393",
device="cpu",
client_config="uci_housing_client/serving_client_conf.prototxt",
server_name="127.0.0.1:9393",
fetch_names=["price"])
fetch_names=["price"],
concurrency=2)
bow_op = UciOp(
name="bow_op",
inputs=[read_channel],
input=read_channel,
in_dtype='float',
outputs=[bow_out_channel],
outputs=[combine_channel],
out_dtype='float',
server_model="./uci_housing_model",
server_port="9292",
device="cpu",
client_config="uci_housing_client/serving_client_conf.prototxt",
server_name="127.0.0.1:9393",
fetch_names=["price"])
fetch_names=["price"],
concurrency=2)
combine_op = CombineOp(
name="combine_op",
inputs=[cnn_out_channel, bow_out_channel],
input=combine_channel,
in_dtype='float',
outputs=[combine_out_channel],
out_dtype='float')
outputs=[out_channel],
out_dtype='float',
concurrency=2)
logging.info(read_channel.debug())
logging.info(combine_channel.debug())
logging.info(out_channel.debug())
pyserver = PyServer()
pyserver.add_channel(read_channel)
pyserver.add_channel(cnn_out_channel)
pyserver.add_channel(bow_out_channel)
pyserver.add_channel(combine_out_channel)
pyserver.add_channel(combine_channel)
pyserver.add_channel(out_channel)
pyserver.add_op(cnn_op)
pyserver.add_op(bow_op)
pyserver.add_op(combine_op)
......
......@@ -25,81 +25,202 @@ import general_python_service_pb2
import general_python_service_pb2_grpc
import python_service_channel_pb2
import logging
import random
import time
class Channel(Queue.Queue):
def __init__(self, maxsize=-1, timeout=None, batchsize=1):
"""
The channel used for communication between Ops.
1. Support multiple different Op feed data (multiple producer)
Different types of data will be packaged through the data ID
2. Support multiple different Op fetch data (multiple consumer)
Only when all types of Ops get the data of the same ID,
the data will be poped; The Op of the same type will not
get the data of the same ID.
3. (TODO) Timeout and BatchSize are not fully supported.
Note:
1. The ID of the data in the channel must be different.
2. The function add_producer() and add_consumer() are not thread safe,
and can only be called during initialization.
"""
def __init__(self, name=None, maxsize=-1, timeout=None, batchsize=1):
Queue.Queue.__init__(self, maxsize=maxsize)
self._maxsize = maxsize
self._timeout = timeout
self._batchsize = batchsize
self._pushlock = threading.Lock()
self._frontlock = threading.Lock()
self._pushbatch = []
self._name = name
#self._batchsize = batchsize
# self._pushbatch = []
self._consumer = {} # {op_name: idx}
self._cv = threading.Condition()
self._producers = []
self._producer_res_count = {} # {data_id: count}
self._push_res = {} # {data_id: {op_name: data}}
self._front_wait_interval = 0.1 # second
self._consumers = {} # {op_name: idx}
self._idx_consumer_num = {} # {idx: num}
self._consumer_base_idx = 0
self._frontbatch = []
self._idx_consumer_num = {}
self._front_res = []
def get_producers(self):
return self._producers
def get_consumers(self):
return self._consumers.keys()
def _log(self, info_str):
return "[{}] {}".format(self._name, info_str)
def debug(self):
return self._log("p: {}, c: {}".format(self.get_producers(),
self.get_consumers()))
def add_producer(self, op_name):
""" not thread safe, and can only be called during initialization """
if op_name in self._producers:
raise ValueError(
self._log("producer({}) is already in channel".format(op_name)))
self._producers.append(op_name)
def add_consumer(self, op_name):
""" not thread safe """
if op_name in self._consumer:
raise ValueError("op_name({}) is already in channel".format(
op_name))
self._consumer_id[op_name] = 0
""" not thread safe, and can only be called during initialization """
if op_name in self._consumers:
raise ValueError(
self._log("consumer({}) is already in channel".format(op_name)))
self._consumers[op_name] = 0
if self._idx_consumer_num.get(0) is None:
self._idx_consumer_num[0] = 0
self._idx_consumer_num[0] += 1
def push(self, item):
with self._pushlock:
self._pushbatch.append(item)
if len(self._pushbatch) == self._batchsize:
self.put(self._pushbatch, timeout=self._timeout)
self._pushbatch = []
def front(self, op_name):
if len(self._consumer) == 0:
def push(self, data, op_name=None):
logging.debug(
self._log("{} try to push data: {}".format(op_name, data)))
if len(self._producers) == 0:
raise Exception(
"expected number of consumers to be greater than 0, but the it is 0."
)
elif len(self._consumer) == 1:
return self.get(timeout=self._timeout)
self._log(
"expected number of producers to be greater than 0, but the it is 0."
))
elif len(self._producers) == 1:
self._cv.acquire()
while True:
try:
self.put(data, timeout=0)
break
except Queue.Empty:
self._cv.wait()
self._cv.notify_all()
self._cv.release()
logging.debug(self._log("{} push data succ!".format(op_name)))
return True
elif op_name is None:
raise Exception(
self._log(
"There are multiple producers, so op_name cannot be None."))
with self._frontlock:
consumer_idx = self._consumer[op_name]
base_idx = self._consumer_base_idx
data_idx = consumer_idx - base_idx
producer_num = len(self._producers)
data_id = data.id
put_data = None
self._cv.acquire()
logging.debug(self._log("{} get lock ~".format(op_name)))
if data_id not in self._push_res:
self._push_res[data_id] = {name: None for name in self._producers}
self._producer_res_count[data_id] = 0
self._push_res[data_id][op_name] = data
if self._producer_res_count[data_id] + 1 == producer_num:
put_data = self._push_res[data_id]
self._push_res.pop(data_id)
self._producer_res_count.pop(data_id)
else:
self._producer_res_count[data_id] += 1
if data_idx >= len(self._frontbatch):
batch_data = self.get(timeout=self._timeout)
self._frontbatch.append(batch_data)
if put_data is None:
logging.debug(
self._log("{} push data succ, not not push to queue.".format(
op_name)))
else:
while True:
try:
self.put(put_data, timeout=0)
break
except Queue.Empty:
self._cv.wait()
logging.debug(
self._log("multi | {} push data succ!".format(op_name)))
self._cv.notify_all()
self._cv.release()
return True
resp = self._frontbatch[data_idx]
def front(self, op_name=None):
logging.debug(self._log("{} try to get data".format(op_name)))
if len(self._consumers) == 0:
raise Exception(
self._log(
"expected number of consumers to be greater than 0, but the it is 0."
))
elif len(self._consumers) == 1:
self._cv.acquire()
resp = None
while resp is None:
try:
resp = self.get(timeout=0)
break
except Queue.Empty:
self._cv.wait()
logging.debug(self._log("{} get data succ!".format(op_name)))
return resp
elif op_name is None:
raise Exception(
self._log(
"There are multiple consumers, so op_name cannot be None."))
self._idx_consumer_num[consumer_idx] -= 1
if consumer_idx == base_idx and self._idx_consumer_num[
consumer_idx] == 0:
self._idx_consumer_num.pop(consumer_idx)
self._frontbatch.pop(0)
self._consumer_base_idx += 1
self._cv.acquire()
# data_idx = consumer_idx - base_idx
while self._consumers[op_name] - self._consumer_base_idx >= len(
self._front_res):
try:
data = self.get(timeout=0)
self._front_res.append(data)
break
except Queue.Empty:
self._cv.wait()
consumer_idx = self._consumers[op_name]
base_idx = self._consumer_base_idx
data_idx = consumer_idx - base_idx
resp = self._front_res[data_idx]
logging.debug(self._log("{} get data: {}".format(op_name, resp)))
self._idx_consumer_num[consumer_idx] -= 1
if consumer_idx == base_idx and self._idx_consumer_num[
consumer_idx] == 0:
self._idx_consumer_num.pop(consumer_idx)
self._front_res.pop(0)
self._consumer_base_idx += 1
self._consumers[op_name] += 1
new_consumer_idx = self._consumers[op_name]
if self._idx_consumer_num.get(new_consumer_idx) is None:
self._idx_consumer_num[new_consumer_idx] = 0
self._idx_consumer_num[new_consumer_idx] += 1
self._consumer[op_name] += 1
new_consumer_idx = self._consumer[op_name]
if self._idx_consumer_num.get(new_consumer_idx) is None:
self._idx_consumer_num[new_consumer_idx] = 0
self._idx_consumer_num[new_consumer_idx] += 1
self._cv.notify_all()
self._cv.release()
logging.debug(self._log("multi | {} get data succ!".format(op_name)))
return resp # reference, read only
class Op(object):
def __init__(self,
name,
inputs,
input,
in_dtype,
outputs,
out_dtype,
......@@ -115,11 +236,11 @@ class Op(object):
# TODO: globally unique check
self._name = name # to identify the type of OP, it must be globally unique
self._concurrency = concurrency # amount of concurrency
self.set_inputs(inputs)
self.set_input(input)
self._in_dtype = in_dtype
self.set_outputs(outputs)
self._out_dtype = out_dtype
self._batch_size = batchsize
# self._batch_size = batchsize
self._client = None
if client_config is not None and \
server_name is not None and \
......@@ -128,7 +249,6 @@ class Op(object):
self._server_model = server_model
self._server_port = server_port
self._device = device
self._data_ids = []
def set_client(self, client_config, server_name, fetch_names):
self._client = Client()
......@@ -139,59 +259,57 @@ class Op(object):
def with_serving(self):
return self._client is not None
def get_inputs(self):
return self._inputs
def get_input(self):
return self._input
def set_inputs(self, channels):
if not isinstance(channels, list):
raise TypeError('channels must be list type')
for channel in channels:
channel.add_consumer(self._name)
self._inputs = channels
def set_input(self, channel):
if not isinstance(channel, Channel):
raise TypeError(
self._log('input channel must be Channel type, not {}'.format(
type(channel))))
channel.add_consumer(self._name)
self._input = channel
def get_outputs(self):
return self._outputs
def set_outputs(self, channels):
if not isinstance(channels, list):
raise TypeError('channels must be list type')
raise TypeError(
self._log('output channels must be list type, not {}'.format(
type(channels))))
for channel in channels:
channel.add_producer(self._name)
self._outputs = channels
def get_data_ids(self):
return self._data_ids
def clear_data_ids(self):
self._data_ids = []
def append_id_to_data_ids(self, data_id):
self._data_ids.append(data_id)
def preprocess(self, input_data):
if len(input_data) != 1:
def preprocess(self, data):
if isinstance(data, dict):
raise Exception(
'this Op has multiple previous channels. Please override this method'
)
feed_batch = []
self.clear_data_ids()
for data in input_data:
if len(data.insts) != self._batch_size:
raise Exception('len(data_insts) != self._batch_size')
feed = {}
for inst in data.insts:
feed[inst.name] = np.frombuffer(inst.data, dtype=self._in_dtype)
feed_batch.append(feed)
self.append_id_to_data_ids(data.id)
return feed_batch
self._log(
'this Op has multiple previous inputs. Please override this method'
))
feed = {}
for inst in data.insts:
feed[inst.name] = np.frombuffer(inst.data, dtype=self._in_dtype)
return feed
def midprocess(self, data):
# data = preprocess(input), which must be a dict
logging.debug('data: {}'.format(data))
logging.debug('fetch: {}'.format(self._fetch_names))
if not isinstance(data, dict):
raise Exception(
self._log(
'data must be dict type(the output of preprocess()), but get {}'.
format(type(data))))
logging.debug(self._log('data: {}'.format(data)))
logging.debug(self._log('fetch: {}'.format(self._fetch_names)))
fetch_map = self._client.predict(feed=data, fetch=self._fetch_names)
logging.debug(self._log("finish predict"))
return fetch_map
def postprocess(self, output_data):
return output_data
raise Exception(
self._log(
'Please override this method to convert data to the format in channel.'
))
def stop(self):
self._run = False
......@@ -199,22 +317,33 @@ class Op(object):
def start(self):
self._run = True
while self._run:
input_data = []
for channel in self._inputs:
input_data.append(channel.front(self._name))
if len(input_data) > 1:
data = self.preprocess(input_data)
input_data = self._input.front(self._name)
data_id = None
logging.debug(self._log("input_data: {}".format(input_data)))
if isinstance(input_data, dict):
key = input_data.keys()[0]
data_id = input_data[key].id
else:
data = self.preprocess(input_data[0])
data_id = input_data.id
data = self.preprocess(input_data)
if self.with_serving():
fetch_map = self.midprocess(data)
output_data = self.postprocess(fetch_map)
else:
output_data = self.postprocess(data)
data = self.midprocess(data)
output_data = self.postprocess(data)
if not isinstance(output_data,
python_service_channel_pb2.ChannelData):
raise TypeError(
self._log(
'output_data must be ChannelData type, but get {}'.
format(type(output_data))))
output_data.id = data_id
for channel in self._outputs:
channel.push(output_data)
channel.push(output_data, self._name)
def _log(self, info_str):
return "[{}] {}".format(self._name, info_str)
def get_concurrency(self):
return self._concurrency
......@@ -224,8 +353,11 @@ class GeneralPythonService(
general_python_service_pb2_grpc.GeneralPythonService):
def __init__(self, in_channel, out_channel):
super(GeneralPythonService, self).__init__()
self._name = "__GeneralPythonService__"
self.set_in_channel(in_channel)
self.set_out_channel(out_channel)
logging.debug(self._log(in_channel.debug()))
logging.debug(self._log(out_channel.debug()))
#TODO:
# multi-lock for different clients
# diffenert lock for server and client
......@@ -236,29 +368,35 @@ class GeneralPythonService(
self._recive_func = threading.Thread(
target=GeneralPythonService._recive_out_channel_func, args=(self, ))
self._recive_func.start()
logging.debug('succ init')
def _log(self, info_str):
return "[{}] {}".format(self._name, info_str)
def set_in_channel(self, in_channel):
if not isinstance(in_channel, Channel):
raise TypeError(
self._log('in_channel must be Channel type, but get {}'.format(
type(in_channel))))
in_channel.add_producer(self._name)
self._in_channel = in_channel
def set_out_channel(self, out_channel):
if isinstance(out_channel, list):
raise TypeError('out_channel can not be list type')
out_channel.add_consumer("__GeneralPythonService__")
if not isinstance(out_channel, Channel):
raise TypeError(
self._log('out_channel must be Channel type, but get {}'.format(
type(out_channel))))
out_channel.add_consumer(self._name)
self._out_channel = out_channel
def _recive_out_channel_func(self):
while True:
data = self._out_channel.front()
data_id = None
for d in data:
if data_id is None:
data_id = d.id
if data_id != d.id:
raise Exception("id not match: {} vs {}".format(data_id,
d.id))
data = self._out_channel.front(self._name)
if not isinstance(data, python_service_channel_pb2.ChannelData):
raise TypeError(
self._log('data must be ChannelData type, but get {}'.
format(type(data))))
self._cv.acquire()
self._globel_resp_dict[data_id] = data
self._globel_resp_dict[data.id] = data
self._cv.notify_all()
self._cv.release()
......@@ -277,13 +415,14 @@ class GeneralPythonService(
return resp
def _pack_data_for_infer(self, request):
logging.debug('start inferce')
logging.debug(self._log('start inferce'))
data = python_service_channel_pb2.ChannelData()
data_id = self._get_next_id()
data.id = data_id
for idx, name in enumerate(request.feed_var_names):
logging.debug('name: {}'.format(request.feed_var_names[idx]))
logging.debug('data: {}'.format(request.feed_insts[idx]))
logging.debug(
self._log('name: {}'.format(request.feed_var_names[idx])))
logging.debug(self._log('data: {}'.format(request.feed_insts[idx])))
inst = python_service_channel_pb2.Inst()
inst.data = request.feed_insts[idx]
inst.name = name
......@@ -291,23 +430,22 @@ class GeneralPythonService(
return data, data_id
def _pack_data_for_resp(self, data):
data = data[0] #TODO batchsize = 1
logging.debug('get data')
logging.debug(self._log('get data'))
resp = general_python_service_pb2.Response()
logging.debug('gen resp')
logging.debug(self._log('gen resp'))
logging.debug(data)
for inst in data.insts:
logging.debug('append data')
logging.debug(self._log('append data'))
resp.fetch_insts.append(inst.data)
logging.debug('append name')
logging.debug(self._log('append name'))
resp.fetch_var_names.append(inst.name)
return resp
def inference(self, request, context):
data, data_id = self._pack_data_for_infer(request)
logging.debug('push data')
self._in_channel.push(data)
logging.debug('wait for infer')
logging.debug(self._log('push data'))
self._in_channel.push(data, self._name)
logging.debug(self._log('wait for infer'))
resp_data = None
resp_data = self._get_data_in_globel_resp_dict(data_id)
resp = self._pack_data_for_resp(resp_data)
......@@ -340,7 +478,7 @@ class PyServer(object):
inputs = set()
outputs = set()
for op in self._ops:
inputs |= set(op.get_inputs())
inputs |= set([op.get_input()])
outputs |= set(op.get_outputs())
if op.with_serving():
self.prepare_serving(op)
......@@ -360,6 +498,8 @@ class PyServer(object):
def _run_ops(self):
for op in self._ops:
op_concurrency = op.get_concurrency()
logging.debug("run op: {}, op_concurrency: {}".format(
op._name, op_concurrency))
for c in range(op_concurrency):
# th = multiprocessing.Process(target=self._op_start_wrapper, args=(op, ))
th = threading.Thread(
......@@ -387,13 +527,13 @@ class PyServer(object):
port = op._server_port
device = op._device
# run a server (not in PyServing)
if device == "cpu":
cmd = "python -m paddle_serving_server.serve --model {} --thread 4 --port {} &>/dev/null &".format(
model_path, port)
else:
cmd = "python -m paddle_serving_server_gpu.serve --model {} --thread 4 --port {} &>/dev/null &".format(
model_path, port)
logging.info(cmd)
# run a server (not in PyServing)
logging.info("run a server (not in PyServing): {}".format(cmd))
return
# os.system(cmd)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册