未验证 提交 06583b18 编写于 作者: Y Yan Chunwei 提交者: GitHub

graph apply new color theme (#135)

上级 42c327c4
......@@ -350,16 +350,16 @@ def load_model(model_pb_path):
class GraphPreviewGenerator(object):
'''
Generate a graph image for ONNX proto.
'''
def __init__(self, model_json):
#self.model = json.loads(model_json)
self.model = model_json
# init graphviz graph
self.graph = gg.Graph(
self.model['name'],
layout="dot",
#resolution=200,
concentrate="true",
# rankdir="LR"
rankdir="TB",
)
......@@ -367,7 +367,7 @@ class GraphPreviewGenerator(object):
self.param_rank = self.graph.rank_group('same', 1)
self.arg_rank = self.graph.rank_group('same', 0)
def __call__(self, path='temp.dot'):
def __call__(self, path='temp.dot', show=False):
self.nodes = {}
self.params = set()
self.ops = set()
......@@ -375,14 +375,12 @@ class GraphPreviewGenerator(object):
for item in self.model['input'] + self.model['output']:
node = self.add_param(**item)
print 'name', item['name']
self.nodes[item['name']] = node
self.params.add(item['name'])
for id, item in enumerate(self.model['node']):
node = self.add_op(**item)
name = "node_" + str(id)
print 'name', name
self.nodes[name] = node
self.ops.add(name)
......@@ -403,15 +401,20 @@ class GraphPreviewGenerator(object):
else:
edge = self.add_edge(style="bold", color="#aaaaaa", **item)
self.graph.display(path)
if not show:
self.graph.display(path)
else:
self.graph.show(path)
def add_param(self, name, data_type, shape):
label = '\n'.join([
'<<table cellpadding="5">',
' <tr>',
' <td bgcolor="#eeeeee">',
' <td bgcolor="#2b787e">',
' <b>',
name,
' </td>'
' </b>',
' </td>',
' </tr>',
' <tr>',
' <td>',
......@@ -429,23 +432,21 @@ class GraphPreviewGenerator(object):
label,
prefix="param",
shape="none",
# rank=self.param_rank,
style="rounded,filled,bold",
width="1.3",
#color="#ffa0a0",
color="#8cc7ff",
color="#148b97",
fontcolor="#ffffff",
fontname="Arial")
def add_op(self, opType, **kwargs):
return self.graph.node(
gg.crepr(opType),
# rank=self.op_rank,
"<<B>%s</B>>" % opType,
prefix="op",
shape="box",
style="rounded, filled, bold",
fillcolor="#8cc7cd",
#fillcolor="#8cc7ff",
color="#303A3A",
fontname="Arial",
fontcolor="#ffffff",
width="1.3",
height="0.84",
)
......@@ -454,11 +455,11 @@ class GraphPreviewGenerator(object):
return self.graph.node(
gg.crepr(name),
prefix="arg",
# rank=self.arg_rank,
shape="box",
style="rounded,filled,bold",
fontname="Arial",
color="grey")
fontcolor="#999999",
color="#dddddd")
def add_edge(self, source, target, label, **kwargs):
source = self.nodes[source]
......@@ -498,7 +499,4 @@ if __name__ == '__main__':
assert json_str
g = GraphPreviewGenerator(json_str)
g('./temp.dot')
# for i in range(10):
# g = GraphPreviewGenerator(json_str)
# g('./temp-%d.dot' % i)
g('./temp.dot', show=False)
......@@ -28,14 +28,9 @@ class Rank(object):
if not self.nodes:
return ''
# repr = []
# for node in self.nodes:
# repr.append(str(node))
return '{' + 'rank={};'.format(self.kind) + \
','.join([node.name for node in self.nodes]) + '}'
# return '\n'.join(repr)
# the python package graphviz is too poor.
class Graph(object):
......@@ -78,14 +73,21 @@ class Graph(object):
file.write(self.__str__())
image_path = dot_path[:-3] + "jpg"
cmd = ["dot", "-Tjpg", dot_path, "-o", image_path]
# cmd = "./preview.sh \"%s\"" % cmd
print 'cmd', cmd
# subprocess.call(cmd, shell=True)
subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
# os.system(' '.join(cmd))
# assert os.path.isfile(image_path), "no image generated"
subprocess.Popen(
cmd,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
return image_path
def show(self, dot_path):
image = self.display(dot_path)
cmd = ["feh", image]
subprocess.Popen(
cmd,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
def _rank_repr(self):
ranks = sorted(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册