提交 8cc4bdb4 编写于 作者: Q qiaolongfei

fix graph_test

上级 834aedc7
...@@ -399,6 +399,7 @@ def transform_for_echars(model_json): ...@@ -399,6 +399,7 @@ def transform_for_echars(model_json):
option['title']['text'] = model_json['name'] option['title']['text'] = model_json['name']
rename_model(model_json)
node_links = get_node_links(model_json) node_links = get_node_links(model_json)
add_level_to_node_links(node_links) add_level_to_node_links(node_links)
level_to_all = get_level_to_all(node_links, model_json) level_to_all = get_level_to_all(node_links, model_json)
...@@ -446,7 +447,7 @@ def transform_for_echars(model_json): ...@@ -446,7 +447,7 @@ def transform_for_echars(model_json):
return option return option
def load_model(model_pb_path): def to_IR_json(model_pb_path):
model = onnx.load(model_pb_path) model = onnx.load(model_pb_path)
graph = model.graph graph = model.graph
del graph.initializer[:] del graph.initializer[:]
...@@ -456,10 +457,12 @@ def load_model(model_pb_path): ...@@ -456,10 +457,12 @@ def load_model(model_pb_path):
model_json = json.loads(json_str) model_json = json.loads(json_str)
reorganize_inout(model_json, 'input') reorganize_inout(model_json, 'input')
reorganize_inout(model_json, 'output') reorganize_inout(model_json, 'output')
rename_model(model_json) return model_json
# debug_print(model_json)
def load_model(model_pb_path):
model_json = to_IR_json(model_pb_path)
options = transform_for_echars(model_json) options = transform_for_echars(model_json)
# debug_print(options)
return options return options
......
...@@ -8,8 +8,7 @@ class GraphTest(unittest.TestCase): ...@@ -8,8 +8,7 @@ class GraphTest(unittest.TestCase):
self.mock_dir = "./mock" self.mock_dir = "./mock"
def test_graph_edges_squeezenet(self): def test_graph_edges_squeezenet(self):
json_str = graph.load_model(self.mock_dir + '/squeezenet_model.pb') json_obj = graph.to_IR_json(self.mock_dir + '/squeezenet_model.pb')
json_obj = json.loads(json_str)
# 126 edges + 66 nodes (out-edge of each node is counted twice) # 126 edges + 66 nodes (out-edge of each node is counted twice)
self.assertEqual(len(json_obj['edges']), 126 + 66) self.assertEqual(len(json_obj['edges']), 126 + 66)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册