提交 c3f348d8 编写于 作者: Y Yan Chunwei 提交者: GitHub

add graph backend (#87)

based on graphviz
上级 83dd19f7
import os
import json import json
from google.protobuf.json_format import MessageToJson from google.protobuf.json_format import MessageToJson
import onnx import onnx
import graphviz_graph as gg
from PIL import Image
def debug_print(json_obj): def debug_print(json_obj):
print(json.dumps(json_obj, sort_keys=True, indent=4, separators=(',', ': '))) print(json.dumps(
json_obj, sort_keys=True, indent=4, separators=(',', ': ')))
def reorganize_inout(json_obj, key): def reorganize_inout(json_obj, key):
...@@ -78,15 +82,16 @@ def get_links(model_json): ...@@ -78,15 +82,16 @@ def get_links(model_json):
name = input['name'] name = input['name']
for node in model_json['node']: for node in model_json['node']:
if name in node['input']: if name in node['input']:
links.append({'source': name, links.append({'source': name, "target": node['name']})
"target": node['name']})
for source_node in model_json['node']: for source_node in model_json['node']:
for output in source_node['output']: for output in source_node['output']:
for target_node in model_json['node']: for target_node in model_json['node']:
if output in target_node['input']: if output in target_node['input']:
links.append({'source': source_node['name'], links.append({
'target': target_node['name']}) 'source': source_node['name'],
'target': target_node['name']
})
return links return links
...@@ -189,8 +194,6 @@ def get_level_to_all(node_links, model_json): ...@@ -189,8 +194,6 @@ def get_level_to_all(node_links, model_json):
level_to_nodes[level] = list() level_to_nodes[level] = list()
level_to_nodes[level].append(idx) level_to_nodes[level].append(idx)
# debug_print(level_to_nodes) # debug_print(level_to_nodes)
""" """
input_to_level {idx -> level} input_to_level {idx -> level}
level_to_inputs {level -> [input1, input2]} level_to_inputs {level -> [input1, input2]}
...@@ -231,7 +234,8 @@ def get_level_to_all(node_links, model_json): ...@@ -231,7 +234,8 @@ def get_level_to_all(node_links, model_json):
if out_level not in output_to_level: if out_level not in output_to_level:
output_to_level[out_idx] = out_level output_to_level[out_idx] = out_level
else: else:
raise Exception("output " + out_name + "have multiple source") raise Exception(
"output " + out_name + "have multiple source")
level_to_outputs = dict() level_to_outputs = dict()
for out_idx in output_to_level: for out_idx in output_to_level:
level = output_to_level[out_idx] level = output_to_level[out_idx]
...@@ -243,7 +247,12 @@ def get_level_to_all(node_links, model_json): ...@@ -243,7 +247,12 @@ def get_level_to_all(node_links, model_json):
def init_level(level): def init_level(level):
if level not in level_to_all: if level not in level_to_all:
level_to_all[level] = {'nodes': list(), 'inputs': list(), 'outputs': list()} level_to_all[level] = {
'nodes': list(),
'inputs': list(),
'outputs': list()
}
# merge all levels # merge all levels
for level in level_to_nodes: for level in level_to_nodes:
init_level(level) init_level(level)
...@@ -321,116 +330,6 @@ def add_edges(json_obj): ...@@ -321,116 +330,6 @@ def add_edges(json_obj):
return json_obj return json_obj
def transform_for_echars(model_json):
opItemStyle = {
"normal": {
"color": '#d95f02'
}
}
paraterItemStyle = {
"normal": {
"color": '#1b9e77'
}
};
paraSymbolSize = [12, 6]
paraSymbol = 'rect'
opSymbolSize = [5, 5]
option = {
"title": {
"text": 'Default Graph Name'
},
"tooltip": {
"show": False
},
"animationDurationUpdate": 1500,
"animationEasingUpdate": 'quinticInOut',
"series": [
{
"type": "graph",
"layout": "none",
"symbolSize": 8,
"roam": True,
"label": {
"normal": {
"show": True,
"color": 'black'
}
},
"edgeSymbol": ['none', 'arrow'],
"edgeSymbolSize": [0, 10],
"edgeLabel": {
"normal": {
"textStyle": {
"fontSize": 20
}
}
},
"lineStyle": {
"normal": {
"opacity": 0.9,
"width": 2,
"curveness": 0
}
},
"data": [],
"links": []
}
]
}
option['title']['text'] = model_json['name']
rename_model(model_json)
node_links = get_node_links(model_json)
add_level_to_node_links(node_links)
level_to_all = get_level_to_all(node_links, model_json)
node_to_coordinate, input_to_coordinate, output_to_coordinate = level_to_coordinate(level_to_all)
inputs = model_json['input']
nodes = model_json['node']
outputs = model_json['output']
echars_data = list()
for in_idx in range(len(inputs)):
input = inputs[in_idx]
data = dict()
data['name'] = input['name']
data['x'] = input_to_coordinate[in_idx]['x']
data['y'] = input_to_coordinate[in_idx]['y']
data['symbol'] = paraSymbol
data['itemStyle'] = paraterItemStyle
data['symbolSize'] = paraSymbolSize
echars_data.append(data)
for node_idx in range(len(nodes)):
node = nodes[node_idx]
data = dict()
data['name'] = node['name']
data['x'] = node_to_coordinate[node_idx]['x']
data['y'] = node_to_coordinate[node_idx]['y']
data['itemStyle'] = opItemStyle
data['symbolSize'] = opSymbolSize
echars_data.append(data)
for out_idx in range(len(outputs)):
output = outputs[out_idx]
data = dict()
data['name'] = output['name']
data['x'] = output_to_coordinate[out_idx]['x']
data['y'] = output_to_coordinate[out_idx]['y']
data['symbol'] = paraSymbol
data['itemStyle'] = paraterItemStyle
data['symbolSize'] = paraSymbolSize
echars_data.append(data)
option['series'][0]['data'] = echars_data
option['series'][0]['links'] = get_links(model_json)
return option
def to_IR_json(model_pb_path): def to_IR_json(model_pb_path):
model = onnx.load(model_pb_path) model = onnx.load(model_pb_path)
graph = model.graph graph = model.graph
...@@ -446,14 +345,160 @@ def to_IR_json(model_pb_path): ...@@ -446,14 +345,160 @@ def to_IR_json(model_pb_path):
def load_model(model_pb_path): def load_model(model_pb_path):
model_json = to_IR_json(model_pb_path) model_json = to_IR_json(model_pb_path)
options = transform_for_echars(model_json) model_json = add_edges(model_json)
return options return model_json
class GraphPreviewGenerator(object):
def __init__(self, model_json):
#self.model = json.loads(model_json)
self.model = model_json
# init graphviz graph
self.graph = gg.Graph(
self.model['name'],
layout="dot",
#resolution=200,
concentrate="true",
# rankdir="LR"
rankdir="TB",
)
self.op_rank = self.graph.rank_group('same', 2)
self.param_rank = self.graph.rank_group('same', 1)
self.arg_rank = self.graph.rank_group('same', 0)
def __call__(self, path='temp.dot'):
self.nodes = {}
self.params = set()
self.ops = set()
self.args = set()
for item in self.model['input'] + self.model['output']:
node = self.add_param(**item)
print 'name', item['name']
self.nodes[item['name']] = node
self.params.add(item['name'])
for id, item in enumerate(self.model['node']):
node = self.add_op(**item)
name = "node_" + str(id)
print 'name', name
self.nodes[name] = node
self.ops.add(name)
for item in self.model['edges']:
source = item['source']
target = item['target']
if source not in self.nodes:
self.nodes[source] = self.add_arg(source)
self.args.add(source)
if target not in self.nodes:
self.nodes[target] = self.add_arg(target)
self.args.add(target)
if source in self.args or target in self.args:
edge = self.add_edge(
style="dashed,bold", color="#aaaaaa", **item)
else:
edge = self.add_edge(style="bold", color="#aaaaaa", **item)
self.graph.display(path)
def add_param(self, name, data_type, shape):
label = '\n'.join([
'<<table cellpadding="5">',
' <tr>',
' <td bgcolor="#eeeeee">',
name,
' </td>'
' </tr>',
' <tr>',
' <td>',
data_type,
' </td>'
' </tr>',
' <tr>',
' <td>',
'[%s]' % 'x'.join(shape),
' </td>'
' </tr>',
'</table>>',
])
return self.graph.node(
label,
prefix="param",
shape="none",
# rank=self.param_rank,
style="rounded,filled,bold",
width="1.3",
#color="#ffa0a0",
color="#8cc7ff",
fontname="Arial")
def add_op(self, opType, **kwargs):
return self.graph.node(
gg.crepr(opType),
# rank=self.op_rank,
prefix="op",
shape="box",
style="rounded, filled, bold",
fillcolor="#8cc7cd",
#fillcolor="#8cc7ff",
fontname="Arial",
width="1.3",
height="0.84",
)
def add_arg(self, name):
return self.graph.node(
gg.crepr(name),
prefix="arg",
# rank=self.arg_rank,
shape="box",
style="rounded,filled,bold",
fontname="Arial",
color="grey")
def add_edge(self, source, target, label, **kwargs):
source = self.nodes[source]
target = self.nodes[target]
return self.graph.edge(source, target, **kwargs)
def draw_graph(model_pb_path, image_dir):
json_str = load_model(model_pb_path)
best_image = None
min_width = None
for i in range(10):
# randomly generate dot images and select the one with minimum width.
g = GraphPreviewGenerator(json_str)
dot_path = os.path.join(image_dir, "temp-%d.dot" % i)
image_path = os.path.join(image_dir, "temp-%d.jpg" % i)
g(dot_path)
try:
im = Image.open(image_path)
if min_width is None or im.size[0] < min_width:
min_width = im.size
best_image = image_path
except:
pass
return best_image
if __name__ == '__main__': if __name__ == '__main__':
import os import os
import sys import sys
current_path = os.path.abspath(os.path.dirname(sys.argv[0])) current_path = os.path.abspath(os.path.dirname(sys.argv[0]))
# json_str = load_model(current_path + "/mock/inception_v1_model.pb") json_str = load_model(current_path + "/mock/inception_v1_model.pb")
json_str = load_model(current_path + "/mock/squeezenet_model.pb") #json_str = load_model(current_path + "/mock/squeezenet_model.pb")
print(json_str) # json_str = load_model('./mock/shufflenet/model.pb')
debug_print(json_str)
assert json_str
g = GraphPreviewGenerator(json_str)
g('./temp.dot')
# for i in range(10):
# g = GraphPreviewGenerator(json_str)
# g('./temp-%d.dot' % i)
import subprocess
import tempfile
import sys
import random
import subprocess
import os
def crepr(v):
if type(v) is str or type(v) is unicode:
return '"%s"' % v
return str(v)
class Rank(object):
def __init__(self, kind, name, priority):
'''
kind: str
name: str
priority: int
'''
self.kind = kind
self.name = name
self.priority = priority
self.nodes = []
def __str__(self):
if not self.nodes:
return ''
# repr = []
# for node in self.nodes:
# repr.append(str(node))
return '{' + 'rank={};'.format(self.kind) + \
','.join([node.name for node in self.nodes]) + '}'
# return '\n'.join(repr)
# the python package graphviz is too poor.
class Graph(object):
rank_counter = 0
def __init__(self, title, **attrs):
self.title = title
self.attrs = attrs
self.nodes = []
self.edges = []
self.rank_groups = {}
def code(self):
return self.__str__()
def rank_group(self, kind, priority):
name = "rankgroup-%d" % Graph.rank_counter
Graph.rank_counter += 1
rank = Rank(kind, name, priority)
self.rank_groups[name] = rank
return name
def node(self, label, prefix, **attrs):
node = Node(label, prefix, **attrs)
if 'rank' in attrs:
rank = self.rank_groups[attrs['rank']]
del attrs['rank']
rank.nodes.append(node)
self.nodes.append(node)
return node
def edge(self, source, target, **attrs):
edge = Edge(source, target, **attrs)
self.edges.append(edge)
return edge
def display(self, dot_path):
file = open(dot_path, 'w')
file.write(self.__str__())
image_path = dot_path[:-3] + "jpg"
cmd = ["/usr/bin/dot", "-Tjpg", dot_path, "-o", image_path]
# cmd = "./preview.sh \"%s\"" % cmd
print 'cmd', cmd
# subprocess.call(cmd, shell=True)
subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
# os.system(' '.join(cmd))
# assert os.path.isfile(image_path), "no image generated"
def _rank_repr(self):
ranks = sorted(
self.rank_groups.items(),
cmp=lambda a, b: a[1].priority > b[1].priority)
repr = []
for x in ranks:
repr.append(str(x[1]))
return '\n'.join(repr) + '\n'
def __str__(self):
reprs = [
'digraph G {',
'title = {}'.format(crepr(self.title)),
]
for attr in self.attrs:
reprs.append("{key}={value};".format(
key=attr, value=crepr(self.attrs[attr])))
reprs.append(self._rank_repr())
random.shuffle(self.nodes)
reprs += [str(node) for node in self.nodes]
for x in self.edges:
reprs.append(str(x))
reprs.append('}')
return '\n'.join(reprs)
class Node(object):
counter = 1
def __init__(self, label, prefix, **attrs):
self.label = label
self.name = "%s_%d" % (prefix, Node.counter)
self.attrs = attrs
Node.counter += 1
def __str__(self):
reprs = '{name} [label={label} {extra} ];'.format(
name=self.name,
label=self.label,
extra=',' + ','.join("%s=%s" % (key, crepr(value))
for key, value in self.attrs.items())
if self.attrs else "")
return reprs
class Edge(object):
def __init__(self, source, target, **attrs):
'''
Link source to target.
:param source: Node
:param target: Node
:param graph: Graph
:param attrs: dic
'''
self.source = source
self.target = target
self.attrs = attrs
def __str__(self):
repr = "{source} -> {target} {extra}".format(
source=self.source.name,
target=self.target.name,
extra="" if not self.attrs else
"[" + ','.join("{}={}".format(attr[0], crepr(attr[1]))
for attr in self.attrs.items()) + "]")
return repr
g_graph = Graph(title="some model")
def add_param(label, graph=None):
if not graph:
graph = g_graph
return graph.node(label=label, prefix='param', color='blue')
def add_op(label, graph=None):
if not graph:
graph = g_graph
label = '\n'.join([
'<table border="0">',
' <tr>',
' <td>',
label,
' </td>'
' </tr>',
'</table>',
])
return graph.node(label=label, prefix='op', shape="none")
def add_edge(source, target):
return g_graph.edge(source, target)
if __name__ == '__main__':
n0 = add_param(crepr("layer/W0.w"))
n1 = add_param(crepr("layer/W0.b"))
n2 = add_op("sum")
add_edge(n0, n2)
add_edge(n1, n2)
print g_graph.code()
g_graph.display('./1.dot')
...@@ -23,12 +23,12 @@ app.config['SEND_FILE_MAX_AGE_DEFAULT'] = 30 ...@@ -23,12 +23,12 @@ app.config['SEND_FILE_MAX_AGE_DEFAULT'] = 30
SERVER_DIR = os.path.join(visualdl.ROOT, 'server') SERVER_DIR = os.path.join(visualdl.ROOT, 'server')
def option_parser(): def option_parser():
""" """
:return: :return:
""" """
parser = OptionParser(usage="usage: visual_dl visual_dl.py "\ parser = OptionParser(usage="usage: visualDL -p port [options]")
"-p port [options]")
parser.add_option( parser.add_option(
"-p", "-p",
"--port", "--port",
...@@ -44,6 +44,12 @@ def option_parser(): ...@@ -44,6 +44,12 @@ def option_parser():
default="0.0.0.0", default="0.0.0.0",
action="store", action="store",
help="api service ip") help="api service ip")
parser.add_option(
"-m",
"--model_pb",
type=str,
action="store",
help="model proto in ONNX format")
parser.add_option( parser.add_option(
"--logdir", action="store", dest="logdir", help="log file directory") "--logdir", action="store", dest="logdir", help="log file directory")
return parser.parse_args() return parser.parse_args()
...@@ -56,6 +62,7 @@ mock_data_path = os.path.join(SERVER_DIR, "./mock_data/") ...@@ -56,6 +62,7 @@ mock_data_path = os.path.join(SERVER_DIR, "./mock_data/")
log_reader = LogReader(options.logdir) log_reader = LogReader(options.logdir)
graph_image_path = None
# return data # return data
# status, msg, data # status, msg, data
...@@ -82,6 +89,10 @@ def serve_static(filename): ...@@ -82,6 +89,10 @@ def serve_static(filename):
return send_from_directory( return send_from_directory(
os.path.join(server_path, static_file_path), filename) os.path.join(server_path, static_file_path), filename)
@app.route('/graphs/image')
def serve_graph():
print 'send file', graph_image_path
return send_file(graph_image_path)
@app.route('/data/logdir') @app.route('/data/logdir')
def logdir(): def logdir():
...@@ -170,13 +181,19 @@ def histogram(): ...@@ -170,13 +181,19 @@ def histogram():
return Response(json.dumps(result), mimetype='application/json') return Response(json.dumps(result), mimetype='application/json')
@app.route('/data/plugin/graphs/graphs') @app.route('/data/plugin/graphs/graph')
def graph(): def graph():
# run = request.args.get('run') global graph_image_path
# model_json_str = mock_data.graph_data() # TODO(ChunweiYan) need to add a config for whether have graph.
# model_json = json.loads(model_json_str)
model_json = vdl_graph.load_model(options.logdir + "/model.pb") image_dir = os.path.join(options.logdir, "graphs")
result = gen_result(0, "", model_json) if not os.path.isdir(image_dir):
os.mkdir(image_dir)
image_path = vdl_graph.draw_graph(options.model_pb, image_dir)
graph_image_path = image_path
print 'image_path', image_path
data = {'url': '/graphs/image'}
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json') return Response(json.dumps(result), mimetype='application/json')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册