# net_drawer.py
import argparse
import json
import logging
from collections import defaultdict

import paddle.v2.framework.core as core
import paddle.v2.framework.proto.framework_pb2 as framework_pb2

# Module-level logger for this file; its level is set to INFO here.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# graphviz provides the Digraph class used to build the drawing. When the
# import fails, installation hints are logged and printed and the interpreter
# is terminated while this module is being imported.
try:
    from graphviz import Digraph
except ImportError:
    logger.info(
        'Cannot import graphviz, which is required for drawing a network. This '
        'can usually be installed in python with "pip install graphviz". Also, '
        'pydot requires graphviz to convert dot files to pdf: in ubuntu, this '
        'can usually be installed with "sudo apt-get install graphviz".')
    print('net_drawer will not run correctly. Please install the correct '
          'dependencies.')
    # NOTE(review): exits with status 0 although the dependency is missing,
    # and does so at import time -- confirm this is the intended behaviour.
    exit(0)

# Default attributes for operator nodes: a filled oval with white text.
OP_STYLE = dict(
    shape='oval',
    color='#0F9D58',
    style='filled',
    fontcolor='#FFFFFF', )

# Default attributes for edges; empty unless a caller supplies overrides.
VAR_STYLE = dict()

# Default graph-level attributes (Graphviz "rankdir" set to "TB").
GRAPH_STYLE = dict(rankdir="TB")

# Process-wide counter behind unique_id(); holds the last id handed out.
GRAPH_ID = 0


def unique_id():
    """Return a callable that hands out increasing graph ids.

    Each call of the returned function increments the module-level
    ``GRAPH_ID`` counter by one and returns the new value. All callables
    produced by this function share that single counter.

    Returns:
        function: a zero-argument function returning the next integer id.
    """

    def generator():
        # An augmented assignment makes the name local to this function
        # unless it is declared global; without the declaration every call
        # raised UnboundLocalError.
        global GRAPH_ID
        GRAPH_ID += 1
        return GRAPH_ID

    return generator


def draw_node(op):
    """Build the keyword arguments used to draw ``op`` as a graph node.

    Args:
        op: operator description; only its ``type`` attribute is read.

    Returns:
        dict: the entries of ``OP_STYLE`` plus ``name`` and ``label``, both
        set to ``op.type``.
    """
    # Work on a copy: assigning into OP_STYLE itself would leave the last
    # operator's name/label behind in the shared default style.
    node = dict(OP_STYLE)
    node["name"] = op.type
    node["label"] = op.type
    return node


def draw_edge(var_parent, op, var, arg):
    """Build the keyword arguments used to draw one input edge of ``op``.

    Args:
        var_parent: mapping from argument name to the name of the node the
            edge starts at; ``var_parent[arg]`` must exist.
        op: operator the edge points to; only its ``type`` attribute is read.
        var: input slot of ``op``; only its ``parameter`` attribute is read.
        arg: argument name, used for the lookup and in the label.

    Returns:
        dict: the entries of ``VAR_STYLE`` plus ``label`` (formatted as
        ``"<parameter>(<arg>)"``), ``head_name`` (``op.type``) and
        ``tail_name`` (``var_parent[arg]``).

    Raises:
        KeyError: if ``arg`` is not present in ``var_parent``.
    """
    # Work on a copy: assigning into VAR_STYLE itself would leave
    # label/head_name/tail_name behind in the shared default edge style.
    edge = dict(VAR_STYLE)
    edge["label"] = "%s(%s)" % (var.parameter, arg)
    edge["head_name"] = op.type
    edge["tail_name"] = var_parent[arg]
    return edge


def parse_graph(program, graph, var_dict, **kwargs):
    """Add the operators of ``program`` and their input edges to ``graph``.

    Args:
        program: object exposing ``blocks`` (each with an iterable ``vars``)
            and ``desc.serialize_to_string()``.
        graph: drawing target; its ``node`` and ``edge`` methods are called
            with keyword arguments built by ``draw_node`` / ``draw_edge``.
        var_dict: mutable mapping from variable name to the name of the node
            that provides it. It is updated in place so that several programs
            can be parsed into the same graph.
        **kwargs: accepted but not used.
    """
    # Variables that no parsed operator has produced so far are attributed
    # to a node called "Feed".
    for block in program.blocks:
        for var in block.vars:
            if var not in var_dict:
                var_dict[var] = "Feed"

    # Round-trip through the serialized form to obtain the protobuf view of
    # the program, whose ops list their inputs and outputs.
    proto = framework_pb2.ProgramDesc.FromString(
        program.desc.serialize_to_string())
    for block in proto.blocks:
        for op in block.ops:
            # NOTE(review): nodes are named by op.type, so operators of the
            # same type share a single node -- confirm this is intended.
            graph.node(**draw_node(op))
            # NOTE(review): outputs are registered before the inputs are
            # drawn, so an op that reads and writes the same variable gets
            # an edge from itself -- confirm this is intended.
            for o in op.outputs:
                for arg in o.arguments:
                    var_dict[arg] = op.type
            for e in op.inputs:
                for arg in e.arguments:
                    # Inputs with no known provider are not drawn.
                    if arg in var_dict:
                        graph.edge(**draw_edge(var_dict, op, e, arg))


def draw_graph(init_program, program, **kwargs):
    """Draw ``init_program`` followed by ``program`` into one Digraph.

    Args:
        init_program: program parsed first (see ``parse_graph``).
        program: program parsed second, sharing the variable table built
            while parsing ``init_program``.
        **kwargs: optional settings.
            graph_attr / node_attr / edge_attr: dicts merged into the
            module-level ``GRAPH_STYLE`` / ``OP_STYLE`` / ``VAR_STYLE``.
            The merge is in place because ``draw_node`` and ``draw_edge``
            read those styles for every node and edge; overrides therefore
            persist for later calls.
            filename: file the graph source is saved to; defaults to
            ``"<graph id>.gv"``.
            Any remaining keyword arguments are forwarded to ``Digraph``.

    Returns:
        Digraph: the populated graph, already saved to ``filename``.
    """
    global GRAPH_ID

    # Pop the options handled here so they are not passed to Digraph a
    # second time through **kwargs (duplicate keyword -> TypeError).
    style_overrides = (
        ("graph_attr", GRAPH_STYLE),
        ("node_attr", OP_STYLE),
        ("edge_attr", VAR_STYLE), )
    for option, style in style_overrides:
        overrides = kwargs.pop(option, None)
        if overrides:
            style.update(overrides)
    filename = kwargs.pop("filename", None)

    # unique_id() returns a generator function rather than an id, so its
    # str() is a function repr; advance the shared counter directly instead.
    GRAPH_ID += 1
    graph_id = GRAPH_ID

    if filename is None:
        filename = str(graph_id) + ".gv"
    # A caller-supplied graph name wins over the generated one.
    kwargs.setdefault("name", str(graph_id))
    g = Digraph(
        filename=filename,
        graph_attr=GRAPH_STYLE,
        node_attr=OP_STYLE,
        edge_attr=VAR_STYLE,
        **kwargs)

    var_dict = {}
    parse_graph(init_program, g, var_dict)
    parse_graph(program, g, var_dict)

    # filename is always set at this point, so the graph is always saved.
    g.save()
    return g