diff --git a/python/paddle_serving_server/pyserver.py b/python/paddle_serving_server/pyserver.py index e9bc43be7d6f7131526e3f38ff9422f1c2c42143..1a776ecaefd528cb6f3161c215c5b5cae1e7cb07 100644 --- a/python/paddle_serving_server/pyserver.py +++ b/python/paddle_serving_server/pyserver.py @@ -1004,6 +1004,7 @@ class PyServer(object): op.name = "#G" # update read_op.name break outdegs = {op.name: [] for op in self._user_ops} + zero_indeg_num, zero_outdeg_num = 0, 0 for idx, op in enumerate(self._user_ops): # check the name of op is globally unique if op.name in indeg_num: @@ -1011,8 +1012,16 @@ class PyServer(object): indeg_num[op.name] = len(op.get_input_ops()) if indeg_num[op.name] == 0: ques[que_idx].put(op) + zero_indeg_num += 1 for pred_op in op.get_input_ops(): outdegs[pred_op.name].append(op) + if zero_indeg_num != 1: + raise Exception("DAG contains multiple input Ops") + for _, succ_list in outdegs.items(): + if len(succ_list) == 0: + zero_outdeg_num += 1 + if zero_outdeg_num != 1: + raise Exception("DAG contains multiple output Ops") # topo sort to get dag_views dag_views = [] @@ -1035,10 +1044,6 @@ class PyServer(object): que_idx = (que_idx + 1) % 2 if sorted_op_num < len(self._user_ops): raise Exception("not legal DAG") - if len(dag_views[0]) != 1: - raise Exception("DAG contains multiple input Ops") - if len(dag_views[-1]) != 1: - raise Exception("DAG contains multiple output Ops") # create channels and virtual ops def name_generator(prefix): @@ -1121,7 +1126,7 @@ class PyServer(object): processed_op.add(other_op.name) output_channel = Channel(self._manager, name=channel_name_gen.next()) channels.append(output_channel) - last_op = dag_views[-1][0] # TODO: fix it + last_op = dag_views[-1][0] last_op.add_output_channel(output_channel) self._actual_ops = virtual_ops diff --git a/python/paddle_serving_server_gpu/pyserver.py b/python/paddle_serving_server_gpu/pyserver.py index e9bc43be7d6f7131526e3f38ff9422f1c2c42143..1a776ecaefd528cb6f3161c215c5b5cae1e7cb07 100644 --- a/python/paddle_serving_server_gpu/pyserver.py +++ b/python/paddle_serving_server_gpu/pyserver.py @@ -1004,6 +1004,7 @@ class PyServer(object): op.name = "#G" # update read_op.name break outdegs = {op.name: [] for op in self._user_ops} + zero_indeg_num, zero_outdeg_num = 0, 0 for idx, op in enumerate(self._user_ops): # check the name of op is globally unique if op.name in indeg_num: @@ -1011,8 +1012,16 @@ class PyServer(object): indeg_num[op.name] = len(op.get_input_ops()) if indeg_num[op.name] == 0: ques[que_idx].put(op) + zero_indeg_num += 1 for pred_op in op.get_input_ops(): outdegs[pred_op.name].append(op) + if zero_indeg_num != 1: + raise Exception("DAG contains multiple input Ops") + for _, succ_list in outdegs.items(): + if len(succ_list) == 0: + zero_outdeg_num += 1 + if zero_outdeg_num != 1: + raise Exception("DAG contains multiple output Ops") # topo sort to get dag_views dag_views = [] @@ -1035,10 +1044,6 @@ class PyServer(object): que_idx = (que_idx + 1) % 2 if sorted_op_num < len(self._user_ops): raise Exception("not legal DAG") - if len(dag_views[0]) != 1: - raise Exception("DAG contains multiple input Ops") - if len(dag_views[-1]) != 1: - raise Exception("DAG contains multiple output Ops") # create channels and virtual ops def name_generator(prefix): @@ -1121,7 +1126,7 @@ class PyServer(object): processed_op.add(other_op.name) output_channel = Channel(self._manager, name=channel_name_gen.next()) channels.append(output_channel) - last_op = dag_views[-1][0] # TODO: fix it + last_op = dag_views[-1][0] last_op.add_output_channel(output_channel) self._actual_ops = virtual_ops