From 2c399c084c19754039c65fb2975e184df95a6463 Mon Sep 17 00:00:00 2001 From: Nicky Chan Date: Thu, 19 Jul 2018 16:03:31 -0700 Subject: [PATCH] Add download ONNX model in scratch_log, clean up graph image logic (#473) --- demo/vdl_create_scratch_log | 19 +++++++++---------- visualdl/server/graph.py | 2 +- visualdl/server/visualdl | 21 +++------------------ 3 files changed, 13 insertions(+), 29 deletions(-) diff --git a/demo/vdl_create_scratch_log b/demo/vdl_create_scratch_log index cb60e722..e860e133 100755 --- a/demo/vdl_create_scratch_log +++ b/demo/vdl_create_scratch_log @@ -197,10 +197,9 @@ with logw.mode("train") as logger: embedding.add_embeddings_with_word_dict(hot_vectors, word_dict) -def download_graph_image(): +def download_onnx(): ''' - This is a scratch demo, it do not generate a ONNX proto, but just download an image - that generated before to show how the graph frontend works. + This is a scratch demo, it do not generate a ONNX proto, but just download a prebuilt ONNX file. For real cases, just refer to README. ''' @@ -218,12 +217,12 @@ def download_graph_image(): myssl = ssl.create_default_context() myssl.check_hostname = False myssl.verify_mode = ssl.CERT_NONE - image_url = "https://github.com/PaddlePaddle/VisualDL/blob/develop/demo/mxnet/super_resolution_graph.png?raw=true" - log.warning('download graph demo from {}'.format(image_url)) - graph_image = ur.urlopen(image_url, context=myssl).read() - with open(os.path.join(logdir, 'graph.jpg'), 'wb') as f: - f.write(graph_image) - log.warning('graph ready!') + onnx_url = "https://github.com/PaddlePaddle/VisualDL/blob/develop/demo/mnist_model.onnx?raw=true" + log.warning('download ONNX file from {}'.format(onnx_url)) + onnx_model = ur.urlopen(onnx_url, context=myssl).read() + with open(os.path.join(logdir, 'mnist_model.onnx'), 'wb') as f: + f.write(onnx_model) + log.warning('ONNX model ready! use visualdl --logdir=scratch_log -m scratch_log/mnist_model.onnx to launch') -download_graph_image() +download_onnx() diff --git a/visualdl/server/graph.py b/visualdl/server/graph.py index 0fc6154d..0e9cf2c8 100644 --- a/visualdl/server/graph.py +++ b/visualdl/server/graph.py @@ -486,7 +486,7 @@ class GraphPreviewGenerator(object): return self.graph.edge(source, target, **kwargs) -def draw_graph(model_pb_path, image_dir): +def draw_graph(model_pb_path): json_str = load_model(model_pb_path) return json_str diff --git a/visualdl/server/visualdl b/visualdl/server/visualdl index eceaddf3..9bb80588 100644 --- a/visualdl/server/visualdl +++ b/visualdl/server/visualdl @@ -299,29 +299,14 @@ def individual_audio(): @app.route('/data/plugin/graphs/graph') def graph(): - # TODO(daming-lu): rename variables because we switched from static graph generated by graphviz - # to d3/dagre drawn svg - if graph_image_path is not None: - data = { - 'data': graph_image_path, - } - else: - image_dir = os.path.join(args.logdir, "graphs") - if not os.path.isdir(image_dir): - os.mkdir(image_dir) - json_str = vdl_graph.draw_graph(args.model_pb, image_dir) - data = {'data': json_str} + json_str = vdl_graph.draw_graph(args.model_pb) + data = {'data': json_str} + result = gen_result(0, "", data) return Response(json.dumps(result), mimetype='application/json') if __name__ == '__main__': logger.info(" port=" + str(args.port)) - if args.model_pb: - # draw graph - image_dir = os.path.join(args.logdir, "graphs") - if not os.path.isdir(image_dir): - os.mkdir(image_dir) - graph_image_path = vdl_graph.draw_graph(args.model_pb, image_dir) app.run(debug=False, host=args.host, port=args.port, threaded=True) -- GitLab