提交 70945df5 编写于 作者: J jiangjiajun

test code

上级 9fd4bba6
...@@ -60,15 +60,16 @@ class Graph(object): ...@@ -60,15 +60,16 @@ class Graph(object):
num_inputs = dict() num_inputs = dict()
for name, node in self.node_map.items(): for name, node in self.node_map.items():
num_inputs[name] = len(node.inputs) num_inputs[name] = len(node.inputs)
print(len(self.node_map))
self.topo_sort = self.input_nodes[:] self.topo_sort = self.input_nodes[:]
idx = 0 idx = 0
while idx < len(self.topo_sort): while idx < len(self.topo_sort):
current_node = self.node_map[self.topo_sort[idx]] current_node = self.node_map[self.topo_sort[idx]]
for node in current_node.outputs: for node in current_node.outputs:
num_inputs[node.layer_name] -= 1 num_inputs[node] -= 1
if num_inputs[node.layer_name] == 0: if num_inputs[node] == 0:
self.topo_sort.append(node.layer_name) self.topo_sort.append(node)
idx += 1 idx += 1
for i, tmp in enumerate(self.topo_sort): for i, tmp in enumerate(self.topo_sort):
...@@ -84,3 +85,4 @@ class Graph(object): ...@@ -84,3 +85,4 @@ class Graph(object):
if dst not in self.node_map: if dst not in self.node_map:
raise Exception("node[{}] not in graph".format(dst)) raise Exception("node[{}] not in graph".format(dst))
self.node_map[dst].inputs.append(src) self.node_map[dst].inputs.append(src)
self.node_map[src].outputs.append(dst)
...@@ -29,9 +29,6 @@ class TFGraphNode(GraphNode): ...@@ -29,9 +29,6 @@ class TFGraphNode(GraphNode):
class TFGraph(Graph): class TFGraph(Graph):
def __init__(self, model): def __init__(self, model):
super(TFGraph, self).__init__(model) super(TFGraph, self).__init__(model)
self.multi_output_ops = [
'Split',
'Unpack']
def build(self): def build(self):
for layer in self.model.node: for layer in self.model.node:
...@@ -41,12 +38,10 @@ class TFGraph(Graph): ...@@ -41,12 +38,10 @@ class TFGraph(Graph):
for in_node in node.layer.input: for in_node in node.layer.input:
if in_node not in self.node_map: if in_node not in self.node_map:
if in_node.strip().split(':')[0] in self.node_map: if in_node.strip().split(':')[0] in self.node_map:
self.connect(in_node, layer_name) self.connect(in_node.strip().split(':')[0], layer_name)
else: else:
raise Exception('input[{}] of node[{}] does not exist in node_map'.format(in_node, layer_name)) raise Exception('input[{}] of node[{}] does not exist in node_map'.format(in_node, layer_name))
else: else:
if self.node_map[in_node].layer_type in self.multi_output_ops:
in_node += ":0"
self.connect(in_node, layer_name) self.connect(in_node, layer_name)
super(TFGraph, self).build() super(TFGraph, self).build()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册