From 9c17e45e88538a5237617a7f2a94cf8a7a9b3036 Mon Sep 17 00:00:00 2001
From: Yuanle Liu
Date: Mon, 3 Jul 2023 12:30:03 +0800
Subject: [PATCH] fix conv+bn pattern (#55073)

---
 .../pattern_rewrite/pattern_rewrite_test.cc   | 61 ++++++++-----------
 1 file changed, 25 insertions(+), 36 deletions(-)

diff --git a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
index 821420d5de8..e399fc254cc 100644
--- a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
+++ b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
@@ -183,7 +183,7 @@ TEST(PatternRewrite, FrozenRewritePatternSet) {
             2U);
 }
 
-class TransposePatternRewrite
+class RedundantTransposeFusePattern
     : public ir::OpRewritePattern<paddle::dialect::TransposeOp> {
  public:
   using ir::OpRewritePattern<paddle::dialect::TransposeOp>::OpRewritePattern;
@@ -265,37 +265,26 @@ class Conv2dBnFusePattern
     ir::Value bn_scale = op.scale();
     ir::Value bn_bias = op.bias();
 
-    ir::OpResult bn_mean_result = bn_mean.dyn_cast<ir::OpResult>();
-    IR_ENFORCE(bn_mean_result);
-    ir::OpResult bn_variance_result = bn_variance.dyn_cast<ir::OpResult>();
-    IR_ENFORCE(bn_variance_result);
-    ir::OpResult bn_scale_result = bn_scale.dyn_cast<ir::OpResult>();
-    IR_ENFORCE(bn_scale_result);
-    ir::OpResult bn_bias_result = bn_bias.dyn_cast<ir::OpResult>();
-    IR_ENFORCE(bn_bias_result);
-
     // --- deal with filter ---
-    rewriter.SetInsertionPoint(conv2d_op);
+    rewriter.SetInsertionPoint(op);
     phi::DDim bn_variance_shape =
         bn_variance.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
     float epsilon = op.attribute<ir::FloatAttribute>("epsilon").data();
     paddle::dialect::FullOp full_op = rewriter.Build<paddle::dialect::FullOp>(
         phi::vectorize(bn_variance_shape), epsilon);
     paddle::dialect::AddOp add_op = rewriter.Build<paddle::dialect::AddOp>(
-        bn_variance_result, full_op.out());
+        bn_variance.dyn_cast<ir::OpResult>(), full_op.out());
     paddle::dialect::SqrtOp sqrt_op =
         rewriter.Build<paddle::dialect::SqrtOp>(add_op.out());
     paddle::dialect::DivideOp div_op =
-        rewriter.Build<paddle::dialect::DivideOp>(bn_scale_result,
-                                                  sqrt_op.out());
-
+        rewriter.Build<paddle::dialect::DivideOp>(
+            bn_scale.dyn_cast<ir::OpResult>(), sqrt_op.out());
     // reshape scale
     phi::DDim conv2d_filter_shape = ir::GetShapeFromValue(conv2d_filter);
     phi::DDim bn_scale_shape =
         bn_scale.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
     std::vector<int64_t> bn_scale_new_shape(conv2d_filter_shape.size(), 1);
     bn_scale_new_shape[0] = bn_scale_shape[0];
-
     paddle::dialect::ReshapeOp reshape_scale_op =
         rewriter.Build<paddle::dialect::ReshapeOp>(div_op.out(),
                                                    bn_scale_new_shape);
@@ -303,39 +292,39 @@ class Conv2dBnFusePattern
     paddle::dialect::MultiplyOp mul_op =
         rewriter.Build<paddle::dialect::MultiplyOp>(conv2d_filter_result,
                                                     reshape_scale_op.out());
-    // TODO(liuyuanle): Use rewriter.
-    conv2d_op->op_operand(1).set_source(mul_op.out());
+
+    auto conv2d_attributes = conv2d_op->attributes();
+    auto new_conv2d_op = rewriter.Build<paddle::dialect::Conv2dOp>(
+        conv2d_op.input().dyn_cast<ir::OpResult>(),
+        mul_op.out(),
+        conv2d_attributes);
 
     // --- deal with bias ---
-    rewriter.SetInsertionPoint(op);
     paddle::dialect::MultiplyOp mul_bias_op =
-        rewriter.Build<paddle::dialect::MultiplyOp>(bn_mean_result,
-                                                    div_op.out());
+        rewriter.Build<paddle::dialect::MultiplyOp>(
+            bn_mean.dyn_cast<ir::OpResult>(), div_op.out());
     // new bias --> sub_op.out()
     paddle::dialect::SubtractOp sub_op =
-        rewriter.Build<paddle::dialect::SubtractOp>(bn_bias_result,
-                                                    mul_bias_op.out());
-
+        rewriter.Build<paddle::dialect::SubtractOp>(
+            bn_bias.dyn_cast<ir::OpResult>(), mul_bias_op.out());
     // reshape new bias
-    phi::DDim conv2d_out_shape = ir::GetShapeFromValue(conv2d_out);
-    std::vector<int64_t> new_bias_new_shape(conv2d_out_shape.size(), 1);
+    phi::DDim new_conv2d_out_shape = ir::GetShapeFromValue(new_conv2d_op.out());
+    std::vector<int64_t> new_bias_new_shape(new_conv2d_out_shape.size(), 1);
     std::string data_format =
-        conv2d_op.attribute<ir::StrAttribute>("data_format").data();
-
+        new_conv2d_op.attribute<ir::StrAttribute>("data_format").data();
     IR_ENFORCE(data_format == "NCHW", "Only support NCHW now.");
-    new_bias_new_shape[0] = conv2d_out_shape[0];
-    new_bias_new_shape[1] = conv2d_out_shape[1];
-
+    new_bias_new_shape[0] = new_conv2d_out_shape[0];
+    new_bias_new_shape[1] = new_conv2d_out_shape[1];
     paddle::dialect::ReshapeOp reshape_bias_op =
         rewriter.Build<paddle::dialect::ReshapeOp>(sub_op.out(),
                                                    new_bias_new_shape);
-
     paddle::dialect::AddOp add_bias_op = rewriter.Build<paddle::dialect::AddOp>(
-        conv2d_out, reshape_bias_op.out());
-    auto next_op = ir::GetFirstUseOperationForOutput<0>(op);
-    rewriter.ReplaceAllUsesWith(next_op->operand(0), add_bias_op.out());
+        new_conv2d_op.out(), reshape_bias_op.out());
+
+    rewriter.ReplaceAllUsesWith(op.out(), add_bias_op.out());
 
     rewriter.EraseOp(op);
+    rewriter.EraseOp(conv2d_op);
     return true;
   }
 };
@@ -345,7 +334,7 @@ class TestPass : public ir::Pass {
   TestPass() : ir::Pass("TestPass", 1) {}
   void Run(ir::Operation *op) override {
     ir::RewritePatternSet ps(op->ir_context());
-    ps.Add<TransposePatternRewrite>(op->ir_context());
+    ps.Add<RedundantTransposeFusePattern>(op->ir_context());
     ps.Add<Conv2dBnFusePattern>(op->ir_context());
 
     ir::FrozenRewritePatternSet frozen_ps(std::move(ps));
--
GitLab
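Note (added for illustration; not part of the patch): Conv2dBnFusePattern above folds a batch_norm into the preceding conv2d by building full/add/sqrt/divide/reshape/multiply ops for the filter and multiply/subtract/reshape/add ops for the bias. The sketch below redoes the same arithmetic on plain arrays, assuming an OIHW filter layout, per-channel BN parameters, and a conv2d with no pre-existing bias; the names FoldedConv and FoldBnIntoConv are illustrative only and are not Paddle APIs.

// Illustrative only -- not part of the patch and not a Paddle API.
// Same math as Conv2dBnFusePattern: fold BatchNorm(mean, variance, scale,
// bias, epsilon) into the filter and bias of a preceding bias-free conv2d.
#include <cmath>
#include <cstddef>
#include <vector>

struct FoldedConv {
  std::vector<float> filter;  // same OIHW layout as the input filter
  std::vector<float> bias;    // one value per output channel, added after conv
};

// filter holds out_c * in_c * kh * kw weights; the BN vectors hold out_c values.
FoldedConv FoldBnIntoConv(const std::vector<float>& filter,
                          std::size_t out_c,
                          const std::vector<float>& bn_mean,
                          const std::vector<float>& bn_variance,
                          const std::vector<float>& bn_scale,
                          const std::vector<float>& bn_bias,
                          float epsilon) {
  FoldedConv result;
  result.filter = filter;
  result.bias.resize(out_c);
  const std::size_t per_channel = filter.size() / out_c;
  for (std::size_t oc = 0; oc < out_c; ++oc) {
    // alpha = scale / sqrt(variance + epsilon); in the pattern this is
    // divide(scale, sqrt(add(variance, full(epsilon)))).
    const float alpha = bn_scale[oc] / std::sqrt(bn_variance[oc] + epsilon);
    // New filter: every weight of this output channel is scaled by alpha,
    // matching the reshape(alpha) broadcast-multiplied over the filter.
    for (std::size_t i = 0; i < per_channel; ++i) {
      result.filter[oc * per_channel + i] *= alpha;
    }
    // New bias: bn_bias - bn_mean * alpha, matching subtract(bias,
    // multiply(mean, alpha)); it is added to the new conv2d output.
    result.bias[oc] = bn_bias[oc] - bn_mean[oc] * alpha;
  }
  return result;
}

The per-channel factor alpha = scale / sqrt(variance + epsilon) is computed once and reused for both the filter and the bias, which is why the rewritten pattern feeds div_op.out() into both mul_op and mul_bias_op.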