提交 9232a461 编写于 作者: B barrierye

bug fix

上级 a2a8409b
...@@ -27,8 +27,6 @@ logging.basicConfig( ...@@ -27,8 +27,6 @@ logging.basicConfig(
class CombineOp(Op): class CombineOp(Op):
pass
'''
def preprocess(self, input_data): def preprocess(self, input_data):
combined_prediction = 0 combined_prediction = 0
for op_name, channeldata in input_data.items(): for op_name, channeldata in input_data.items():
...@@ -37,7 +35,6 @@ class CombineOp(Op): ...@@ -37,7 +35,6 @@ class CombineOp(Op):
combined_prediction += data["prediction"] combined_prediction += data["prediction"]
data = {"combined_prediction": combined_prediction / 2} data = {"combined_prediction": combined_prediction / 2}
return data return data
'''
read_op = Op(name="read", inputs=None) read_op = Op(name="read", inputs=None)
......
...@@ -98,10 +98,10 @@ class ChannelData(object): ...@@ -98,10 +98,10 @@ class ChannelData(object):
''' '''
There are several ways to use it: There are several ways to use it:
- ChannelData(future, pbdata[, callback_func]) 1. ChannelData(future, pbdata[, callback_func])
- ChannelData(future, data_id[, callback_func]) 2. ChannelData(future, data_id[, callback_func])
- ChannelData(pbdata) 3. ChannelData(pbdata)
- ChannelData(ecode, error_info, data_id) 4. ChannelData(ecode, error_info, data_id)
''' '''
if ecode is not None: if ecode is not None:
if data_id is None or error_info is None: if data_id is None or error_info is None:
...@@ -138,9 +138,8 @@ class ChannelData(object): ...@@ -138,9 +138,8 @@ class ChannelData(object):
if self.callback_func is not None: if self.callback_func is not None:
feed = self.callback_func(feed) feed = self.callback_func(feed)
else: else:
raise TypeError( raise TypeError("Error type({}) in pbdata.type.".format(
self._log("Error type({}) in pbdata.type.".format( self.pbdata.type))
self.pbdata.type)))
return feed return feed
...@@ -334,6 +333,7 @@ class Channel(Queue.Queue): ...@@ -334,6 +333,7 @@ class Channel(Queue.Queue):
#TODO #TODO
self.close() self.close()
self._stop = True self._stop = True
self._cv.notify_all()
class Op(object): class Op(object):
...@@ -358,7 +358,7 @@ class Op(object): ...@@ -358,7 +358,7 @@ class Op(object):
self._server_port = server_port self._server_port = server_port
self._device = device self._device = device
self._timeout = timeout self._timeout = timeout
self._retry = retry self._retry = max(1, retry)
self._input = None self._input = None
self._outputs = [] self._outputs = []
...@@ -443,48 +443,51 @@ class Op(object): ...@@ -443,48 +443,51 @@ class Op(object):
self._run = False self._run = False
def _parse_channeldata(self, channeldata): def _parse_channeldata(self, channeldata):
data_id, error_data = None, None data_id, error_pbdata = None, None
if isinstance(channeldata, dict): if isinstance(channeldata, dict):
parsed_data = {} parsed_data = {}
key = channeldata.keys()[0] key = channeldata.keys()[0]
data_id = channeldata[key].pbdata.id data_id = channeldata[key].pbdata.id
for _, data in channeldata.items(): for _, data in channeldata.items():
if data.pbdata.ecode != 0: if data.pbdata.ecode != ChannelDataEcode.OK.value:
error_data = data.pbdata error_pbdata = data.pbdata
break break
else: else:
data_id = channeldata.pbdata.id data_id = channeldata.pbdata.id
if channeldata.pbdata.ecode != 0: if channeldata.pbdata.ecode != ChannelDataEcode.OK.value:
error_data = channeldata.pbdata error_pbdata = channeldata.pbdata
return data_id, error_data return data_id, error_pbdata
def _push_to_output_channels(self, data): def _push_to_output_channels(self, data, name=None):
if name is None:
name = self.name
for channel in self._outputs: for channel in self._outputs:
channel.push(data, self.name) channel.push(data, name)
def start(self, concurrency_idx): def start(self, concurrency_idx):
op_info_prefix = "[{}{}]".format(self.name, concurrency_idx) op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
log = self._get_log_func(op_info_prefix) log = self._get_log_func(op_info_prefix)
self._run = True self._run = True
while self._run: while self._run:
_profiler.record("{}-get_0".format(op_info_prefix)) _profiler.record("{}-get_0".format(op_info_prefix))
input_data = self._input.front(self.name) channeldata = self._input.front(self.name)
_profiler.record("{}-get_1".format(op_info_prefix)) _profiler.record("{}-get_1".format(op_info_prefix))
logging.debug(log("input_data: {}".format(input_data))) logging.debug(log("input_data: {}".format(channeldata)))
data_id, error_data = self._parse_channeldata(input_data) data_id, error_pbdata = self._parse_channeldata(channeldata)
# predecessor Op error # error data in predecessor Op
if error_data is not None: if error_pbdata is not None:
self._push_to_output_channels(ChannelData(pbdata=error_data)) self._push_to_output_channels(ChannelData(pbdata=error_pbdata))
continue continue
# preprocess function not implemented # preprecess
try: try:
_profiler.record("{}-prep_0".format(op_info_prefix)) _profiler.record("{}-prep_0".format(op_info_prefix))
data = self.preprocess(input_data) preped_data = self.preprocess(channeldata)
_profiler.record("{}-prep_1".format(op_info_prefix)) _profiler.record("{}-prep_1".format(op_info_prefix))
except NotImplementedError as e: except NotImplementedError as e:
# preprocess function not implemented
error_info = log(e) error_info = log(e)
logging.error(error_info) logging.error(error_info)
self._push_to_output_channels( self._push_to_output_channels(
...@@ -493,18 +496,35 @@ class Op(object): ...@@ -493,18 +496,35 @@ class Op(object):
error_info=error_info, error_info=error_info,
data_id=data_id)) data_id=data_id))
continue continue
except TypeError as e:
# Error type in channeldata.pbdata.type
error_info = log(e)
logging.error(error_info)
self._push_to_output_channels(
ChannelData(
ecode=ChannelDataEcode.TYPE_ERROR.value,
error_info=error_info,
data_id=data_id))
continue
except Exception as e:
error_info = log(e)
logging.error(error_info)
self._push_to_output_channels(
ChannelData(
ecode=ChannelDataEcode.TYPE_ERROR.value,
error_info=error_info,
data_id=data_id))
continue
# midprocess # midprocess
call_future = None call_future = None
ecode = 0
error_info = None
if self.with_serving(): if self.with_serving():
ecode = ChannelDataEcode.OK.value
_profiler.record("{}-midp_0".format(op_info_prefix)) _profiler.record("{}-midp_0".format(op_info_prefix))
if self._timeout <= 0: if self._timeout <= 0:
try: try:
call_future = self.midprocess(data) call_future = self.midprocess(preped_data)
except Exception as e: except Exception as e:
logging.error(self._log(e))
ecode = ChannelDataEcode.UNKNOW.value ecode = ChannelDataEcode.UNKNOW.value
error_info = log(e) error_info = log(e)
logging.error(error_info) logging.error(error_info)
...@@ -512,15 +532,17 @@ class Op(object): ...@@ -512,15 +532,17 @@ class Op(object):
for i in range(self._retry): for i in range(self._retry):
try: try:
call_future = func_timeout.func_timeout( call_future = func_timeout.func_timeout(
self._timeout, self.midprocess, args=(data, )) self._timeout,
except func_timeout.FunctionTimedOut: self.midprocess,
args=(preped_data, ))
except func_timeout.FunctionTimedOut as e:
if i + 1 >= self._retry: if i + 1 >= self._retry:
ecode = ChannelDataEcode.TIMEOUT.value ecode = ChannelDataEcode.TIMEOUT.value
error_info = "{} timeout".format(op_info_prefix) error_info = log(e)
logging.error(error_info)
else: else:
logging.warn( logging.warn(
log("warn: timeout, retry({})".format(i + log("timeout, retry({})".format(i + 1)))
1)))
except Exception as e: except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value ecode = ChannelDataEcode.UNKNOW.value
error_info = log(e) error_info = log(e)
...@@ -528,7 +550,7 @@ class Op(object): ...@@ -528,7 +550,7 @@ class Op(object):
break break
else: else:
break break
if ecode != 0: if ecode != ChannelDataEcode.OK.value:
self._push_to_output_channels( self._push_to_output_channels(
ChannelData( ChannelData(
ecode=ecode, error_info=error_info, ecode=ecode, error_info=error_info,
...@@ -539,32 +561,63 @@ class Op(object): ...@@ -539,32 +561,63 @@ class Op(object):
# postprocess # postprocess
output_data = None output_data = None
_profiler.record("{}-postp_0".format(op_info_prefix)) _profiler.record("{}-postp_0".format(op_info_prefix))
if self.with_serving(): # use call_future if self.with_serving():
# use call_future
output_data = ChannelData( output_data = ChannelData(
future=call_future, future=call_future,
data_id=data_id, data_id=data_id,
callback_func=self.postprocess) callback_func=self.postprocess)
else: else:
post_data = self.postprocess(data) try:
if not isinstance(post_data, dict): postped_data = self.postprocess(preped_data)
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log(e)
logging.error(error_info)
self._push_to_output_channels(
ChannelData(
ecode=ecode, error_info=error_info,
data_id=data_id))
continue
if not isinstance(postped_data, dict):
ecode = ChannelDataEcode.TYPE_ERROR.value ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = log("output of postprocess funticon must be " \ error_info = log("output of postprocess funticon must be " \
"dict type, but get {}".format(type(post_data))) "dict type, but get {}".format(type(postped_data)))
logging.error(error_info) logging.error(error_info)
self._push_to_output_channels( self._push_to_output_channels(
ChannelData( ChannelData(
ecode=ecode, error_info=error_info, ecode=ecode, error_info=error_info,
data_id=data_id)) data_id=data_id))
continue continue
ecode = ChannelDataEcode.OK.value
error_info = None
pbdata = channel_pb2.ChannelData() pbdata = channel_pb2.ChannelData()
for name, value in post_data.items(): for name, value in postped_data.items():
if not isinstance(name, (str, unicode)):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = log("the key of postped_data must " \
"be str, but get {}".format(type(name)))
break
if not isinstance(value, np.ndarray):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = log("the value of postped_data must " \
"be np.ndarray, but get {}".format(type(value)))
break
inst = channel_pb2.Inst() inst = channel_pb2.Inst()
inst.data = value.tobytes() inst.data = value.tobytes()
inst.name = name inst.name = name
inst.shape = np.array(value.shape, dtype="int32").tobytes() inst.shape = np.array(value.shape, dtype="int32").tobytes()
inst.type = str(value.dtype) inst.type = str(value.dtype)
pbdata.insts.append(inst) pbdata.insts.append(inst)
pbdata.ecode = 0 if ecode != ChannelDataEcode.OK.value:
logging.error(error_info)
self._push_to_output_channels(
ChannelData(
ecode=ecode, error_info=error_info,
data_id=data_id))
continue
pbdata.ecode = ecode
pbdata.id = data_id pbdata.id = data_id
output_data = ChannelData(pbdata=pbdata) output_data = ChannelData(pbdata=pbdata)
_profiler.record("{}-postp_1".format(op_info_prefix)) _profiler.record("{}-postp_1".format(op_info_prefix))
...@@ -587,6 +640,45 @@ class Op(object): ...@@ -587,6 +640,45 @@ class Op(object):
return self._concurrency return self._concurrency
class VirtualOp(Op):
''' For connecting two channels. '''
def __init__(self, name, concurrency=1):
super(VirtualOp, self).__init__(
name=name, inputs=None, concurrency=concurrency)
self._virtual_pred_ops = []
def add_virtual_pred_op(self, op):
self._virtual_pred_ops.append(op)
def add_output_channel(self, channel):
if not isinstance(channel, Channel):
raise TypeError(
self._log('output channel must be Channel type, not {}'.format(
type(channel))))
for op in self._virtual_pred_ops:
channel.add_producer(op.name)
self._outputs.append(channel)
def start(self, concurrency_idx):
op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
log = self._get_log_func(op_info_prefix)
self._run = True
while self._run:
_profiler.record("{}-get_0".format(op_info_prefix))
channeldata = self._input.front(self.name)
_profiler.record("{}-get_1".format(op_info_prefix))
_profiler.record("{}-push_0".format(op_info_prefix))
if isinstance(channeldata, dict):
for name, data in channeldata.items():
self._push_to_output_channels(data, name=name)
else:
self._push_to_output_channels(channeldata,
self._virtual_pred_ops[0].name)
_profiler.record("{}-push_1".format(op_info_prefix))
class GeneralPythonService( class GeneralPythonService(
general_python_service_pb2_grpc.GeneralPythonService): general_python_service_pb2_grpc.GeneralPythonService):
def __init__(self, in_channel, out_channel, retry=2): def __init__(self, in_channel, out_channel, retry=2):
...@@ -668,35 +760,27 @@ class GeneralPythonService( ...@@ -668,35 +760,27 @@ class GeneralPythonService(
inst.name = name inst.name = name
inst.type = request.type[idx] inst.type = request.type[idx]
pbdata.insts.append(inst) pbdata.insts.append(inst)
pbdata.ecode = 0 #TODO: parse request error pbdata.ecode = ChannelDataEcode.OK.value #TODO: parse request error
return ChannelData(pbdata=pbdata), data_id return ChannelData(pbdata=pbdata), data_id
def _pack_data_for_resp(self, channeldata): def _pack_data_for_resp(self, channeldata):
logging.debug(self._log('get channeldata')) logging.debug(self._log('get channeldata'))
logging.debug(self._log('gen resp'))
resp = pyservice_pb2.Response() resp = pyservice_pb2.Response()
resp.ecode = channeldata.pbdata.ecode resp.ecode = channeldata.pbdata.ecode
if resp.ecode == 0: if resp.ecode == ChannelDataEcode.OK.value:
if channeldata.pbdata.type == ChannelDataType.CHANNEL_PBDATA.value: if channeldata.pbdata.type == ChannelDataType.CHANNEL_PBDATA.value:
for inst in channeldata.pbdata.insts: for inst in channeldata.pbdata.insts:
logging.debug(self._log('append data'))
resp.fetch_insts.append(inst.data) resp.fetch_insts.append(inst.data)
logging.debug(self._log('append name'))
resp.fetch_var_names.append(inst.name) resp.fetch_var_names.append(inst.name)
logging.debug(self._log('append shape'))
resp.shape.append(inst.shape) resp.shape.append(inst.shape)
logging.debug(self._log('append type'))
resp.type.append(inst.type) resp.type.append(inst.type)
elif channeldata.pbdata.type == ChannelDataType.CHANNEL_FUTURE.value: elif channeldata.pbdata.type == ChannelDataType.CHANNEL_FUTURE.value:
feed = channeldata.futures.result() feed = channeldata.futures.result()
if channeldata.callback_func is not None: if channeldata.callback_func is not None:
feed = channeldata.callback_func(feed) feed = channeldata.callback_func(feed)
for name, var in feed: for name, var in feed:
logging.debug(self._log('append data'))
resp.fetch_insts.append(var.tobytes()) resp.fetch_insts.append(var.tobytes())
logging.debug(self._log('append name'))
resp.fetch_var_names.append(name) resp.fetch_var_names.append(name)
logging.debug(self._log('append shape'))
resp.shape.append( resp.shape.append(
np.array( np.array(
var.shape, dtype="int32").tobytes()) var.shape, dtype="int32").tobytes())
...@@ -726,7 +810,7 @@ class GeneralPythonService( ...@@ -726,7 +810,7 @@ class GeneralPythonService(
resp_channeldata = self._get_data_in_globel_resp_dict(data_id) resp_channeldata = self._get_data_in_globel_resp_dict(data_id)
_profiler.record("{}-fetch_1".format(self.name)) _profiler.record("{}-fetch_1".format(self.name))
if resp_channeldata.pbdata.ecode == 0: if resp_channeldata.pbdata.ecode == ChannelDataEcode.OK.value:
break break
if i + 1 < self._retry: if i + 1 < self._retry:
logging.warn("retry({}): {}".format( logging.warn("retry({}): {}".format(
...@@ -743,7 +827,7 @@ class PyServer(object): ...@@ -743,7 +827,7 @@ class PyServer(object):
def __init__(self, retry=2, profile=False): def __init__(self, retry=2, profile=False):
self._channels = [] self._channels = []
self._user_ops = [] self._user_ops = []
self._total_ops = [] self._actual_ops = []
self._op_threads = [] self._op_threads = []
self._port = None self._port = None
self._worker_num = None self._worker_num = None
...@@ -767,9 +851,13 @@ class PyServer(object): ...@@ -767,9 +851,13 @@ class PyServer(object):
def _topo_sort(self): def _topo_sort(self):
indeg_num = {} indeg_num = {}
outdegs = {op.name: [] for op in self._user_ops}
que_idx = 0 # scroll queue que_idx = 0 # scroll queue
ques = [Queue.Queue() for _ in range(2)] ques = [Queue.Queue() for _ in range(2)]
for op in self._user_ops:
if len(op.get_input_ops()) == 0:
op.name = "#G" # update read_op.name
break
outdegs = {op.name: [] for op in self._user_ops}
for idx, op in enumerate(self._user_ops): for idx, op in enumerate(self._user_ops):
# check the name of op is globally unique # check the name of op is globally unique
if op.name in indeg_num: if op.name in indeg_num:
...@@ -780,7 +868,7 @@ class PyServer(object): ...@@ -780,7 +868,7 @@ class PyServer(object):
for pred_op in op.get_input_ops(): for pred_op in op.get_input_ops():
outdegs[pred_op.name].append(op) outdegs[pred_op.name].append(op)
# get dag_views # topo sort to get dag_views
dag_views = [] dag_views = []
sorted_op_num = 0 sorted_op_num = 0
while True: while True:
...@@ -790,9 +878,8 @@ class PyServer(object): ...@@ -790,9 +878,8 @@ class PyServer(object):
while que.qsize() != 0: while que.qsize() != 0:
op = que.get() op = que.get()
dag_view.append(op) dag_view.append(op)
op_name = op.name
sorted_op_num += 1 sorted_op_num += 1
for succ_op in outdegs[op_name]: for succ_op in outdegs[op.name]:
indeg_num[succ_op.name] -= 1 indeg_num[succ_op.name] -= 1
if indeg_num[succ_op.name] == 0: if indeg_num[succ_op.name] == 0:
next_que.put(succ_op) next_que.put(succ_op)
...@@ -808,52 +895,69 @@ class PyServer(object): ...@@ -808,52 +895,69 @@ class PyServer(object):
raise Exception("DAG contains multiple output Ops") raise Exception("DAG contains multiple output Ops")
# create channels and virtual ops # create channels and virtual ops
virtual_op_idx = 0 def name_generator(prefix):
channel_idx = 0 def number_generator():
idx = 0
while True:
yield "{}{}".format(prefix, idx)
idx += 1
return number_generator()
virtual_op_name_gen = name_generator("vir")
channel_name_gen = name_generator("chl")
virtual_ops = [] virtual_ops = []
channels = [] channels = []
input_channel = None input_channel = None
actual_view = None
for v_idx, view in enumerate(dag_views): for v_idx, view in enumerate(dag_views):
if v_idx + 1 >= len(dag_views): if v_idx + 1 >= len(dag_views):
break break
next_view = dag_views[v_idx + 1] next_view = dag_views[v_idx + 1]
if actual_view is None:
actual_view = view
actual_next_view = [] actual_next_view = []
pred_op_of_next_view_op = {} pred_op_of_next_view_op = {}
for op in view: for op in actual_view:
# create virtual op # find actual succ op in next view and create virtual op
for succ_op in outdegs[op.name]: for succ_op in outdegs[op.name]:
if succ_op in next_view: if succ_op in next_view:
actual_next_view.append(succ_op) if succ_op not in actual_next_view:
actual_next_view.append(succ_op)
if succ_op.name not in pred_op_of_next_view_op: if succ_op.name not in pred_op_of_next_view_op:
pred_op_of_next_view_op[succ_op.name] = [] pred_op_of_next_view_op[succ_op.name] = []
pred_op_of_next_view_op[succ_op.name].append(op) pred_op_of_next_view_op[succ_op.name].append(op)
else: else:
vop = Op(name="vir{}".format(virtual_op_idx), inputs=[]) # create virtual op
virtual_op_idx += 1 virtual_op = None
virtual_op = VirtualOp(name=virtual_op_name_gen.next())
virtual_ops.append(virtual_op) virtual_ops.append(virtual_op)
outdegs[vop.name] = [succ_op] outdegs[virtual_op.name] = [succ_op]
actual_next_view.append(vop) actual_next_view.append(virtual_op)
# TODO: combine vop pred_op_of_next_view_op[virtual_op.name] = [op]
pred_op_of_next_view_op[vop.name] = [op] virtual_op.add_virtual_pred_op(op)
actual_view = actual_next_view
# create channel # create channel
processed_op = set() processed_op = set()
for o_idx, op in enumerate(actual_next_view): for o_idx, op in enumerate(actual_next_view):
op_name = op.name if op.name in processed_op:
if op_name in processed_op:
continue continue
channel = Channel(name="chl{}".format(channel_idx)) channel = Channel(name=channel_name_gen.next())
channel_idx += 1
channels.append(channel) channels.append(channel)
logging.debug("{} => {}".format(channel.name, op.name))
op.add_input_channel(channel) op.add_input_channel(channel)
pred_ops = pred_op_of_next_view_op[op_name] pred_ops = pred_op_of_next_view_op[op.name]
if v_idx == 0: if v_idx == 0:
input_channel = channel input_channel = channel
else: else:
# if pred_op is virtual op, it will use ancestors as producers to channel
for pred_op in pred_ops: for pred_op in pred_ops:
logging.debug("{} => {}".format(pred_op.name,
channel.name))
pred_op.add_output_channel(channel) pred_op.add_output_channel(channel)
processed_op.add(op_name) processed_op.add(op.name)
# combine channel # find same input op to combine channel
for other_op in actual_next_view[o_idx:]: for other_op in actual_next_view[o_idx + 1:]:
if other_op.name in processed_op: if other_op.name in processed_op:
continue continue
other_pred_ops = pred_op_of_next_view_op[other_op.name] other_pred_ops = pred_op_of_next_view_op[other_op.name]
...@@ -865,19 +969,24 @@ class PyServer(object): ...@@ -865,19 +969,24 @@ class PyServer(object):
same_flag = False same_flag = False
break break
if same_flag: if same_flag:
logging.debug("{} => {}".format(channel.name,
other_op.name))
other_op.add_input_channel(channel) other_op.add_input_channel(channel)
processed_op.add(other_op.name) processed_op.add(other_op.name)
output_channel = Channel(name="Ochl") output_channel = Channel(name=channel_name_gen.next())
channels.append(output_channel) channels.append(output_channel)
last_op = dag_views[-1][0] last_op = dag_views[-1][0]
last_op.add_output_channel(output_channel) last_op.add_output_channel(output_channel)
self._ops = virtual_ops self._actual_ops = virtual_ops
for op in self._user_ops: for op in self._user_ops:
if len(op.get_input_ops()) == 0: if len(op.get_input_ops()) == 0:
# pass read op
continue continue
self._ops.append(op) self._actual_ops.append(op)
self._channels = channels self._channels = channels
for c in channels:
logging.debug(c.debug())
return input_channel, output_channel return input_channel, output_channel
def prepare_server(self, port, worker_num): def prepare_server(self, port, worker_num):
...@@ -887,7 +996,7 @@ class PyServer(object): ...@@ -887,7 +996,7 @@ class PyServer(object):
input_channel, output_channel = self._topo_sort() input_channel, output_channel = self._topo_sort()
self._in_channel = input_channel self._in_channel = input_channel
self._out_channel = output_channel self._out_channel = output_channel
for op in self._ops: for op in self._actual_ops:
if op.with_serving(): if op.with_serving():
self.prepare_serving(op) self.prepare_serving(op)
self.gen_desc() self.gen_desc()
...@@ -896,7 +1005,7 @@ class PyServer(object): ...@@ -896,7 +1005,7 @@ class PyServer(object):
return op.start(concurrency_idx) return op.start(concurrency_idx)
def _run_ops(self): def _run_ops(self):
for op in self._ops: for op in self._actual_ops:
op_concurrency = op.get_concurrency() op_concurrency = op.get_concurrency()
logging.debug("run op: {}, op_concurrency: {}".format( logging.debug("run op: {}, op_concurrency: {}".format(
op.name, op_concurrency)) op.name, op_concurrency))
...@@ -907,7 +1016,7 @@ class PyServer(object): ...@@ -907,7 +1016,7 @@ class PyServer(object):
self._op_threads.append(th) self._op_threads.append(th)
def _stop_ops(self): def _stop_ops(self):
for op in self._ops: for op in self._actual_ops:
op.stop() op.stop()
def run_server(self): def run_server(self):
...@@ -921,6 +1030,8 @@ class PyServer(object): ...@@ -921,6 +1030,8 @@ class PyServer(object):
server.start() server.start()
server.wait_for_termination() server.wait_for_termination()
self._stop_ops() # TODO self._stop_ops() # TODO
for th in self._op_threads:
th.join()
def prepare_serving(self, op): def prepare_serving(self, op):
model_path = op._server_model model_path = op._server_model
...@@ -935,5 +1046,3 @@ class PyServer(object): ...@@ -935,5 +1046,3 @@ class PyServer(object):
" --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port) " --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
# run a server (not in PyServing) # run a server (not in PyServing)
logging.info("run a server (not in PyServing): {}".format(cmd)) logging.info("run a server (not in PyServing): {}".format(cmd))
return
# os.system(cmd)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册