提交 8cc4bdb4 编写于 作者: Q qiaolongfei

fix graph_test

上级 834aedc7
......@@ -399,6 +399,7 @@ def transform_for_echars(model_json):
option['title']['text'] = model_json['name']
rename_model(model_json)
node_links = get_node_links(model_json)
add_level_to_node_links(node_links)
level_to_all = get_level_to_all(node_links, model_json)
......@@ -446,7 +447,7 @@ def transform_for_echars(model_json):
return option
def load_model(model_pb_path):
def to_IR_json(model_pb_path):
model = onnx.load(model_pb_path)
graph = model.graph
del graph.initializer[:]
......@@ -456,10 +457,12 @@ def load_model(model_pb_path):
model_json = json.loads(json_str)
reorganize_inout(model_json, 'input')
reorganize_inout(model_json, 'output')
rename_model(model_json)
# debug_print(model_json)
return model_json
def load_model(model_pb_path):
model_json = to_IR_json(model_pb_path)
options = transform_for_echars(model_json)
# debug_print(options)
return options
......
......@@ -8,8 +8,7 @@ class GraphTest(unittest.TestCase):
self.mock_dir = "./mock"
def test_graph_edges_squeezenet(self):
json_str = graph.load_model(self.mock_dir + '/squeezenet_model.pb')
json_obj = json.loads(json_str)
json_obj = graph.to_IR_json(self.mock_dir + '/squeezenet_model.pb')
# 126 edges + 66 nodes (out-edge of each node is counted twice)
self.assertEqual(len(json_obj['edges']), 126 + 66)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册