graph.py 6.4 KB
Newer Older
W
whs 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
W
WangZhen 已提交
14 15 16
from __future__ import print_function
import os
import subprocess
W
whs 已提交
17
from ....framework import Program
W
WangZhen 已提交
18 19
from ....framework import Block
from .... import core
W
whs 已提交
20

W
WangZhen 已提交
21
__all__ = ['Graph', 'ImitationGraph', 'IRGraph', 'PyGraph']
W
WangZhen 已提交
22 23 24 25 26 27 28


class PyGraph(object):
    """
    PyGraph uses core.Graph as the delegation to accomplish the manipulation.

    Node creation, node removal, visualization and conversion back to a
    Program are all forwarded to the wrapped C++ graph stored in
    ``self.graph``.
    """

    def __init__(self, graph, for_test=False):
        """
        Construct the PyGraph using core.Graph.
        Args:
            graph(core.Graph): C++ Graph.
            for_test(bool): True for the test graph and false for the train graph.
        """
        assert isinstance(
            graph, core.Graph), 'graph must be the instance of core.Graph.'
        self.graph = graph
        self.for_test = for_test

    def is_test(self):
        """Return the ``for_test`` flag given at construction."""
        return self.for_test

    def all_parameters(self):
        """
        Return the set of variable nodes whose VarDesc is persistable.

        Variable nodes without an attached VarDesc (``node.var()`` returns
        None) are skipped.
        """
        param_nodes = set()
        for node in self.graph.nodes():
            if not node.is_var():
                continue
            var = node.var()
            if var is not None and var.persistable():
                param_nodes.add(node)
        return param_nodes

    def all_vars(self):
        """Return the set of all variable nodes in the graph."""
        return {node for node in self.graph.nodes() if node.is_var()}

    def all_ops(self):
        """Return the set of all operator nodes in the graph."""
        return {node for node in self.graph.nodes() if node.is_op()}

    @staticmethod
    def _make_var_desc(name, var_type, shape, var_dtype):
        """Build a core.VarDesc with the given name, type, shape and dtype."""
        var_desc = core.VarDesc(name)
        var_desc.set_type(var_type)
        var_desc.set_shape(shape)
        var_desc.set_dtype(var_dtype)
        return var_desc

    def create_param_node(self, name, var_type, shape, var_dtype):
        """
        Create a persistable variable node in the graph and return it.

        Args:
            name(str): variable name.
            var_type: value passed to VarDesc.set_type.
            shape: value passed to VarDesc.set_shape.
            var_dtype: value passed to VarDesc.set_dtype.
        """
        var_desc = self._make_var_desc(name, var_type, shape, var_dtype)
        var_desc.set_persistable(True)
        return self.graph.create_var_node(var_desc)

    def create_var_node(self, name, var_type, shape, var_dtype):
        """
        Create a (non-persistable by default) variable node and return it.

        Arguments are the same as for ``create_param_node``.
        """
        var_desc = self._make_var_desc(name, var_type, shape, var_dtype)
        return self.graph.create_var_node(var_desc)

    def create_var_node_from_desc(self, var_desc):
        """Create a variable node from an existing core.VarDesc."""
        return self.graph.create_var_node(var_desc)

    @staticmethod
    def _node_names(var_nodes):
        """Return the names of ``var_nodes`` (a single node, a list or a tuple)."""
        if not isinstance(var_nodes, (list, tuple)):
            var_nodes = [var_nodes]
        return [var_node.name() for var_node in var_nodes]

    def create_op_node(self, op_type, attrs, inputs, outputs):
        """
        Create an operator node in the graph and return it.

        Args:
            op_type(str): operator type.
            attrs(dict): attribute name -> value, applied via _update_desc_attr.
            inputs(dict): input slot name -> var node, or list/tuple of var nodes.
            outputs(dict): output slot name -> var node, or list/tuple of var nodes.
        """
        op_desc = core.OpDesc()
        op_desc.set_type(op_type)
        # dict.items() works on both Python 2 and 3; iteritems() is Python 2 only.
        for attr, value in attrs.items():
            self._update_desc_attr(op_desc, attr, value)
        for input_name, var_nodes in inputs.items():
            op_desc.set_input(input_name, self._node_names(var_nodes))
        for output_name, var_nodes in outputs.items():
            op_desc.set_output(output_name, self._node_names(var_nodes))
        return self.graph.create_op_node(op_desc)

    def create_op_node_from_desc(self, op_desc):
        """Create an operator node from an existing core.OpDesc."""
        return self.graph.create_op_node(op_desc)

    def _update_desc_attr(self, desc, name, val):
        """
        Update the value of desc's attribute by attribute's name.

        Blocks (and non-empty lists of Blocks) are stored as block attributes,
        BlockDesc/ProgramDesc values are serialized, and anything else is set
        directly.
        """
        if isinstance(val, Block):
            desc.set_block_attr(name, val.desc)
        elif isinstance(val, list) and val and all(
                isinstance(v, Block) for v in val):
            desc.set_blocks_attr(name, [v.desc for v in val])
        elif isinstance(val, (core.BlockDesc, core.ProgramDesc)):
            desc.set_serialized_attr(name, val.serialize_to_string())
        else:
            desc._set_attr(name, val)

    def safe_remove_nodes(self, remove_nodes):
        """Remove ``remove_nodes`` (any iterable of nodes) via core.graph_safe_remove_nodes."""
        if not isinstance(remove_nodes, set):
            remove_nodes = set(remove_nodes)
        core.graph_safe_remove_nodes(self.graph, remove_nodes)

    def draw(self, save_path, name, marked_nodes=None):
        """
        Write the graph to ``<save_path>/<name>.dot`` with graph_viz_pass, then
        try to convert it to ``<save_path>/<name>.pdf`` with Graphviz ``dot``.

        Note: control-dependency variable nodes are removed from ``self.graph``
        before drawing, so this call mutates the graph.

        Args:
            save_path(str): output directory.
            name(str): base file name (without extension).
            marked_nodes(iterable|None): nodes to record under the
                '__graphviz__marked_node__' graph attribute for the viz pass.
        """

        def _convert_to_pdf(dot_file_path):
            pdf_save_path = os.path.splitext(dot_file_path)[0] + '.pdf'
            # An argument list (no shell) keeps paths with spaces or shell
            # metacharacters intact and avoids shell injection.
            try:
                exited_code = subprocess.call(
                    ['dot', '-Tpdf', dot_file_path, '-o', pdf_save_path])
            except OSError:
                # Raised when the ``dot`` executable cannot be found.
                exited_code = -1
            if exited_code != 0:
                print('The dot command is needed for creating pdf files.')
                print('The {} is saved as the dot filetype.'.format(
                    dot_file_path))

        remove_ctr_vars = set()
        ops_num = 0
        for node in self.graph.nodes():
            if node.is_ctrl_var():
                remove_ctr_vars.add(node)
            elif node.is_op():
                ops_num += 1
        print('Total ops num = {}.'.format(ops_num))
        self.safe_remove_nodes(remove_ctr_vars)
        if marked_nodes is not None:
            if not isinstance(marked_nodes, set):
                marked_nodes = set(marked_nodes)
            # Removed control vars must not be referenced by the viz pass.
            marked_nodes = marked_nodes - remove_ctr_vars
            if self.graph.has('__graphviz__marked_node__'):
                self.graph.erase('__graphviz__marked_node__')
            self.graph.set('__graphviz__marked_node__', marked_nodes)
        viz_dot_path = os.path.join(save_path, name) + '.dot'
        viz_pass = core.get_pass('graph_viz_pass')
        viz_pass.set_str('graph_viz_path', viz_dot_path)
        viz_pass.apply(self.graph)
        _convert_to_pdf(viz_dot_path)

    def to_program(self):
        """Convert the graph back into a new Program via graph_to_program_pass."""
        convert_pass = core.get_pass('graph_to_program_pass')
        convert_pass.set_program('program', Program().desc)
        convert_pass.apply(self.graph)
        desc = convert_pass.get_program('program')
        program = Program.construct_from_desc(desc)
        return program

W
whs 已提交
156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178

class Graph(object):
    """
    Common base class of the graph wrappers in this module.

    It stores no state; ImitationGraph below overrides all_parameters.
    """

    def __init__(self):
        """Construct the base graph; nothing is stored here."""

    def all_parameters(self):
        """
        Return all the parameters in current graph.

        The base implementation has nothing to report and returns None.
        """
        return None


class ImitationGraph(Graph):
    """
    Graph view over a ``Program``.

    Args:
        program(Program|None): the program to wrap; a fresh ``Program()`` is
            created when omitted.
    """

    def __init__(self, program=None):
        super(ImitationGraph, self).__init__()
        if program is None:
            program = Program()
        self.program = program

    def all_parameters(self):
        """Return the parameters of the wrapped program's global block."""
        global_block = self.program.global_block()
        return global_block.all_parameters()
W
WangZhen 已提交
179 180 181 182


class IRGraph(Graph):
    """Placeholder subclass of Graph; it adds no behavior of its own."""