提交 20985a8c 编写于 作者: B barrierye

Support concurrent data fetching from one channel by ops of different types && TODO:...

Support concurrent data fetching from one channel by ops of different types && TODO: ops like combine_op, which must get data with the same data_id, are not supported yet
上级 57d28af9
......@@ -26,21 +26,10 @@ logging.basicConfig(
level=logging.INFO)
# channel data: {name(str): data(bytes)}
"""
class ImdbOp(Op):
def postprocess(self, output_data):
data = python_service_channel_pb2.ChannelData()
inst = python_service_channel_pb2.Inst()
pred = np.array(output_data["prediction"][0][0], dtype='float')
inst.data = np.ndarray.tobytes(pred)
inst.name = "prediction"
inst.id = 0 #TODO
data.insts.append(inst)
return data
"""
class CombineOp(Op):
#TODO: different id of data
def preprocess(self, input_data):
data_id = None
cnt = 0
......@@ -58,7 +47,6 @@ class CombineOp(Op):
inst.name = "resp"
data.insts.append(inst)
data.id = data_id
print(data)
return data
......@@ -75,11 +63,13 @@ class UciOp(Op):
return data
read_channel = Channel(consumer=2)
read_channel = Channel()
cnn_out_channel = Channel()
bow_out_channel = Channel()
combine_out_channel = Channel()
cnn_op = UciOp(
name="cnn_op",
inputs=[read_channel],
in_dtype='float',
outputs=[cnn_out_channel],
......@@ -90,7 +80,9 @@ cnn_op = UciOp(
client_config="uci_housing_client/serving_client_conf.prototxt",
server_name="127.0.0.1:9393",
fetch_names=["price"])
bow_op = UciOp(
name="bow_op",
inputs=[read_channel],
in_dtype='float',
outputs=[bow_out_channel],
......@@ -101,27 +93,9 @@ bow_op = UciOp(
client_config="uci_housing_client/serving_client_conf.prototxt",
server_name="127.0.0.1:9393",
fetch_names=["price"])
'''
cnn_op = ImdbOp(
inputs=[read_channel],
outputs=[cnn_out_channel],
server_model="./imdb_cnn_model",
server_port="9393",
device="cpu",
client_config="imdb_cnn_client_conf/serving_client_conf.prototxt",
server_name="127.0.0.1:9393",
fetch_names=["acc", "cost", "prediction"])
bow_op = ImdbOp(
inputs=[read_channel],
outputs=[bow_out_channel],
server_model="./imdb_bow_model",
server_port="9292",
device="cpu",
client_config="imdb_bow_client_conf/serving_client_conf.prototxt",
server_name="127.0.0.1:9292",
fetch_names=["acc", "cost", "prediction"])
'''
combine_op = CombineOp(
name="combine_op",
inputs=[cnn_out_channel, bow_out_channel],
in_dtype='float',
outputs=[combine_out_channel],
......
......@@ -29,43 +29,76 @@ import time
class Channel(Queue.Queue):
def __init__(self, consumer=1, maxsize=-1, timeout=None, batchsize=1):
def __init__(self, maxsize=-1, timeout=None, batchsize=1):
Queue.Queue.__init__(self, maxsize=maxsize)
# super(Channel, self).__init__(maxsize=maxsize)
self._maxsize = maxsize
self._timeout = timeout
self._batchsize = batchsize
self._consumer = consumer
self._pushlock = threading.Lock()
self._frontlock = threading.Lock()
self._pushbatch = []
self._frontbatch = None
self._count = 0
self._order = 0
self._consumer = {} # {op_name: idx}
self._consumer_base_idx = 0
self._frontbatch = []
self._idx_consumer_num = {}
def add_consumer(self, op_name):
    """Register ``op_name`` as a consumer of this channel.

    The consumer's read index is initialised to 0 and the number of
    consumers positioned at index 0 is incremented by one.

    Not thread safe.

    Args:
        op_name: name identifying the consuming op.

    Raises:
        ValueError: if ``op_name`` has already been registered.
    """
    if op_name in self._consumer:
        raise ValueError("op_name({}) is already in channel".format(
            op_name))
    # Record the index in self._consumer, the same dict the duplicate check
    # above consults. The previous code assigned to self._consumer_id, an
    # attribute that is never created, so registration raised AttributeError.
    self._consumer[op_name] = 0
    # One more consumer is now waiting at index 0.
    self._idx_consumer_num[0] = self._idx_consumer_num.get(0, 0) + 1
def push(self, item):
with self._pushlock:
self._pushbatch.append(item)
self._order += 1
if len(self._pushbatch) == self._batchsize:
self.put(self._pushbatch, timeout=self._timeout)
# self.put(self._pushbatch)
self._pushbatch = []
def front(self):
if self._consumer == 1:
def front(self, op_name):
if len(self._consumer) == 0:
raise Exception(
"expected number of consumers to be greater than 0, but the it is 0."
)
elif len(self._consumer) == 1:
return self.get(timeout=self._timeout)
with self._frontlock:
if self._count == 0:
self._frontbatch = self.get(timeout=self._timeout)
self._count += 1
if self._count == self._consumer:
self._count = 0
return self._frontbatch
consumer_idx = self._consumer[op_name]
base_idx = self._consumer_base_idx
data_idx = consumer_idx - base_idx
if data_idx >= len(self._frontbatch):
batch_data = self.get(timeout=self._timeout)
self._frontbatch.append(batch_data)
resp = self._frontbatch[data_idx]
self._idx_consumer_num[consumer_idx] -= 1
if consumer_idx == base_idx and self._idx_consumer_num[
consumer_idx] == 0:
self._idx_consumer_num.pop(consumer_idx)
self._frontbatch.pop(0)
self._consumer_base_idx += 1
self._consumer[op_name] += 1
new_consumer_idx = self._consumer[op_name]
if self._idx_consumer_num.get(new_consumer_idx) is None:
self._idx_consumer_num[new_consumer_idx] = 0
self._idx_consumer_num[new_consumer_idx] += 1
return resp # reference, read only
class Op(object):
def __init__(self,
name,
inputs,
in_dtype,
outputs,
......@@ -76,8 +109,12 @@ class Op(object):
device=None,
client_config=None,
server_name=None,
fetch_names=None):
fetch_names=None,
concurrency=1):
self._run = False
# TODO: globally unique check
self._name = name # to identify the type of OP, it must be globally unique
self._concurrency = concurrency # amount of concurrency
self.set_inputs(inputs)
self._in_dtype = in_dtype
self.set_outputs(outputs)
......@@ -108,6 +145,8 @@ class Op(object):
def set_inputs(self, channels):
if not isinstance(channels, list):
raise TypeError('channels must be list type')
for channel in channels:
channel.add_consumer(self._name)
self._inputs = channels
def get_outputs(self):
......@@ -162,7 +201,7 @@ class Op(object):
while self._run:
input_data = []
for channel in self._inputs:
input_data.append(channel.front())
input_data.append(channel.front(self._name))
if len(input_data) > 1:
data = self.preprocess(input_data)
else:
......@@ -177,13 +216,16 @@ class Op(object):
for channel in self._outputs:
channel.push(output_data)
def get_concurrency(self):
return self._concurrency
class GeneralPythonService(
general_python_service_pb2_grpc.GeneralPythonService):
def __init__(self, in_channel, out_channel):
super(GeneralPythonService, self).__init__()
self._in_channel = in_channel
self._out_channel = out_channel
self.set_in_channel(in_channel)
self.set_out_channel(out_channel)
#TODO:
# multi-lock for different clients
# diffenert lock for server and client
......@@ -196,6 +238,15 @@ class GeneralPythonService(
self._recive_func.start()
logging.debug('succ init')
def set_in_channel(self, in_channel):
self._in_channel = in_channel
def set_out_channel(self, out_channel):
if isinstance(out_channel, list):
raise TypeError('out_channel can not be list type')
out_channel.add_consumer("__GeneralPythonService__")
self._out_channel = out_channel
def _recive_out_channel_func(self):
while True:
data = self._out_channel.front()
......@@ -303,15 +354,21 @@ class PyServer(object):
self._out_channel = out_channel.pop()
self.gen_desc()
def op_start_wrapper(self, op):
def _op_start_wrapper(self, op):
return op.start()
def run_server(self):
def _run_ops(self):
for op in self._ops:
# th = multiprocessing.Process(target=self.op_start_wrapper, args=(op, ))
th = threading.Thread(target=self.op_start_wrapper, args=(op, ))
op_concurrency = op.get_concurrency()
for c in range(op_concurrency):
# th = multiprocessing.Process(target=self._op_start_wrapper, args=(op, ))
th = threading.Thread(
target=self._op_start_wrapper, args=(op, ))
th.start()
self._op_threads.append(th)
def run_server(self):
self._run_ops()
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self._worker_num))
general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
......@@ -339,4 +396,4 @@ class PyServer(object):
model_path, port)
logging.info(cmd)
return
os.system(cmd)
# os.system(cmd)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册