interpolate_bilinear_fuser.py 10.5 KB
Newer Older
S
SunAhong1993 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

S
SunAhong1993 已提交
15
import copy
S
SunAhong1993 已提交
16
import numpy as np
S
SunAhong1993 已提交
17
from x2paddle.optimizer.pattern_matcher import FuseBase
S
SunAhong1993 已提交
18 19 20 21
from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import *


S
SunAhong1993 已提交
22
class InterpolateBilinearFuser(FuseBase):
    """Fuse the bilinear-interpolation subgraph into a single layer.

    ``build_pattern`` describes the multi-layer subgraph (shape/len checks,
    a loop building a list of ``None`` and nested ``prim.if`` branches around
    a ``paddle.nn.functional.interpolate`` call). ``gen_new_layer`` and
    ``insert_new_layer`` replace each match with one fused layer.
    """

    def __init__(self):
        super(InterpolateBilinearFuser, self).__init__()
        # Patterns built by ``build_pattern`` are appended here.
        self.patterns = list()

    def build_pattern(self):
        """Describe the bilinear-interpolation graph structure to be replaced.

        Example Python code of the interpolate_bilinear layer pattern
        (``...`` elides unrelated lines):

            x2195 = x2181.shape
            x2195 = len(x2195)
            x2196 = x2195 - 2
            x2197 = []
            for _x2199 in range(x2196):
                x2197.append(None)
            x2200 = (x2181, x8, None, None)
            ...
            x2267 = x2266 == 3
            if x2267 :
                raise RaiseException('Exception')
                x2268 = None
            else:
                x2270 = x2181.shape
                x2270 = len(x2270)
                x2271 = x2270 == 4
                if x2271 :
                    x2274 = x2197[0]
                    x2275 = x2197[1]
                    x2233_isinstance = isinstance(x2233, paddle.static.Variable)
                    if x2233_isinstance :
                        x2233 = x2233.numpy().tolist()
                    x2276 = paddle.nn.functional.interpolate(x=x2181, size=x2233, scale_factor=x2274, align_corners=False, align_mode=0, mode='bilinear')
                    x2272 = x2276
                else:
                    x2277 = x2181.shape
                    x2277 = len(x2277)
                    x2278 = x2277 == 5
                    if x2278 :
                        raise RaiseException('Exception')
                    else:
                        raise RaiseException('Exception')
                    x2272 = None
                x2268 = x2272

        The pattern has five external inputs, ``interpolate-input-0`` through
        ``interpolate-input-4``. The finished pattern is appended to
        ``self.patterns``.
        """

        def gen_name(num):
            return "x" + str(num)

        pattern = PaddleGraph()
        # x2195 = len(x.shape); x2196 = x2195 - 2; x2197 = []
        pattern.add_layer(
            "prim.shape",
            inputs={"input": "interpolate-input-0"},
            outputs=[gen_name(9)])
        pattern.add_layer(
            "prim.len", inputs={"input": gen_name(9)}, outputs=[gen_name(9)])
        pattern.add_layer(
            "prim.sub", inputs={"x": gen_name(9)}, outputs=[gen_name(10)], y=2)
        pattern.add_layer("prim.list", inputs={}, outputs=[gen_name(11)])
        # for _ in range(x2196): x2197.append(None)
        pattern.add_layer(
            "prim.loop",
            inputs={"input": gen_name(10)},
            outputs=[gen_name(12.1), gen_name(12.2)])
        loop_layer = pattern.layers[list(pattern.layers.keys())[-1]]
        pattern_block = PaddleGraph(loop_layer)
        pattern_block.add_layer(
            "prim.append",
            inputs={"list": gen_name(11)},
            outputs=[],
            element=None)
        loop_layer.inputs["input-0"] = gen_name(11)
        loop_layer.add_block(pattern_block)
        pattern.add_layer(
            "prim.tuple",
            inputs={
                "input0": "interpolate-input-0",
                "input1": "interpolate-input-4",
            },
            outputs=[gen_name(12)],
            input2=None,
            input3=None)

        # x2267 = x2266 == 3
        pattern.add_layer(
            "prim.eq",
            inputs={"x": "interpolate-input-2"},
            outputs=[gen_name(10.1)],
            y=3)

        # Outer `if x2267:` — first block raises, second block holds the
        # rank-4 / rank-5 checks.
        pattern.add_layer(
            "prim.if", inputs={"input": gen_name(10.1)},
            outputs=[gen_name(14)])
        if_layer1 = pattern.layers[list(pattern.layers.keys())[-1]]
        pattern_block = PaddleGraph(parent_layer=if_layer1)
        pattern_block.add_layer(
            "prim.exception",
            inputs={},
            outputs=[gen_name(15)],
            input="Exception")
        pattern_block.add_layer(
            "prim.equal", inputs={}, outputs=[gen_name(14)], input=None)
        if_layer1.add_block(pattern_block)
        pattern_block = PaddleGraph(parent_layer=if_layer1)
        # x2271 = len(x.shape) == 4
        pattern_block.add_layer(
            "prim.shape",
            inputs={"input": "interpolate-input-0"},
            outputs=[gen_name(18)])
        pattern_block.add_layer(
            "prim.len", inputs={"input": gen_name(18)}, outputs=[gen_name(18)])
        pattern_block.add_layer(
            "prim.eq", inputs={"x": gen_name(18)}, outputs=[gen_name(19)], y=4)

        pattern_block.add_layer(
            "prim.if", inputs={"input": gen_name(19)}, outputs=[gen_name(20)])
        if_layer2 = pattern_block.layers[list(pattern_block.layers.keys())[-1]]
        # Rank-4 branch: read list items, normalise `size`, call interpolate.
        pattern_block_block = PaddleGraph(parent_layer=if_layer2)
        pattern_block_block.add_layer(
            "prim.getitem",
            inputs={"list": gen_name(11)},
            outputs=[gen_name(21)],
            element=0)
        pattern_block_block.add_layer(
            "prim.getitem",
            inputs={"list": gen_name(11)},
            outputs=[gen_name(22)],
            element=1)
        pattern_block_block.add_layer(
            "prim.isinstance",
            inputs={"input": "interpolate-input-3"},
            outputs=["interpolate-input-0_isinstance"],
            cls="paddle.static.Variable")
        pattern_block_block.add_layer(
            "prim.if", {"input": "interpolate-input-0_isinstance"},
            outputs=["interpolate-input-0_if1"])
        if_layer_isinstance = pattern_block_block.layers[list(
            pattern_block_block.layers.keys())[-1]]
        pattern_block_block_block = PaddleGraph(if_layer_isinstance)
        pattern_block_block_block.add_layer(
            "prim.var2list",
            inputs={"input": "interpolate-input-3"},
            outputs=["interpolate-input-3"])
        if_layer_isinstance.add_block(pattern_block_block_block)
        # The `else` side of the isinstance check is empty.
        pattern_block_block_block = PaddleGraph(if_layer_isinstance)
        if_layer_isinstance.add_block(pattern_block_block_block)
        if_layer_isinstance.inputs["input-0"] = "interpolate-input-3"
        pattern_block_block.add_layer(
            "paddle.nn.functional.interpolate",
            inputs={
                "input": "interpolate-input-0",
                "size": "interpolate-input-3",
            },
            outputs=[gen_name(23)])
        pattern_block_block.add_layer(
            "prim.equal",
            inputs={"input": gen_name(23)},
            outputs=[gen_name(20)])
        if_layer2.add_block(pattern_block_block)
        # Non-rank-4 branch: rank-5 check whose both sides raise.
        pattern_block_block = PaddleGraph(if_layer2)
        pattern_block_block.add_layer(
            "prim.shape",
            inputs={"input": "interpolate-input-0"},
            outputs=[gen_name(24)])
        pattern_block_block.add_layer(
            "prim.len", inputs={"input": gen_name(24)}, outputs=[gen_name(24)])
        pattern_block_block.add_layer(
            "prim.eq", inputs={"x": gen_name(24)}, outputs=[gen_name(25)], y=5)
        pattern_block_block.add_layer(
            "prim.if", inputs={"input": gen_name(25)}, outputs=[gen_name(26)])
        if_layer3 = pattern_block_block.layers[list(
            pattern_block_block.layers.keys())[-1]]
        pattern_block_block_block = PaddleGraph(parent_layer=if_layer3)
        pattern_block_block_block.add_layer(
            "prim.exception",
            inputs={},
            outputs=[gen_name(27)],
            input="Exception")
        if_layer3.add_block(pattern_block_block_block)
        pattern_block_block_block = PaddleGraph(parent_layer=if_layer3)
        pattern_block_block_block.add_layer(
            "prim.exception",
            inputs={},
            outputs=[gen_name(28)],
            input="Exception")
        if_layer3.add_block(pattern_block_block_block)
        pattern_block_block.add_layer(
            "prim.equal", inputs={}, outputs=[gen_name(20)], input=None)
        if_layer2.add_block(pattern_block_block)
        if_layer2.inputs.update({
            "input-0": "interpolate-input-0",
            "input-1": "interpolate-input-3",
            "input-2": "interpolate-input-3",
            "input-3": gen_name(11),
            "input-5": gen_name(11),
        })
        pattern_block.add_layer(
            "prim.equal",
            inputs={"input": gen_name(20)},
            outputs=[gen_name(14)])
        if_layer1.add_block(pattern_block)
        if_layer1.inputs.update({
            'input-2': 'interpolate-input-0',
            'input-4': gen_name(11),
            'input-6': gen_name(11),
            'input-8': 'interpolate-input-0',
            'input-9': 'interpolate-input-3',
            'input-10': 'interpolate-input-0'
        })
        pattern.build(inputs={
            "input-0": "interpolate-input-0",
            "input-1": "interpolate-input-1",
            "input-2": "interpolate-input-2",
            "input-3": "interpolate-input-3",
            "input-4": "interpolate-input-4"
        })
        self.patterns.append(pattern)

    def insert_new_layer(self, graph, parameters, matches):
        """Replace one matched subgraph in ``graph`` with the fused layer.

        Args:
            graph: graph containing the match; must provide
                ``get_global_layers()`` and a ``layers`` dict.
            parameters: forwarded unchanged to ``gen_new_layer``.
            matches: ordered dict of matched ``layer_id -> layer``. It is
                rewritten in place to hold every global layer from the first
                to the last matched id (inclusive), minus the id reused by
                the fused layer. Presumably the caller then deletes the
                layers left in ``matches`` — confirm against FuseBase.

        Raises:
            KeyError: if the fused layer's id is not inside the collected
                range.
        """
        new_layer = self.gen_new_layer(parameters, matches)
        global_layers = graph.get_global_layers()
        # Hoisted: ``matches`` does not change while scanning.
        match_ids = list(matches.keys())
        first_id, last_id = match_ids[0], match_ids[-1]
        # Collect the contiguous run of global layers from the first matched
        # id through the last matched id (inclusive).
        new_matches = dict()
        in_range = False
        for layer_id, layer in global_layers.items():
            if layer_id == first_id:
                in_range = True
            if in_range:
                new_matches[layer_id] = layer
                if layer_id == last_id:
                    break
        new_layer_id = new_layer.layer_id
        # The fused layer reuses an existing matched id. If that id is already
        # a key of ``graph.layers``, this replaces the entry in place.
        graph.layers[new_layer_id] = new_layer
        # Keep the fused layer's id out of the layers left in ``matches``.
        new_matches.pop(new_layer_id)
        matches.clear()
        matches.update(new_matches)

    def gen_new_layer(self, parameters, matches):
        """Build the fused interpolate layer from the matched layers.

        The fixed indices below address ``list(matches.keys())``. They depend
        on the order in which the pattern matcher records matched layers,
        including those nested in sub-blocks. TODO: confirm which pattern
        layer each index corresponds to.

        Args:
            parameters: not used by this fuser.
            matches: ordered dict of matched ``layer_id -> layer``.

        Returns:
            A deep copy of the layer at index 19, with:
            - its first output taken from the layer at index 9,
            - its ``layer_id`` set to the id at index 7,
            - its ``size`` input set to the ``input1`` input of the layer at
              index 6.
        """
        layers_id = list(matches.keys())
        # Presumably the ``prim.tuple`` layer: it is the only pattern layer
        # with an ``input1`` input.
        size = matches[layers_id[6]].inputs["input1"]
        new_layer = copy.deepcopy(matches[layers_id[19]])
        new_layer.outputs[0] = matches[layers_id[9]].outputs[0]
        new_layer.layer_id = layers_id[7]
        new_layer.inputs["size"] = size
        return new_layer