graph.py 3.3 KB
Newer Older
J
jiangjiajun 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

M
mamingjie-China 已提交
15 16
from __future__ import print_function
from __future__ import division
J
jiangjiajun 已提交
17
import collections
J
jiangjiajun 已提交
18
import copy as cp
J
jiangjiajun 已提交
19 20 21 22 23 24 25 26


class GraphNode(object):
    """One node of a graph, wrapping a single layer of the source model.

    Equality and hashing are both based on ``layer.name``, so two nodes that
    wrap layers with the same name compare equal and hash identically.
    """

    def __init__(self, layer, layer_name=None):
        # Edge lists; Graph.connect appends node names to these.
        self.inputs = list()
        self.outputs = list()
        self.layer = layer

        assert layer_name is not None, "layer_name for GraphNode should not be None"
        self.layer_name = layer_name

    def __hash__(self):
        return hash(self.layer.name)

    def __eq__(self, other):
        try:
            other_name = other.layer.name
        except AttributeError:
            # Not node-like (e.g. a plain name string): let Python fall back
            # to its default comparison instead of raising, so expressions
            # such as ``node in list_of_names`` evaluate to False.
            return NotImplemented
        return self.layer.name == other_name

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``; define it
        # explicitly so both interpreters agree.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result


class Graph(object):
    """Directed graph whose nodes are stored by name in ``node_map``.

    Edges are kept on the nodes themselves: each node's ``inputs`` and
    ``outputs`` lists hold the ``node_map`` keys of its predecessors and
    successors (see ``connect``).
    """

    def __init__(self, model):
        # name -> node; insertion order is the iteration order used by build().
        self.node_map = collections.OrderedDict()
        # Names of nodes without inputs / without outputs, with '/' and '-'
        # replaced by '_' (filled by get_input_nodes / get_output_nodes).
        self.input_nodes = list()
        self.output_nodes = list()
        # node_map keys in topological order (filled by get_topo_sort).
        self.topo_sort = list()
        # Not read by any method of this class.
        self.model = model

    @staticmethod
    def _normalize_name(name):
        """Return *name* with every '/' and '-' replaced by '_'."""
        return name.replace('/', '_').replace('-', '_')

    def build(self):
        """Populate input_nodes, output_nodes and topo_sort from node_map."""
        self.get_input_nodes()
        self.get_output_nodes()
        self.get_topo_sort()

    def get_input_nodes(self):
        """Append the normalized names of nodes with no inputs to input_nodes."""
        for name, node in self.node_map.items():
            if len(node.inputs) == 0:
                self.input_nodes.append(self._normalize_name(name))

    def get_output_nodes(self):
        """Append the normalized names of nodes with no outputs to output_nodes."""
        for name, node in self.node_map.items():
            if len(node.outputs) == 0:
                self.output_nodes.append(self._normalize_name(name))

    def get_topo_sort(self):
        """Recompute topo_sort with Kahn's algorithm, seeded from input_nodes.

        Seeds come first, in input_nodes order, followed by every node whose
        predecessors have all been emitted. Nodes not reachable that way
        (e.g. nodes on a cycle) are left out of topo_sort.

        Raises:
            KeyError: if a seed name matches no node_map key, either directly
                or after normalization.
        """
        num_inputs = dict()
        for name, node in self.node_map.items():
            num_inputs[name] = len(node.inputs)

        # input_nodes holds normalized names, which differ from the node_map
        # keys whenever a key contains '/' or '-'. Map them back so that the
        # node_map lookups below do not fail.
        normalized_to_key = {
            self._normalize_name(key): key
            for key in self.node_map
        }
        self.topo_sort = list()
        seen_seeds = set()
        for name in self.input_nodes:
            key = name if name in self.node_map else normalized_to_key.get(
                name, name)
            # input_nodes accumulates across repeated build() calls; seeding
            # a node twice would decrement its successors' counts twice.
            if key in seen_seeds:
                continue
            seen_seeds.add(key)
            self.topo_sort.append(key)

        idx = 0
        while idx < len(self.topo_sort):
            current_node = self.node_map[self.topo_sort[idx]]
            for node in current_node.outputs:
                num_inputs[node] -= 1
                if num_inputs[node] == 0:
                    self.topo_sort.append(node)
            idx += 1

    def get_node(self, name, copy=False):
        """Look up a node by name, optionally with an output-index suffix.

        Args:
            name: a node_map key, or ``"<key>:<index>"``.
            copy: if True, return a shallow copy instead of the stored node.

        Returns:
            The node (or its copy), or None if no matching key exists. For the
            ``"<key>:<index>"`` form, ``int(index)`` is written to the returned
            node's ``index`` attribute; with ``copy=False`` this modifies the
            node stored in node_map.

        Raises:
            ValueError: if the prefix before the first ':' is a key but the
                name contains more than one ':' or a non-integer index.
        """
        if name in self.node_map:
            node = self.node_map[name]
            return cp.copy(node) if copy else node

        parts = name.split(':')
        if parts[0] not in self.node_map:
            return None
        name_prefix, idx = parts
        node = self.node_map[name_prefix]
        if copy:
            node = cp.copy(node)
        node.index = int(idx)
        return node

    def connect(self, src, dst):
        """Add a directed edge from node ``src`` to node ``dst``.

        Both endpoints are validated before anything is modified, so a failed
        call leaves the graph unchanged.

        Raises:
            Exception: if ``dst`` or ``src`` is not in node_map.
        """
        for endpoint in (dst, src):
            if endpoint not in self.node_map:
                raise Exception("node[{}] not in graph".format(endpoint))
        self.node_map[dst].inputs.append(src)
        self.node_map[src].outputs.append(dst)