提交 0f18e403 编写于 作者: B barrierye

[WIP] Remove channel definitions from the user side

上级 26eda7a0
...@@ -25,8 +25,6 @@ logging.basicConfig( ...@@ -25,8 +25,6 @@ logging.basicConfig(
#level=logging.DEBUG) #level=logging.DEBUG)
level=logging.INFO) level=logging.INFO)
# channel data: {name(str): data(narray)}
class CombineOp(Op): class CombineOp(Op):
def preprocess(self, input_data): def preprocess(self, input_data):
...@@ -39,13 +37,9 @@ class CombineOp(Op): ...@@ -39,13 +37,9 @@ class CombineOp(Op):
return data return data
read_channel = Channel(name="read_channel") read_op = Op(name="read", input=None)
combine_channel = Channel(name="combine_channel")
out_channel = Channel(name="out_channel")
uci1_op = Op(name="uci1", uci1_op = Op(name="uci1",
input=read_channel, inputs=[read_op],
outputs=[combine_channel],
server_model="./uci_housing_model", server_model="./uci_housing_model",
server_port="9393", server_port="9393",
device="cpu", device="cpu",
...@@ -55,10 +49,8 @@ uci1_op = Op(name="uci1", ...@@ -55,10 +49,8 @@ uci1_op = Op(name="uci1",
concurrency=1, concurrency=1,
timeout=0.1, timeout=0.1,
retry=2) retry=2)
uci2_op = Op(name="uci2", uci2_op = Op(name="uci2",
input=read_channel, inputs=[read_op],
outputs=[combine_channel],
server_model="./uci_housing_model", server_model="./uci_housing_model",
server_port="9292", server_port="9292",
device="cpu", device="cpu",
...@@ -68,24 +60,14 @@ uci2_op = Op(name="uci2", ...@@ -68,24 +60,14 @@ uci2_op = Op(name="uci2",
concurrency=1, concurrency=1,
timeout=-1, timeout=-1,
retry=1) retry=1)
combine_op = CombineOp( combine_op = CombineOp(
name="combine", name="combine",
input=combine_channel, inputs=[uci1_op, uci2_op],
outputs=[out_channel],
concurrency=1, concurrency=1,
timeout=-1, timeout=-1,
retry=1) retry=1)
logging.info(read_channel.debug())
logging.info(combine_channel.debug())
logging.info(out_channel.debug())
pyserver = PyServer(profile=False, retry=1) pyserver = PyServer(profile=False, retry=1)
pyserver.add_channel(read_channel) pyserver.add_ops([read_op, uci1_op, uci2_op, combine_op])
pyserver.add_channel(combine_channel)
pyserver.add_channel(out_channel)
pyserver.add_op(uci1_op)
pyserver.add_op(uci2_op)
pyserver.add_op(combine_op)
pyserver.prepare_server(port=8080, worker_num=2) pyserver.prepare_server(port=8080, worker_num=2)
pyserver.run_server() pyserver.run_server()
...@@ -31,6 +31,7 @@ import random ...@@ -31,6 +31,7 @@ import random
import time import time
import func_timeout import func_timeout
import enum import enum
import collections
class _TimeProfiler(object): class _TimeProfiler(object):
...@@ -140,7 +141,7 @@ class Channel(Queue.Queue): ...@@ -140,7 +141,7 @@ class Channel(Queue.Queue):
Queue.Queue.__init__(self, maxsize=maxsize) Queue.Queue.__init__(self, maxsize=maxsize)
self._maxsize = maxsize self._maxsize = maxsize
self._timeout = timeout self._timeout = timeout
self._name = name self.name = name
self._stop = False self._stop = False
self._cv = threading.Condition() self._cv = threading.Condition()
...@@ -161,7 +162,7 @@ class Channel(Queue.Queue): ...@@ -161,7 +162,7 @@ class Channel(Queue.Queue):
return self._consumers.keys() return self._consumers.keys()
def _log(self, info_str): def _log(self, info_str):
return "[{}] {}".format(self._name, info_str) return "[{}] {}".format(self.name, info_str)
def debug(self): def debug(self):
return self._log("p: {}, c: {}".format(self.get_producers(), return self._log("p: {}, c: {}".format(self.get_producers(),
...@@ -313,8 +314,7 @@ class Channel(Queue.Queue): ...@@ -313,8 +314,7 @@ class Channel(Queue.Queue):
class Op(object): class Op(object):
def __init__(self, def __init__(self,
name, name,
input, inputs,
outputs,
server_model=None, server_model=None,
server_port=None, server_port=None,
device=None, device=None,
...@@ -325,23 +325,24 @@ class Op(object): ...@@ -325,23 +325,24 @@ class Op(object):
timeout=-1, timeout=-1,
retry=2): retry=2):
self._run = False self._run = False
# TODO: globally unique check self.name = name # to identify the type of OP, it must be globally unique
self._name = name # to identify the type of OP, it must be globally unique
self._concurrency = concurrency # amount of concurrency self._concurrency = concurrency # amount of concurrency
self.set_input(input) self.set_input_ops(inputs)
self.set_outputs(outputs) self.set_client(client_config, server_name, fetch_names)
self._client = None
if client_config is not None and \
server_name is not None and \
fetch_names is not None:
self.set_client(client_config, server_name, fetch_names)
self._server_model = server_model self._server_model = server_model
self._server_port = server_port self._server_port = server_port
self._device = device self._device = device
self._timeout = timeout self._timeout = timeout
self._retry = retry self._retry = retry
self._input = None
self._outputs = []
def set_client(self, client_config, server_name, fetch_names): def set_client(self, client_config, server_name, fetch_names):
self._client = None
if client_config is None or \
server_name is None or \
fetch_names is None:
return
self._client = Client() self._client = Client()
self._client.load_client_config(client_config) self._client.load_client_config(client_config)
self._client.connect([server_name]) self._client.connect([server_name])
...@@ -350,28 +351,41 @@ class Op(object): ...@@ -350,28 +351,41 @@ class Op(object):
def with_serving(self): def with_serving(self):
return self._client is not None return self._client is not None
def get_input(self): def get_input_channel(self):
return self._input return self._input
def set_input(self, channel): def get_input_ops(self):
return self._input_ops
def set_input_ops(self, ops):
    """Replace this Op's predecessor list.

    Accepts a single Op, a list of Ops, or None (meaning no
    predecessors).

    Raises:
        TypeError: if any element is not an Op instance.
    """
    if ops is None:
        op_list = []
    elif isinstance(ops, list):
        op_list = ops
    else:
        op_list = [ops]
    self._input_ops = []
    for pred in op_list:
        if not isinstance(pred, Op):
            raise TypeError(
                self._log('input op must be Op type, not {}'.format(
                    type(pred))))
        self._input_ops.append(pred)
def add_input_channel(self, channel):
if not isinstance(channel, Channel): if not isinstance(channel, Channel):
raise TypeError( raise TypeError(
self._log('input channel must be Channel type, not {}'.format( self._log('input channel must be Channel type, not {}'.format(
type(channel)))) type(channel))))
channel.add_consumer(self._name) channel.add_consumer(self.name)
self._input = channel self._input = channel
def get_outputs(self): def get_output_channels(self):
return self._outputs return self._outputs
def set_outputs(self, channels): def add_output_channel(self, channel):
if not isinstance(channels, list): if not isinstance(channel, Channel):
raise TypeError( raise TypeError(
self._log('output channels must be list type, not {}'.format( self._log('output channel must be Channel type, not {}'.format(
type(channels)))) type(channel))))
for channel in channels: channel.add_producer(self.name)
channel.add_producer(self._name) self._outputs.append(channel)
self._outputs = channels
def preprocess(self, channeldata): def preprocess(self, channeldata):
if isinstance(channeldata, dict): if isinstance(channeldata, dict):
...@@ -430,26 +444,26 @@ class Op(object): ...@@ -430,26 +444,26 @@ class Op(object):
def start(self, concurrency_idx): def start(self, concurrency_idx):
self._run = True self._run = True
while self._run: while self._run:
_profiler.record("{}{}-get_0".format(self._name, concurrency_idx)) _profiler.record("{}{}-get_0".format(self.name, concurrency_idx))
input_data = self._input.front(self._name) input_data = self._input.front(self.name)
_profiler.record("{}{}-get_1".format(self._name, concurrency_idx)) _profiler.record("{}{}-get_1".format(self.name, concurrency_idx))
logging.debug(self._log("input_data: {}".format(input_data))) logging.debug(self._log("input_data: {}".format(input_data)))
data_id, error_data = self._parse_channeldata(input_data) data_id, error_data = self._parse_channeldata(input_data)
output_data = None output_data = None
if error_data is None: if error_data is None:
_profiler.record("{}{}-prep_0".format(self._name, _profiler.record("{}{}-prep_0".format(self.name,
concurrency_idx)) concurrency_idx))
data = self.preprocess(input_data) data = self.preprocess(input_data)
_profiler.record("{}{}-prep_1".format(self._name, _profiler.record("{}{}-prep_1".format(self.name,
concurrency_idx)) concurrency_idx))
call_future = None call_future = None
error_info = None error_info = None
if self.with_serving(): if self.with_serving():
for i in range(self._retry): for i in range(self._retry):
_profiler.record("{}{}-midp_0".format(self._name, _profiler.record("{}{}-midp_0".format(self.name,
concurrency_idx)) concurrency_idx))
if self._timeout > 0: if self._timeout > 0:
try: try:
...@@ -460,21 +474,21 @@ class Op(object): ...@@ -460,21 +474,21 @@ class Op(object):
except func_timeout.FunctionTimedOut: except func_timeout.FunctionTimedOut:
logging.error("error: timeout") logging.error("error: timeout")
error_info = "{}({}): timeout".format( error_info = "{}({}): timeout".format(
self._name, concurrency_idx) self.name, concurrency_idx)
except Exception as e: except Exception as e:
logging.error("error: {}".format(e)) logging.error("error: {}".format(e))
error_info = "{}({}): {}".format( error_info = "{}({}): {}".format(
self._name, concurrency_idx, e) self.name, concurrency_idx, e)
else: else:
call_future = self.midprocess(data) call_future = self.midprocess(data)
_profiler.record("{}{}-midp_1".format(self._name, _profiler.record("{}{}-midp_1".format(self.name,
concurrency_idx)) concurrency_idx))
if i + 1 < self._retry: if i + 1 < self._retry:
error_info = None error_info = None
logging.warn( logging.warn(
self._log("warn: timeout, retry({})".format(i + self._log("warn: timeout, retry({})".format(i +
1))) 1)))
_profiler.record("{}{}-postp_0".format(self._name, _profiler.record("{}{}-postp_0".format(self.name,
concurrency_idx)) concurrency_idx))
if error_info is not None: if error_info is not None:
error_data = self.errorprocess(error_info, data_id) error_data = self.errorprocess(error_info, data_id)
...@@ -504,18 +518,18 @@ class Op(object): ...@@ -504,18 +518,18 @@ class Op(object):
pbdata.ecode = 0 pbdata.ecode = 0
pbdata.id = data_id pbdata.id = data_id
output_data = ChannelData(pbdata=pbdata) output_data = ChannelData(pbdata=pbdata)
_profiler.record("{}{}-postp_1".format(self._name, _profiler.record("{}{}-postp_1".format(self.name,
concurrency_idx)) concurrency_idx))
else: else:
output_data = ChannelData(pbdata=error_data) output_data = ChannelData(pbdata=error_data)
_profiler.record("{}{}-push_0".format(self._name, concurrency_idx)) _profiler.record("{}{}-push_0".format(self.name, concurrency_idx))
for channel in self._outputs: for channel in self._outputs:
channel.push(output_data, self._name) channel.push(output_data, self.name)
_profiler.record("{}{}-push_1".format(self._name, concurrency_idx)) _profiler.record("{}{}-push_1".format(self.name, concurrency_idx))
def _log(self, info_str): def _log(self, info_str):
return "[{}] {}".format(self._name, info_str) return "[{}] {}".format(self.name, info_str)
def get_concurrency(self): def get_concurrency(self):
return self._concurrency return self._concurrency
...@@ -525,7 +539,7 @@ class GeneralPythonService( ...@@ -525,7 +539,7 @@ class GeneralPythonService(
general_python_service_pb2_grpc.GeneralPythonService): general_python_service_pb2_grpc.GeneralPythonService):
def __init__(self, in_channel, out_channel, retry=2): def __init__(self, in_channel, out_channel, retry=2):
super(GeneralPythonService, self).__init__() super(GeneralPythonService, self).__init__()
self._name = "#G" self.name = "#G"
self.set_in_channel(in_channel) self.set_in_channel(in_channel)
self.set_out_channel(out_channel) self.set_out_channel(out_channel)
logging.debug(self._log(in_channel.debug())) logging.debug(self._log(in_channel.debug()))
...@@ -543,14 +557,14 @@ class GeneralPythonService( ...@@ -543,14 +557,14 @@ class GeneralPythonService(
self._recive_func.start() self._recive_func.start()
def _log(self, info_str): def _log(self, info_str):
return "[{}] {}".format(self._name, info_str) return "[{}] {}".format(self.name, info_str)
def set_in_channel(self, in_channel): def set_in_channel(self, in_channel):
if not isinstance(in_channel, Channel): if not isinstance(in_channel, Channel):
raise TypeError( raise TypeError(
self._log('in_channel must be Channel type, but get {}'.format( self._log('in_channel must be Channel type, but get {}'.format(
type(in_channel)))) type(in_channel))))
in_channel.add_producer(self._name) in_channel.add_producer(self.name)
self._in_channel = in_channel self._in_channel = in_channel
def set_out_channel(self, out_channel): def set_out_channel(self, out_channel):
...@@ -558,12 +572,12 @@ class GeneralPythonService( ...@@ -558,12 +572,12 @@ class GeneralPythonService(
raise TypeError( raise TypeError(
self._log('out_channel must be Channel type, but get {}'.format( self._log('out_channel must be Channel type, but get {}'.format(
type(out_channel)))) type(out_channel))))
out_channel.add_consumer(self._name) out_channel.add_consumer(self.name)
self._out_channel = out_channel self._out_channel = out_channel
def _recive_out_channel_func(self): def _recive_out_channel_func(self):
while True: while True:
channeldata = self._out_channel.front(self._name) channeldata = self._out_channel.front(self.name)
if not isinstance(channeldata, ChannelData): if not isinstance(channeldata, ChannelData):
raise TypeError( raise TypeError(
self._log('data must be ChannelData type, but get {}'. self._log('data must be ChannelData type, but get {}'.
...@@ -644,38 +658,43 @@ class GeneralPythonService( ...@@ -644,38 +658,43 @@ class GeneralPythonService(
return resp return resp
def inference(self, request, context): def inference(self, request, context):
_profiler.record("{}-prepack_0".format(self._name)) _profiler.record("{}-prepack_0".format(self.name))
data, data_id = self._pack_data_for_infer(request) data, data_id = self._pack_data_for_infer(request)
_profiler.record("{}-prepack_1".format(self._name)) _profiler.record("{}-prepack_1".format(self.name))
resp_channeldata = None resp_channeldata = None
for i in range(self._retry): for i in range(self._retry):
logging.debug(self._log('push data')) logging.debug(self._log('push data'))
_profiler.record("{}-push_0".format(self._name)) _profiler.record("{}-push_0".format(self.name))
self._in_channel.push(data, self._name) self._in_channel.push(data, self.name)
_profiler.record("{}-push_1".format(self._name)) _profiler.record("{}-push_1".format(self.name))
logging.debug(self._log('wait for infer')) logging.debug(self._log('wait for infer'))
_profiler.record("{}-fetch_0".format(self._name)) _profiler.record("{}-fetch_0".format(self.name))
resp_channeldata = self._get_data_in_globel_resp_dict(data_id) resp_channeldata = self._get_data_in_globel_resp_dict(data_id)
_profiler.record("{}-fetch_1".format(self._name)) _profiler.record("{}-fetch_1".format(self.name))
if resp_channeldata.pbdata.ecode == 0: if resp_channeldata.pbdata.ecode == 0:
break break
logging.warn("retry({}): {}".format( logging.warn("retry({}): {}".format(
i + 1, resp_channeldata.pbdata.error_info)) i + 1, resp_channeldata.pbdata.error_info))
_profiler.record("{}-postpack_0".format(self._name)) _profiler.record("{}-postpack_0".format(self.name))
resp = self._pack_data_for_resp(resp_channeldata) resp = self._pack_data_for_resp(resp_channeldata)
_profiler.record("{}-postpack_1".format(self._name)) _profiler.record("{}-postpack_1".format(self.name))
_profiler.print_profile() _profiler.print_profile()
return resp return resp
class VirtualOp(Op):
    """Pass-through Op inserted by PyServer._topo_sort to bridge an edge
    whose successor skips a DAG layer; it adds no processing logic of
    its own (behavior is inherited from Op unchanged).
    """
    pass
class PyServer(object): class PyServer(object):
def __init__(self, retry=2, profile=False): def __init__(self, retry=2, profile=False):
self._channels = [] self._channels = []
self._ops = [] self._user_ops = []
self._total_ops = []
self._op_threads = [] self._op_threads = []
self._port = None self._port = None
self._worker_num = None self._worker_num = None
...@@ -688,40 +707,147 @@ class PyServer(object): ...@@ -688,40 +707,147 @@ class PyServer(object):
self._channels.append(channel) self._channels.append(channel)
def add_op(self, op): def add_op(self, op):
self._ops.append(op) self._user_ops.append(op)
def add_ops(self, ops):
    """Register a list of user-defined Ops with the server.

    Args:
        ops: list of Op instances to append to the user-op list.
    """
    # list.extend, not the nonexistent list.expand (AttributeError)
    self._user_ops.extend(ops)
def gen_desc(self): def gen_desc(self):
logging.info('here will generate desc for PAAS') logging.info('here will generate desc for PAAS')
pass pass
def _topo_sort(self):
    """Topologically sort the user Ops into BFS layers (dag_views),
    insert VirtualOps for edges that skip a layer, and connect the
    Ops with freshly created Channels.

    Returns:
        (input_channel, output_channel): the Channel that feeds the
        single source Op and the Channel fed by the single sink Op.

    Raises:
        Exception: on duplicate Op names, a graph that is not a DAG,
            or more than one source/sink Op.
    """
    indeg_num = {}  # op name -> count of not-yet-resolved predecessors
    outdegs = {}  # op name -> list of successor Ops
    que_idx = 0  # scroll queue: double-buffer between BFS layers
    # collections.deque instead of Queue.SimpleQueue: SimpleQueue does
    # not exist in the py2 Queue module this file imports
    ques = [collections.deque() for _ in range(2)]
    for op in self._user_ops:
        # check the name of op is globally unique
        if op.name in indeg_num:
            raise Exception("the name of Op must be unique")
        indeg_num[op.name] = len(op.get_input_ops())
        if indeg_num[op.name] == 0:
            ques[que_idx].append(op)
        for pred_op in op.get_input_ops():
            # key successors by the *predecessor* name (was op.name,
            # which built a self-referential adjacency map)
            if pred_op.name in outdegs:
                outdegs[pred_op.name].append(op)
            else:
                outdegs[pred_op.name] = [op]

    # get dag_views: one list of Ops per BFS layer (Kahn's algorithm)
    dag_views = []
    sorted_op_num = 0
    while True:
        que = ques[que_idx]
        next_que = ques[(que_idx + 1) % 2]
        dag_view = []
        while len(que) != 0:
            op = que.popleft()
            dag_view.append(op)
            sorted_op_num += 1
            # decrement each *successor's* indegree (was decrementing
            # the dequeued op's own); .get guards sink ops that have
            # no outdegs entry
            for succ_op in outdegs.get(op.name, []):
                indeg_num[succ_op.name] -= 1
                if indeg_num[succ_op.name] == 0:
                    next_que.append(succ_op)
        dag_views.append(dag_view)
        if len(next_que) == 0:
            break
        que_idx = (que_idx + 1) % 2
    if sorted_op_num < len(self._user_ops):
        raise Exception("not legal DAG")
    if len(dag_views[0]) != 1:
        raise Exception("DAG contains multiple input Ops")
    if len(dag_views[-1]) != 1:
        raise Exception("DAG contains multiple output Ops")

    # create channels and virtual ops
    virtual_op_idx = 0
    channel_idx = 0
    virtual_ops = []
    channels = []
    input_channel = None
    for v_idx, view in enumerate(dag_views):
        if v_idx + 1 >= len(dag_views):
            break
        next_view = dag_views[v_idx + 1]
        actual_next_view = []
        # next-layer op name -> list of its predecessors in this layer
        pred_op_of_next_view_op = {}
        for op in view:
            for succ_op in outdegs.get(op.name, []):
                if succ_op in next_view:
                    actual_next_view.append(succ_op)
                    if succ_op.name not in pred_op_of_next_view_op:
                        pred_op_of_next_view_op[succ_op.name] = []
                    pred_op_of_next_view_op[succ_op.name].append(op)
                else:
                    # successor skips this layer: bridge the edge with
                    # a virtual op
                    vop = VirtualOp(name="vir{}".format(virtual_op_idx))
                    virtual_op_idx += 1
                    virtual_ops.append(vop)  # was undefined `virtual_op`
                    outdegs[vop.name] = [succ_op]
                    actual_next_view.append(vop)
                    # TODO: combine vop
                    # NOTE(review): vop never joins a later dag_view, so
                    # its output channel is not wired downstream yet -- WIP
                    pred_op_of_next_view_op[vop.name] = [op]
        # create channel
        processed_op = set()
        for o_idx, op in enumerate(actual_next_view):
            op_name = op.name
            if op_name in processed_op:
                continue
            channel = Channel(name="chl{}".format(channel_idx))
            channel_idx += 1
            channels.append(channel)
            op.add_input_channel(channel)
            pred_ops = pred_op_of_next_view_op[op_name]
            if v_idx == 0:
                # the first layer's input channel is the server entry
                input_channel = channel
            else:
                for pred_op in pred_ops:
                    pred_op.add_output_channel(channel)
            processed_op.add(op_name)
            # combine channel: later ops with the same predecessor set
            # share this op's input channel
            for other_op in actual_next_view[o_idx:]:
                if other_op.name in processed_op:
                    continue
                other_pred_ops = pred_op_of_next_view_op[other_op.name]
                if len(other_pred_ops) != len(pred_ops):
                    continue
                same_flag = True
                for pred_op in pred_ops:
                    if pred_op not in other_pred_ops:
                        same_flag = False
                        break
                if same_flag:
                    other_op.add_input_channel(channel)
                    processed_op.add(other_op.name)
    output_channel = Channel(name="Ochl")
    channels.append(output_channel)
    last_op = dag_views[-1][0]
    last_op.add_output_channel(output_channel)
    self._ops = self._user_ops + virtual_ops
    self._channels = channels
    return input_channel, output_channel
def prepare_server(self, port, worker_num): def prepare_server(self, port, worker_num):
self._port = port self._port = port
self._worker_num = worker_num self._worker_num = worker_num
inputs = set()
outputs = set() input_channel, output_channel = self._topo_sort()
for op in self._ops: self._in_channel = input_channel
inputs |= set([op.get_input()]) self.out_channel = output_channel
outputs |= set(op.get_outputs())
if op.with_serving():
self.prepare_serving(op)
in_channel = inputs - outputs
out_channel = outputs - inputs
if len(in_channel) != 1 or len(out_channel) != 1:
raise Exception(
"in_channel(out_channel) more than 1 or no in_channel(out_channel)"
)
self._in_channel = in_channel.pop()
self._out_channel = out_channel.pop()
self.gen_desc() self.gen_desc()
def _op_start_wrapper(self, op, concurrency_idx): def _op_start_wrapper(self, op, concurrency_idx):
return op.start(concurrency_idx) return op.start(concurrency_idx)
def _run_ops(self): def _run_ops(self):
#TODO
for op in self._ops: for op in self._ops:
op_concurrency = op.get_concurrency() op_concurrency = op.get_concurrency()
logging.debug("run op: {}, op_concurrency: {}".format( logging.debug("run op: {}, op_concurrency: {}".format(
op._name, op_concurrency)) op.name, op_concurrency))
for c in range(op_concurrency): for c in range(op_concurrency):
# th = multiprocessing.Process( # th = multiprocessing.Process(
th = threading.Thread( th = threading.Thread(
...@@ -730,6 +856,7 @@ class PyServer(object): ...@@ -730,6 +856,7 @@ class PyServer(object):
self._op_threads.append(th) self._op_threads.append(th)
def _stop_ops(self): def _stop_ops(self):
# TODO
for op in self._ops: for op in self._ops:
op.stop() op.stop()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册