未验证 提交 2c399c08 编写于 作者: N Nicky Chan 提交者: GitHub

Add download ONNX model in scratch_log, clean up graph image logic (#473)

上级 66c979ef
......@@ -197,10 +197,9 @@ with logw.mode("train") as logger:
embedding.add_embeddings_with_word_dict(hot_vectors, word_dict)
def download_graph_image():
def download_onnx():
'''
This is a scratch demo, it do not generate a ONNX proto, but just download an image
that generated before to show how the graph frontend works.
This is a scratch demo, it do not generate a ONNX proto, but just download a prebuilt ONNX file.
For real cases, just refer to README.
'''
......@@ -218,12 +217,12 @@ def download_graph_image():
myssl = ssl.create_default_context()
myssl.check_hostname = False
myssl.verify_mode = ssl.CERT_NONE
image_url = "https://github.com/PaddlePaddle/VisualDL/blob/develop/demo/mxnet/super_resolution_graph.png?raw=true"
log.warning('download graph demo from {}'.format(image_url))
graph_image = ur.urlopen(image_url, context=myssl).read()
with open(os.path.join(logdir, 'graph.jpg'), 'wb') as f:
f.write(graph_image)
log.warning('graph ready!')
onnx_url = "https://github.com/PaddlePaddle/VisualDL/blob/develop/demo/mnist_model.onnx?raw=true"
log.warning('download ONNX file from {}'.format(onnx_url))
onnx_model = ur.urlopen(onnx_url, context=myssl).read()
with open(os.path.join(logdir, 'mnist_model.onnx'), 'wb') as f:
f.write(onnx_model)
log.warning('ONNX model ready! use visualdl --logdir=scratch_log -m scratch_log/mnist_model.onnx to launch')
download_graph_image()
download_onnx()
......@@ -486,7 +486,7 @@ class GraphPreviewGenerator(object):
return self.graph.edge(source, target, **kwargs)
def draw_graph(model_pb_path, image_dir):
def draw_graph(model_pb_path):
json_str = load_model(model_pb_path)
return json_str
......
......@@ -299,29 +299,14 @@ def individual_audio():
@app.route('/data/plugin/graphs/graph')
def graph():
# TODO(daming-lu): rename variables because we switched from static graph generated by graphviz
# to d3/dagre drawn svg
if graph_image_path is not None:
data = {
'data': graph_image_path,
}
else:
image_dir = os.path.join(args.logdir, "graphs")
if not os.path.isdir(image_dir):
os.mkdir(image_dir)
json_str = vdl_graph.draw_graph(args.model_pb, image_dir)
data = {'data': json_str}
json_str = vdl_graph.draw_graph(args.model_pb)
data = {'data': json_str}
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
if __name__ == '__main__':
logger.info(" port=" + str(args.port))
if args.model_pb:
# draw graph
image_dir = os.path.join(args.logdir, "graphs")
if not os.path.isdir(image_dir):
os.mkdir(image_dir)
graph_image_path = vdl_graph.draw_graph(args.model_pb, image_dir)
app.run(debug=False, host=args.host, port=args.port, threaded=True)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册