提交 32ccedb4 编写于 作者: Q qiaolongfei

fix graph_test, clean code

上级 51940612
......@@ -54,6 +54,7 @@ def rename_model(model_json):
+ '\nshape=' + str(new_shape)
variable['name'] = new_name
rename_edge(model, old_name, new_name)
rename_variables(model_json, model_json['input'])
rename_variables(model_json, model_json['output'])
......@@ -79,9 +80,6 @@ def get_links(model_json):
if name in node['input']:
links.append({'source': name,
"target": node['name']})
# links.append({'source': name,
# "target": node['name'],
# "label": name})
for source_node in model_json['node']:
for output in source_node['output']:
......@@ -89,9 +87,6 @@ def get_links(model_json):
if output in target_node['input']:
links.append({'source': source_node['name'],
'target': target_node['name']})
# links.append({'source': source_node['name'],
# 'target': target_node['name'],
# 'label': output})
return links
......@@ -174,7 +169,6 @@ def add_level_to_node_links(node_links):
assert in_level is not None
if cur_level is None or in_level >= cur_level:
node_links[idx]['level'] = in_level + 1
# debug_print(node_links)
def get_level_to_all(node_links, model_json):
......@@ -249,7 +243,6 @@ def get_level_to_all(node_links, model_json):
if level not in level_to_outputs:
level_to_outputs[level] = list()
level_to_outputs[level].append(out_idx)
# debug_print(level_to_outputs)
level_to_all = dict()
......@@ -302,13 +295,9 @@ def level_to_coordinate(level_to_all):
output_to_coordinate[out_idx] = get_coordinate(x_idx, level)
x_idx += 1
# debug_print(node_to_coordinate)
# debug_print(input_to_coordinate)
# debug_print(output_to_coordinate)
return node_to_coordinate, input_to_coordinate, output_to_coordinate
def add_edges(json_obj):
# TODO(daming-lu): should try to de-duplicate node's out-edge
# Currently it is counted twice: 1 as out-edge, 1 as in-edge
......
......@@ -9,6 +9,7 @@ class GraphTest(unittest.TestCase):
def test_graph_edges_squeezenet(self):
json_obj = graph.to_IR_json(self.mock_dir + '/squeezenet_model.pb')
json_obj = graph.add_edges(json_obj)
# 126 edges + 66 nodes (out-edge of each node is counted twice)
self.assertEqual(len(json_obj['edges']), 126 + 66)
......@@ -39,6 +40,7 @@ class GraphTest(unittest.TestCase):
def test_graph_edges_inception_v1(self):
json_obj = graph.to_IR_json(self.mock_dir + '/inception_v1_model.pb')
json_obj = graph.add_edges(json_obj)
# 286 edges + 143 nodes (out-edge of each node is counted twice)
self.assertEqual(len(json_obj['edges']), 286 + 143)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册