bn_scale_fuser.py 6.2 KB
Newer Older
S
SunAhong1993 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
#   Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from x2paddle.optimizer.pattern_matcher import FuseBase
from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import *


class Static_BNScaleFuser(FuseBase):
    """Fuse a batch_norm followed by a per-channel scale into one batch_norm.

    The matched subgraph is ``batch_norm -> multiply(param) -> add(param)``.
    Optionally there is a ``reshape`` applied to the additive parameter. The
    multiplicative parameter becomes the batch_norm ``weight`` and the
    additive parameter becomes its ``bias``.

    NOTE(review): the rewrite discards the original batch_norm weight and
    bias inputs. It is therefore only numerically equivalent when those are
    ones and zeros, as a Caffe BatchNorm without an affine part would
    presumably be converted. Confirm this against the op mapper that emits
    these layers.
    """

    def __init__(self):
        super(Static_BNScaleFuser, self).__init__(graph_type="static")
        # ``self.patterns`` (appended to in build_pattern) is not created
        # here; presumably FuseBase initialises it — verify in FuseBase.

    def build_pattern(self):
        """Register the batch_norm + scale graph structures to be replaced.

        Example Python code for the batch_norm + scale patterns:

        Pattern 1 (additive parameter reshaped before the add):
        conv1_bn = paddle.nn.functional.batch_norm(x=conv1, weight=conv1_bn_weight, bias=conv1_bn_bias, running_mean=conv1_bn_mean, running_var=conv1_bn_variance, epsilon=9.999999747378752e-06, momentum=0.9990000128746033)
        conv1_scale_cparam1 = paddle.static.create_parameter(shape=(32,), dtype='float32', name='conv1_scale_cparam1')
        conv1_scale_mul = paddle.multiply(x=conv1_bn, y=conv1_scale_cparam1, axis=1)
        conv1_scale_cparam2 = paddle.static.create_parameter(shape=(32,), dtype='float32', name='conv1_scale_cparam2')
        conv1_scale_cparam2 = paddle.reshape(x=conv1_scale_cparam2, shape=[32, 1, 1])
        conv1_scale = paddle.add(x=conv1_scale_mul, y=conv1_scale_cparam2)

        Pattern 2 (additive parameter used directly):
        conv1_bn = paddle.nn.functional.batch_norm(x=conv1, weight=conv1_bn_weight, bias=conv1_bn_bias, running_mean=conv1_bn_mean, running_var=conv1_bn_variance, epsilon=9.999999747378752e-06, momentum=0.9990000128746033)
        conv1_scale_cparam1 = paddle.static.create_parameter(shape=(32,), dtype='float32', name='conv1_scale_cparam1')
        conv1_scale_mul = paddle.multiply(x=conv1_bn, y=conv1_scale_cparam1, axis=1)
        conv1_scale_cparam2 = paddle.static.create_parameter(shape=(32,), dtype='float32', name='conv1_scale_cparam2')
        conv1_scale = paddle.add(x=conv1_scale_mul, y=conv1_scale_cparam2)
        """
        # Registration order matters: pattern 1 (with reshape) first.
        self.patterns.append(self._build_bn_scale_pattern(reshape_bias=True))
        self.patterns.append(self._build_bn_scale_pattern(reshape_bias=False))

    @staticmethod
    def _build_bn_scale_pattern(reshape_bias):
        """Build one ``batch_norm -> multiply -> [reshape] -> add`` pattern.

        Args:
            reshape_bias: if True, insert a ``paddle.reshape`` between the
                additive parameter and the ``paddle.add`` (pattern 1).
                Otherwise the parameter feeds the add directly (pattern 2).

        Returns:
            The built PaddleGraph pattern. Layers are added in this order:
            batch_norm, create_parameter (scale), multiply,
            create_parameter (bias), [reshape], add. insert_new_layer and
            gen_new_layer index the matched layers by this order.
        """
        def gen_name(index):
            return "x" + str(index)

        # NOTE(review): patterns are built as "dygraph" graphs although the
        # fuser itself is "static"; kept as-is, confirm this is intended.
        pattern = PaddleGraph(graph_type="dygraph")
        pattern.add_layer(
            "paddle.nn.functional.batch_norm",
            inputs={"input": "bn-input-0",
                    "weight": "bn-input-1",
                    "bias": "bn-input-2",
                    "running_mean": "bn-input-3",
                    "running_var": "bn-input-4"},
            outputs=[gen_name(0)])
        # Multiplicative (scale) parameter.
        pattern.add_layer(
            "paddle.static.create_parameter",
            inputs={},
            outputs=[gen_name(1)])
        pattern.add_layer(
            "paddle.multiply",
            inputs={"x": gen_name(0), "y": gen_name(1)},
            outputs=[gen_name(2)])
        # Additive (bias) parameter.
        pattern.add_layer(
            "paddle.static.create_parameter",
            inputs={},
            outputs=[gen_name(3)])

        bias_name = gen_name(3)
        add_output_index = 4
        if reshape_bias:
            pattern.add_layer(
                "paddle.reshape",
                inputs={"x": bias_name},
                outputs=[gen_name(4)])
            bias_name = gen_name(4)
            add_output_index = 5

        pattern.add_layer(
            "paddle.add",
            inputs={"x": gen_name(2), "y": bias_name},
            outputs=[gen_name(add_output_index)])
        pattern.build(inputs={"input-0": "bn-input-0",
                              "input-1": "bn-input-1",
                              "input-2": "bn-input-2",
                              "input-3": "bn-input-3",
                              "input-4": "bn-input-4"})
        return pattern

    def insert_new_layer(self, graph, parameters, matches):
        """Put the fused batch_norm in place of the matched subgraph's output.

        ``matches`` maps graph layer ids to matched layers in pattern order.
        The fused layer is stored under the id of the last matched layer
        (the add). Three ids are then removed from ``matches``: the two
        create_parameter layers, which feed the fused layer, and the add's
        id, which now holds the fused layer. Presumably FuseBase afterwards
        deletes whatever is still in ``matches`` from the graph.
        """
        new_layer = self.gen_new_layer(parameters, matches)
        layer_ids = list(matches.keys())
        new_layer_id = layer_ids[-1]
        graph.layers[new_layer_id] = new_layer
        # Positions 1 and 3 are the scale and bias create_parameter layers.
        # Capture the ids before popping so the indices refer to the
        # original pattern order.
        for layer_id in (layer_ids[1], layer_ids[3], new_layer_id):
            matches.pop(layer_id)

    def gen_new_layer(self, parameters, matches):
        """Turn the matched batch_norm layer into the fused layer.

        The matched batch_norm layer object is modified in place:
        - its ``weight`` input becomes the scale parameter's output;
        - its ``bias`` input becomes the bias parameter's output, taken
          before any reshape;
        - it takes over the id and outputs of the final add layer.

        Args:
            parameters: unused here.
            matches: matched layers keyed by graph layer id, in pattern order.

        Returns:
            The modified batch_norm layer.
        """
        layer_ids = list(matches.keys())
        bn_layer = matches[layer_ids[0]]
        scale_param_layer = matches[layer_ids[1]]
        bias_param_layer = matches[layer_ids[3]]
        last_layer = matches[layer_ids[-1]]
        bn_layer.inputs["weight"] = scale_param_layer.outputs[0]
        bn_layer.inputs["bias"] = bias_param_layer.outputs[0]
        bn_layer.id = layer_ids[-1]
        bn_layer.outputs = last_layer.outputs
        return bn_layer