提交 4f8f9801 编写于 作者: B barrierye

fix bug

上级 6498b3fd
...@@ -37,7 +37,7 @@ class CombineOp(Op): ...@@ -37,7 +37,7 @@ class CombineOp(Op):
return data return data
read_op = Op(name="read", input=None) read_op = Op(name="read", inputs=None)
uci1_op = Op(name="uci1", uci1_op = Op(name="uci1",
inputs=[read_op], inputs=[read_op],
server_model="./uci_housing_model", server_model="./uci_housing_model",
......
...@@ -686,10 +686,6 @@ class GeneralPythonService( ...@@ -686,10 +686,6 @@ class GeneralPythonService(
return resp return resp
class VirtualOp(Op):
pass
class PyServer(object): class PyServer(object):
def __init__(self, retry=2, profile=False): def __init__(self, retry=2, profile=False):
self._channels = [] self._channels = []
...@@ -710,7 +706,7 @@ class PyServer(object): ...@@ -710,7 +706,7 @@ class PyServer(object):
self._user_ops.append(op) self._user_ops.append(op)
def add_ops(self, ops): def add_ops(self, ops):
self._user_ops.expand(ops) self._user_ops.extend(ops)
def gen_desc(self): def gen_desc(self):
logging.info('here will generate desc for PAAS') logging.info('here will generate desc for PAAS')
...@@ -718,9 +714,9 @@ class PyServer(object): ...@@ -718,9 +714,9 @@ class PyServer(object):
def _topo_sort(self): def _topo_sort(self):
indeg_num = {} indeg_num = {}
outdegs = {} outdegs = {op.name: [] for op in self._user_ops}
que_idx = 0 # scroll queue que_idx = 0 # scroll queue
ques = [Queue.SimpleQueue() for _ in range(2)] ques = [Queue.Queue() for _ in range(2)]
for idx, op in enumerate(self._user_ops): for idx, op in enumerate(self._user_ops):
# check the name of op is globally unique # check the name of op is globally unique
if op.name in indeg_num: if op.name in indeg_num:
...@@ -729,10 +725,7 @@ class PyServer(object): ...@@ -729,10 +725,7 @@ class PyServer(object):
if indeg_num[op.name] == 0: if indeg_num[op.name] == 0:
ques[que_idx].put(op) ques[que_idx].put(op)
for pred_op in op.get_input_ops(): for pred_op in op.get_input_ops():
if op.name in outdegs: outdegs[pred_op.name].append(op)
outdegs[op.name].append(op)
else:
outdegs[op.name] = [op]
# get dag_views # get dag_views
dag_views = [] dag_views = []
...@@ -747,7 +740,7 @@ class PyServer(object): ...@@ -747,7 +740,7 @@ class PyServer(object):
op_name = op.name op_name = op.name
sorted_op_num += 1 sorted_op_num += 1
for succ_op in outdegs[op_name]: for succ_op in outdegs[op_name]:
indeg_num[op_name] -= 1 indeg_num[succ_op.name] -= 1
if indeg_num[succ_op.name] == 0: if indeg_num[succ_op.name] == 0:
next_que.put(succ_op) next_que.put(succ_op)
dag_views.append(dag_view) dag_views.append(dag_view)
...@@ -782,7 +775,7 @@ class PyServer(object): ...@@ -782,7 +775,7 @@ class PyServer(object):
pred_op_of_next_view_op[succ_op.name] = [] pred_op_of_next_view_op[succ_op.name] = []
pred_op_of_next_view_op[succ_op.name].append(op) pred_op_of_next_view_op[succ_op.name].append(op)
else: else:
vop = VirtualOp(name="vir{}".format(virtual_op_idx)) vop = Op(name="vir{}".format(virtual_op_idx), inputs=[])
virtual_op_idx += 1 virtual_op_idx += 1
virtual_ops.append(virtual_op) virtual_ops.append(virtual_op)
outdegs[vop.name] = [succ_op] outdegs[vop.name] = [succ_op]
...@@ -826,7 +819,11 @@ class PyServer(object): ...@@ -826,7 +819,11 @@ class PyServer(object):
last_op = dag_views[-1][0] last_op = dag_views[-1][0]
last_op.add_output_channel(output_channel) last_op.add_output_channel(output_channel)
self._ops = self._user_ops + virtual_ops self._ops = virtual_ops
for op in self._user_ops:
if len(op.get_input_ops()) == 0:
continue
self._ops.append(op)
self._channels = channels self._channels = channels
return input_channel, output_channel return input_channel, output_channel
...@@ -836,7 +833,10 @@ class PyServer(object): ...@@ -836,7 +833,10 @@ class PyServer(object):
input_channel, output_channel = self._topo_sort() input_channel, output_channel = self._topo_sort()
self._in_channel = input_channel self._in_channel = input_channel
self.out_channel = output_channel self._out_channel = output_channel
for op in self._ops:
if op.with_serving():
self.prepare_serving(op)
self.gen_desc() self.gen_desc()
def _op_start_wrapper(self, op, concurrency_idx): def _op_start_wrapper(self, op, concurrency_idx):
...@@ -849,14 +849,12 @@ class PyServer(object): ...@@ -849,14 +849,12 @@ class PyServer(object):
logging.debug("run op: {}, op_concurrency: {}".format( logging.debug("run op: {}, op_concurrency: {}".format(
op.name, op_concurrency)) op.name, op_concurrency))
for c in range(op_concurrency): for c in range(op_concurrency):
# th = multiprocessing.Process(
th = threading.Thread( th = threading.Thread(
target=self._op_start_wrapper, args=(op, c)) target=self._op_start_wrapper, args=(op, c))
th.start() th.start()
self._op_threads.append(th) self._op_threads.append(th)
def _stop_ops(self): def _stop_ops(self):
# TODO
for op in self._ops: for op in self._ops:
op.stop() op.stop()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册