提交 67ca0f84 编写于 作者: B barrierye

fix bug

上级 31b11557
......@@ -37,7 +37,7 @@ class CombineOp(Op):
return data
read_op = Op(name="read", input=None)
read_op = Op(name="read", inputs=None)
uci1_op = Op(name="uci1",
inputs=[read_op],
server_model="./uci_housing_model",
......
......@@ -686,10 +686,6 @@ class GeneralPythonService(
return resp
class VirtualOp(Op):
pass
class PyServer(object):
def __init__(self, retry=2, profile=False):
self._channels = []
......@@ -710,7 +706,7 @@ class PyServer(object):
self._user_ops.append(op)
def add_ops(self, ops):
self._user_ops.expand(ops)
self._user_ops.extend(ops)
def gen_desc(self):
logging.info('here will generate desc for PAAS')
......@@ -718,9 +714,9 @@ class PyServer(object):
def _topo_sort(self):
indeg_num = {}
outdegs = {}
outdegs = {op.name: [] for op in self._user_ops}
que_idx = 0 # scroll queue
ques = [Queue.SimpleQueue() for _ in range(2)]
ques = [Queue.Queue() for _ in range(2)]
for idx, op in enumerate(self._user_ops):
# check the name of op is globally unique
if op.name in indeg_num:
......@@ -729,10 +725,7 @@ class PyServer(object):
if indeg_num[op.name] == 0:
ques[que_idx].put(op)
for pred_op in op.get_input_ops():
if op.name in outdegs:
outdegs[op.name].append(op)
else:
outdegs[op.name] = [op]
outdegs[pred_op.name].append(op)
# get dag_views
dag_views = []
......@@ -747,7 +740,7 @@ class PyServer(object):
op_name = op.name
sorted_op_num += 1
for succ_op in outdegs[op_name]:
indeg_num[op_name] -= 1
indeg_num[succ_op.name] -= 1
if indeg_num[succ_op.name] == 0:
next_que.put(succ_op)
dag_views.append(dag_view)
......@@ -782,7 +775,7 @@ class PyServer(object):
pred_op_of_next_view_op[succ_op.name] = []
pred_op_of_next_view_op[succ_op.name].append(op)
else:
vop = VirtualOp(name="vir{}".format(virtual_op_idx))
vop = Op(name="vir{}".format(virtual_op_idx), inputs=[])
virtual_op_idx += 1
virtual_ops.append(virtual_op)
outdegs[vop.name] = [succ_op]
......@@ -826,7 +819,11 @@ class PyServer(object):
last_op = dag_views[-1][0]
last_op.add_output_channel(output_channel)
self._ops = self._user_ops + virtual_ops
self._ops = virtual_ops
for op in self._user_ops:
if len(op.get_input_ops()) == 0:
continue
self._ops.append(op)
self._channels = channels
return input_channel, output_channel
......@@ -836,7 +833,10 @@ class PyServer(object):
input_channel, output_channel = self._topo_sort()
self._in_channel = input_channel
self.out_channel = output_channel
self._out_channel = output_channel
for op in self._ops:
if op.with_serving():
self.prepare_serving(op)
self.gen_desc()
def _op_start_wrapper(self, op, concurrency_idx):
......@@ -849,14 +849,12 @@ class PyServer(object):
logging.debug("run op: {}, op_concurrency: {}".format(
op.name, op_concurrency))
for c in range(op_concurrency):
# th = multiprocessing.Process(
th = threading.Thread(
target=self._op_start_wrapper, args=(op, c))
th.start()
self._op_threads.append(th)
def _stop_ops(self):
# TODO
for op in self._ops:
op.stop()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册