From d1dc677a584393ca9494c2a73596268ed281fc24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E5=86=AC?= <78149749+winter-wang@users.noreply.github.com> Date: Wed, 5 Jan 2022 09:19:48 +0800 Subject: [PATCH] [infrt] optimize the infrt rewriter pattern format. test=develop (#38694) --- paddle/infrt/dialect/infrt_base.td | 4 ---- paddle/infrt/dialect/rewrite.td | 9 ++++++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/paddle/infrt/dialect/infrt_base.td b/paddle/infrt/dialect/infrt_base.td index 61dcfe5bfb1..7d6fdbbbf2f 100644 --- a/paddle/infrt/dialect/infrt_base.td +++ b/paddle/infrt/dialect/infrt_base.td @@ -35,8 +35,4 @@ def INFRT_cvtValueToValueRange : NativeCodeCall< def INFRT_concatTwoValueRange : NativeCodeCall< "mlir::concatTwoValueRange($0, $1)">; - -class IsBoolAttrEq : Constraint< - CPred<"($0.getValue() ==" # value # ")">, - "Bool attrbute value constraint">; #endif // INFRT_BASE diff --git a/paddle/infrt/dialect/rewrite.td b/paddle/infrt/dialect/rewrite.td index aa81dd72d05..b5b7cf0667f 100644 --- a/paddle/infrt/dialect/rewrite.td +++ b/paddle/infrt/dialect/rewrite.td @@ -15,13 +15,16 @@ include "paddle/infrt/dialect/pd_ops.td" // which corresponds to the following computation: // (FusedFC) out = x * y + bias // +// while meeting the following attribute constrait: +// Matmul: transpose_x: false +// transpose_y: false +// // Todo: // 1. Make the constrait more completely. // 2. Consider the case of : out = bias + z //===----------------------------------------------------------------------===// -def FuseMulAdd : Pat<(PD_ElementwiseAdd (PD_MatmulOp $x, $y, $transpose_x, $transpose_y, $alpha), $bias, $axis), - (PD_FusedFC $x, $y, $bias, (INFRT_createI32Attr<"1">)), - [(IsBoolAttrEq<"false"> $transpose_x),(IsBoolAttrEq<"false"> $transpose_y)]>; +def FuseMulAdd : Pat<(PD_ElementwiseAdd (PD_MatmulOp $x, $y, ConstBoolAttrFalse:$_, ConstBoolAttrFalse:$_, $alpha), $bias, $axis), + (PD_FusedFC $x, $y, $bias, (INFRT_createI32Attr<"1">))>; //===----------------------------------------------------------------------===// -- GitLab