# fuse_resnet_unit_pass.py (last committed by wuhuanzhou)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.fluid.ir as ir


def set_resnet_unit_attrs(resnet_unit, has_shortcut):
    """Set the attributes of a fused ``resnet_unit`` op built by this pass.

    Fixed attributes are written with ``SetAttr``. Attributes that must come
    from the matched subgraph are declared with ``Attr(...).MappedPattern``,
    which presumably copies the named attribute from the matched ``conv2d`` /
    ``batch_norm`` op. ``element_index=0`` presumably selects the first
    element of a list-valued attribute such as ``strides`` (TODO confirm in
    ``paddle.fluid.ir``).

    Args:
        resnet_unit: the op helper returned by
            ``ir.PassDesc.OP.resnet_unit(...)`` in a replace function.
        has_shortcut: value stored in the ``"has_shortcut"`` attribute;
            ``True`` for the two-input (shortcut conv) form.
    """
    # Constant attributes of the fused op.
    resnet_unit.SetAttr("fuse_add", False)
    resnet_unit.SetAttr("act_type", "relu")
    resnet_unit.SetAttr("has_shortcut", has_shortcut)
    resnet_unit.SetAttr("data_format", 'NHWC')
    resnet_unit.SetAttr("dilation", 1)
    # Convolution attributes taken from the matched conv2d op.
    resnet_unit.Attr("stride").MappedPattern(op="conv2d",
                                             name="strides",
                                             element_index=0)
    resnet_unit.Attr("padding").MappedPattern(op="conv2d",
                                              name="paddings",
                                              element_index=0)
    resnet_unit.Attr("group").MappedPattern(op="conv2d", name="groups")
    resnet_unit.Attr("op_device").MappedPattern(op="conv2d", name="op_device")
    resnet_unit.Attr("op_namescope").MappedPattern(op="conv2d",
                                                   name="op_namescope")
    # Normalization attributes taken from the matched batch_norm op.
    resnet_unit.Attr("momentum").MappedPattern(op="batch_norm", name="momentum")
    resnet_unit.Attr("epsilon").MappedPattern(op="batch_norm", name="epsilon")
    resnet_unit.Attr("use_global_stats").MappedPattern(op="batch_norm",
                                                       name="use_global_stats")


def set_resnet_unit_outputs(resnet_unit, meanX, varX, meanZ=None, varZ=None):
    """Bind the running-statistics outputs of a fused ``resnet_unit`` op.

    The replace functions pass the same mean/variance variables that feed the
    op's ``MeanX``/``VarX`` (and ``MeanZ``/``VarZ``) inputs, so the
    statistics are written back to those variables.

    Args:
        resnet_unit: the op helper returned by
            ``ir.PassDesc.OP.resnet_unit(...)``.
        meanX, varX: variables bound to ``RunningMeanX`` / ``RunningVarX``.
        meanZ, varZ: variables bound to ``RunningMeanZ`` / ``RunningVarZ``.
            They are left as ``None`` for the single-input (no shortcut)
            form. How ``SetOutputs`` treats ``None`` is up to the pass
            framework (TODO confirm).
    """
    resnet_unit.SetOutputs(RunningMeanX=meanX,
                           RunningVarX=varX,
                           RunningMeanZ=meanZ,
                           RunningVarZ=varZ)


@ir.RegisterPass
def fuse_resnet_unit():
    """Register the ``fuse_resnet_unit`` pattern-rewrite pass.

    Returns two ``(pattern, replace)`` pairs:

    * ``conv2d(x) -> batch_norm -> relu`` is rewritten to a single
      ``resnet_unit`` op with ``has_shortcut=False``.
    * ``conv2d(x) -> batch_norm`` and ``conv2d(z) -> batch_norm``, summed by
      ``elementwise_add`` and followed by ``relu``, are rewritten to a single
      ``resnet_unit`` op with ``has_shortcut=True``.

    NOTE(review): the argument names of the nested pattern/replace functions
    are presumably read by the pass framework to name and link pattern
    variables (TODO confirm in ``paddle.fluid.ir``). Keep them identical
    between each pattern and its replacement. This is also why ``filter`` is
    kept even though it shadows the builtin.
    """

    def pattern_conv_bn(x, filter, scale, bias, mean, var):
        """Match an NHWC ``conv2d(x, filter)`` followed by ``batch_norm``.

        The ``filter.Attr("shape")`` expressions restrict which filters match.
        Their return values are discarded, so presumably each call registers
        a condition on the pattern (TODO confirm). Read as: shape[0]
        divisible by 32, shape[1] divisible by 8, shape[2] == shape[3] == 1.
        If the filter is laid out as [out, in, h, w], that means a 1x1 kernel
        (TODO confirm).

        Returns the ``batch_norm`` op helper; callers use its ``"Y"`` output.
        """
        filter.Attr("shape")[0].Mod(32).EQ(0)
        filter.Attr("shape")[1].Mod(8).EQ(0)
        filter.Attr("shape")[2].EQ(1)
        filter.Attr("shape")[3].EQ(1)
        conv2d = ir.PassDesc.OP.conv2d(Input=x, Filter=filter)
        # Presumably restricts the match to convolutions already in NHWC.
        conv2d.SetAttr("data_format", 'NHWC')
        bn = ir.PassDesc.OP.batch_norm(X=conv2d,
                                       Bias=bias,
                                       Mean=mean,
                                       Scale=scale,
                                       Variance=var)
        return bn

    def pattern_one_input(x, filter, scale, bias, mean, var):
        """Pattern: conv2d -> batch_norm -> relu on a single input ``x``."""
        bn = pattern_conv_bn(x, filter, scale, bias, mean, var)
        relu = ir.PassDesc.OP.relu(X=bn.Output("Y"))
        return relu

    def replace_one_input(x, filter, scale, bias, mean, var):
        """Replacement for ``pattern_one_input``: one ``resnet_unit`` op."""
        resnet_unit = ir.PassDesc.OP.resnet_unit(X=x,
                                                 FilterX=filter,
                                                 ScaleX=scale,
                                                 BiasX=bias,
                                                 MeanX=mean,
                                                 VarX=var)
        set_resnet_unit_attrs(resnet_unit, False)
        set_resnet_unit_outputs(resnet_unit, mean, var)
        return resnet_unit.Output("Y")

    def pattern_two_input(x, filterX, scaleX, biasX, meanX, varX, z, filterZ,
                          scaleZ, biasZ, meanZ, varZ):
        """Pattern: conv-bn on ``x`` plus conv-bn on ``z``, added, then relu."""
        bnX = pattern_conv_bn(x, filterX, scaleX, biasX, meanX, varX)
        # The shortcut branch convolves ``z``: the same variable that
        # replace_two_input binds to the fused op's ``Z`` input. Matching it
        # against ``x`` left ``z`` unbound in the pattern and only matched
        # two convolutions of the same tensor.
        bnZ = pattern_conv_bn(z, filterZ, scaleZ, biasZ, meanZ, varZ)
        ewadd = ir.PassDesc.OP.elementwise_add(X=bnX.Output("Y"),
                                               Y=bnZ.Output("Y"))
        relu = ir.PassDesc.OP.relu(X=ewadd)
        return relu

    def replace_two_input(x, filterX, scaleX, biasX, meanX, varX, z, filterZ,
                          scaleZ, biasZ, meanZ, varZ):
        """Replacement for ``pattern_two_input``: one ``resnet_unit`` op
        with ``has_shortcut=True``."""
        resnet_unit = ir.PassDesc.OP.resnet_unit(X=x,
                                                 FilterX=filterX,
                                                 ScaleX=scaleX,
                                                 BiasX=biasX,
                                                 MeanX=meanX,
                                                 VarX=varX,
                                                 Z=z,
                                                 FilterZ=filterZ,
                                                 ScaleZ=scaleZ,
                                                 BiasZ=biasZ,
                                                 MeanZ=meanZ,
                                                 VarZ=varZ)
        set_resnet_unit_attrs(resnet_unit, True)
        set_resnet_unit_outputs(resnet_unit, meanX, varX, meanZ, varZ)
        return resnet_unit.Output("Y")

    return (pattern_one_input, replace_one_input), (pattern_two_input,
                                                    replace_two_input)