# fuse_resnet_unit_pass.py (4.0 KB) -- committed by wuhuanzhou
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.fluid.ir as ir


def set_resnet_unit_attrs(resnet_unit, has_shortcut):
    """Populate the attributes of a fused ``resnet_unit`` replacement op.

    The first group of attributes is set directly through ``SetAttr``.
    The second group is copied from the matched pattern through
    ``MappedPattern``: convolution settings come from the ``conv2d`` op and
    normalization settings come from the ``batch_norm`` op.

    Args:
        resnet_unit: the ``resnet_unit`` op created inside a replace function.
        has_shortcut: stored verbatim as the ``has_shortcut`` attribute (the
            one-input replacement passes False, the two-input one True).
    """
    fixed = [
        ("fuse_add", False),
        ("act_type", "relu"),
        ("has_shortcut", has_shortcut),
        ("data_format", 'NHWC'),
        ("dilation", 1),
    ]
    for key, value in fixed:
        resnet_unit.SetAttr(key, value)

    # Each entry: (resnet_unit attr, source op type, source attr name,
    # extra MappedPattern kwargs). ``element_index=0`` is passed only for the
    # conv2d ``strides`` / ``paddings`` attributes.
    # NOTE(review): copying a single stride/padding element and fixing
    # dilation to 1 assumes symmetric stride/padding and no dilation; the
    # conv2d pattern does not appear to check either -- confirm.
    copied = [
        ("stride", "conv2d", "strides", {"element_index": 0}),
        ("padding", "conv2d", "paddings", {"element_index": 0}),
        ("group", "conv2d", "groups", {}),
        ("op_device", "conv2d", "op_device", {}),
        ("op_namescope", "conv2d", "op_namescope", {}),
        ("momentum", "batch_norm", "momentum", {}),
        ("epsilon", "batch_norm", "epsilon", {}),
        ("use_global_stats", "batch_norm", "use_global_stats", {}),
    ]
    for key, src_op, src_key, extra in copied:
        resnet_unit.Attr(key).MappedPattern(op=src_op, name=src_key, **extra)


def set_resnet_unit_outputs(resnet_unit, meanX, varX, meanZ=None, varZ=None):
    """Bind the running-statistics outputs of a fused ``resnet_unit`` op.

    The replace functions in this file pass the same vars they used for the
    ``MeanX``/``VarX`` (and ``MeanZ``/``VarZ``) inputs, so the running
    statistics are written back onto the matched batch_norm's vars.

    Args:
        resnet_unit: the ``resnet_unit`` op created inside a replace function.
        meanX, varX: vars bound to ``RunningMeanX`` / ``RunningVarX``.
        meanZ, varZ: vars bound to ``RunningMeanZ`` / ``RunningVarZ``. They
            default to None, and None is forwarded to ``SetOutputs`` as-is
            (presumably meaning "no such output" -- confirm with the
            ``SetOutputs`` implementation).
    """
    running_stats = dict(
        RunningMeanX=meanX,
        RunningVarX=varX,
        RunningMeanZ=meanZ,
        RunningVarZ=varZ,
    )
    resnet_unit.SetOutputs(**running_stats)


@ir.RegisterPass
def fuse_resnet_unit():
    """Fuse conv2d + batch_norm (+ elementwise_add) + relu into resnet_unit.

    Returns two ``(pattern, replace)`` pairs:

    * one input:  relu(batch_norm(conv2d(x)))  ->  resnet_unit(X=x)
    * two inputs: relu(batch_norm(conv2d(x)) + batch_norm(conv2d(z)))
                  ->  resnet_unit(X=x, Z=z) with ``has_shortcut=True``

    NOTE(review): each pattern/replace pair uses identical parameter names.
    The pass framework appears to pair pattern vars with replace vars by
    these names, so do not rename them (this is also why ``filter`` shadows
    the builtin).
    """

    def pattern_conv_bn(x, filter, scale, bias, mean, var):
        # conv2d(x, filter) followed by batch_norm; shared by both patterns.
        # Filter-shape conditions: dims 2 and 3 must equal 1 (a 1x1 kernel,
        # assuming a [out_c, in_c / groups, kh, kw] filter layout -- TODO
        # confirm), dim 0 must be a multiple of 32 and dim 1 a multiple of 8.
        filter.Attr("shape")[0].Mod(32).EQ(0)
        filter.Attr("shape")[1].Mod(8).EQ(0)
        filter.Attr("shape")[2].EQ(1)
        filter.Attr("shape")[3].EQ(1)
        conv2d = ir.PassDesc.OP.conv2d(Input=x, Filter=filter)
        # Presumably restricts the match to NHWC convolutions; the fused
        # resnet_unit is configured with data_format='NHWC' as well.
        conv2d.SetAttr("data_format", 'NHWC')
        bn = ir.PassDesc.OP.batch_norm(
            X=conv2d, Bias=bias, Mean=mean, Scale=scale, Variance=var)
        return bn

    def pattern_one_input(x, filter, scale, bias, mean, var):
        # relu(batch_norm(conv2d(x))).
        bn = pattern_conv_bn(x, filter, scale, bias, mean, var)
        relu = ir.PassDesc.OP.relu(X=bn.Output("Y"))
        return relu

    def replace_one_input(x, filter, scale, bias, mean, var):
        # Single resnet_unit without a shortcut branch. The batch_norm
        # Mean/Variance vars are reused as the running-statistics outputs.
        resnet_unit = ir.PassDesc.OP.resnet_unit(
            X=x, FilterX=filter, ScaleX=scale, BiasX=bias, MeanX=mean, VarX=var)
        set_resnet_unit_attrs(resnet_unit, False)
        set_resnet_unit_outputs(resnet_unit, mean, var)
        return resnet_unit.Output("Y")

    def pattern_two_input(x, filterX, scaleX, biasX, meanX, varX, z, filterZ,
                          scaleZ, biasZ, meanZ, varZ):
        # relu(batch_norm(conv2d(x)) + batch_norm(conv2d(z))).
        bnX = pattern_conv_bn(x, filterX, scaleX, biasX, meanX, varX)
        # The shortcut branch must convolve its own input ``z``, not ``x``.
        # replace_two_input feeds ``z`` to resnet_unit's ``Z`` input, so
        # ``z`` has to be bound by this pattern.
        bnZ = pattern_conv_bn(z, filterZ, scaleZ, biasZ, meanZ, varZ)
        ewadd = ir.PassDesc.OP.elementwise_add(
            X=bnX.Output("Y"), Y=bnZ.Output("Y"))
        relu = ir.PassDesc.OP.relu(X=ewadd)
        return relu

    def replace_two_input(x, filterX, scaleX, biasX, meanX, varX, z, filterZ,
                          scaleZ, biasZ, meanZ, varZ):
        # Single resnet_unit that also computes the shortcut conv + bn on
        # ``z`` (has_shortcut=True); both branches' batch_norm statistics
        # vars are reused as the running-statistics outputs.
        resnet_unit = ir.PassDesc.OP.resnet_unit(
            X=x,
            FilterX=filterX,
            ScaleX=scaleX,
            BiasX=biasX,
            MeanX=meanX,
            VarX=varX,
            Z=z,
            FilterZ=filterZ,
            ScaleZ=scaleZ,
            BiasZ=biasZ,
            MeanZ=meanZ,
            VarZ=varZ)
        set_resnet_unit_attrs(resnet_unit, True)
        set_resnet_unit_outputs(resnet_unit, meanX, varX, meanZ, varZ)
        return resnet_unit.Output("Y")

    return (pattern_one_input, replace_one_input), (pattern_two_input,
                                                    replace_two_input)