提交 0f18e403 编写于 作者: B barrierye

[WIP] Remove channel definitions from the user side

上级 26eda7a0
......@@ -25,8 +25,6 @@ logging.basicConfig(
#level=logging.DEBUG)
level=logging.INFO)
# channel data: {name(str): data(narray)}
class CombineOp(Op):
def preprocess(self, input_data):
......@@ -39,13 +37,9 @@ class CombineOp(Op):
return data
read_channel = Channel(name="read_channel")
combine_channel = Channel(name="combine_channel")
out_channel = Channel(name="out_channel")
read_op = Op(name="read", input=None)
uci1_op = Op(name="uci1",
input=read_channel,
outputs=[combine_channel],
inputs=[read_op],
server_model="./uci_housing_model",
server_port="9393",
device="cpu",
......@@ -55,10 +49,8 @@ uci1_op = Op(name="uci1",
concurrency=1,
timeout=0.1,
retry=2)
uci2_op = Op(name="uci2",
input=read_channel,
outputs=[combine_channel],
inputs=[read_op],
server_model="./uci_housing_model",
server_port="9292",
device="cpu",
......@@ -68,24 +60,14 @@ uci2_op = Op(name="uci2",
concurrency=1,
timeout=-1,
retry=1)
combine_op = CombineOp(
name="combine",
input=combine_channel,
outputs=[out_channel],
inputs=[uci1_op, uci2_op],
concurrency=1,
timeout=-1,
retry=1)
logging.info(read_channel.debug())
logging.info(combine_channel.debug())
logging.info(out_channel.debug())
pyserver = PyServer(profile=False, retry=1)
pyserver.add_channel(read_channel)
pyserver.add_channel(combine_channel)
pyserver.add_channel(out_channel)
pyserver.add_op(uci1_op)
pyserver.add_op(uci2_op)
pyserver.add_op(combine_op)
pyserver.add_ops([read_op, uci1_op, uci2_op, combine_op])
pyserver.prepare_server(port=8080, worker_num=2)
pyserver.run_server()
......@@ -31,6 +31,7 @@ import random
import time
import func_timeout
import enum
import collections
class _TimeProfiler(object):
......@@ -140,7 +141,7 @@ class Channel(Queue.Queue):
Queue.Queue.__init__(self, maxsize=maxsize)
self._maxsize = maxsize
self._timeout = timeout
self._name = name
self.name = name
self._stop = False
self._cv = threading.Condition()
......@@ -161,7 +162,7 @@ class Channel(Queue.Queue):
return self._consumers.keys()
def _log(self, info_str):
return "[{}] {}".format(self._name, info_str)
return "[{}] {}".format(self.name, info_str)
def debug(self):
return self._log("p: {}, c: {}".format(self.get_producers(),
......@@ -313,8 +314,7 @@ class Channel(Queue.Queue):
class Op(object):
def __init__(self,
name,
input,
outputs,
inputs,
server_model=None,
server_port=None,
device=None,
......@@ -325,23 +325,24 @@ class Op(object):
timeout=-1,
retry=2):
self._run = False
# TODO: globally unique check
self._name = name # to identify the type of OP, it must be globally unique
self.name = name # to identify the type of OP, it must be globally unique
self._concurrency = concurrency # amount of concurrency
self.set_input(input)
self.set_outputs(outputs)
self._client = None
if client_config is not None and \
server_name is not None and \
fetch_names is not None:
self.set_client(client_config, server_name, fetch_names)
self.set_input_ops(inputs)
self.set_client(client_config, server_name, fetch_names)
self._server_model = server_model
self._server_port = server_port
self._device = device
self._timeout = timeout
self._retry = retry
self._input = None
self._outputs = []
def set_client(self, client_config, server_name, fetch_names):
self._client = None
if client_config is None or \
server_name is None or \
fetch_names is None:
return
self._client = Client()
self._client.load_client_config(client_config)
self._client.connect([server_name])
......@@ -350,28 +351,41 @@ class Op(object):
def with_serving(self):
return self._client is not None
def get_input(self):
def get_input_channel(self):
return self._input
def set_input(self, channel):
def get_input_ops(self):
return self._input_ops
def set_input_ops(self, ops):
if not isinstance(ops, list):
ops = [] if ops is None else [ops]
self._input_ops = []
for op in ops:
if not isinstance(op, Op):
raise TypeError(
self._log('input op must be Op type, not {}'.format(
type(op))))
self._input_ops.append(op)
def add_input_channel(self, channel):
if not isinstance(channel, Channel):
raise TypeError(
self._log('input channel must be Channel type, not {}'.format(
type(channel))))
channel.add_consumer(self._name)
channel.add_consumer(self.name)
self._input = channel
def get_outputs(self):
def get_output_channels(self):
return self._outputs
def set_outputs(self, channels):
if not isinstance(channels, list):
def add_output_channel(self, channel):
if not isinstance(channel, Channel):
raise TypeError(
self._log('output channels must be list type, not {}'.format(
type(channels))))
for channel in channels:
channel.add_producer(self._name)
self._outputs = channels
self._log('output channel must be Channel type, not {}'.format(
type(channel))))
channel.add_producer(self.name)
self._outputs.append(channel)
def preprocess(self, channeldata):
if isinstance(channeldata, dict):
......@@ -430,26 +444,26 @@ class Op(object):
def start(self, concurrency_idx):
self._run = True
while self._run:
_profiler.record("{}{}-get_0".format(self._name, concurrency_idx))
input_data = self._input.front(self._name)
_profiler.record("{}{}-get_1".format(self._name, concurrency_idx))
_profiler.record("{}{}-get_0".format(self.name, concurrency_idx))
input_data = self._input.front(self.name)
_profiler.record("{}{}-get_1".format(self.name, concurrency_idx))
logging.debug(self._log("input_data: {}".format(input_data)))
data_id, error_data = self._parse_channeldata(input_data)
output_data = None
if error_data is None:
_profiler.record("{}{}-prep_0".format(self._name,
_profiler.record("{}{}-prep_0".format(self.name,
concurrency_idx))
data = self.preprocess(input_data)
_profiler.record("{}{}-prep_1".format(self._name,
_profiler.record("{}{}-prep_1".format(self.name,
concurrency_idx))
call_future = None
error_info = None
if self.with_serving():
for i in range(self._retry):
_profiler.record("{}{}-midp_0".format(self._name,
_profiler.record("{}{}-midp_0".format(self.name,
concurrency_idx))
if self._timeout > 0:
try:
......@@ -460,21 +474,21 @@ class Op(object):
except func_timeout.FunctionTimedOut:
logging.error("error: timeout")
error_info = "{}({}): timeout".format(
self._name, concurrency_idx)
self.name, concurrency_idx)
except Exception as e:
logging.error("error: {}".format(e))
error_info = "{}({}): {}".format(
self._name, concurrency_idx, e)
self.name, concurrency_idx, e)
else:
call_future = self.midprocess(data)
_profiler.record("{}{}-midp_1".format(self._name,
_profiler.record("{}{}-midp_1".format(self.name,
concurrency_idx))
if i + 1 < self._retry:
error_info = None
logging.warn(
self._log("warn: timeout, retry({})".format(i +
1)))
_profiler.record("{}{}-postp_0".format(self._name,
_profiler.record("{}{}-postp_0".format(self.name,
concurrency_idx))
if error_info is not None:
error_data = self.errorprocess(error_info, data_id)
......@@ -504,18 +518,18 @@ class Op(object):
pbdata.ecode = 0
pbdata.id = data_id
output_data = ChannelData(pbdata=pbdata)
_profiler.record("{}{}-postp_1".format(self._name,
_profiler.record("{}{}-postp_1".format(self.name,
concurrency_idx))
else:
output_data = ChannelData(pbdata=error_data)
_profiler.record("{}{}-push_0".format(self._name, concurrency_idx))
_profiler.record("{}{}-push_0".format(self.name, concurrency_idx))
for channel in self._outputs:
channel.push(output_data, self._name)
_profiler.record("{}{}-push_1".format(self._name, concurrency_idx))
channel.push(output_data, self.name)
_profiler.record("{}{}-push_1".format(self.name, concurrency_idx))
def _log(self, info_str):
return "[{}] {}".format(self._name, info_str)
return "[{}] {}".format(self.name, info_str)
def get_concurrency(self):
return self._concurrency
......@@ -525,7 +539,7 @@ class GeneralPythonService(
general_python_service_pb2_grpc.GeneralPythonService):
def __init__(self, in_channel, out_channel, retry=2):
super(GeneralPythonService, self).__init__()
self._name = "#G"
self.name = "#G"
self.set_in_channel(in_channel)
self.set_out_channel(out_channel)
logging.debug(self._log(in_channel.debug()))
......@@ -543,14 +557,14 @@ class GeneralPythonService(
self._recive_func.start()
def _log(self, info_str):
return "[{}] {}".format(self._name, info_str)
return "[{}] {}".format(self.name, info_str)
def set_in_channel(self, in_channel):
if not isinstance(in_channel, Channel):
raise TypeError(
self._log('in_channel must be Channel type, but get {}'.format(
type(in_channel))))
in_channel.add_producer(self._name)
in_channel.add_producer(self.name)
self._in_channel = in_channel
def set_out_channel(self, out_channel):
......@@ -558,12 +572,12 @@ class GeneralPythonService(
raise TypeError(
self._log('out_channel must be Channel type, but get {}'.format(
type(out_channel))))
out_channel.add_consumer(self._name)
out_channel.add_consumer(self.name)
self._out_channel = out_channel
def _recive_out_channel_func(self):
while True:
channeldata = self._out_channel.front(self._name)
channeldata = self._out_channel.front(self.name)
if not isinstance(channeldata, ChannelData):
raise TypeError(
self._log('data must be ChannelData type, but get {}'.
......@@ -644,38 +658,43 @@ class GeneralPythonService(
return resp
def inference(self, request, context):
_profiler.record("{}-prepack_0".format(self._name))
_profiler.record("{}-prepack_0".format(self.name))
data, data_id = self._pack_data_for_infer(request)
_profiler.record("{}-prepack_1".format(self._name))
_profiler.record("{}-prepack_1".format(self.name))
resp_channeldata = None
for i in range(self._retry):
logging.debug(self._log('push data'))
_profiler.record("{}-push_0".format(self._name))
self._in_channel.push(data, self._name)
_profiler.record("{}-push_1".format(self._name))
_profiler.record("{}-push_0".format(self.name))
self._in_channel.push(data, self.name)
_profiler.record("{}-push_1".format(self.name))
logging.debug(self._log('wait for infer'))
_profiler.record("{}-fetch_0".format(self._name))
_profiler.record("{}-fetch_0".format(self.name))
resp_channeldata = self._get_data_in_globel_resp_dict(data_id)
_profiler.record("{}-fetch_1".format(self._name))
_profiler.record("{}-fetch_1".format(self.name))
if resp_channeldata.pbdata.ecode == 0:
break
logging.warn("retry({}): {}".format(
i + 1, resp_channeldata.pbdata.error_info))
_profiler.record("{}-postpack_0".format(self._name))
_profiler.record("{}-postpack_0".format(self.name))
resp = self._pack_data_for_resp(resp_channeldata)
_profiler.record("{}-postpack_1".format(self._name))
_profiler.record("{}-postpack_1".format(self.name))
_profiler.print_profile()
return resp
class VirtualOp(Op):
    """Pass-through Op inserted automatically by PyServer._topo_sort to
    bridge DAG edges that skip over a layer.

    fix: _topo_sort constructs it as VirtualOp(name=...), but Op.__init__
    has a required `inputs` parameter, so that call would raise TypeError.
    Default `inputs` to None here and forward everything else to Op.
    NOTE(review): a VirtualOp should simply forward channel data; its
    start()/processing behavior is still Op's — TODO confirm once the
    WIP refactor lands.
    """

    def __init__(self, name, inputs=None, **kwargs):
        super(VirtualOp, self).__init__(name=name, inputs=inputs, **kwargs)
class PyServer(object):
def __init__(self, retry=2, profile=False):
self._channels = []
self._ops = []
self._user_ops = []
self._total_ops = []
self._op_threads = []
self._port = None
self._worker_num = None
......@@ -688,40 +707,147 @@ class PyServer(object):
self._channels.append(channel)
def add_op(self, op):
self._ops.append(op)
self._user_ops.append(op)
def add_ops(self, ops):
self._user_ops.expand(ops)
def gen_desc(self):
logging.info('here will generate desc for PAAS')
pass
def _topo_sort(self):
indeg_num = {}
outdegs = {}
que_idx = 0 # scroll queue
ques = [Queue.SimpleQueue() for _ in range(2)]
for idx, op in enumerate(self._user_ops):
# check the name of op is globally unique
if op.name in indeg_num:
raise Exception("the name of Op must be unique")
indeg_num[op.name] = len(op.get_input_ops())
if indeg_num[op.name] == 0:
ques[que_idx].put(op)
for pred_op in op.get_input_ops():
if op.name in outdegs:
outdegs[op.name].append(op)
else:
outdegs[op.name] = [op]
# get dag_views
dag_views = []
sorted_op_num = 0
while True:
que = ques[que_idx]
next_que = ques[(que_idx + 1) % 2]
dag_view = []
while que.qsize() != 0:
op = que.get()
dag_view.append(op)
op_name = op.name
sorted_op_num += 1
for succ_op in outdegs[op_name]:
indeg_num[op_name] -= 1
if indeg_num[succ_op.name] == 0:
next_que.put(succ_op)
dag_views.append(dag_view)
if next_que.qsize() == 0:
break
que_idx = (que_idx + 1) % 2
if sorted_op_num < len(self._user_ops):
raise Exception("not legal DAG")
if len(dag_views[0]) != 1:
raise Exception("DAG contains multiple input Ops")
if len(dag_views[-1]) != 1:
raise Exception("DAG contains multiple output Ops")
# create channels and virtual ops
virtual_op_idx = 0
channel_idx = 0
virtual_ops = []
channels = []
input_channel = None
for v_idx, view in enumerate(dag_views):
if v_idx + 1 >= len(dag_views):
break
next_view = dag_views[v_idx + 1]
actual_next_view = []
pred_op_of_next_view_op = {}
for op in view:
# create virtual op
for succ_op in outdegs[op.name]:
if succ_op in next_view:
actual_next_view.append(succ_op)
if succ_op.name not in pred_op_of_next_view_op:
pred_op_of_next_view_op[succ_op.name] = []
pred_op_of_next_view_op[succ_op.name].append(op)
else:
vop = VirtualOp(name="vir{}".format(virtual_op_idx))
virtual_op_idx += 1
virtual_ops.append(virtual_op)
outdegs[vop.name] = [succ_op]
actual_next_view.append(vop)
# TODO: combine vop
pred_op_of_next_view_op[vop.name] = [op]
# create channel
processed_op = set()
for o_idx, op in enumerate(actual_next_view):
op_name = op.name
if op_name in processed_op:
continue
channel = Channel(name="chl{}".format(channel_idx))
channel_idx += 1
channels.append(channel)
op.add_input_channel(channel)
pred_ops = pred_op_of_next_view_op[op_name]
if v_idx == 0:
input_channel = channel
else:
for pred_op in pred_ops:
pred_op.add_output_channel(channel)
processed_op.add(op_name)
# combine channel
for other_op in actual_next_view[o_idx:]:
if other_op.name in processed_op:
continue
other_pred_ops = pred_op_of_next_view_op[other_op.name]
if len(other_pred_ops) != len(pred_ops):
continue
same_flag = True
for pred_op in pred_ops:
if pred_op not in other_pred_ops:
same_flag = False
break
if same_flag:
other_op.add_input_channel(channel)
processed_op.add(other_op.name)
output_channel = Channel(name="Ochl")
channels.append(output_channel)
last_op = dag_views[-1][0]
last_op.add_output_channel(output_channel)
self._ops = self._user_ops + virtual_ops
self._channels = channels
return input_channel, output_channel
def prepare_server(self, port, worker_num):
self._port = port
self._worker_num = worker_num
inputs = set()
outputs = set()
for op in self._ops:
inputs |= set([op.get_input()])
outputs |= set(op.get_outputs())
if op.with_serving():
self.prepare_serving(op)
in_channel = inputs - outputs
out_channel = outputs - inputs
if len(in_channel) != 1 or len(out_channel) != 1:
raise Exception(
"in_channel(out_channel) more than 1 or no in_channel(out_channel)"
)
self._in_channel = in_channel.pop()
self._out_channel = out_channel.pop()
input_channel, output_channel = self._topo_sort()
self._in_channel = input_channel
self.out_channel = output_channel
self.gen_desc()
def _op_start_wrapper(self, op, concurrency_idx):
return op.start(concurrency_idx)
def _run_ops(self):
#TODO
for op in self._ops:
op_concurrency = op.get_concurrency()
logging.debug("run op: {}, op_concurrency: {}".format(
op._name, op_concurrency))
op.name, op_concurrency))
for c in range(op_concurrency):
# th = multiprocessing.Process(
th = threading.Thread(
......@@ -730,6 +856,7 @@ class PyServer(object):
self._op_threads.append(th)
    def _stop_ops(self):
        # Signal every op (user ops plus generated virtual ops) to exit its
        # worker loop.
        # TODO
        for op in self._ops:
            op.stop()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册