rewrite.td 3.6 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
#ifndef INFRT_REWRITE
#define INFRT_REWRITE

include "paddle/infrt/dialect/infrt_base.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "paddle/infrt/dialect/pd_ops.td"

//===----------------------------------------------------------------------===//
// This is to fuse the composition: 'Matmul o ElementwiseAdd' into 'PD_FusedFC'.
//
// We have:
//   (Matmul)      z = x * y
//   (Add)         out = z + bias 
//
// which corresponds to the following computation:
//   (FusedFC)  out = x * y + bias
// 
// Todo:
//  1. Make the constraints more complete.
//  2. Consider the case of : out = bias + z
//===----------------------------------------------------------------------===//
// Rewrites PD_ElementwiseAdd(PD_MatmulOp(x, y), bias) into a single PD_FusedFC.
// The rewrite only fires when both Matmul transpose attributes are "false".
// The last PD_FusedFC argument is built with INFRT_createI32Attr<"1">.
// NOTE(review): $alpha and $axis are bound by the source pattern but are
// neither constrained nor forwarded to PD_FusedFC -- confirm the fusion is
// really meant to apply for every alpha/axis value.
def FuseMulAdd : Pat<
    // Source: the Matmul result feeds the first operand of the Add.
    (PD_ElementwiseAdd
        (PD_MatmulOp $x, $y, $transpose_x, $transpose_y, $alpha),
        $bias,
        $axis),
    // Result: one fused op taking the Matmul inputs plus the Add's other operand.
    (PD_FusedFC
        $x,
        $y,
        $bias,
        (INFRT_createI32Attr<"1">)),
    // Constraints: neither Matmul operand may be transposed.
    [
        (IsBoolAttrEq<"false"> $transpose_x),
        (IsBoolAttrEq<"false"> $transpose_y)
    ]>;


//===----------------------------------------------------------------------===//
// This is to fuse the composition: 'FusedFC o Relu' into 'FusedRepeatedFCRelu'.
//
// We have:
//   (FusedFC)      z = fc(x, y, bias)
//   (Relu)         out = relu(z)
//
// which corresponds to the following computation:
//   (FusedRepeatedFCRelu)  out = RepeatedFCRelu(x, [y], [bias])
// 
//===----------------------------------------------------------------------===//
// Collapses PD_ReluOp(PD_FusedFC(x, y, bias, _)) into one PD_FusedRepeatedFCRelu.
// The single weight and bias values are each passed through
// INFRT_cvtValueToValueRange before being handed to the fused op. The fourth
// PD_FusedFC argument is matched with the `$_` placeholder and not reused.
def FuseFCRelu : Pat<
    // Source: a Relu applied directly to a FusedFC result.
    (PD_ReluOp
        (PD_FusedFC $x, $y, $bias, $_)),
    // Result: the repeated-FC-Relu op seeded with this one (y, bias) pair.
    (PD_FusedRepeatedFCRelu
        $x,
        (INFRT_cvtValueToValueRange $y),
        (INFRT_cvtValueToValueRange $bias))>;

//===----------------------------------------------------------------------===//
// This is to fold 'FusedRepeatedFCRelu' op.
//
// We have:
//   (FusedRepeatedFCRelu)      z = RepeatedFCRelu(x, [y, ...], [bias, ...])
//   (FusedRepeatedFCRelu)      out = RepeatedFCRelu(z, [y1, ...], [bias1, ...])
//
// which corresponds to the following computation:
//   (FusedRepeatedFCRelu)  out = RepeatedFCRelu(x, [y, ..., y1, ...], [bias, ..., bias1, ...])
// 
//===----------------------------------------------------------------------===//
// Folds two chained PD_FusedRepeatedFCRelu ops into one.
// The inner op's weight/bias ranges come first and the outer op's ranges are
// appended to them via INFRT_concatTwoValueRange.
def FuseRepeatedFCRelu2 : Pat<
    // Source: the outer op consumes the inner op's result as its input.
    (PD_FusedRepeatedFCRelu
        (PD_FusedRepeatedFCRelu $x, $y, $bias),
        $y_2,
        $bias_2),
    // Result: a single op over the original input with concatenated ranges.
    (PD_FusedRepeatedFCRelu
        $x,
        (INFRT_concatTwoValueRange $y, $y_2),
        (INFRT_concatTwoValueRange $bias, $bias_2))>;


//===----------------------------------------------------------------------===//
// This is to fuse the composition: 'BatchNorm o Conv' into 'Conv'
// by deriving new 'w' and 'b' for 'Conv':
//
// We have:
//   (Conv)      z = w * x + b 
//   (BatchNorm) y = scale * (z - mean) / sqrt(var + eps) + bias
//
// which corresponds to the following computation:
//   y = w_ * x + b_
// where
//   w_ = scale * w / sqrt(var + eps)
//   b_ = bias + scale * (b - mean) / sqrt(var + eps)
//
//===----------------------------------------------------------------------===//
// Folds PD_BatchNormOp(PD_Conv2dOp(input, filter, bias), ...) into a single
// PD_Conv2dOp with rewritten filter and bias operands.
//
// Bound symbols:
//   $bias    - bias operand of the matched Conv2d.
//   $bias_2  - bias operand of the matched BatchNorm.
//   $coefficientW - the value scale / sqrt(var + epsilon), bound once in the
//                   result pattern and reused for both the new filter and the
//                   new bias.
//
// Fix: the new bias previously added the Conv bias ($bias) as the additive
// term, leaving the BatchNorm bias ($bias_2) unused. Expanding
//   scale * ((w * x + b) - mean) / sqrt(var + eps) + bn_bias
// shows the additive term must be the BatchNorm bias, so $bias_2 is used.
def FuseBatchNormWithConvPattern: Pat<
    (PD_BatchNormOp
        (PD_Conv2dOp $input, $filter, $bias),
        $scale, $bias_2, $mean, $var, $epsilon),
    (PD_Conv2dOp
        $input,
        // New filter: filter * (scale / sqrt(var + epsilon)).
        (PD_MulOp $filter,
            (PD_ElementwiseDiv:$coefficientW
                $scale,
                (PD_SqrtOp (PD_ElementwiseAdd $var, (PD_ConstantOp $epsilon), (INFRT_createI32Attr<"1">))),
                (INFRT_createI32Attr<"1">))),
        // New bias: bn_bias + (conv_bias - mean) * (scale / sqrt(var + epsilon)).
        (PD_ElementwiseAdd
            $bias_2,
            (PD_MulOp
                (PD_ElementwiseSub $bias, $mean, (INFRT_createI32Attr<"1">)),
                $coefficientW),
            (INFRT_createI32Attr<"1">)))
>;

#endif // INFRT_REWRITE