提交 f99da947 编写于 作者: B barrierye

bug fix

上级 e8d18527
......@@ -1004,6 +1004,7 @@ class PyServer(object):
op.name = "#G" # update read_op.name
break
outdegs = {op.name: [] for op in self._user_ops}
zero_indeg_num, zero_outdeg_num = 0, 0
for idx, op in enumerate(self._user_ops):
# check the name of op is globally unique
if op.name in indeg_num:
......@@ -1011,8 +1012,16 @@ class PyServer(object):
indeg_num[op.name] = len(op.get_input_ops())
if indeg_num[op.name] == 0:
ques[que_idx].put(op)
zero_indeg_num += 1
for pred_op in op.get_input_ops():
outdegs[pred_op.name].append(op)
if zero_indeg_num != 1:
raise Exception("DAG contains multiple input Ops")
for _, succ_list in outdegs.items():
if len(succ_list) == 0:
zero_outdeg_num += 1
if zero_outdeg_num != 1:
raise Exception("DAG contains multiple output Ops")
# topo sort to get dag_views
dag_views = []
......@@ -1035,10 +1044,6 @@ class PyServer(object):
que_idx = (que_idx + 1) % 2
if sorted_op_num < len(self._user_ops):
raise Exception("not legal DAG")
if len(dag_views[0]) != 1:
raise Exception("DAG contains multiple input Ops")
if len(dag_views[-1]) != 1:
raise Exception("DAG contains multiple output Ops")
# create channels and virtual ops
def name_generator(prefix):
......@@ -1121,7 +1126,7 @@ class PyServer(object):
processed_op.add(other_op.name)
output_channel = Channel(self._manager, name=channel_name_gen.next())
channels.append(output_channel)
last_op = dag_views[-1][0] # TODO: fix it
last_op = dag_views[-1][0]
last_op.add_output_channel(output_channel)
self._actual_ops = virtual_ops
......
......@@ -1004,6 +1004,7 @@ class PyServer(object):
op.name = "#G" # update read_op.name
break
outdegs = {op.name: [] for op in self._user_ops}
zero_indeg_num, zero_outdeg_num = 0, 0
for idx, op in enumerate(self._user_ops):
# check the name of op is globally unique
if op.name in indeg_num:
......@@ -1011,8 +1012,16 @@ class PyServer(object):
indeg_num[op.name] = len(op.get_input_ops())
if indeg_num[op.name] == 0:
ques[que_idx].put(op)
zero_indeg_num += 1
for pred_op in op.get_input_ops():
outdegs[pred_op.name].append(op)
if zero_indeg_num != 1:
raise Exception("DAG contains multiple input Ops")
for _, succ_list in outdegs.items():
if len(succ_list) == 0:
zero_outdeg_num += 1
if zero_outdeg_num != 1:
raise Exception("DAG contains multiple output Ops")
# topo sort to get dag_views
dag_views = []
......@@ -1035,10 +1044,6 @@ class PyServer(object):
que_idx = (que_idx + 1) % 2
if sorted_op_num < len(self._user_ops):
raise Exception("not legal DAG")
if len(dag_views[0]) != 1:
raise Exception("DAG contains multiple input Ops")
if len(dag_views[-1]) != 1:
raise Exception("DAG contains multiple output Ops")
# create channels and virtual ops
def name_generator(prefix):
......@@ -1121,7 +1126,7 @@ class PyServer(object):
processed_op.add(other_op.name)
output_channel = Channel(self._manager, name=channel_name_gen.next())
channels.append(output_channel)
last_op = dag_views[-1][0] # TODO: fix it
last_op = dag_views[-1][0]
last_op.add_output_channel(output_channel)
self._actual_ops = virtual_ops
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册