bn_scale_fuser.py 7.0 KB
Newer Older
S
SunAhong1993 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
#   Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from x2paddle.optimizer.pattern_matcher import FuseBase
from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import *


class Static_BNScaleFuser(FuseBase):
    def __init__(self):
        """Set up the fuser for static graphs with an empty pattern list."""
        super(Static_BNScaleFuser, self).__init__(graph_type="static")
        # Filled by build_pattern(); one entry per sub-graph pattern.
        self.patterns = []
S
SunAhong1993 已提交
25 26 27 28

    def build_pattern(self):
        """Describe the batch_norm + scale sub-graphs that should be replaced.

        Two patterns are appended to ``self.patterns``, in this order. They
        differ only in whether the second scale parameter goes through a
        ``paddle.reshape`` before the final ``paddle.add``.

        Example Python code for the matched structure:

        Pattern 1:
        conv1_bn_mean = paddle.static.create_parameter(shape=(128,), dtype='float32', name='conv1_bn_mean')
        conv1_bn_variance = paddle.static.create_parameter(shape=(128,), dtype='float32', name='conv1_bn_variance')
        conv1_bn = paddle.nn.functional.batch_norm(x=conv1, weight=conv1_bn_weight, bias=conv1_bn_bias, running_mean=conv1_bn_mean, running_var=conv1_bn_variance, epsilon=9.999999747378752e-06, momentum=0.9990000128746033)
        conv1_scale_cparam1 = paddle.static.create_parameter(shape=(32,), dtype='float32', name='conv1_scale_cparam1')
        conv1_scale_mul = paddle.multiply(x=conv1_bn, y=conv1_scale_cparam1, axis=1)
        conv1_scale_cparam2 = paddle.static.create_parameter(shape=(32,), dtype='float32', name='conv1_scale_cparam2')
        conv1_scale_cparam2 = paddle.reshape(x=conv1_scale_cparam2, shape=[32, 1, 1])
        conv1_scale = paddle.add(x=conv1_scale_mul, y=conv1_scale_cparam2)

        Pattern 2:
        conv1_bn_mean = paddle.static.create_parameter(shape=(128,), dtype='float32', name='conv1_bn_mean')
        conv1_bn_variance = paddle.static.create_parameter(shape=(128,), dtype='float32', name='conv1_bn_variance')
        conv1_bn = paddle.nn.functional.batch_norm(x=conv1, weight=conv1_bn_weight, bias=conv1_bn_bias, running_mean=conv1_bn_mean, running_var=conv1_bn_variance, epsilon=9.999999747378752e-06, momentum=0.9990000128746033)
        conv1_scale_cparam1 = paddle.static.create_parameter(shape=(32,), dtype='float32', name='conv1_scale_cparam1')
        conv1_scale_mul = paddle.multiply(x=conv1_bn, y=conv1_scale_cparam1, axis=1)
        conv1_scale_cparam2 = paddle.static.create_parameter(shape=(32,), dtype='float32', name='conv1_scale_cparam2')
        conv1_scale = paddle.add(x=conv1_scale_mul, y=conv1_scale_cparam2)
        """

        def name_of(index):
            # Placeholder tensor name used inside the pattern graphs.
            return "x" + str(index)

        # The order of add_layer calls is kept exactly as written per pattern;
        # insert_new_layer/gen_new_layer address matched layers by position.
        for reshape_bias in (True, False):
            # NOTE(review): the pattern graph is created with
            # graph_type="dygraph" although this fuser is constructed with
            # graph_type="static" -- kept as-is, confirm this is intended.
            pattern = PaddleGraph(graph_type="dygraph")

            # Two parameters feeding running_mean / running_var.
            for index in (10, 11):
                pattern.add_layer(
                    "paddle.static.create_parameter",
                    inputs={},
                    outputs=[name_of(index)])

            pattern.add_layer(
                "paddle.nn.functional.batch_norm",
                inputs={
                    "input": "bn-input-0",
                    "weight": "bn-input-1",
                    "bias": "bn-input-2",
                    "running_mean": name_of(10),
                    "running_var": name_of(11),
                },
                outputs=[name_of(0)])

            # First scale parameter, multiplied with the batch_norm output.
            pattern.add_layer(
                "paddle.static.create_parameter",
                inputs={},
                outputs=[name_of(1)])
            pattern.add_layer(
                "paddle.multiply",
                inputs={"x": name_of(0), "y": name_of(1)},
                outputs=[name_of(2)])

            # Second scale parameter, added to the product.
            pattern.add_layer(
                "paddle.static.create_parameter",
                inputs={},
                outputs=[name_of(3)])
            if reshape_bias:
                # Pattern 1 only: the parameter is reshaped before the add.
                pattern.add_layer(
                    "paddle.reshape",
                    inputs={"x": name_of(3)},
                    outputs=[name_of(4)])
                addend, result = name_of(4), name_of(5)
            else:
                addend, result = name_of(3), name_of(4)
            pattern.add_layer(
                "paddle.add",
                inputs={"x": name_of(2), "y": addend},
                outputs=[result])

            pattern.build(inputs={
                "input-0": "bn-input-0",
                "input-1": "bn-input-1",
                "input-2": "bn-input-2",
            })
            self.patterns.append(pattern)
S
SunAhong1993 已提交
142 143 144

    def insert_new_layer(self, graph, parameters, matches):
        """Put the fused layer into ``graph`` and trim ``matches``.

        The layer returned by ``gen_new_layer`` is stored in ``graph.layers``
        under the id of the last matched layer. Afterwards the entries at
        original positions 0, 1, 3, 5 and the last one are removed from
        ``matches`` (mutated in place); every other entry is left in it.

        With the patterns from ``build_pattern`` the removed positions would
        be the four create_parameter layers plus the final add -- this assumes
        ``matches`` is ordered like the pattern's layers (TODO confirm).
        Presumably the caller deletes whatever remains in ``matches`` from the
        graph; verify against the caller.

        Args:
            graph: object whose ``layers`` mapping receives the fused layer.
            parameters: forwarded unchanged to ``gen_new_layer``.
            matches: ordered mapping of layer id -> matched layer.
        """
        fused_layer = self.gen_new_layer(parameters, matches)

        # Snapshot the ids once so the positions below refer to the original
        # ordering instead of shifting after every pop.
        ordered_ids = list(matches.keys())
        last_id = ordered_ids[-1]
        graph.layers[last_id] = fused_layer

        for position in (0, 1, 3, 5, -1):
            matches.pop(ordered_ids[position])

    def gen_new_layer(self, parameters, matches):
        """Turn the matched layer at position 2 into the fused layer.

        The layer at position 2 of ``matches`` is modified in place and
        returned: its ``"weight"`` and ``"bias"`` inputs are pointed at the
        first output of the layers at positions 3 and 5, and it takes over the
        id and the ``outputs`` of the last matched layer.

        With the patterns from ``build_pattern`` position 2 would be the
        batch_norm layer and positions 3 / 5 the two scale parameters -- this
        assumes ``matches`` is ordered like the pattern's layers (TODO
        confirm).

        Args:
            parameters: not used by this method.
            matches: ordered mapping of layer id -> matched layer.

        Returns:
            The modified layer taken from position 2 of ``matches``.
        """
        ids = list(matches.keys())
        fused_bn = matches[ids[2]]

        # Feed the scale parameters directly into the batch_norm layer.
        fused_bn.inputs["weight"] = matches[ids[3]].outputs[0]
        fused_bn.inputs["bias"] = matches[ids[5]].outputs[0]

        # The fused layer stands in for the last layer of the match.
        tail_id = ids[-1]
        fused_bn.id = tail_id
        fused_bn.outputs = matches[tail_id].outputs
        return fused_bn