diff --git a/paddle/infrt/dialect/infrt_base.td b/paddle/infrt/dialect/infrt_base.td index 61dcfe5bfb1c3723440c7cf760fa7a2f33f90a0a..7d6fdbbbf2f68f6629c2299f807cbb9fa7605f74 100644 --- a/paddle/infrt/dialect/infrt_base.td +++ b/paddle/infrt/dialect/infrt_base.td @@ -35,8 +35,4 @@ def INFRT_cvtValueToValueRange : NativeCodeCall< def INFRT_concatTwoValueRange : NativeCodeCall< "mlir::concatTwoValueRange($0, $1)">; - -class IsBoolAttrEq : Constraint< - CPred<"($0.getValue() ==" # value # ")">, - "Bool attrbute value constraint">; #endif // INFRT_BASE diff --git a/paddle/infrt/dialect/rewrite.td b/paddle/infrt/dialect/rewrite.td index aa81dd72d059b481b9509045e45ee1d0be3e09e9..b5b7cf0667f6823ac85ce13c7f8a6d818b30b7bb 100644 --- a/paddle/infrt/dialect/rewrite.td +++ b/paddle/infrt/dialect/rewrite.td @@ -15,13 +15,16 @@ include "paddle/infrt/dialect/pd_ops.td" // which corresponds to the following computation: // (FusedFC) out = x * y + bias // +// while meeting the following attribute constrait: +// Matmul: transpose_x: false +// transpose_y: false +// // Todo: // 1. Make the constrait more completely. // 2. Consider the case of : out = bias + z //===----------------------------------------------------------------------===// -def FuseMulAdd : Pat<(PD_ElementwiseAdd (PD_MatmulOp $x, $y, $transpose_x, $transpose_y, $alpha), $bias, $axis), - (PD_FusedFC $x, $y, $bias, (INFRT_createI32Attr<"1">)), - [(IsBoolAttrEq<"false"> $transpose_x),(IsBoolAttrEq<"false"> $transpose_y)]>; +def FuseMulAdd : Pat<(PD_ElementwiseAdd (PD_MatmulOp $x, $y, ConstBoolAttrFalse:$_, ConstBoolAttrFalse:$_, $alpha), $bias, $axis), + (PD_FusedFC $x, $y, $bias, (INFRT_createI32Attr<"1">))>; //===----------------------------------------------------------------------===//