未验证 提交 6ff9620c 编写于 作者: D daminglu 提交者: GitHub

Fix graph bugs (#391)

上级 f3c0a861
......@@ -19,6 +19,7 @@ import * as svgToPngDownloadHelper from './svgToPngDownloadHelper.js';
import * as d3 from 'd3';
import has from 'lodash/has';
import isArrayLike from 'lodash/isArrayLike';
export default {
props: {
......@@ -74,9 +75,11 @@ export default {
if (has(graphData, 'input') === false) {
return;
}
let inputIdToIndex = {};
for (let i=0; i<graphData['input'].length; ++i) {
let curInputNode = graphData['input'][i];
let nodeKey = curInputNode['name'];
inputIdToIndex[nodeKey] = i;
g.setNode(
nodeKey,
{
......@@ -144,15 +147,17 @@ export default {
nodeKeys.push(outputNodeKey);
// add edges from inputs to node and from node to output
for (let e=0; e<curOperatorNode['input'].length; ++e) {
// TODO(daming-lu): hard-coding style here to this polyline shows shadows.
if (has(curOperatorNode, 'input') && isArrayLike(curOperatorNode['input'])) {
for (let e = 0; e < curOperatorNode['input'].length; ++e) {
g.setEdge(curOperatorNode['input'][e], nodeKey);
}
}
if (has(curOperatorNode, 'output') && isArrayLike(curOperatorNode['output'])) {
g.setEdge(nodeKey, curOperatorNode['output'][0], {
style: 'stroke: #333;stroke-width: 1.5px',
});
}
}
// TODO(daming-lu): add prettier styles to diff nodes
let svg = d3.select('svg')
......@@ -172,7 +177,7 @@ export default {
let opIndex = d.slice(7); // remove prefix "opNode_"
nodeInfo = graphData.node[opIndex];
} else if (nodeType === 'input') {
nodeInfo = graphData.input[d-1];
nodeInfo = graphData.input[inputIdToIndex[d]];
} else {
nodeInfo = 'output';
}
......
......@@ -18,7 +18,6 @@ import json
import os
from google.protobuf.json_format import MessageToJson
from PIL import Image
from . import graphviz_graph as gg
from . import onnx
......@@ -328,21 +327,27 @@ def add_edges(json_obj):
cur_node = json_obj['node'][node_index]
# input edges
if 'input' in cur_node and len(cur_node['input']) > 0:
for source in cur_node['input']:
json_obj['edges'].append({
'source': source,
'target': 'node_' + str(node_index),
'label': 'label_' + str(label_incrementer)
'source':
source,
'target':
'node_' + str(node_index),
'label':
'label_' + str(label_incrementer)
})
label_incrementer += 1
# output edge
if 'output' in cur_node and len(cur_node['output']) > 0:
json_obj['edges'].append({
'source': 'node_' + str(node_index),
'target': cur_node['output'][0],
'label': 'label_' + str(label_incrementer)
})
label_incrementer += 1
return json_obj
......@@ -483,23 +488,7 @@ class GraphPreviewGenerator(object):
def draw_graph(model_pb_path, image_dir):
json_str = load_model(model_pb_path)
best_image = None
min_width = None
for i in range(10):
# randomly generate dot images and select the one with minimum width.
g = GraphPreviewGenerator(json_str)
dot_path = os.path.join(image_dir, "temp-%d.dot" % i)
image_path = os.path.join(image_dir, "temp-%d.jpg" % i)
g(dot_path)
try:
im = Image.open(image_path)
if min_width is None or im.size[0] < min_width:
min_width = im.size
best_image = image_path
except Exception:
pass
return best_image
return json_str
if __name__ == '__main__':
......
......@@ -298,11 +298,18 @@ def individual_audio():
@app.route('/data/plugin/graphs/graph')
def graph():
# TODO(ChunweiYan) need to add a config for whether have graph.
if graph_image_path is None or not os.path.isfile(graph_image_path):
data = {'url': ''}
# TODO(daming-lu): rename variables because we switched from static graph generated by graphviz
# to d3/dagre drawn svg
if graph_image_path is not None:
data = {
'data': graph_image_path,
}
else:
data = {'url': '/graphs/image'}
image_dir = os.path.join(args.logdir, "graphs")
if not os.path.isdir(image_dir):
os.mkdir(image_dir)
json_str = vdl_graph.draw_graph(args.model_pb, image_dir)
data = {'data': json_str}
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册