From 70945df54de8c968505fbf5ba8701f3143ff5258 Mon Sep 17 00:00:00 2001 From: jiangjiajun Date: Tue, 16 Jul 2019 15:00:34 +0800 Subject: [PATCH] test code --- x2paddle/core/graph.py | 8 +++++--- x2paddle/parser/tf_parser.py | 7 +------ 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/x2paddle/core/graph.py b/x2paddle/core/graph.py index 9689fe9..18aa683 100644 --- a/x2paddle/core/graph.py +++ b/x2paddle/core/graph.py @@ -60,15 +60,16 @@ class Graph(object): num_inputs = dict() for name, node in self.node_map.items(): num_inputs[name] = len(node.inputs) + print(len(self.node_map)) self.topo_sort = self.input_nodes[:] idx = 0 while idx < len(self.topo_sort): current_node = self.node_map[self.topo_sort[idx]] for node in current_node.outputs: - num_inputs[node.layer_name] -= 1 - if num_inputs[node.layer_name] == 0: - self.topo_sort.append(node.layer_name) + num_inputs[node] -= 1 + if num_inputs[node] == 0: + self.topo_sort.append(node) idx += 1 for i, tmp in enumerate(self.topo_sort): @@ -84,3 +85,4 @@ class Graph(object): if dst not in self.node_map: raise Exception("node[{}] not in graph".format(dst)) self.node_map[dst].inputs.append(src) + self.node_map[src].outputs.append(dst) diff --git a/x2paddle/parser/tf_parser.py b/x2paddle/parser/tf_parser.py index 5db5055..ef7b9fc 100644 --- a/x2paddle/parser/tf_parser.py +++ b/x2paddle/parser/tf_parser.py @@ -29,9 +29,6 @@ class TFGraphNode(GraphNode): class TFGraph(Graph): def __init__(self, model): super(TFGraph, self).__init__(model) - self.multi_output_ops = [ - 'Split', - 'Unpack'] def build(self): for layer in self.model.node: @@ -41,12 +38,10 @@ class TFGraph(Graph): for in_node in node.layer.input: if in_node not in self.node_map: if in_node.strip().split(':')[0] in self.node_map: - self.connect(in_node, layer_name) + self.connect(in_node.strip().split(':')[0], layer_name) else: raise Exception('input[{}] of node[{}] does not exist in node_map'.format(in_node, layer_name)) else: - if self.node_map[in_node].layer_type in self.multi_output_ops: - in_node += ":0" self.connect(in_node, layer_name) super(TFGraph, self).build() -- GitLab