graph.py 3.3 KB
Newer Older
J
jiangjiajun 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
J
jiangjiajun 已提交
16
import copy as cp
J
jiangjiajun 已提交
17 18 19 20 21 22 23 24


class GraphNode(object):
    """A vertex of a ``Graph`` that wraps a single layer object.

    Hashing and equality are based on ``layer.name`` (the wrapped layer's own
    name), not on ``layer_name``.
    """

    def __init__(self, layer, layer_name=None):
        """Wrap ``layer`` under the name ``layer_name``.

        Args:
            layer: The wrapped layer; must expose a ``name`` attribute, which
                ``__hash__``/``__eq__`` read.
            layer_name: Name of this node. Despite the ``None`` default it is
                required.

        Raises:
            AssertionError: if ``layer_name`` is None.
        """
        # Neighbour node names; Graph.connect appends to these lists.
        self.inputs = list()
        self.outputs = list()
        self.layer = layer

        assert layer_name is not None, "layer_name for GraphNode should not be None"
        self.layer_name = layer_name

    def __hash__(self):
        return hash(self.layer.name)

    def __eq__(self, other):
        # Defer to Python's default handling for foreign types instead of
        # raising AttributeError on ``other.layer``.
        if not isinstance(other, GraphNode):
            return NotImplemented
        return self.layer.name == other.layer.name


class Graph(object):
    """Name-keyed directed graph with input/output detection and topological order.

    ``node_map`` maps a node name to a node object exposing ``inputs`` and
    ``outputs`` lists of neighbour names (e.g. ``GraphNode``). This class never
    adds entries to ``node_map`` itself; presumably subclasses or callers fill
    it and wire edges with ``connect`` before calling ``build``.
    """

    def __init__(self, model):
        self.node_map = collections.OrderedDict()  # name -> node, insertion order
        self.input_nodes = list()  # names of nodes without inputs (see build)
        self.output_nodes = list()  # names of nodes without outputs (see build)
        self.topo_sort = list()  # node names in topological order (see build)
        # Not read by this class; presumably consumed by subclasses.
        self.model = model

    def build(self):
        """Compute ``input_nodes``, ``output_nodes`` and ``topo_sort``.

        ``input_nodes`` and ``output_nodes`` are appended to rather than
        reset, so calling ``build`` more than once duplicates their entries.
        """
        self._make_input_nodes()
        self._make_output_nodes()
        self._get_topo_sort()

    def _make_input_nodes(self):
        """Append the name of every node with no inputs to ``input_nodes``."""
        self.input_nodes.extend(
            name for name, node in self.node_map.items() if not node.inputs)

    def _make_output_nodes(self):
        """Append the name of every node with no outputs to ``output_nodes``."""
        self.output_nodes.extend(
            name for name, node in self.node_map.items() if not node.outputs)

    def _get_topo_sort(self):
        """Fill ``topo_sort`` using Kahn's algorithm.

        Traversal starts from ``input_nodes``; a node is emitted once all of
        its recorded inputs have been emitted. Nodes on (or downstream of) a
        cycle never reach zero pending inputs and are left out.
        """
        num_inputs = {
            name: len(node.inputs)
            for name, node in self.node_map.items()
        }

        self.topo_sort = self.input_nodes[:]
        # topo_sort doubles as the work queue: idx walks it while newly
        # ready nodes are appended to its end.
        idx = 0
        while idx < len(self.topo_sort):
            current_node = self.node_map[self.topo_sort[idx]]
            for name in current_node.outputs:
                num_inputs[name] -= 1
                if num_inputs[name] == 0:
                    self.topo_sort.append(name)
            idx += 1

    def get_node(self, name, copy=False):
        """Look up a node by name, optionally as a shallow copy.

        A name not present in ``node_map`` is treated as ``"<prefix>:<idx>"``:
        the node registered under ``<prefix>`` is returned with its ``index``
        attribute set to ``int(<idx>)``. With ``copy=False`` this sets
        ``index`` on the shared node stored in ``node_map``.

        Args:
            name: Node name, optionally suffixed with ``":<int>"``.
            copy: Return ``copy.copy`` of the node (its ``inputs``/``outputs``
                lists are still shared with the original).

        Raises:
            Exception: if neither ``name`` nor its prefix is in the graph.
            ValueError: if the suffix after the first ``':'`` is not an int.
        """
        if name in self.node_map:
            node = self.node_map[name]
            return cp.copy(node) if copy else node

        name_prefix, _, idx = name.partition(':')
        if name_prefix not in self.node_map:
            raise Exception("Graph doesn't have node [%s]." % name)

        node = self.node_map[name_prefix]
        if copy:
            node = cp.copy(node)
        node.index = int(idx)
        return node

    def connect(self, src, dst):
        """Add a directed edge ``src -> dst`` between two existing nodes.

        Both endpoints are validated before anything is mutated, so a failed
        call never leaves a half-connected edge behind.

        Raises:
            Exception: if ``dst`` or ``src`` is not in the graph.
        """
        for endpoint in (dst, src):
            if endpoint not in self.node_map:
                raise Exception("node[{}] not in graph".format(endpoint))
        self.node_map[dst].inputs.append(src)
        self.node_map[src].outputs.append(dst)

    def print(self):
        """Print name, layer_type, inputs and outputs of each node in topo order.

        Requires nodes to expose ``layer_type``, which the base ``GraphNode``
        does not define.
        """
        for name in self.topo_sort:
            node = self.node_map[name]
            print(name, node.layer_type, node.inputs, node.outputs)