pattern_matcher.py 8.8 KB
Newer Older
S
SunAhong1993 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
#   Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from x2paddle.core.program import PaddleGraph


class PatternMatcher(object):
    """Locate subgraphs of a graph that match a pattern graph.

    Candidate subgraphs are built by walking the graph's layers in their
    stored order and comparing them one by one with the pattern's layers
    (kernel name, input wiring, fan-out and nested control-flow blocks).
    """

    def __init__(self, pattern):
        """Remember the pattern and start with an empty match list.

        Args:
            pattern: Pattern graph. The matcher reads its ``layers``,
                ``edges_in``, ``edges_out`` and ``outputs`` attributes and
                calls its ``get_global_layers()`` method.
        """
        self.pattern = pattern
        # Each match is a dict {layer_id: layer} whose insertion order is the
        # order in which the layers were visited in the graph.
        self.matches = list()

    def operate(self, graph, match_kind="topo"):
        """Detect all matches in ``graph`` and drop overlapping ones.

        Args:
            graph: Graph to search.
            match_kind (str): "topo" to match layers in stored order, "edge"
                for the (not yet implemented) order-independent matching.
                Any other value performs no detection.

        Returns:
            list: ``self.matches`` after overlap removal.
        """
        if match_kind == "topo":
            self.detect_patterns_by_topo(graph)
        elif match_kind == "edge":
            self.detect_patterns_by_edge(graph)
        self.remove_overlapped_match()
        return self.matches

    def detect_patterns_by_topo(self, graph):
        """Find subgraphs that match the pattern layer-by-layer.

        Every layer of ``graph`` is tried as the start of a match; each
        complete match is appended to ``self.matches`` as a dict
        {layer_id: layer}. Non-empty blocks nested in a layer are searched
        recursively.
        """

        def match_from(pattern, graph, start_index):
            """Match ``pattern`` against ``graph`` starting at ``start_index``.

            Returns the dict {layer_id: layer} of the matched layers, or
            False when the layers from ``start_index`` on do not form a
            complete match.
            """
            pattern_layers = list(pattern.layers.values())
            pattern_ids = list(pattern.get_global_layers().keys())
            pattern_index = 0
            subgraph_id2layers = dict()
            graph_items = list(graph.layers.items())[start_index:]
            for layer_id, layer in graph_items:
                pattern_layer = pattern_layers[pattern_index]
                if layer.kernel != pattern_layer.kernel:
                    return False
                subgraph_id2layers[layer_id] = layer
                pattern_layer_id = pattern_layer.id

                # The inputs must be wired the same way as in the pattern.
                if layer_id in graph.edges_in:
                    if pattern_layer_id not in pattern.edges_in:
                        return False
                    layer_in = graph.edges_in[layer_id]
                    pattern_layer_in = pattern.edges_in[pattern_layer_id]
                    if len(layer_in) != len(pattern_layer_in):
                        return False
                    subgraph_ids = list(subgraph_id2layers.keys())
                    for layer_id_in, pattern_layer_id_in in zip(
                            layer_in, pattern_layer_in):
                        if pattern_layer_id_in == -1:
                            # -1 entries of the pattern are not compared.
                            continue
                        if layer_id_in not in subgraph_id2layers:
                            # The pattern expects a producer inside the
                            # subgraph, but this input comes from elsewhere.
                            return False
                        # The producer must sit at the same position in the
                        # subgraph as its counterpart does in the pattern.
                        if pattern_ids.index(pattern_layer_id_in) != \
                                subgraph_ids.index(layer_id_in):
                            return False

                # A layer consumed outside the subgraph invalidates the match
                # unless its outputs are outputs of the pattern itself.
                if layer_id in graph.edges_out:
                    same_fanout = (
                        pattern_layer_id in pattern.edges_out and
                        len(graph.edges_out[layer_id]) == len(
                            pattern.edges_out[pattern_layer_id]))
                    if not same_fanout and not set(
                            pattern_layer.outputs).issubset(pattern.outputs):
                        return False

                # Control-flow layers must also match block by block.
                if layer.kernel == "prim.if" or layer.kernel == "prim.loop":
                    if len(pattern_layer.blocks) != len(layer.blocks):
                        return False
                    for pattern_block, block in zip(pattern_layer.blocks,
                                                    layer.blocks):
                        match_info = match_from(pattern_block, block, 0)
                        if match_info is False:
                            return False
                        subgraph_id2layers.update(match_info)

                pattern_index += 1
                if pattern_index == len(pattern_layers):
                    return subgraph_id2layers
            # The graph ran out of layers: only a fully consumed pattern
            # (e.g. an empty pattern block) counts as a match.
            if pattern_index == len(pattern_layers):
                return subgraph_id2layers
            return False

        for i, layer in enumerate(graph.layers.values()):
            match_info = match_from(self.pattern, graph, i)
            if match_info:
                self.matches.append(match_info)
            for block in layer.blocks:
                if len(block.layers) > 0:
                    self.detect_patterns_by_topo(block)

    def detect_patterns_by_edge(self, graph):
        """Intended for patterns whose layer order is not strictly fixed.

        Not implemented yet; currently does nothing.
        """
        pass

    def remove_overlapped_match(self):
        """Keep only the earlier match when two matches share a layer id.

        ``self.matches`` is updated in place. Ids of a dropped match are not
        reserved, so a later match may still use them.
        """
        seen_ids = set()
        kept = []
        for match in self.matches:
            if seen_ids.isdisjoint(match):
                kept.append(match)
                seen_ids.update(match)
        # Slice assignment keeps the identity of the list object.
        self.matches[:] = kept


def get_subgraph(prefix_layer_id, suffix_layer_id, graph):
    """ Return the (sub)graph that directly contains the given layer.

        A layer id such as "57.0.1" is read as alternating layer / block-index
        parts: each leading pair selects block number <index> of layer <layer>
        and descends into it. The graph reached once at most one part is left
        is returned.

        Args:
            prefix_layer_id (str): Id prefix already consumed; an empty string
                for the initial call.
            suffix_layer_id (str): Remaining part of the layer id; the full
                layer id for the initial call, e.g. "57.0.1".
            graph (x2paddle.core.program.PaddleGraph): Graph to start from.
    """
    remaining = suffix_layer_id.split(".")
    consumed = prefix_layer_id
    current = graph
    while len(remaining) > 1:
        layer_part, block_part = remaining[0], remaining[1]
        # Keys in ``layers`` are full dotted ids, so prepend what has
        # already been consumed.
        if consumed == "":
            layer_id = layer_part
        else:
            layer_id = consumed + "." + layer_part
        consumed = layer_id + "." + block_part
        current = current.layers[layer_id].blocks[int(block_part)]
        remaining = remaining[2:]
    return current
S
SunAhong1993 已提交
154 155 156


class FuseBase(object):
    """Base class for pattern-based fusion passes.

    This class calls ``build_pattern()`` and
    ``insert_new_layer(subgraph, parameters, match)`` but does not define
    them here.
    """

    def __init__(self, graph_type):
        """Create the empty pattern graph of the given ``graph_type``."""
        self.pattern = PaddleGraph(graph_type=graph_type)

    def operate(self, graph, match_kind="topo"):
        """Run the whole pass on ``graph``.

        Builds the pattern, finds its matches, inserts a new layer for each
        match, deletes the matched layers and finally rebuilds the graph.

        Args:
            graph: Graph to rewrite in place.
            match_kind (str): Matching strategy forwarded to the matcher.
        """
        parameters = graph.parameters
        self.build_pattern()
        self.perform_pattern_matcher(graph, match_kind)
        for match in self.matches:
            # The first matched layer decides which (sub)graph owns the match.
            owner = get_subgraph("", list(match)[0], graph)
            self.insert_new_layer(owner, parameters, match)
        self.delete_inter_layer(graph)
        graph.build()

    def perform_pattern_matcher(self, graph, match_kind="topo"):
        """ Run pattern matching and store the matches in ``self.matches``.
        """
        self.matches = PatternMatcher(self.pattern).operate(graph, match_kind)

    def delete_inter_layer(self, graph):
        """ Delete the matched intermediate layers and their parameters.
        """
        for match in self.matches:
            owner = get_subgraph("", list(match)[0], graph)
            for layer_id in match:
                layer = match[layer_id]
                if layer.kernel == "fluid.dygraph.base.to_variable":
                    value = layer.attrs["value"]
                    if value.startswith("params["):
                        # Drop the first 8 and the last 2 characters of the
                        # value to get the parameter name.
                        param_name = value[8:-2]
                        if param_name in graph.parameters:
                            graph.parameters.pop(param_name)
                # layer_id may belong to a nested block; in that case the
                # parent layer is deleted, which deletes the whole block.
                if layer_id in owner.layers:
                    owner.layers.pop(layer_id)