# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Generate dot diagram file for the given paddle model config
# The generated file can be viewed using Graphviz (http://graphviz.org)

import sys
import traceback

from paddle.trainer.config_parser import parse_config


def make_layer_label(layer_config):
    """Build the Graphviz node label text for a single layer.

    The first part is "<name> type=<type>", with " <==" appended when
    ``layer_config.reversed`` is truthy.  When the layer has an activation
    type and/or a bias parameter name, they are appended as "act=... " and
    "bias=... " after a Graphviz left-justified line break (a backslash
    followed by the letter l).

    Args:
        layer_config: object exposing ``name``, ``type``, ``reversed``,
            ``active_type`` and ``bias_parameter_name`` attributes.

    Returns:
        The label string, suitable for embedding in a dot ``label="..."``.
    """
    label = '%s type=%s' % (layer_config.name, layer_config.type)
    if layer_config.reversed:
        label += ' <=='

    details = ''
    if layer_config.active_type:
        details += 'act=%s ' % layer_config.active_type
    if layer_config.bias_parameter_name:
        details += 'bias=%s ' % layer_config.bias_parameter_name

    if details:
        # The backslash is escaped explicitly: a bare backslash-l is not a
        # valid Python escape sequence and triggers a warning on Python 3.
        label += '\\l' + details
    return label


def make_diagram(config_file, dot_file, config_arg_str):
    """Parse a paddle model config and write its dot diagram to a file.

    Args:
        config_file: path of the model config passed to ``parse_config``.
        dot_file: path of the dot file to write.
        config_arg_str: config argument string passed to ``parse_config``.
    """
    model_config = parse_config(config_file, config_arg_str).model_config
    make_diagram_from_proto(model_config, dot_file)


def make_diagram_from_proto(model_config, dot_file):
    """Write a Graphviz dot description of ``model_config`` to ``dot_file``.

    Every layer becomes a box node named ``l<index>``, where the index is the
    layer's position in ``model_config.layers``.  Layers listed by a sub-model
    other than 'root' are emitted inside a dashed ``subgraph cluster_<n>``
    labelled with the sub-model name.  Edges are emitted for sub-model
    in/out links, sub-model memories (dashed), and every layer input
    (labelled with the input parameter name).

    Args:
        model_config: object exposing ``layers`` and ``sub_models``.
        dot_file: path of the output file; it is overwritten.

    Raises:
        KeyError: if a link, memory, sub-model or input refers to a layer
            name that is not present in ``model_config.layers``.
    """
    # Layer name -> index in model_config.layers; the index is the node id.
    name2id = {}
    for idx, layer in enumerate(model_config.layers):
        name2id[layer.name] = idx

    # Names of sub-models and of the layers already drawn inside a cluster.
    submodel_layers = set()

    def make_link(link):
        return 'l%s -> l%s;' % (name2id[link.layer_name],
                                name2id[link.link_name])

    def make_mem(mem):
        edges = []
        if mem.boot_layer_name:
            edges.append('l%s -> l%s;' % (name2id[mem.boot_layer_name],
                                          name2id[mem.layer_name]))
        edges.append('l%s -> l%s [style=dashed];' % (name2id[mem.layer_name],
                                                     name2id[mem.link_name]))
        return '\n'.join(edges)

    # The with-block guarantees the file is closed even when a lookup in
    # name2id raises part-way through.
    with open(dot_file, 'w') as f:

        def emit(line):
            f.write(line + '\n')

        emit('digraph graphname {')
        emit('node [width=0.375,height=0.25];')

        # One dashed cluster per non-root sub-model.
        cluster_id = 0
        for sub_model in model_config.sub_models:
            if sub_model.name == 'root':
                continue
            emit('subgraph cluster_%s {' % cluster_id)
            emit('style=dashed;')
            label = '%s ' % sub_model.name
            if sub_model.reversed:
                label += '<=='
            emit('label = "%s";' % label)
            cluster_id += 1
            submodel_layers.add(sub_model.name)
            for layer_name in sub_model.layer_names:
                submodel_layers.add(layer_name)
                lid = name2id[layer_name]
                layer_config = model_config.layers[lid]
                emit('l%s [label="%s", shape=box];' %
                     (lid, make_layer_label(layer_config)))
            emit('}')

        # Layers that were not drawn inside any cluster.
        for idx, layer in enumerate(model_config.layers):
            if layer.name not in submodel_layers:
                emit('l%s [label="%s", shape=box];' %
                     (idx, make_layer_label(layer)))

        # Edges contributed by sub-models: in/out links and memories.
        for sub_model in model_config.sub_models:
            if sub_model.name == 'root':
                continue
            for link in sub_model.in_links:
                emit(make_link(link))
            for link in sub_model.out_links:
                emit(make_link(link))
            for mem in sub_model.memories:
                emit(make_mem(mem))

        # Edges from each input layer to the layer consuming it.
        for idx, layer in enumerate(model_config.layers):
            for layer_input in layer.inputs:
                emit('l%s -> l%s [label="%s"];' %
                     (name2id[layer_input.input_layer_name], idx,
                      layer_input.input_parameter_name))

        emit('}')


def usage():
    """Print the command-line usage message to stderr and exit with status 1.

    Raises:
        SystemExit: always, with exit code 1.
    """
    sys.stderr.write("Usage: python make_model_diagram.py"
                     " CONFIG_FILE DOT_FILE [config_str]\n")
    # sys.exit instead of the bare ``exit`` name, which is only injected by
    # the ``site`` module and is not guaranteed to be available.
    sys.exit(1)


if __name__ == '__main__':
    # Expect CONFIG_FILE and DOT_FILE, plus an optional config string.
    if not 3 <= len(sys.argv) <= 4:
        usage()

    config_file = sys.argv[1]
    dot_file = sys.argv[2]
    config_arg_str = sys.argv[3] if len(sys.argv) == 4 else ''

    try:
        make_diagram(config_file, dot_file, config_arg_str)
    except Exception:
        # Report the failure once and exit non-zero.  Re-raising after
        # print_exc() would make the interpreter print the same traceback a
        # second time, and a bare except would also intercept
        # KeyboardInterrupt and SystemExit.
        traceback.print_exc()
        sys.exit(1)