提交 f99da947 编写于 作者: B barrierye

bug fix

上级 e8d18527
...@@ -1004,6 +1004,7 @@ class PyServer(object): ...@@ -1004,6 +1004,7 @@ class PyServer(object):
op.name = "#G" # update read_op.name op.name = "#G" # update read_op.name
break break
outdegs = {op.name: [] for op in self._user_ops} outdegs = {op.name: [] for op in self._user_ops}
zero_indeg_num, zero_outdeg_num = 0, 0
for idx, op in enumerate(self._user_ops): for idx, op in enumerate(self._user_ops):
# check the name of op is globally unique # check the name of op is globally unique
if op.name in indeg_num: if op.name in indeg_num:
...@@ -1011,8 +1012,16 @@ class PyServer(object): ...@@ -1011,8 +1012,16 @@ class PyServer(object):
indeg_num[op.name] = len(op.get_input_ops()) indeg_num[op.name] = len(op.get_input_ops())
if indeg_num[op.name] == 0: if indeg_num[op.name] == 0:
ques[que_idx].put(op) ques[que_idx].put(op)
zero_indeg_num += 1
for pred_op in op.get_input_ops(): for pred_op in op.get_input_ops():
outdegs[pred_op.name].append(op) outdegs[pred_op.name].append(op)
if zero_indeg_num != 1:
raise Exception("DAG contains multiple input Ops")
for _, succ_list in outdegs.items():
if len(succ_list) == 0:
zero_outdeg_num += 1
if zero_outdeg_num != 1:
raise Exception("DAG contains multiple output Ops")
# topo sort to get dag_views # topo sort to get dag_views
dag_views = [] dag_views = []
...@@ -1035,10 +1044,6 @@ class PyServer(object): ...@@ -1035,10 +1044,6 @@ class PyServer(object):
que_idx = (que_idx + 1) % 2 que_idx = (que_idx + 1) % 2
if sorted_op_num < len(self._user_ops): if sorted_op_num < len(self._user_ops):
raise Exception("not legal DAG") raise Exception("not legal DAG")
if len(dag_views[0]) != 1:
raise Exception("DAG contains multiple input Ops")
if len(dag_views[-1]) != 1:
raise Exception("DAG contains multiple output Ops")
# create channels and virtual ops # create channels and virtual ops
def name_generator(prefix): def name_generator(prefix):
...@@ -1121,7 +1126,7 @@ class PyServer(object): ...@@ -1121,7 +1126,7 @@ class PyServer(object):
processed_op.add(other_op.name) processed_op.add(other_op.name)
output_channel = Channel(self._manager, name=channel_name_gen.next()) output_channel = Channel(self._manager, name=channel_name_gen.next())
channels.append(output_channel) channels.append(output_channel)
last_op = dag_views[-1][0] # TODO: fix it last_op = dag_views[-1][0]
last_op.add_output_channel(output_channel) last_op.add_output_channel(output_channel)
self._actual_ops = virtual_ops self._actual_ops = virtual_ops
......
...@@ -1004,6 +1004,7 @@ class PyServer(object): ...@@ -1004,6 +1004,7 @@ class PyServer(object):
op.name = "#G" # update read_op.name op.name = "#G" # update read_op.name
break break
outdegs = {op.name: [] for op in self._user_ops} outdegs = {op.name: [] for op in self._user_ops}
zero_indeg_num, zero_outdeg_num = 0, 0
for idx, op in enumerate(self._user_ops): for idx, op in enumerate(self._user_ops):
# check the name of op is globally unique # check the name of op is globally unique
if op.name in indeg_num: if op.name in indeg_num:
...@@ -1011,8 +1012,16 @@ class PyServer(object): ...@@ -1011,8 +1012,16 @@ class PyServer(object):
indeg_num[op.name] = len(op.get_input_ops()) indeg_num[op.name] = len(op.get_input_ops())
if indeg_num[op.name] == 0: if indeg_num[op.name] == 0:
ques[que_idx].put(op) ques[que_idx].put(op)
zero_indeg_num += 1
for pred_op in op.get_input_ops(): for pred_op in op.get_input_ops():
outdegs[pred_op.name].append(op) outdegs[pred_op.name].append(op)
if zero_indeg_num != 1:
raise Exception("DAG contains multiple input Ops")
for _, succ_list in outdegs.items():
if len(succ_list) == 0:
zero_outdeg_num += 1
if zero_outdeg_num != 1:
raise Exception("DAG contains multiple output Ops")
# topo sort to get dag_views # topo sort to get dag_views
dag_views = [] dag_views = []
...@@ -1035,10 +1044,6 @@ class PyServer(object): ...@@ -1035,10 +1044,6 @@ class PyServer(object):
que_idx = (que_idx + 1) % 2 que_idx = (que_idx + 1) % 2
if sorted_op_num < len(self._user_ops): if sorted_op_num < len(self._user_ops):
raise Exception("not legal DAG") raise Exception("not legal DAG")
if len(dag_views[0]) != 1:
raise Exception("DAG contains multiple input Ops")
if len(dag_views[-1]) != 1:
raise Exception("DAG contains multiple output Ops")
# create channels and virtual ops # create channels and virtual ops
def name_generator(prefix): def name_generator(prefix):
...@@ -1121,7 +1126,7 @@ class PyServer(object): ...@@ -1121,7 +1126,7 @@ class PyServer(object):
processed_op.add(other_op.name) processed_op.add(other_op.name)
output_channel = Channel(self._manager, name=channel_name_gen.next()) output_channel = Channel(self._manager, name=channel_name_gen.next())
channels.append(output_channel) channels.append(output_channel)
last_op = dag_views[-1][0] # TODO: fix it last_op = dag_views[-1][0]
last_op.add_output_channel(output_channel) last_op.add_output_channel(output_channel)
self._actual_ops = virtual_ops self._actual_ops = virtual_ops
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册