diff --git a/x2paddle/core/graph.py b/x2paddle/core/graph.py index 9689fe9c65106b9b5257bbf208095704cbb1b512..18aa6831bf313ea19d629c2941357d607d468021 100644 --- a/x2paddle/core/graph.py +++ b/x2paddle/core/graph.py @@ -60,15 +60,16 @@ class Graph(object): num_inputs = dict() for name, node in self.node_map.items(): num_inputs[name] = len(node.inputs) + print(len(self.node_map)) self.topo_sort = self.input_nodes[:] idx = 0 while idx < len(self.topo_sort): current_node = self.node_map[self.topo_sort[idx]] for node in current_node.outputs: - num_inputs[node.layer_name] -= 1 - if num_inputs[node.layer_name] == 0: - self.topo_sort.append(node.layer_name) + num_inputs[node] -= 1 + if num_inputs[node] == 0: + self.topo_sort.append(node) idx += 1 for i, tmp in enumerate(self.topo_sort): @@ -84,3 +85,4 @@ class Graph(object): if dst not in self.node_map: raise Exception("node[{}] not in graph".format(dst)) self.node_map[dst].inputs.append(src) + self.node_map[src].outputs.append(dst) diff --git a/x2paddle/parser/tf_parser.py b/x2paddle/parser/tf_parser.py index 5db5055372f47356b23cac7bf3307a4848f95c3c..ef7b9fc201792c6623dacc3bc0936bbd333138ba 100644 --- a/x2paddle/parser/tf_parser.py +++ b/x2paddle/parser/tf_parser.py @@ -29,9 +29,6 @@ class TFGraphNode(GraphNode): class TFGraph(Graph): def __init__(self, model): super(TFGraph, self).__init__(model) - self.multi_output_ops = [ - 'Split', - 'Unpack'] def build(self): for layer in self.model.node: @@ -41,12 +38,10 @@ class TFGraph(Graph): for in_node in node.layer.input: if in_node not in self.node_map: if in_node.strip().split(':')[0] in self.node_map: - self.connect(in_node, layer_name) + self.connect(in_node.strip().split(':')[0], layer_name) else: raise Exception('input[{}] of node[{}] does not exist in node_map'.format(in_node, layer_name)) else: - if self.node_map[in_node].layer_type in self.multi_output_ops: - in_node += ":0" self.connect(in_node, layer_name) super(TFGraph, self).build()