未验证 提交 06583b18 编写于 作者: Y Yan Chunwei 提交者: GitHub

graph apply new color theme (#135)

上级 42c327c4
...@@ -350,16 +350,16 @@ def load_model(model_pb_path): ...@@ -350,16 +350,16 @@ def load_model(model_pb_path):
class GraphPreviewGenerator(object): class GraphPreviewGenerator(object):
'''
Generate a graph image for ONNX proto.
'''
def __init__(self, model_json): def __init__(self, model_json):
#self.model = json.loads(model_json)
self.model = model_json self.model = model_json
# init graphviz graph # init graphviz graph
self.graph = gg.Graph( self.graph = gg.Graph(
self.model['name'], self.model['name'],
layout="dot", layout="dot",
#resolution=200,
concentrate="true", concentrate="true",
# rankdir="LR"
rankdir="TB", rankdir="TB",
) )
...@@ -367,7 +367,7 @@ class GraphPreviewGenerator(object): ...@@ -367,7 +367,7 @@ class GraphPreviewGenerator(object):
self.param_rank = self.graph.rank_group('same', 1) self.param_rank = self.graph.rank_group('same', 1)
self.arg_rank = self.graph.rank_group('same', 0) self.arg_rank = self.graph.rank_group('same', 0)
def __call__(self, path='temp.dot'): def __call__(self, path='temp.dot', show=False):
self.nodes = {} self.nodes = {}
self.params = set() self.params = set()
self.ops = set() self.ops = set()
...@@ -375,14 +375,12 @@ class GraphPreviewGenerator(object): ...@@ -375,14 +375,12 @@ class GraphPreviewGenerator(object):
for item in self.model['input'] + self.model['output']: for item in self.model['input'] + self.model['output']:
node = self.add_param(**item) node = self.add_param(**item)
print 'name', item['name']
self.nodes[item['name']] = node self.nodes[item['name']] = node
self.params.add(item['name']) self.params.add(item['name'])
for id, item in enumerate(self.model['node']): for id, item in enumerate(self.model['node']):
node = self.add_op(**item) node = self.add_op(**item)
name = "node_" + str(id) name = "node_" + str(id)
print 'name', name
self.nodes[name] = node self.nodes[name] = node
self.ops.add(name) self.ops.add(name)
...@@ -403,15 +401,20 @@ class GraphPreviewGenerator(object): ...@@ -403,15 +401,20 @@ class GraphPreviewGenerator(object):
else: else:
edge = self.add_edge(style="bold", color="#aaaaaa", **item) edge = self.add_edge(style="bold", color="#aaaaaa", **item)
if not show:
self.graph.display(path) self.graph.display(path)
else:
self.graph.show(path)
def add_param(self, name, data_type, shape): def add_param(self, name, data_type, shape):
label = '\n'.join([ label = '\n'.join([
'<<table cellpadding="5">', '<<table cellpadding="5">',
' <tr>', ' <tr>',
' <td bgcolor="#eeeeee">', ' <td bgcolor="#2b787e">',
' <b>',
name, name,
' </td>' ' </b>',
' </td>',
' </tr>', ' </tr>',
' <tr>', ' <tr>',
' <td>', ' <td>',
...@@ -429,23 +432,21 @@ class GraphPreviewGenerator(object): ...@@ -429,23 +432,21 @@ class GraphPreviewGenerator(object):
label, label,
prefix="param", prefix="param",
shape="none", shape="none",
# rank=self.param_rank,
style="rounded,filled,bold", style="rounded,filled,bold",
width="1.3", width="1.3",
#color="#ffa0a0", color="#148b97",
color="#8cc7ff", fontcolor="#ffffff",
fontname="Arial") fontname="Arial")
def add_op(self, opType, **kwargs): def add_op(self, opType, **kwargs):
return self.graph.node( return self.graph.node(
gg.crepr(opType), "<<B>%s</B>>" % opType,
# rank=self.op_rank,
prefix="op", prefix="op",
shape="box", shape="box",
style="rounded, filled, bold", style="rounded, filled, bold",
fillcolor="#8cc7cd", color="#303A3A",
#fillcolor="#8cc7ff",
fontname="Arial", fontname="Arial",
fontcolor="#ffffff",
width="1.3", width="1.3",
height="0.84", height="0.84",
) )
...@@ -454,11 +455,11 @@ class GraphPreviewGenerator(object): ...@@ -454,11 +455,11 @@ class GraphPreviewGenerator(object):
return self.graph.node( return self.graph.node(
gg.crepr(name), gg.crepr(name),
prefix="arg", prefix="arg",
# rank=self.arg_rank,
shape="box", shape="box",
style="rounded,filled,bold", style="rounded,filled,bold",
fontname="Arial", fontname="Arial",
color="grey") fontcolor="#999999",
color="#dddddd")
def add_edge(self, source, target, label, **kwargs): def add_edge(self, source, target, label, **kwargs):
source = self.nodes[source] source = self.nodes[source]
...@@ -498,7 +499,4 @@ if __name__ == '__main__': ...@@ -498,7 +499,4 @@ if __name__ == '__main__':
assert json_str assert json_str
g = GraphPreviewGenerator(json_str) g = GraphPreviewGenerator(json_str)
g('./temp.dot') g('./temp.dot', show=False)
# for i in range(10):
# g = GraphPreviewGenerator(json_str)
# g('./temp-%d.dot' % i)
...@@ -28,14 +28,9 @@ class Rank(object): ...@@ -28,14 +28,9 @@ class Rank(object):
if not self.nodes: if not self.nodes:
return '' return ''
# repr = []
# for node in self.nodes:
# repr.append(str(node))
return '{' + 'rank={};'.format(self.kind) + \ return '{' + 'rank={};'.format(self.kind) + \
','.join([node.name for node in self.nodes]) + '}' ','.join([node.name for node in self.nodes]) + '}'
# return '\n'.join(repr)
# the python package graphviz is too poor. # the python package graphviz is too poor.
class Graph(object): class Graph(object):
...@@ -78,14 +73,21 @@ class Graph(object): ...@@ -78,14 +73,21 @@ class Graph(object):
file.write(self.__str__()) file.write(self.__str__())
image_path = dot_path[:-3] + "jpg" image_path = dot_path[:-3] + "jpg"
cmd = ["dot", "-Tjpg", dot_path, "-o", image_path] cmd = ["dot", "-Tjpg", dot_path, "-o", image_path]
# cmd = "./preview.sh \"%s\"" % cmd subprocess.Popen(
print 'cmd', cmd cmd,
# subprocess.call(cmd, shell=True) stdin=subprocess.PIPE,
subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
return image_path
def show(self, dot_path):
image = self.display(dot_path)
cmd = ["feh", image]
subprocess.Popen(
cmd,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE) stderr=subprocess.PIPE)
# os.system(' '.join(cmd))
# assert os.path.isfile(image_path), "no image generated"
def _rank_repr(self): def _rank_repr(self):
ranks = sorted( ranks = sorted(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册