diff --git a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc index 821420d5de8a1ad492affefddcd06fcba1e303c1..e399fc254ccc927c38e6713cc24def893b27e5c6 100644 --- a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc +++ b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc @@ -183,7 +183,7 @@ TEST(PatternRewrite, FrozenRewritePatternSet) { 2U); } -class TransposePatternRewrite +class RedundantTransposeFusePattern : public ir::OpRewritePattern { public: using ir::OpRewritePattern::OpRewritePattern; @@ -265,37 +265,26 @@ class Conv2dBnFusePattern ir::Value bn_scale = op.scale(); ir::Value bn_bias = op.bias(); - ir::OpResult bn_mean_result = bn_mean.dyn_cast(); - IR_ENFORCE(bn_mean_result); - ir::OpResult bn_variance_result = bn_variance.dyn_cast(); - IR_ENFORCE(bn_variance_result); - ir::OpResult bn_scale_result = bn_scale.dyn_cast(); - IR_ENFORCE(bn_scale_result); - ir::OpResult bn_bias_result = bn_bias.dyn_cast(); - IR_ENFORCE(bn_bias_result); - // --- deal with filter --- - rewriter.SetInsertionPoint(conv2d_op); + rewriter.SetInsertionPoint(op); phi::DDim bn_variance_shape = bn_variance.type().dyn_cast().dims(); float epsilon = op.attribute("epsilon").data(); paddle::dialect::FullOp full_op = rewriter.Build( phi::vectorize(bn_variance_shape), epsilon); paddle::dialect::AddOp add_op = rewriter.Build( - bn_variance_result, full_op.out()); + bn_variance.dyn_cast(), full_op.out()); paddle::dialect::SqrtOp sqrt_op = rewriter.Build(add_op.out()); paddle::dialect::DivideOp div_op = - rewriter.Build(bn_scale_result, - sqrt_op.out()); - + rewriter.Build( + bn_scale.dyn_cast(), sqrt_op.out()); // reshape scale phi::DDim conv2d_filter_shape = ir::GetShapeFromValue(conv2d_filter); phi::DDim bn_scale_shape = bn_scale.type().dyn_cast().dims(); std::vector bn_scale_new_shape(conv2d_filter_shape.size(), 1); bn_scale_new_shape[0] = bn_scale_shape[0]; - paddle::dialect::ReshapeOp reshape_scale_op = rewriter.Build(div_op.out(), bn_scale_new_shape); @@ -303,39 +292,39 @@ class Conv2dBnFusePattern paddle::dialect::MultiplyOp mul_op = rewriter.Build(conv2d_filter_result, reshape_scale_op.out()); - // TODO(liuyuanle): Use rewriter. - conv2d_op->op_operand(1).set_source(mul_op.out()); + + auto conv2d_attributes = conv2d_op->attributes(); + auto new_conv2d_op = rewriter.Build( + conv2d_op.input().dyn_cast(), + mul_op.out(), + conv2d_attributes); // --- deal with bias --- - rewriter.SetInsertionPoint(op); paddle::dialect::MultiplyOp mul_bias_op = - rewriter.Build(bn_mean_result, - div_op.out()); + rewriter.Build( + bn_mean.dyn_cast(), div_op.out()); // new bias --> sub_op.out() paddle::dialect::SubtractOp sub_op = - rewriter.Build(bn_bias_result, - mul_bias_op.out()); - + rewriter.Build( + bn_bias.dyn_cast(), mul_bias_op.out()); // reshape new bias - phi::DDim conv2d_out_shape = ir::GetShapeFromValue(conv2d_out); - std::vector new_bias_new_shape(conv2d_out_shape.size(), 1); + phi::DDim new_conv2d_out_shape = ir::GetShapeFromValue(new_conv2d_op.out()); + std::vector new_bias_new_shape(new_conv2d_out_shape.size(), 1); std::string data_format = - conv2d_op.attribute("data_format").data(); - + new_conv2d_op.attribute("data_format").data(); IR_ENFORCE(data_format == "NCHW", "Only support NCHW now."); - new_bias_new_shape[0] = conv2d_out_shape[0]; - new_bias_new_shape[1] = conv2d_out_shape[1]; - + new_bias_new_shape[0] = new_conv2d_out_shape[0]; + new_bias_new_shape[1] = new_conv2d_out_shape[1]; paddle::dialect::ReshapeOp reshape_bias_op = rewriter.Build(sub_op.out(), new_bias_new_shape); - paddle::dialect::AddOp add_bias_op = rewriter.Build( - conv2d_out, reshape_bias_op.out()); - auto next_op = ir::GetFirstUseOperationForOutput<0>(op); - rewriter.ReplaceAllUsesWith(next_op->operand(0), add_bias_op.out()); + new_conv2d_op.out(), reshape_bias_op.out()); + + rewriter.ReplaceAllUsesWith(op.out(), add_bias_op.out()); rewriter.EraseOp(op); + rewriter.EraseOp(conv2d_op); return true; } }; @@ -345,7 +334,7 @@ class TestPass : public ir::Pass { TestPass() : ir::Pass("TestPass", 1) {} void Run(ir::Operation *op) override { ir::RewritePatternSet ps(op->ir_context()); - ps.Add(op->ir_context()); + ps.Add(op->ir_context()); ps.Add(op->ir_context()); ir::FrozenRewritePatternSet frozen_ps(std::move(ps));