interpolate_bilinear_fuser.py 12.2 KB
Newer Older
S
SunAhong1993 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

S
SunAhong1993 已提交
15
import copy
S
SunAhong1993 已提交
16
import numpy as np
S
SunAhong1993 已提交
17
from x2paddle.optimizer.pattern_matcher import FuseBase
S
SunAhong1993 已提交
18 19 20 21
from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import *


S
renam  
SunAhong1993 已提交
22
class DygraphInterpolateBilinearFuser(FuseBase):
    """Fuse the traced bilinear-interpolation subgraph into one layer.

    The pattern built by ``build_pattern`` covers the shape/len checks, the
    exception branches, the ``isinstance``/``var2list`` size conversion and
    the ``paddle.nn.functional.interpolate`` call emitted for a bilinear
    ``interpolate``. ``gen_new_layer`` replaces that whole subgraph with a
    copy of the matched interpolate layer whose ``scale_factor`` input is
    dropped and whose ``size`` input is taken from the matched tuple layer.
    """

    def __init__(self):
        super(DygraphInterpolateBilinearFuser, self).__init__(graph_type="dygraph")
        # torch is imported here rather than at module level (presumably so
        # that importing this module does not require torch).
        import torch
        # Compare (major, minor) lexicographically so that every release newer
        # than 1.5 -- including 2.x -- enables the pattern. The previous check
        # ``major == 1 and minor > 5`` was False on torch 2.x, which left the
        # pattern empty and silently disabled this fuser.
        major, minor = torch.__version__.split(".")[:2]
        self.version_gt_150 = (int(major), int(minor)) > (1, 5)

    def build_pattern(self):
        """ Describe the bilinear-interpolation graph structure to be replaced.
        Example Python code corresponding to the interpolate_bilinear pattern:
            x2195 = x2181.shape
            x2195 = len(x2195)
            x2196 = x2195 - 2
            x2197 = []
            for _x2199 in range(x2196):
                x2197.append(None)
            x2200 = (x2181, x8, None, None)
            ...
            x2267 = x2266 == 3
            if x2267 :
                raise RaiseException('Exception')
                x2268 = None
            else:
                x2270 = x2181.shape
                x2270 = len(x2270)
                x2271 = x2270 == 4
                if x2271 :
                    x2274 = x2197[0]
                    x2275 = x2197[1]
                    x2233_isinstance = isinstance(x2233, paddle.fluid.Variable)
                    if x2233_isinstance :
                        x2233 = x2233.numpy().tolist()
                    x2276 = paddle.nn.functional.interpolate(x=x2181, size=x2233, scale_factor=x2274, align_corners=False, align_mode=0, mode='bilinear')
                    x2272 = x2276
                else:
                    x2277 = x2181.shape
                    x2277 = len(x2277)
                    x2278 = x2277 == 5
                    if x2278 :
                        raise RaiseException('Exception')
                    else:
                        raise RaiseException('Exception')
                    x2272 = None
                x2268 = x2272

        The pattern is only populated when ``self.version_gt_150`` is True;
        otherwise ``self.pattern`` is left untouched.
        """

        def gen_name(layer_num):
            return "x" + str(layer_num)

        if self.version_gt_150:
            # x9 = len(input.shape); x10 = x9 - 2
            self.pattern.add_layer(
                "prim.shape",
                inputs={"input": "interpolate-input-0"},
                outputs=[gen_name(9)])
            self.pattern.add_layer(
                "prim.len",
                inputs={"input": gen_name(9)},
                outputs=[gen_name(9)])
            self.pattern.add_layer(
                "prim.sub",
                inputs={"x": gen_name(9)},
                outputs=[gen_name(10)],
                y=2)
            # x11 = []; for _ in range(x10): x11.append(None)
            self.pattern.add_layer(
                "prim.list", inputs={}, outputs=[gen_name(11)])
            self.pattern.add_layer(
                "prim.loop",
                inputs={"input": gen_name(10)},
                outputs=[gen_name(12.1), gen_name(12.2)])
            loop_layer = self.pattern.layers[list(self.pattern.layers.keys())[
                -1]]
            pattern_block = PaddleGraph(loop_layer, graph_type="dygraph")
            pattern_block.add_layer(
                "prim.append",
                inputs={"list": gen_name(11)},
                outputs=[],
                element=None)
            loop_layer.inputs["input-0"] = gen_name(11)
            loop_layer.add_block(pattern_block)
            # x12 = (input, input-4, None, None)
            self.pattern.add_layer(
                "prim.tuple",
                inputs={
                    "input0": "interpolate-input-0",
                    "input1": "interpolate-input-4",
                },
                outputs=[gen_name(12)],
                input2=None,
                input3=None)

            # Outer branch: if input-2 == 3 -> exception, else rank checks.
            self.pattern.add_layer(
                "prim.eq",
                inputs={"x": "interpolate-input-2"},
                outputs=[gen_name(10.1)],
                y=3)

            self.pattern.add_layer(
                "prim.if",
                inputs={"input": gen_name(10.1)},
                outputs=[gen_name(14)])
            if_layer1 = self.pattern.layers[list(self.pattern.layers.keys())[
                -1]]
            # if_layer1 "true" block: raise, then x14 = None.
            pattern_block = PaddleGraph(parent_layer=if_layer1, graph_type="dygraph")
            pattern_block.add_layer(
                "prim.exception",
                inputs={},
                outputs=[gen_name(15)],
                input="Exception")
            pattern_block.add_layer(
                "prim.equal", inputs={}, outputs=[gen_name(14)], input=None)
            if_layer1.add_block(pattern_block)
            # if_layer1 "false" block: check len(input.shape) == 4.
            pattern_block = PaddleGraph(parent_layer=if_layer1, graph_type="dygraph")
            pattern_block.add_layer(
                "prim.shape",
                inputs={"input": "interpolate-input-0"},
                outputs=[gen_name(18)])
            pattern_block.add_layer(
                "prim.len",
                inputs={"input": gen_name(18)},
                outputs=[gen_name(18)])
            pattern_block.add_layer(
                "prim.eq",
                inputs={"x": gen_name(18)},
                outputs=[gen_name(19)],
                y=4)

            pattern_block.add_layer(
                "prim.if",
                inputs={"input": gen_name(19)},
                outputs=[gen_name(20)])
            if_layer2 = pattern_block.layers[list(pattern_block.layers.keys())[
                -1]]
            # if_layer2 "true" block (4-D input): the actual interpolate call.
            pattern_block_block = PaddleGraph(parent_layer=if_layer2, graph_type="dygraph")
            pattern_block_block.add_layer(
                "prim.getitem",
                inputs={"list": gen_name(11)},
                outputs=[gen_name(21)],
                element=0)
            pattern_block_block.add_layer(
                "prim.getitem",
                inputs={"list": gen_name(11)},
                outputs=[gen_name(22)],
                element=1)
            # if isinstance(size, paddle.fluid.Variable): size = var2list(size)
            pattern_block_block.add_layer(
                "prim.isinstance",
                inputs={"input": "interpolate-input-3"},
                outputs=["interpolate-input-0_isinstance"],
                cls="paddle.fluid.Variable")
            pattern_block_block.add_layer(
                "prim.if", {"input": "interpolate-input-0_isinstance"},
                outputs=["interpolate-input-0_if1"])
            if_layer_isinstance = pattern_block_block.layers[list(
                pattern_block_block.layers.keys())[-1]]
            pattern_block_block_block = PaddleGraph(
                if_layer_isinstance, graph_type="dygraph")
            pattern_block_block_block.add_layer(
                "prim.var2list",
                inputs={"input": "interpolate-input-3"},
                outputs=["interpolate-input-3"])
            if_layer_isinstance.add_block(pattern_block_block_block)
            # Empty "else" block of the isinstance check.
            pattern_block_block_block = PaddleGraph(
                if_layer_isinstance, graph_type="dygraph")
            if_layer_isinstance.add_block(pattern_block_block_block)
            if_layer_isinstance.inputs["input-0"] = "interpolate-input-3"
            pattern_block_block.add_layer(
                "paddle.nn.functional.interpolate",
                inputs={
                    "input": "interpolate-input-0",
                    "size": "interpolate-input-3",
                    "scale_factor": gen_name(21)
                },
                outputs=[gen_name(23)])
            pattern_block_block.add_layer(
                "prim.equal",
                inputs={"input": gen_name(23)},
                outputs=[gen_name(20)])
            if_layer2.add_block(pattern_block_block)
            # if_layer2 "false" block: len(input.shape) == 5 check whose both
            # branches raise, then x20 = None.
            pattern_block_block = PaddleGraph(if_layer2, graph_type="dygraph")
            pattern_block_block.add_layer(
                "prim.shape",
                inputs={"input": "interpolate-input-0"},
                outputs=[gen_name(24)])
            pattern_block_block.add_layer(
                "prim.len",
                inputs={"input": gen_name(24)},
                outputs=[gen_name(24)])
            pattern_block_block.add_layer(
                "prim.eq",
                inputs={"x": gen_name(24)},
                outputs=[gen_name(25)],
                y=5)
            pattern_block_block.add_layer(
                "prim.if",
                inputs={"input": gen_name(25)},
                outputs=[gen_name(26)])
            if_layer3 = pattern_block_block.layers[list(
                pattern_block_block.layers.keys())[-1]]
            pattern_block_block_block = PaddleGraph(
                parent_layer=if_layer3, graph_type="dygraph")
            pattern_block_block_block.add_layer(
                "prim.exception",
                inputs={},
                outputs=[gen_name(27)],
                input="Exception")
            if_layer3.add_block(pattern_block_block_block)
            pattern_block_block_block = PaddleGraph(
                parent_layer=if_layer3, graph_type="dygraph")
            pattern_block_block_block.add_layer(
                "prim.exception",
                inputs={},
                outputs=[gen_name(28)],
                input="Exception")
            if_layer3.add_block(pattern_block_block_block)
            pattern_block_block.add_layer(
                "prim.equal", inputs={}, outputs=[gen_name(20)], input=None)
            if_layer2.add_block(pattern_block_block)
            if_layer2.inputs.update({
                "input-0": "interpolate-input-0",
                "input-1": "interpolate-input-3",
                "input-2": "interpolate-input-3",
                "input-3": gen_name(11),
                "input-5": gen_name(11),
            })
            # x14 = x20 at the end of if_layer1's "false" block.
            pattern_block.add_layer(
                "prim.equal",
                inputs={"input": gen_name(20)},
                outputs=[gen_name(14)])
            if_layer1.add_block(pattern_block)
            if_layer1.inputs.update({
                'input-2': 'interpolate-input-0',
                'input-4': gen_name(11),
                'input-6': gen_name(11),
                'input-8': 'interpolate-input-0',
                'input-9': 'interpolate-input-3',
                'input-10': 'interpolate-input-0'
            })
            self.pattern.build(inputs={
                "input-0": "interpolate-input-0",
                "input-1": "interpolate-input-1",
                "input-2": "interpolate-input-2",
                "input-3": "interpolate-input-3",
                "input-4": "interpolate-input-4"
            })

    def insert_new_layer(self, graph, parameters, matches):
        """Insert the fused layer into ``graph`` and rewrite ``matches``.

        The fused layer from ``gen_new_layer`` is stored in ``graph.layers``
        under its (reused) layer id. ``matches`` is then rewritten in place so
        that it holds every global layer of ``graph`` from the first matched
        id through the last matched id, except the fused layer's own id.
        Presumably the caller (FuseBase) removes the layers left in
        ``matches`` -- confirm against FuseBase.

        Args:
            graph: the graph being optimized (must provide
                ``get_global_layers()`` and a ``layers`` dict).
            parameters: forwarded unchanged to ``gen_new_layer``.
            matches: ordered dict of matched layer id -> layer; mutated.

        Raises:
            KeyError: if the fused layer's id is not within the collected
                span of global layers.
        """
        new_layer = self.gen_new_layer(parameters, matches)
        global_layers = graph.get_global_layers()
        matched_ids = list(matches.keys())
        first_id, last_id = matched_ids[0], matched_ids[-1]
        # Collect the contiguous run of global layers spanning the match,
        # including any nested sub-block layers in between.
        new_matches = dict()
        is_match = False
        for layer_id, layer in global_layers.items():
            if layer_id == first_id:
                is_match = True
            if is_match:
                new_matches[layer_id] = layer
                if layer_id == last_id:
                    break
        new_layer_id = new_layer.layer_id
        graph.layers[new_layer_id] = new_layer
        # Keep the fused layer out of the set handed back for removal.
        new_matches.pop(new_layer_id)
        matches.clear()
        for layer_id, layer in new_matches.items():
            matches[layer_id] = layer

    def gen_new_layer(self, parameters, matches):
        """Build the fused interpolate layer from the matched layers.

        The positions used below depend on the ordering of ``matches``
        produced by the pattern matcher (TODO confirm against the matcher):
        index 6 is read for its ``"input1"`` input (presumably the
        ``prim.tuple`` layer), index 19 is the layer that gets deep-copied
        and must have a ``"scale_factor"`` input (presumably the
        ``paddle.nn.functional.interpolate`` call), index 9 supplies the
        output name, and index 7 supplies the layer id to reuse.

        Args:
            parameters: unused; kept for the FuseBase interface.
            matches: ordered dict of matched layer id -> layer.

        Returns:
            The new layer: a deep copy of the matched layer at index 19 with
            ``scale_factor`` removed and ``size`` taken from layer 6.
        """
        layers_id = list(matches.keys())
        layer = matches[layers_id[6]]
        size = layer.inputs["input1"]
        layer = matches[layers_id[19]]
        new_layer = copy.deepcopy(layer)
        layer = matches[layers_id[9]]
        new_layer.outputs[0] = layer.outputs[0]
        new_layer.layer_id = layers_id[7]
        new_layer.inputs.pop("scale_factor")
        new_layer.inputs["size"] = size
        return new_layer