#   Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import sys
import numpy as np
from x2paddle.optimizer.pattern_matcher import FuseBase
from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import *


class DygraphTransposeElimination(FuseBase):
    def __init__(self):
        super(DygraphTransposeElimination, self).__init__(graph_type="dygraph")
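        # Single-input, layout-agnostic kernels: a transpose can be pushed
        # straight through them without changing their result.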
        self.direct_layers = [
            'paddle.nn.ReLU', 'paddle.nn.ReLU6', 'paddle.abs',
            'paddle.nn.Sigmoid', 'paddle.exp', 'paddle.rsqrt',
            'paddle.nn.Swish', 'paddle.nn.Tanh',
            'paddle.nn.Softplus', 'paddle.nn.LeakyReLU',
            'paddle.floor', 'paddle.erf', 'paddle.square'
        ]
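        # Binary elementwise kernels: layout-agnostic when both operands share
        # a layout or one side broadcasts from a scalar.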
        self.elementwise_layers = [
            'paddle.add', 'fluid.layers.elementwise_sub',
            'paddle.multiply', 'paddle.divide'
        ]
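        # Reduction kernels: only safe to cross with keepdim=True, and their
        # axis attributes must be remapped after the transposes are removed.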
        self.reduce_layers = [
            'paddle.mean', 'paddle.all',
            'paddle.max', 'paddle.any',
            'paddle.sum', 'paddle.prod'
        ]

    def get_transpose_num(self, graph):
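        """Count the paddle.transpose layers currently in the graph."""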
        count = 0
        for layer_id, layer in graph.layers.items():
            if layer.kernel == "paddle.transpose":
                count += 1
        return count
    
    def operate(self, graph):
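        """Eliminate cancelling pairs of paddle.transpose layers.

        Repeatedly searches for a transpose(perm=[0, 2, 3, 1]) whose effect
        can be propagated through layout-agnostic layers until matching
        transpose(perm=[0, 3, 1, 2]) layers absorb it, removes the whole set,
        and remaps the axis attributes of the reduce/concat layers in between.
        """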
        total_layer_num = len(graph.layers)
        scanned_layers = set()
        optimized_transpose_layers = list()
        optimized_reduce_layers = list()
        optimized_concat_layers = list()
        optimized_elementwise_layers = list()
        
        def get_index(layer):
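            # paddle.nn.* module layers expose their tensor output at index 1;
            # functional kernels expose it at index 0.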
            if layer.kernel.startswith("paddle.nn") and "functional" not in layer.kernel:
                return 1
            else:
                return 0

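        # Locate and remove one eliminable transpose pattern per call;
        # returns True if a pattern was removed, False once none remain.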
        def strip_transpose(_graph):
            layers = copy.deepcopy(_graph.layers)
            for layer_id, layer in layers.items():
                if layer_id in scanned_layers:
                    continue
                scanned_layers.add(layer_id)
                percent = round(len(scanned_layers) / total_layer_num * 100, 2)
                sys.stderr.write("\rOptimize Transpose Layers...{}%".format(
                    percent))

                if layer.kernel != "paddle.transpose":
                    continue
                if layer.attrs["perm"] != [0, 2, 3, 1]:
                    continue
                transpose_layers = list()
                propagate_layers = list()
                reduce_layers = list()
                concat_layers = list()
                # elementwise_layers is reserved for layers of the
                # shape(4) + shape(1) broadcast form.
                elementwise_layers = list()
                can_be_optimized = True
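                # Seed the worklist with the transpose's direct consumers; any
                # consumer outside the supported kernel sets aborts the match.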
                for out in _graph.edges_out.get(layer_id, []):
                    if _graph.layers[out].kernel == "paddle.transpose":
                        if _graph.layers[out].attrs["perm"] != [0, 3, 1, 2]:
                            can_be_optimized = False
                            break
                        transpose_layers.append(out)
                    elif _graph.layers[out].kernel in self.elementwise_layers:
                        propagate_layers.append(out)
                    elif _graph.layers[out].kernel in self.direct_layers:
                        output_index = get_index(_graph.layers[out])
                        if _graph.layers[out].outputs[output_index] in _graph.outputs:
                            can_be_optimized = False
                            break
                        propagate_layers.append(out)
                    elif _graph.layers[out].kernel in self.reduce_layers:
                        output_index = get_index(_graph.layers[out])
                        if _graph.layers[out].outputs[output_index] in _graph.outputs:
                            can_be_optimized = False
                            break
                        if not _graph.layers[out].attrs.get('keepdim', False):
                            can_be_optimized = False
                            break
                        propagate_layers.append(out)
                        reduce_layers.append(out)
                    elif _graph.layers[out].kernel == "paddle.concat":
                        output_index = get_index(_graph.layers[out])
                        if _graph.layers[out].outputs[output_index] in _graph.outputs:
                            can_be_optimized = False
                            break
                        propagate_layers.append(out)
                        concat_layers.append(out)
                    else:
                        can_be_optimized = False
                        break

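                # Breadth-first walk over the affected subgraph: every layer
                # reachable from the seed must be transpose-compatible,
                # otherwise the whole pattern is abandoned.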
                visited_layers = set()
                while len(propagate_layers) > 0 and can_be_optimized:
                    current_id = propagate_layers.pop(0)
                    visited_layers.add(current_id)
                    for out in _graph.edges_out.get(current_id, []):
                        if _graph.layers[
                                out].kernel == "paddle.transpose":
                            if _graph.layers[out].attrs["perm"] != [0, 3, 1, 2]:
                                can_be_optimized = False
                                break
                            transpose_layers.append(out)
                        elif _graph.layers[
                                out].kernel in self.elementwise_layers:
                            output_index = get_index(_graph.layers[out])
                            if _graph.layers[out].outputs[output_index] in _graph.outputs:
                                can_be_optimized = False
                                break
                            if out not in visited_layers:
                                propagate_layers.append(out)
                        elif _graph.layers[out].kernel in self.direct_layers:
                            output_index = get_index(_graph.layers[out])
                            if _graph.layers[out].outputs[output_index] in _graph.outputs:
                                can_be_optimized = False
                                break
                            if out not in visited_layers:
                                propagate_layers.append(out)
                        elif _graph.layers[out].kernel in self.reduce_layers:
                            output_index = get_index(_graph.layers[out])
                            if _graph.layers[out].outputs[output_index] in _graph.outputs:
                                can_be_optimized = False
                                break
                            if not _graph.layers[out].attrs.get('keepdim',
                                                                False):
                                can_be_optimized = False
                                break
                            if out not in visited_layers:
                                propagate_layers.append(out)
                                reduce_layers.append(out)
                        elif _graph.layers[out].kernel == "paddle.concat":
                            output_index = get_index(_graph.layers[out])
                            if _graph.layers[out].outputs[output_index] in _graph.outputs:
                                can_be_optimized = False
                                break
                            if out not in visited_layers:
                                propagate_layers.append(out)
                                concat_layers.append(out)
                        else:
                            can_be_optimized = False
                            break
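                    # Also walk the producers of the current layer; each input
                    # must carry the same layout or broadcast from a scalar.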
                    for ipt in _graph.edges_in.get(current_id, []):
                        if _graph.layers[
                                current_id].kernel in self.elementwise_layers:
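                            # A scalar operand (shape [1] or rank 0) broadcasts
                            # the same way in either layout, so its producer can
                            # be skipped; record the layer as a broadcast case.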
                            try:
                                x_shape = _graph.layers[
                                    current_id].input_shapes['x']
                                y_shape = _graph.layers[
                                    current_id].input_shapes['y']
                                output_index = get_index(_graph.layers[ipt])
                                if _graph.layers[ipt].outputs[
                                        output_index] == _graph.layers[
                                            current_id].inputs['x']:
                                    if list(x_shape) == [1] or len(x_shape) < 1:
                                        elementwise_layers.append(current_id)
                                        continue
                                elif _graph.layers[ipt].outputs[
                                        output_index] == _graph.layers[
                                            current_id].inputs['y']:
                                    if list(y_shape) == [1] or len(y_shape) < 1:
                                        elementwise_layers.append(current_id)
                                        continue
                                else:
                                    raise Exception(
                                        "Unexpected situation happened while optimizing transpose"
                                    )
                            except Exception:
                                can_be_optimized = False
                                break
                        output_index = get_index(_graph.layers[ipt])
                        if _graph.layers[
                                ipt].kernel == "paddle.transpose":
                            if _graph.layers[ipt].attrs["perm"] != [0, 2, 3, 1]:
                                can_be_optimized = False
                                break
                            if ipt not in visited_layers:
                                transpose_layers.append(ipt)
                        elif _graph.layers[
                                ipt].kernel in self.elementwise_layers:
                            if _graph.layers[ipt].outputs[output_index] in _graph.outputs:
                                can_be_optimized = False
                                break
                            if ipt not in visited_layers:
                                propagate_layers.append(ipt)
                        elif _graph.layers[ipt].kernel in self.direct_layers:
                            if _graph.layers[ipt].outputs[output_index] in _graph.outputs:
                                can_be_optimized = False
                                break
                            if ipt not in visited_layers:
                                propagate_layers.append(ipt)
                        elif _graph.layers[ipt].kernel in self.reduce_layers:
                            if _graph.layers[ipt].outputs[output_index] in _graph.outputs:
                                can_be_optimized = False
                                break
                            if not _graph.layers[ipt].attrs.get('keepdim',
                                                                False):
                                can_be_optimized = False
                                break
                            if ipt not in visited_layers:
                                propagate_layers.append(ipt)
                                reduce_layers.append(ipt)
                        elif _graph.layers[ipt].kernel == "paddle.concat":
                            if _graph.layers[ipt].outputs[output_index] in _graph.outputs:
                                can_be_optimized = False
                                break
                            if ipt not in visited_layers:
                                propagate_layers.append(ipt)
                                concat_layers.append(ipt)
                        else:
                            can_be_optimized = False
                            break
                    if not can_be_optimized:
                        break
                if not can_be_optimized:
                    continue

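                # Pattern matched: the seed plus every inverse transpose found
                # can be removed, unless one of them produces a graph output.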
                transpose_layers.append(layer_id)
                transpose_layers = list(set(transpose_layers))
                for l in transpose_layers:
                    output_index = get_index(graph.layers[l])
                    if graph.layers[l].outputs[output_index] in graph.outputs:
                        can_be_optimized = False
                        break
                if not can_be_optimized:
                    continue

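                # Remove every transpose in the matched set from the working copy.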
                for l in transpose_layers:
                    _graph.del_layer(l)

                optimized_transpose_layers.extend(transpose_layers)
                optimized_reduce_layers.extend(reduce_layers)
                optimized_concat_layers.extend(concat_layers)
                optimized_elementwise_layers.extend(elementwise_layers)
                return True
            return False

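        # Strip patterns to a fixed point on a scratch copy, then replay the
        # recorded deletions and axis fixes on the real graph.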
        before_transpose_num = self.get_transpose_num(graph)
        opt_graph = copy.deepcopy(graph)
        total_layer_num = len(opt_graph.layers)

        while strip_transpose(opt_graph):
            pass

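        # Replay the recorded transpose deletions on the original graph.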
        for layer_id in list(set(optimized_transpose_layers)):
            graph.del_layer(layer_id)
        for layer_id in list(set(optimized_reduce_layers)):
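            # These reductions now see NCHW input instead of NHWC, so each
            # axis index is remapped through the permutation [0, 2, 3, 1].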
            dim = graph.layers[layer_id].attrs.get('axis', None)
            if dim is not None:
                for i in range(len(dim)):
                    dim[i] = [0, 2, 3, 1][dim[i]]
                graph.layers[layer_id].attrs['axis'] = dim
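        # Concat axes are single integers; remap them the same way.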
        for layer_id in list(set(optimized_concat_layers)):
            axis = graph.layers[layer_id].attrs.get('axis', 0)
            graph.layers[layer_id].attrs['axis'] = [0, 2, 3, 1][axis]

        current_transpose_num = self.get_transpose_num(graph)
        print(
            "\nTranspose layers optimized, before: transpose_num={}, after: transpose_num={}".
            format(before_transpose_num, current_transpose_num))