graph.py 3.5 KB
Newer Older
J
jiangjiajun 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections


class GraphNode(object):
    """A node of a :class:`Graph`, wrapping one layer of the source model.

    Attributes:
        inputs: list of predecessor entries; ``Graph.connect`` appends the
            source node's name here.
        outputs: list of successor entries; ``Graph.connect`` appends the
            destination node's name here.
        layer: the wrapped layer object. Hashing and equality read its
            ``name`` attribute.
        layer_name: the name given to this node at construction time.
    """

    def __init__(self, layer, layer_name=None):
        """Create a node for ``layer``.

        Args:
            layer: the layer object to wrap.
            layer_name: name of the node; required despite the ``None``
                default.

        Raises:
            AssertionError: if ``layer_name`` is ``None``.
        """
        # Raised explicitly instead of via an ``assert`` statement so the
        # check is not stripped when Python runs with ``-O``; the exception
        # type seen by callers is unchanged.
        if layer_name is None:
            raise AssertionError("layer_name for GraphNode should not be None")

        self.inputs = list()
        self.outputs = list()
        self.layer = layer
        self.layer_name = layer_name

    def __hash__(self):
        """Hash on the wrapped layer's name, consistent with ``__eq__``."""
        return hash(self.layer.name)

    def __eq__(self, other):
        """Two nodes are equal when their wrapped layers share a name.

        Returns ``NotImplemented`` for objects that carry no ``layer``
        attribute (e.g. ``None`` or a plain string), so such comparisons
        evaluate to ``False`` instead of raising ``AttributeError``.
        """
        if not hasattr(other, "layer"):
            return NotImplemented
        return self.layer.name == other.layer.name


class Graph(object):
    """Directed graph of named nodes with a derived topological order.

    Attributes:
        node_map: ``OrderedDict`` mapping node name -> node object. Nodes are
            expected to expose ``inputs`` and ``outputs`` lists of node names.
        input_nodes: names of nodes that have no inputs (filled by ``build``).
        output_nodes: names of nodes that have no outputs (filled by
            ``build``).
        topo_sort: node names in topological order (filled by ``build``).
        model: the object passed to the constructor; stored but not used by
            this class.
    """

    def __init__(self, model):
        """Create an empty graph that keeps a reference to ``model``."""
        self.node_map = collections.OrderedDict()
        self.input_nodes = list()
        self.output_nodes = list()
        self.topo_sort = list()
        self.model = model

    def build(self):
        """Compute ``input_nodes``, ``output_nodes`` and ``topo_sort``.

        Safe to call more than once: names already recorded in
        ``input_nodes`` / ``output_nodes`` are not appended again.
        """
        self._make_input_nodes()
        self._make_output_nodes()
        self._get_topo_sort()

    def _make_input_nodes(self):
        """Append the name of every node without inputs to ``input_nodes``."""
        # Track what is already listed so repeated builds do not duplicate
        # entries, while anything a caller put there beforehand is kept.
        known = set(self.input_nodes)
        for name, node in self.node_map.items():
            if len(node.inputs) == 0 and name not in known:
                self.input_nodes.append(name)
                known.add(name)

    def _make_output_nodes(self):
        """Append the name of every node without outputs to ``output_nodes``."""
        known = set(self.output_nodes)
        for name, node in self.node_map.items():
            if len(node.outputs) == 0 and name not in known:
                self.output_nodes.append(name)
                known.add(name)

    def _get_topo_sort(self):
        """Rebuild ``topo_sort`` starting from ``input_nodes``.

        Each node is emitted once all of its inputs have been emitted
        (Kahn's algorithm). Nodes whose input count never reaches zero,
        e.g. nodes on a cycle, are left out of ``topo_sort``.
        """
        # Remaining number of not-yet-emitted inputs per node.
        num_inputs = dict()
        for name, node in self.node_map.items():
            num_inputs[name] = len(node.inputs)

        # ``topo_sort`` doubles as the work queue: ``idx`` is the read
        # position and newly released nodes are appended at the end.
        self.topo_sort = self.input_nodes[:]
        idx = 0
        while idx < len(self.topo_sort):
            current_node = self.node_map[self.topo_sort[idx]]
            for successor in current_node.outputs:
                num_inputs[successor] -= 1
                if num_inputs[successor] == 0:
                    self.topo_sort.append(successor)
            idx += 1

    def get_node(self, name):
        """Return the node registered under ``name``.

        Raises:
            Exception: if ``name`` is not a key of ``node_map``.
        """
        if name not in self.node_map:
            raise Exception("Graph doesn't have node [%s]." % name)
        return self.node_map[name]

    def connect(self, src, dst):
        """Add an edge from node ``src`` to node ``dst`` (both given by name).

        Raises:
            Exception: if ``dst`` is not in the graph.
            KeyError: if ``src`` is not in the graph.
        """
        if dst not in self.node_map:
            raise Exception("node[{}] not in graph".format(dst))
        # Resolve both endpoints before mutating anything so that a missing
        # ``src`` cannot leave a half-registered edge on ``dst``.
        src_node = self.node_map[src]
        dst_node = self.node_map[dst]
        dst_node.inputs.append(src)
        src_node.outputs.append(dst)

    def remove_node(self, node_name):
        """Remove ``node_name`` and detach it from all of its neighbours.

        The name is also dropped from ``topo_sort`` when present there.
        ``input_nodes`` and ``output_nodes`` are not modified.

        Raises:
            Exception: if ``node_name`` is not in the graph.
        """
        if node_name not in self.node_map:
            raise Exception("Node[{}] not in graph".format(node_name))
        node = self.node_map[node_name]
        # Iterate over copies: a self-loop would otherwise mutate a list
        # that is being iterated.
        for predecessor in list(node.inputs):
            self.node_map[predecessor].outputs.remove(node_name)
        for successor in list(node.outputs):
            self.node_map[successor].inputs.remove(node_name)
        del self.node_map[node_name]

        # ``topo_sort`` is empty until ``build`` has run, so the name may
        # legitimately be absent.
        if node_name in self.topo_sort:
            self.topo_sort.remove(node_name)

    def print(self):
        """Print name, layer type and inputs of each node in ``topo_sort`` order.

        ``layer_type`` is printed as ``None`` for nodes that do not define
        that attribute.
        """
        for name in self.topo_sort:
            node = self.node_map[name]
            print(name, getattr(node, "layer_type", None), node.inputs)