提交 efc4e6c3 编写于 作者: Y Yi Wang 提交者: GitHub

Merge pull request #2328 from emailweixu/make_model_diagram

Modify make_model_diagram.py so that it can be used to draw from proto
...@@ -39,6 +39,10 @@ def make_layer_label(layer_config): ...@@ -39,6 +39,10 @@ def make_layer_label(layer_config):
def make_diagram(config_file, dot_file, config_arg_str): def make_diagram(config_file, dot_file, config_arg_str):
config = parse_config(config_file, config_arg_str) config = parse_config(config_file, config_arg_str)
make_diagram_from_proto(config.model_config, dot_file)
def make_diagram_from_proto(model_config, dot_file):
# print >> sys.stderr, config # print >> sys.stderr, config
name2id = {} name2id = {}
f = open(dot_file, 'w') f = open(dot_file, 'w')
...@@ -59,12 +63,12 @@ def make_diagram(config_file, dot_file, config_arg_str): ...@@ -59,12 +63,12 @@ def make_diagram(config_file, dot_file, config_arg_str):
print >> f, 'digraph graphname {' print >> f, 'digraph graphname {'
print >> f, 'node [width=0.375,height=0.25];' print >> f, 'node [width=0.375,height=0.25];'
for i in xrange(len(config.model_config.layers)): for i in xrange(len(model_config.layers)):
l = config.model_config.layers[i] l = model_config.layers[i]
name2id[l.name] = i name2id[l.name] = i
i = 0 i = 0
for sub_model in config.model_config.sub_models: for sub_model in model_config.sub_models:
if sub_model.name == 'root': if sub_model.name == 'root':
continue continue
print >> f, 'subgraph cluster_%s {' % i print >> f, 'subgraph cluster_%s {' % i
...@@ -78,18 +82,18 @@ def make_diagram(config_file, dot_file, config_arg_str): ...@@ -78,18 +82,18 @@ def make_diagram(config_file, dot_file, config_arg_str):
for layer_name in sub_model.layer_names: for layer_name in sub_model.layer_names:
submodel_layers.add(layer_name) submodel_layers.add(layer_name)
lid = name2id[layer_name] lid = name2id[layer_name]
layer_config = config.model_config.layers[lid] layer_config = model_config.layers[lid]
label = make_layer_label(layer_config) label = make_layer_label(layer_config)
print >> f, 'l%s [label="%s", shape=box];' % (lid, label) print >> f, 'l%s [label="%s", shape=box];' % (lid, label)
print >> f, '}' print >> f, '}'
for i in xrange(len(config.model_config.layers)): for i in xrange(len(model_config.layers)):
l = config.model_config.layers[i] l = model_config.layers[i]
if l.name not in submodel_layers: if l.name not in submodel_layers:
label = make_layer_label(l) label = make_layer_label(l)
print >> f, 'l%s [label="%s", shape=box];' % (i, label) print >> f, 'l%s [label="%s", shape=box];' % (i, label)
for sub_model in config.model_config.sub_models: for sub_model in model_config.sub_models:
if sub_model.name == 'root': if sub_model.name == 'root':
continue continue
for link in sub_model.in_links: for link in sub_model.in_links:
...@@ -99,8 +103,8 @@ def make_diagram(config_file, dot_file, config_arg_str): ...@@ -99,8 +103,8 @@ def make_diagram(config_file, dot_file, config_arg_str):
for mem in sub_model.memories: for mem in sub_model.memories:
print >> f, make_mem(mem) print >> f, make_mem(mem)
for i in xrange(len(config.model_config.layers)): for i in xrange(len(model_config.layers)):
for l in config.model_config.layers[i].inputs: for l in model_config.layers[i].inputs:
print >> f, 'l%s -> l%s [label="%s"];' % ( print >> f, 'l%s -> l%s [label="%s"];' % (
name2id[l.input_layer_name], i, l.input_parameter_name) name2id[l.input_layer_name], i, l.input_parameter_name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册