make_model_diagram.py 4.4 KB
Newer Older
1
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
Z
zhangjinchao01 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Generate a dot diagram file for the given Paddle model config.
# The generated file can be viewed using Graphviz (http://graphviz.org).

M
minqiyang 已提交
18 19 20
from __future__ import print_function

import six
Z
zhangjinchao01 已提交
21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44
import sys
import traceback

from paddle.trainer.config_parser import parse_config


def make_layer_label(layer_config):
    """Build the Graphviz node label for a single layer.

    The first line of the label is ``"<name> type=<type>"``, followed by
    ``" <=="`` when the layer is marked as reversed.  When the layer has an
    activation type and/or a bias parameter name, a second line listing
    them is appended after Graphviz's left-justified line-break escape.

    Args:
        layer_config: object exposing ``name``, ``type``, ``reversed``,
            ``active_type`` and ``bias_parameter_name``.

    Returns:
        The label text, ready to be placed inside a quoted dot attribute.
    """
    label = '%s type=%s' % (layer_config.name, layer_config.type)
    if layer_config.reversed:
        label += ' <=='

    details = ''
    if layer_config.active_type:
        details += 'act=%s ' % layer_config.active_type
    if layer_config.bias_parameter_name:
        details += 'bias=%s ' % layer_config.bias_parameter_name

    if details:
        # The backslash is escaped explicitly: a bare "\l" is an invalid
        # Python escape sequence and triggers a warning on modern
        # interpreters.  The emitted text (backslash + "l") is unchanged.
        label += '\\l' + details
    return label


def make_diagram(config_file, dot_file, config_arg_str):
    """Parse a model config file and write its diagram to ``dot_file``.

    Args:
        config_file: path of the model config, handed to ``parse_config``.
        dot_file: path of the dot file to write.
        config_arg_str: extra argument string forwarded to ``parse_config``.
    """
    parsed = parse_config(config_file, config_arg_str)
    model_config = parsed.model_config
    make_diagram_from_proto(model_config, dot_file)


def make_diagram_from_proto(model_config, dot_file):
    """Write a Graphviz dot description of ``model_config`` to ``dot_file``.

    Every layer becomes a box node named ``l<index>``, where the index is
    the layer's position in ``model_config.layers``.  Each sub-model other
    than ``'root'`` is drawn as a dashed cluster around its layers, and its
    in-links, out-links and memories are drawn as edges (the memory link
    itself is dashed).  Finally, one edge labelled with the input parameter
    name is emitted for every layer input.

    Args:
        model_config: object exposing ``layers`` and ``sub_models`` with the
            attributes read below (presumably the model config proto
            produced by ``parse_config`` -- confirm against the caller).
        dot_file: path of the dot file to create or overwrite.

    Raises:
        KeyError: if a sub-model, link, memory or input refers to a layer
            name that is not present in ``model_config.layers``.
    """
    name2id = {}
    submodel_layers = set()

    def make_link(link):
        # Plain edge between the two layers joined by a sub-model link.
        return 'l%s -> l%s;' % (name2id[link.layer_name],
                                name2id[link.link_name])

    def make_mem(mem):
        # Optional solid edge from the boot layer, then the dashed memory
        # edge itself.
        s = ''
        if mem.boot_layer_name:
            s += 'l%s -> l%s;\n' % (name2id[mem.boot_layer_name],
                                    name2id[mem.layer_name])
        s += 'l%s -> l%s [style=dashed];' % (name2id[mem.layer_name],
                                             name2id[mem.link_name])
        return s

    # The context manager guarantees the file is closed even when one of
    # the name lookups below raises.
    with open(dot_file, 'w') as f:
        print('digraph graphname {', file=f)
        print('node [width=0.375,height=0.25];', file=f)

        # A layer's position in the list doubles as its node id.
        for idx, layer in enumerate(model_config.layers):
            name2id[layer.name] = idx

        # One dashed cluster per non-root sub-model, numbered in order of
        # appearance.
        cluster_id = 0
        for sub_model in model_config.sub_models:
            if sub_model.name == 'root':
                continue
            print('subgraph cluster_%s {' % cluster_id, file=f)
            print('style=dashed;', file=f)
            label = '%s ' % sub_model.name
            if sub_model.reversed:
                label += '<=='
            print('label = "%s";' % label, file=f)
            cluster_id += 1
            submodel_layers.add(sub_model.name)
            for layer_name in sub_model.layer_names:
                submodel_layers.add(layer_name)
                lid = name2id[layer_name]
                label = make_layer_label(model_config.layers[lid])
                print('l%s [label="%s", shape=box];' % (lid, label), file=f)
            print('}', file=f)

        # Layers that were not already drawn inside a cluster.
        for idx, layer in enumerate(model_config.layers):
            if layer.name not in submodel_layers:
                label = make_layer_label(layer)
                print('l%s [label="%s", shape=box];' % (idx, label), file=f)

        # Edges contributed by the sub-models.
        for sub_model in model_config.sub_models:
            if sub_model.name == 'root':
                continue
            for link in sub_model.in_links:
                print(make_link(link), file=f)
            for link in sub_model.out_links:
                print(make_link(link), file=f)
            for mem in sub_model.memories:
                print(make_mem(mem), file=f)

        # Ordinary layer-to-layer edges, labelled with the parameter name.
        for idx, layer in enumerate(model_config.layers):
            for layer_input in layer.inputs:
                print(
                    'l%s -> l%s [label="%s"];' %
                    (name2id[layer_input.input_layer_name], idx,
                     layer_input.input_parameter_name),
                    file=f)

        print('}', file=f)


def usage():
    """Print the command-line synopsis to stderr and exit with status 1.

    Raises:
        SystemExit: always, with exit code 1.
    """
    print(
        ("Usage: python make_model_diagram.py" +
         " CONFIG_FILE DOT_FILE [config_str]"),
        file=sys.stderr)
    # sys.exit is used rather than the bare ``exit`` builtin, which is
    # injected by the ``site`` module and is absent when Python runs
    # with -S.
    sys.exit(1)


if __name__ == '__main__':
    # Two mandatory arguments plus one optional config string; anything
    # else prints the usage text and exits.
    if not 3 <= len(sys.argv) <= 4:
        usage()

    config_file, dot_file = sys.argv[1], sys.argv[2]
    if len(sys.argv) == 4:
        config_arg_str = sys.argv[3]
    else:
        config_arg_str = ''

    try:
        make_diagram(config_file, dot_file, config_arg_str)
    except:
        # Print the traceback explicitly, then let the error propagate so
        # the process still terminates with a failure.
        traceback.print_exc()
        raise