From 8cc4bdb49a21ab764d5ed01457eedb28cf8e0069 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 8 Jan 2018 14:33:33 +0800 Subject: [PATCH] fix graph_test --- server/visualdl/graph.py | 11 +++++++---- server/visualdl/graph_test.py | 3 +-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/server/visualdl/graph.py b/server/visualdl/graph.py index 523ebf10..c95de020 100644 --- a/server/visualdl/graph.py +++ b/server/visualdl/graph.py @@ -399,6 +399,7 @@ def transform_for_echars(model_json): option['title']['text'] = model_json['name'] + rename_model(model_json) node_links = get_node_links(model_json) add_level_to_node_links(node_links) level_to_all = get_level_to_all(node_links, model_json) @@ -446,7 +447,7 @@ def transform_for_echars(model_json): return option -def load_model(model_pb_path): +def to_IR_json(model_pb_path): model = onnx.load(model_pb_path) graph = model.graph del graph.initializer[:] @@ -456,10 +457,12 @@ def load_model(model_pb_path): model_json = json.loads(json_str) reorganize_inout(model_json, 'input') reorganize_inout(model_json, 'output') - rename_model(model_json) - # debug_print(model_json) + return model_json + + +def load_model(model_pb_path): + model_json = to_IR_json(model_pb_path) options = transform_for_echars(model_json) - # debug_print(options) return options diff --git a/server/visualdl/graph_test.py b/server/visualdl/graph_test.py index 6917afe9..11d30029 100644 --- a/server/visualdl/graph_test.py +++ b/server/visualdl/graph_test.py @@ -8,8 +8,7 @@ class GraphTest(unittest.TestCase): self.mock_dir = "./mock" def test_graph_edges_squeezenet(self): - json_str = graph.load_model(self.mock_dir + '/squeezenet_model.pb') - json_obj = json.loads(json_str) + json_obj = graph.to_IR_json(self.mock_dir + '/squeezenet_model.pb') # 126 edges + 66 nodes (out-edge of each node is counted twice) self.assertEqual(len(json_obj['edges']), 126 + 66) -- GitLab