提交 da0949c4 编写于 作者: B barrierye

support different type ops concurrent fetch data from one channel && TODO:...

support different type ops concurrent fetch data from one channel && TODO: some op likes combine_op should get data with the sanme data_id not support yet
上级 d1c969ac
...@@ -26,21 +26,10 @@ logging.basicConfig( ...@@ -26,21 +26,10 @@ logging.basicConfig(
level=logging.INFO) level=logging.INFO)
# channel data: {name(str): data(bytes)} # channel data: {name(str): data(bytes)}
"""
class ImdbOp(Op):
def postprocess(self, output_data):
data = python_service_channel_pb2.ChannelData()
inst = python_service_channel_pb2.Inst()
pred = np.array(output_data["prediction"][0][0], dtype='float')
inst.data = np.ndarray.tobytes(pred)
inst.name = "prediction"
inst.id = 0 #TODO
data.insts.append(inst)
return data
"""
class CombineOp(Op): class CombineOp(Op):
#TODO: different id of data
def preprocess(self, input_data): def preprocess(self, input_data):
data_id = None data_id = None
cnt = 0 cnt = 0
...@@ -58,7 +47,6 @@ class CombineOp(Op): ...@@ -58,7 +47,6 @@ class CombineOp(Op):
inst.name = "resp" inst.name = "resp"
data.insts.append(inst) data.insts.append(inst)
data.id = data_id data.id = data_id
print(data)
return data return data
...@@ -75,11 +63,13 @@ class UciOp(Op): ...@@ -75,11 +63,13 @@ class UciOp(Op):
return data return data
read_channel = Channel(consumer=2) read_channel = Channel()
cnn_out_channel = Channel() cnn_out_channel = Channel()
bow_out_channel = Channel() bow_out_channel = Channel()
combine_out_channel = Channel() combine_out_channel = Channel()
cnn_op = UciOp( cnn_op = UciOp(
name="cnn_op",
inputs=[read_channel], inputs=[read_channel],
in_dtype='float', in_dtype='float',
outputs=[cnn_out_channel], outputs=[cnn_out_channel],
...@@ -90,7 +80,9 @@ cnn_op = UciOp( ...@@ -90,7 +80,9 @@ cnn_op = UciOp(
client_config="uci_housing_client/serving_client_conf.prototxt", client_config="uci_housing_client/serving_client_conf.prototxt",
server_name="127.0.0.1:9393", server_name="127.0.0.1:9393",
fetch_names=["price"]) fetch_names=["price"])
bow_op = UciOp( bow_op = UciOp(
name="bow_op",
inputs=[read_channel], inputs=[read_channel],
in_dtype='float', in_dtype='float',
outputs=[bow_out_channel], outputs=[bow_out_channel],
...@@ -101,27 +93,9 @@ bow_op = UciOp( ...@@ -101,27 +93,9 @@ bow_op = UciOp(
client_config="uci_housing_client/serving_client_conf.prototxt", client_config="uci_housing_client/serving_client_conf.prototxt",
server_name="127.0.0.1:9393", server_name="127.0.0.1:9393",
fetch_names=["price"]) fetch_names=["price"])
'''
cnn_op = ImdbOp(
inputs=[read_channel],
outputs=[cnn_out_channel],
server_model="./imdb_cnn_model",
server_port="9393",
device="cpu",
client_config="imdb_cnn_client_conf/serving_client_conf.prototxt",
server_name="127.0.0.1:9393",
fetch_names=["acc", "cost", "prediction"])
bow_op = ImdbOp(
inputs=[read_channel],
outputs=[bow_out_channel],
server_model="./imdb_bow_model",
server_port="9292",
device="cpu",
client_config="imdb_bow_client_conf/serving_client_conf.prototxt",
server_name="127.0.0.1:9292",
fetch_names=["acc", "cost", "prediction"])
'''
combine_op = CombineOp( combine_op = CombineOp(
name="combine_op",
inputs=[cnn_out_channel, bow_out_channel], inputs=[cnn_out_channel, bow_out_channel],
in_dtype='float', in_dtype='float',
outputs=[combine_out_channel], outputs=[combine_out_channel],
......
...@@ -29,43 +29,76 @@ import time ...@@ -29,43 +29,76 @@ import time
class Channel(Queue.Queue): class Channel(Queue.Queue):
def __init__(self, consumer=1, maxsize=-1, timeout=None, batchsize=1): def __init__(self, maxsize=-1, timeout=None, batchsize=1):
Queue.Queue.__init__(self, maxsize=maxsize) Queue.Queue.__init__(self, maxsize=maxsize)
# super(Channel, self).__init__(maxsize=maxsize)
self._maxsize = maxsize self._maxsize = maxsize
self._timeout = timeout self._timeout = timeout
self._batchsize = batchsize self._batchsize = batchsize
self._consumer = consumer
self._pushlock = threading.Lock() self._pushlock = threading.Lock()
self._frontlock = threading.Lock() self._frontlock = threading.Lock()
self._pushbatch = [] self._pushbatch = []
self._frontbatch = None
self._count = 0 self._consumer = {} # {op_name: idx}
self._order = 0 self._consumer_base_idx = 0
self._frontbatch = []
self._idx_consumer_num = {}
def add_consumer(self, op_name):
""" not thread safe """
if op_name in self._consumer:
raise ValueError("op_name({}) is already in channel".format(
op_name))
self._consumer_id[op_name] = 0
if self._idx_consumer_num.get(0) is None:
self._idx_consumer_num[0] = 0
self._idx_consumer_num[0] += 1
def push(self, item): def push(self, item):
with self._pushlock: with self._pushlock:
self._pushbatch.append(item) self._pushbatch.append(item)
self._order += 1
if len(self._pushbatch) == self._batchsize: if len(self._pushbatch) == self._batchsize:
self.put(self._pushbatch, timeout=self._timeout) self.put(self._pushbatch, timeout=self._timeout)
# self.put(self._pushbatch)
self._pushbatch = [] self._pushbatch = []
def front(self): def front(self, op_name):
if self._consumer == 1: if len(self._consumer) == 0:
raise Exception(
"expected number of consumers to be greater than 0, but the it is 0."
)
elif len(self._consumer) == 1:
return self.get(timeout=self._timeout) return self.get(timeout=self._timeout)
with self._frontlock: with self._frontlock:
if self._count == 0: consumer_idx = self._consumer[op_name]
self._frontbatch = self.get(timeout=self._timeout) base_idx = self._consumer_base_idx
self._count += 1 data_idx = consumer_idx - base_idx
if self._count == self._consumer:
self._count = 0 if data_idx >= len(self._frontbatch):
return self._frontbatch batch_data = self.get(timeout=self._timeout)
self._frontbatch.append(batch_data)
resp = self._frontbatch[data_idx]
self._idx_consumer_num[consumer_idx] -= 1
if consumer_idx == base_idx and self._idx_consumer_num[
consumer_idx] == 0:
self._idx_consumer_num.pop(consumer_idx)
self._frontbatch.pop(0)
self._consumer_base_idx += 1
self._consumer[op_name] += 1
new_consumer_idx = self._consumer[op_name]
if self._idx_consumer_num.get(new_consumer_idx) is None:
self._idx_consumer_num[new_consumer_idx] = 0
self._idx_consumer_num[new_consumer_idx] += 1
return resp # reference, read only
class Op(object): class Op(object):
def __init__(self, def __init__(self,
name,
inputs, inputs,
in_dtype, in_dtype,
outputs, outputs,
...@@ -76,8 +109,12 @@ class Op(object): ...@@ -76,8 +109,12 @@ class Op(object):
device=None, device=None,
client_config=None, client_config=None,
server_name=None, server_name=None,
fetch_names=None): fetch_names=None,
concurrency=1):
self._run = False self._run = False
# TODO: globally unique check
self._name = name # to identify the type of OP, it must be globally unique
self._concurrency = concurrency # amount of concurrency
self.set_inputs(inputs) self.set_inputs(inputs)
self._in_dtype = in_dtype self._in_dtype = in_dtype
self.set_outputs(outputs) self.set_outputs(outputs)
...@@ -108,6 +145,8 @@ class Op(object): ...@@ -108,6 +145,8 @@ class Op(object):
def set_inputs(self, channels): def set_inputs(self, channels):
if not isinstance(channels, list): if not isinstance(channels, list):
raise TypeError('channels must be list type') raise TypeError('channels must be list type')
for channel in channels:
channel.add_consumer(self._name)
self._inputs = channels self._inputs = channels
def get_outputs(self): def get_outputs(self):
...@@ -162,7 +201,7 @@ class Op(object): ...@@ -162,7 +201,7 @@ class Op(object):
while self._run: while self._run:
input_data = [] input_data = []
for channel in self._inputs: for channel in self._inputs:
input_data.append(channel.front()) input_data.append(channel.front(self._name))
if len(input_data) > 1: if len(input_data) > 1:
data = self.preprocess(input_data) data = self.preprocess(input_data)
else: else:
...@@ -177,13 +216,16 @@ class Op(object): ...@@ -177,13 +216,16 @@ class Op(object):
for channel in self._outputs: for channel in self._outputs:
channel.push(output_data) channel.push(output_data)
def get_concurrency(self):
return self._concurrency
class GeneralPythonService( class GeneralPythonService(
general_python_service_pb2_grpc.GeneralPythonService): general_python_service_pb2_grpc.GeneralPythonService):
def __init__(self, in_channel, out_channel): def __init__(self, in_channel, out_channel):
super(GeneralPythonService, self).__init__() super(GeneralPythonService, self).__init__()
self._in_channel = in_channel self.set_in_channel(in_channel)
self._out_channel = out_channel self.set_out_channel(out_channel)
#TODO: #TODO:
# multi-lock for different clients # multi-lock for different clients
# diffenert lock for server and client # diffenert lock for server and client
...@@ -196,6 +238,15 @@ class GeneralPythonService( ...@@ -196,6 +238,15 @@ class GeneralPythonService(
self._recive_func.start() self._recive_func.start()
logging.debug('succ init') logging.debug('succ init')
def set_in_channel(self, in_channel):
self._in_channel = in_channel
def set_out_channel(self, out_channel):
if isinstance(out_channel, list):
raise TypeError('out_channel can not be list type')
out_channel.add_consumer("__GeneralPythonService__")
self._out_channel = out_channel
def _recive_out_channel_func(self): def _recive_out_channel_func(self):
while True: while True:
data = self._out_channel.front() data = self._out_channel.front()
...@@ -303,15 +354,21 @@ class PyServer(object): ...@@ -303,15 +354,21 @@ class PyServer(object):
self._out_channel = out_channel.pop() self._out_channel = out_channel.pop()
self.gen_desc() self.gen_desc()
def op_start_wrapper(self, op): def _op_start_wrapper(self, op):
return op.start() return op.start()
def run_server(self): def _run_ops(self):
for op in self._ops: for op in self._ops:
# th = multiprocessing.Process(target=self.op_start_wrapper, args=(op, )) op_concurrency = op.get_concurrency()
th = threading.Thread(target=self.op_start_wrapper, args=(op, )) for c in range(op_concurrency):
# th = multiprocessing.Process(target=self._op_start_wrapper, args=(op, ))
th = threading.Thread(
target=self._op_start_wrapper, args=(op, ))
th.start() th.start()
self._op_threads.append(th) self._op_threads.append(th)
def run_server(self):
self._run_ops()
server = grpc.server( server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self._worker_num)) futures.ThreadPoolExecutor(max_workers=self._worker_num))
general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server( general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
...@@ -339,4 +396,4 @@ class PyServer(object): ...@@ -339,4 +396,4 @@ class PyServer(object):
model_path, port) model_path, port)
logging.info(cmd) logging.info(cmd)
return return
os.system(cmd) # os.system(cmd)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册