未验证 提交 6ff9620c 编写于 作者: D daminglu 提交者: GitHub

Fix graph bugs (#391)

上级 f3c0a861
...@@ -19,6 +19,7 @@ import * as svgToPngDownloadHelper from './svgToPngDownloadHelper.js'; ...@@ -19,6 +19,7 @@ import * as svgToPngDownloadHelper from './svgToPngDownloadHelper.js';
import * as d3 from 'd3'; import * as d3 from 'd3';
import has from 'lodash/has'; import has from 'lodash/has';
import isArrayLike from 'lodash/isArrayLike';
export default { export default {
props: { props: {
...@@ -74,9 +75,11 @@ export default { ...@@ -74,9 +75,11 @@ export default {
if (has(graphData, 'input') === false) { if (has(graphData, 'input') === false) {
return; return;
} }
let inputIdToIndex = {};
for (let i=0; i<graphData['input'].length; ++i) { for (let i=0; i<graphData['input'].length; ++i) {
let curInputNode = graphData['input'][i]; let curInputNode = graphData['input'][i];
let nodeKey = curInputNode['name']; let nodeKey = curInputNode['name'];
inputIdToIndex[nodeKey] = i;
g.setNode( g.setNode(
nodeKey, nodeKey,
{ {
...@@ -144,15 +147,17 @@ export default { ...@@ -144,15 +147,17 @@ export default {
nodeKeys.push(outputNodeKey); nodeKeys.push(outputNodeKey);
// add edges from inputs to node and from node to output // add edges from inputs to node and from node to output
for (let e=0; e<curOperatorNode['input'].length; ++e) { if (has(curOperatorNode, 'input') && isArrayLike(curOperatorNode['input'])) {
// TODO(daming-lu): hard-coding style here to this polyline shows shadows. for (let e = 0; e < curOperatorNode['input'].length; ++e) {
g.setEdge(curOperatorNode['input'][e], nodeKey); g.setEdge(curOperatorNode['input'][e], nodeKey);
} }
}
if (has(curOperatorNode, 'output') && isArrayLike(curOperatorNode['output'])) {
g.setEdge(nodeKey, curOperatorNode['output'][0], { g.setEdge(nodeKey, curOperatorNode['output'][0], {
style: 'stroke: #333;stroke-width: 1.5px', style: 'stroke: #333;stroke-width: 1.5px',
}); });
} }
}
// TODO(daming-lu): add prettier styles to diff nodes // TODO(daming-lu): add prettier styles to diff nodes
let svg = d3.select('svg') let svg = d3.select('svg')
...@@ -172,7 +177,7 @@ export default { ...@@ -172,7 +177,7 @@ export default {
let opIndex = d.slice(7); // remove prefix "opNode_" let opIndex = d.slice(7); // remove prefix "opNode_"
nodeInfo = graphData.node[opIndex]; nodeInfo = graphData.node[opIndex];
} else if (nodeType === 'input') { } else if (nodeType === 'input') {
nodeInfo = graphData.input[d-1]; nodeInfo = graphData.input[inputIdToIndex[d]];
} else { } else {
nodeInfo = 'output'; nodeInfo = 'output';
} }
......
...@@ -18,7 +18,6 @@ import json ...@@ -18,7 +18,6 @@ import json
import os import os
from google.protobuf.json_format import MessageToJson from google.protobuf.json_format import MessageToJson
from PIL import Image
from . import graphviz_graph as gg from . import graphviz_graph as gg
from . import onnx from . import onnx
...@@ -328,21 +327,27 @@ def add_edges(json_obj): ...@@ -328,21 +327,27 @@ def add_edges(json_obj):
cur_node = json_obj['node'][node_index] cur_node = json_obj['node'][node_index]
# input edges # input edges
if 'input' in cur_node and len(cur_node['input']) > 0:
for source in cur_node['input']: for source in cur_node['input']:
json_obj['edges'].append({ json_obj['edges'].append({
'source': source, 'source':
'target': 'node_' + str(node_index), source,
'label': 'label_' + str(label_incrementer) 'target':
'node_' + str(node_index),
'label':
'label_' + str(label_incrementer)
}) })
label_incrementer += 1 label_incrementer += 1
# output edge # output edge
if 'output' in cur_node and len(cur_node['output']) > 0:
json_obj['edges'].append({ json_obj['edges'].append({
'source': 'node_' + str(node_index), 'source': 'node_' + str(node_index),
'target': cur_node['output'][0], 'target': cur_node['output'][0],
'label': 'label_' + str(label_incrementer) 'label': 'label_' + str(label_incrementer)
}) })
label_incrementer += 1 label_incrementer += 1
return json_obj return json_obj
...@@ -483,23 +488,7 @@ class GraphPreviewGenerator(object): ...@@ -483,23 +488,7 @@ class GraphPreviewGenerator(object):
def draw_graph(model_pb_path, image_dir): def draw_graph(model_pb_path, image_dir):
json_str = load_model(model_pb_path) json_str = load_model(model_pb_path)
best_image = None return json_str
min_width = None
for i in range(10):
# randomly generate dot images and select the one with minimum width.
g = GraphPreviewGenerator(json_str)
dot_path = os.path.join(image_dir, "temp-%d.dot" % i)
image_path = os.path.join(image_dir, "temp-%d.jpg" % i)
g(dot_path)
try:
im = Image.open(image_path)
if min_width is None or im.size[0] < min_width:
min_width = im.size
best_image = image_path
except Exception:
pass
return best_image
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -298,11 +298,18 @@ def individual_audio(): ...@@ -298,11 +298,18 @@ def individual_audio():
@app.route('/data/plugin/graphs/graph') @app.route('/data/plugin/graphs/graph')
def graph(): def graph():
# TODO(ChunweiYan) need to add a config for whether have graph. # TODO(daming-lu): rename variables because we switched from static graph generated by graphviz
if graph_image_path is None or not os.path.isfile(graph_image_path): # to d3/dagre drawn svg
data = {'url': ''} if graph_image_path is not None:
data = {
'data': graph_image_path,
}
else: else:
data = {'url': '/graphs/image'} image_dir = os.path.join(args.logdir, "graphs")
if not os.path.isdir(image_dir):
os.mkdir(image_dir)
json_str = vdl_graph.draw_graph(args.model_pb, image_dir)
data = {'data': json_str}
result = gen_result(0, "", data) result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json') return Response(json.dumps(result), mimetype='application/json')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册