diff --git a/python/paddle/utils/make_model_diagram.py b/python/paddle/utils/make_model_diagram.py index 1370ea83a49955d3152a1147f8e8108371a8ae12..40f99075de7fb2401b3b704afe1eb44dbe6072dd 100644 --- a/python/paddle/utils/make_model_diagram.py +++ b/python/paddle/utils/make_model_diagram.py @@ -39,6 +39,10 @@ def make_layer_label(layer_config): def make_diagram(config_file, dot_file, config_arg_str): config = parse_config(config_file, config_arg_str) + make_diagram_from_proto(config.model_config, dot_file) + + +def make_diagram_from_proto(model_config, dot_file): # print >> sys.stderr, config name2id = {} f = open(dot_file, 'w') @@ -59,12 +63,12 @@ def make_diagram(config_file, dot_file, config_arg_str): print >> f, 'digraph graphname {' print >> f, 'node [width=0.375,height=0.25];' - for i in xrange(len(config.model_config.layers)): - l = config.model_config.layers[i] + for i in xrange(len(model_config.layers)): + l = model_config.layers[i] name2id[l.name] = i i = 0 - for sub_model in config.model_config.sub_models: + for sub_model in model_config.sub_models: if sub_model.name == 'root': continue print >> f, 'subgraph cluster_%s {' % i @@ -78,18 +82,18 @@ def make_diagram(config_file, dot_file, config_arg_str): for layer_name in sub_model.layer_names: submodel_layers.add(layer_name) lid = name2id[layer_name] - layer_config = config.model_config.layers[lid] + layer_config = model_config.layers[lid] label = make_layer_label(layer_config) print >> f, 'l%s [label="%s", shape=box];' % (lid, label) print >> f, '}' - for i in xrange(len(config.model_config.layers)): - l = config.model_config.layers[i] + for i in xrange(len(model_config.layers)): + l = model_config.layers[i] if l.name not in submodel_layers: label = make_layer_label(l) print >> f, 'l%s [label="%s", shape=box];' % (i, label) - for sub_model in config.model_config.sub_models: + for sub_model in model_config.sub_models: if sub_model.name == 'root': continue for link in sub_model.in_links: @@ -99,8 +103,8 @@ def make_diagram(config_file, dot_file, config_arg_str): for mem in sub_model.memories: print >> f, make_mem(mem) - for i in xrange(len(config.model_config.layers)): - for l in config.model_config.layers[i].inputs: + for i in xrange(len(model_config.layers)): + for l in model_config.layers[i].inputs: print >> f, 'l%s -> l%s [label="%s"];' % ( name2id[l.input_layer_name], i, l.input_parameter_name)