提交 0712a90a，作者：barrierye

bug fix

上级提交：b4bcf477
......@@ -27,8 +27,6 @@ logging.basicConfig(
class CombineOp(Op):
pass
'''
def preprocess(self, input_data):
combined_prediction = 0
for op_name, channeldata in input_data.items():
......@@ -37,7 +35,6 @@ class CombineOp(Op):
combined_prediction += data["prediction"]
data = {"combined_prediction": combined_prediction / 2}
return data
'''
read_op = Op(name="read", inputs=None)
......
......@@ -98,10 +98,10 @@ class ChannelData(object):
'''
There are several ways to use it:
- ChannelData(future, pbdata[, callback_func])
- ChannelData(future, data_id[, callback_func])
- ChannelData(pbdata)
- ChannelData(ecode, error_info, data_id)
1. ChannelData(future, pbdata[, callback_func])
2. ChannelData(future, data_id[, callback_func])
3. ChannelData(pbdata)
4. ChannelData(ecode, error_info, data_id)
'''
if ecode is not None:
if data_id is None or error_info is None:
......@@ -138,9 +138,8 @@ class ChannelData(object):
if self.callback_func is not None:
feed = self.callback_func(feed)
else:
raise TypeError(
self._log("Error type({}) in pbdata.type.".format(
self.pbdata.type)))
raise TypeError("Error type({}) in pbdata.type.".format(
self.pbdata.type))
return feed
......@@ -334,6 +333,7 @@ class Channel(Queue.Queue):
#TODO
self.close()
self._stop = True
self._cv.notify_all()
class Op(object):
......@@ -358,7 +358,7 @@ class Op(object):
self._server_port = server_port
self._device = device
self._timeout = timeout
self._retry = retry
self._retry = max(1, retry)
self._input = None
self._outputs = []
......@@ -443,48 +443,51 @@ class Op(object):
self._run = False
def _parse_channeldata(self, channeldata):
    """Extract the request id and the first error pbdata (if any).

    Args:
        channeldata: either a single ChannelData, or a dict mapping
            predecessor Op name -> ChannelData (fan-in case).

    Returns:
        (data_id, error_pbdata): data_id is the request id carried by the
        input; error_pbdata is the pbdata of the first entry whose ecode
        is not OK, or None when every entry succeeded.
    """
    data_id, error_pbdata = None, None
    if isinstance(channeldata, dict):
        # fan-in: every entry carries the same request id, so any one works;
        # next(iter(...)) is valid in both Python 2 and 3 (keys()[0] is not)
        data_id = next(iter(channeldata.values())).pbdata.id
        for _, data in channeldata.items():
            if data.pbdata.ecode != ChannelDataEcode.OK.value:
                # report only the first failing predecessor
                error_pbdata = data.pbdata
                break
    else:
        data_id = channeldata.pbdata.id
        if channeldata.pbdata.ecode != ChannelDataEcode.OK.value:
            error_pbdata = channeldata.pbdata
    return data_id, error_pbdata
def _push_to_output_channels(self, data, name=None):
    """Broadcast data to every output channel under producer `name`.

    Args:
        data: ChannelData to push.
        name: producer name recorded by the channel; defaults to this
            Op's own name. VirtualOp passes the ancestor Op's name so
            consumers see the original producer, not the virtual relay.
    """
    if name is None:
        name = self.name
    for channel in self._outputs:
        channel.push(data, name)
def start(self, concurrency_idx):
op_info_prefix = "[{}{}]".format(self.name, concurrency_idx)
op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
log = self._get_log_func(op_info_prefix)
self._run = True
while self._run:
_profiler.record("{}-get_0".format(op_info_prefix))
input_data = self._input.front(self.name)
channeldata = self._input.front(self.name)
_profiler.record("{}-get_1".format(op_info_prefix))
logging.debug(log("input_data: {}".format(input_data)))
logging.debug(log("input_data: {}".format(channeldata)))
data_id, error_data = self._parse_channeldata(input_data)
data_id, error_pbdata = self._parse_channeldata(channeldata)
# predecessor Op error
if error_data is not None:
self._push_to_output_channels(ChannelData(pbdata=error_data))
# error data in predecessor Op
if error_pbdata is not None:
self._push_to_output_channels(ChannelData(pbdata=error_pbdata))
continue
# preprocess function not implemented
# preprecess
try:
_profiler.record("{}-prep_0".format(op_info_prefix))
data = self.preprocess(input_data)
preped_data = self.preprocess(channeldata)
_profiler.record("{}-prep_1".format(op_info_prefix))
except NotImplementedError as e:
# preprocess function not implemented
error_info = log(e)
logging.error(error_info)
self._push_to_output_channels(
......@@ -493,18 +496,35 @@ class Op(object):
error_info=error_info,
data_id=data_id))
continue
except TypeError as e:
# Error type in channeldata.pbdata.type
error_info = log(e)
logging.error(error_info)
self._push_to_output_channels(
ChannelData(
ecode=ChannelDataEcode.TYPE_ERROR.value,
error_info=error_info,
data_id=data_id))
continue
except Exception as e:
error_info = log(e)
logging.error(error_info)
self._push_to_output_channels(
ChannelData(
ecode=ChannelDataEcode.TYPE_ERROR.value,
error_info=error_info,
data_id=data_id))
continue
# midprocess
call_future = None
ecode = 0
error_info = None
if self.with_serving():
ecode = ChannelDataEcode.OK.value
_profiler.record("{}-midp_0".format(op_info_prefix))
if self._timeout <= 0:
try:
call_future = self.midprocess(data)
call_future = self.midprocess(preped_data)
except Exception as e:
logging.error(self._log(e))
ecode = ChannelDataEcode.UNKNOW.value
error_info = log(e)
logging.error(error_info)
......@@ -512,15 +532,17 @@ class Op(object):
for i in range(self._retry):
try:
call_future = func_timeout.func_timeout(
self._timeout, self.midprocess, args=(data, ))
except func_timeout.FunctionTimedOut:
self._timeout,
self.midprocess,
args=(preped_data, ))
except func_timeout.FunctionTimedOut as e:
if i + 1 >= self._retry:
ecode = ChannelDataEcode.TIMEOUT.value
error_info = "{} timeout".format(op_info_prefix)
error_info = log(e)
logging.error(error_info)
else:
logging.warn(
log("warn: timeout, retry({})".format(i +
1)))
log("timeout, retry({})".format(i + 1)))
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log(e)
......@@ -528,7 +550,7 @@ class Op(object):
break
else:
break
if ecode != 0:
if ecode != ChannelDataEcode.OK.value:
self._push_to_output_channels(
ChannelData(
ecode=ecode, error_info=error_info,
......@@ -539,32 +561,63 @@ class Op(object):
# postprocess
output_data = None
_profiler.record("{}-postp_0".format(op_info_prefix))
if self.with_serving(): # use call_future
if self.with_serving():
# use call_future
output_data = ChannelData(
future=call_future,
data_id=data_id,
callback_func=self.postprocess)
else:
post_data = self.postprocess(data)
if not isinstance(post_data, dict):
try:
postped_data = self.postprocess(preped_data)
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log(e)
logging.error(error_info)
self._push_to_output_channels(
ChannelData(
ecode=ecode, error_info=error_info,
data_id=data_id))
continue
if not isinstance(postped_data, dict):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = log("output of postprocess funticon must be " \
"dict type, but get {}".format(type(post_data)))
"dict type, but get {}".format(type(postped_data)))
logging.error(error_info)
self._push_to_output_channels(
ChannelData(
ecode=ecode, error_info=error_info,
data_id=data_id))
continue
ecode = ChannelDataEcode.OK.value
error_info = None
pbdata = channel_pb2.ChannelData()
for name, value in post_data.items():
for name, value in postped_data.items():
if not isinstance(name, (str, unicode)):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = log("the key of postped_data must " \
"be str, but get {}".format(type(name)))
break
if not isinstance(value, np.ndarray):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = log("the value of postped_data must " \
"be np.ndarray, but get {}".format(type(value)))
break
inst = channel_pb2.Inst()
inst.data = value.tobytes()
inst.name = name
inst.shape = np.array(value.shape, dtype="int32").tobytes()
inst.type = str(value.dtype)
pbdata.insts.append(inst)
pbdata.ecode = 0
if ecode != ChannelDataEcode.OK.value:
logging.error(error_info)
self._push_to_output_channels(
ChannelData(
ecode=ecode, error_info=error_info,
data_id=data_id))
continue
pbdata.ecode = ecode
pbdata.id = data_id
output_data = ChannelData(pbdata=pbdata)
_profiler.record("{}-postp_1".format(op_info_prefix))
......@@ -587,6 +640,45 @@ class Op(object):
return self._concurrency
class VirtualOp(Op):
    ''' For connecting two channels. '''

    # NOTE(review): inserted by the topo sort when a DAG edge skips a level;
    # it relays data unchanged so each channel links adjacent levels only.

    def __init__(self, name, concurrency=1):
        super(VirtualOp, self).__init__(
            name=name, inputs=None, concurrency=concurrency)
        # actual predecessor Ops this virtual op stands in for
        self._virtual_pred_ops = []

    def add_virtual_pred_op(self, op):
        # register an actual (non-virtual) predecessor Op
        self._virtual_pred_ops.append(op)

    def add_output_channel(self, channel):
        # unlike Op: register the ancestor ops as the channel's producers,
        # so downstream consumers see the original producer names
        if not isinstance(channel, Channel):
            raise TypeError(
                self._log('output channel must be Channel type, not {}'.format(
                    type(channel))))
        for op in self._virtual_pred_ops:
            channel.add_producer(op.name)
        self._outputs.append(channel)

    def start(self, concurrency_idx):
        # forwarding loop: pull from input channel, push to outputs unchanged
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        log = self._get_log_func(op_info_prefix)  # currently unused here
        self._run = True
        while self._run:
            _profiler.record("{}-get_0".format(op_info_prefix))
            channeldata = self._input.front(self.name)
            _profiler.record("{}-get_1".format(op_info_prefix))

            _profiler.record("{}-push_0".format(op_info_prefix))
            if isinstance(channeldata, dict):
                # fan-in: forward each entry under its original producer name
                for name, data in channeldata.items():
                    self._push_to_output_channels(data, name=name)
            else:
                # single input: forward under the (only) ancestor's name
                self._push_to_output_channels(channeldata,
                                              self._virtual_pred_ops[0].name)
            _profiler.record("{}-push_1".format(op_info_prefix))
class GeneralPythonService(
general_python_service_pb2_grpc.GeneralPythonService):
def __init__(self, in_channel, out_channel, retry=2):
......@@ -668,35 +760,27 @@ class GeneralPythonService(
inst.name = name
inst.type = request.type[idx]
pbdata.insts.append(inst)
pbdata.ecode = 0 #TODO: parse request error
pbdata.ecode = ChannelDataEcode.OK.value #TODO: parse request error
return ChannelData(pbdata=pbdata), data_id
def _pack_data_for_resp(self, channeldata):
logging.debug(self._log('get channeldata'))
logging.debug(self._log('gen resp'))
resp = pyservice_pb2.Response()
resp.ecode = channeldata.pbdata.ecode
if resp.ecode == 0:
if resp.ecode == ChannelDataEcode.OK.value:
if channeldata.pbdata.type == ChannelDataType.CHANNEL_PBDATA.value:
for inst in channeldata.pbdata.insts:
logging.debug(self._log('append data'))
resp.fetch_insts.append(inst.data)
logging.debug(self._log('append name'))
resp.fetch_var_names.append(inst.name)
logging.debug(self._log('append shape'))
resp.shape.append(inst.shape)
logging.debug(self._log('append type'))
resp.type.append(inst.type)
elif channeldata.pbdata.type == ChannelDataType.CHANNEL_FUTURE.value:
feed = channeldata.futures.result()
if channeldata.callback_func is not None:
feed = channeldata.callback_func(feed)
for name, var in feed:
logging.debug(self._log('append data'))
resp.fetch_insts.append(var.tobytes())
logging.debug(self._log('append name'))
resp.fetch_var_names.append(name)
logging.debug(self._log('append shape'))
resp.shape.append(
np.array(
var.shape, dtype="int32").tobytes())
......@@ -726,7 +810,7 @@ class GeneralPythonService(
resp_channeldata = self._get_data_in_globel_resp_dict(data_id)
_profiler.record("{}-fetch_1".format(self.name))
if resp_channeldata.pbdata.ecode == 0:
if resp_channeldata.pbdata.ecode == ChannelDataEcode.OK.value:
break
if i + 1 < self._retry:
logging.warn("retry({}): {}".format(
......@@ -743,7 +827,7 @@ class PyServer(object):
def __init__(self, retry=2, profile=False):
self._channels = []
self._user_ops = []
self._total_ops = []
self._actual_ops = []
self._op_threads = []
self._port = None
self._worker_num = None
......@@ -767,9 +851,13 @@ class PyServer(object):
def _topo_sort(self):
indeg_num = {}
outdegs = {op.name: [] for op in self._user_ops}
que_idx = 0 # scroll queue
ques = [Queue.Queue() for _ in range(2)]
for op in self._user_ops:
if len(op.get_input_ops()) == 0:
op.name = "#G" # update read_op.name
break
outdegs = {op.name: [] for op in self._user_ops}
for idx, op in enumerate(self._user_ops):
# check the name of op is globally unique
if op.name in indeg_num:
......@@ -780,7 +868,7 @@ class PyServer(object):
for pred_op in op.get_input_ops():
outdegs[pred_op.name].append(op)
# get dag_views
# topo sort to get dag_views
dag_views = []
sorted_op_num = 0
while True:
......@@ -790,9 +878,8 @@ class PyServer(object):
while que.qsize() != 0:
op = que.get()
dag_view.append(op)
op_name = op.name
sorted_op_num += 1
for succ_op in outdegs[op_name]:
for succ_op in outdegs[op.name]:
indeg_num[succ_op.name] -= 1
if indeg_num[succ_op.name] == 0:
next_que.put(succ_op)
......@@ -808,52 +895,69 @@ class PyServer(object):
raise Exception("DAG contains multiple output Ops")
# create channels and virtual ops
virtual_op_idx = 0
channel_idx = 0
def name_generator(prefix):
    """Infinite generator of unique names: "<prefix>0", "<prefix>1", ..."""
    counter = 0
    while True:
        yield "{}{}".format(prefix, counter)
        counter += 1
virtual_op_name_gen = name_generator("vir")
channel_name_gen = name_generator("chl")
virtual_ops = []
channels = []
input_channel = None
actual_view = None
for v_idx, view in enumerate(dag_views):
if v_idx + 1 >= len(dag_views):
break
next_view = dag_views[v_idx + 1]
if actual_view is None:
actual_view = view
actual_next_view = []
pred_op_of_next_view_op = {}
for op in view:
# create virtual op
for op in actual_view:
# find actual succ op in next view and create virtual op
for succ_op in outdegs[op.name]:
if succ_op in next_view:
if succ_op not in actual_next_view:
actual_next_view.append(succ_op)
if succ_op.name not in pred_op_of_next_view_op:
pred_op_of_next_view_op[succ_op.name] = []
pred_op_of_next_view_op[succ_op.name].append(op)
else:
vop = Op(name="vir{}".format(virtual_op_idx), inputs=[])
virtual_op_idx += 1
# create virtual op
virtual_op = None
virtual_op = VirtualOp(name=virtual_op_name_gen.next())
virtual_ops.append(virtual_op)
outdegs[vop.name] = [succ_op]
actual_next_view.append(vop)
# TODO: combine vop
pred_op_of_next_view_op[vop.name] = [op]
outdegs[virtual_op.name] = [succ_op]
actual_next_view.append(virtual_op)
pred_op_of_next_view_op[virtual_op.name] = [op]
virtual_op.add_virtual_pred_op(op)
actual_view = actual_next_view
# create channel
processed_op = set()
for o_idx, op in enumerate(actual_next_view):
op_name = op.name
if op_name in processed_op:
if op.name in processed_op:
continue
channel = Channel(name="chl{}".format(channel_idx))
channel_idx += 1
channel = Channel(name=channel_name_gen.next())
channels.append(channel)
logging.debug("{} => {}".format(channel.name, op.name))
op.add_input_channel(channel)
pred_ops = pred_op_of_next_view_op[op_name]
pred_ops = pred_op_of_next_view_op[op.name]
if v_idx == 0:
input_channel = channel
else:
# if pred_op is virtual op, it will use ancestors as producers to channel
for pred_op in pred_ops:
logging.debug("{} => {}".format(pred_op.name,
channel.name))
pred_op.add_output_channel(channel)
processed_op.add(op_name)
# combine channel
for other_op in actual_next_view[o_idx:]:
processed_op.add(op.name)
# find same input op to combine channel
for other_op in actual_next_view[o_idx + 1:]:
if other_op.name in processed_op:
continue
other_pred_ops = pred_op_of_next_view_op[other_op.name]
......@@ -865,19 +969,24 @@ class PyServer(object):
same_flag = False
break
if same_flag:
logging.debug("{} => {}".format(channel.name,
other_op.name))
other_op.add_input_channel(channel)
processed_op.add(other_op.name)
output_channel = Channel(name="Ochl")
output_channel = Channel(name=channel_name_gen.next())
channels.append(output_channel)
last_op = dag_views[-1][0]
last_op.add_output_channel(output_channel)
self._ops = virtual_ops
self._actual_ops = virtual_ops
for op in self._user_ops:
if len(op.get_input_ops()) == 0:
# pass read op
continue
self._ops.append(op)
self._actual_ops.append(op)
self._channels = channels
for c in channels:
logging.debug(c.debug())
return input_channel, output_channel
def prepare_server(self, port, worker_num):
......@@ -887,7 +996,7 @@ class PyServer(object):
input_channel, output_channel = self._topo_sort()
self._in_channel = input_channel
self._out_channel = output_channel
for op in self._ops:
for op in self._actual_ops:
if op.with_serving():
self.prepare_serving(op)
self.gen_desc()
......@@ -896,7 +1005,7 @@ class PyServer(object):
return op.start(concurrency_idx)
def _run_ops(self):
for op in self._ops:
for op in self._actual_ops:
op_concurrency = op.get_concurrency()
logging.debug("run op: {}, op_concurrency: {}".format(
op.name, op_concurrency))
......@@ -907,7 +1016,7 @@ class PyServer(object):
self._op_threads.append(th)
def _stop_ops(self):
    """Signal every running Op (including virtual ops) to exit its loop."""
    # _actual_ops is the post-topo-sort list (user ops + virtual ops)
    for op in self._actual_ops:
        op.stop()
def run_server(self):
......@@ -921,6 +1030,8 @@ class PyServer(object):
server.start()
server.wait_for_termination()
self._stop_ops() # TODO
for th in self._op_threads:
th.join()
def prepare_serving(self, op):
model_path = op._server_model
......@@ -935,5 +1046,3 @@ class PyServer(object):
" --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
# run a server (not in PyServing)
logging.info("run a server (not in PyServing): {}".format(cmd))
return
# os.system(cmd)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册