Unverified commit 9c17e45e, authored by Yuanle Liu, committed by GitHub

fix conv+bn pattern (#55073)

Parent 3f5c2b5f
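For context (this is the standard conv+BN folding, not text from the commit): `Conv2dBnFusePattern` rewrites a `batch_norm` that consumes a `conv2d` result into a single convolution with adjusted weights and bias. With per-channel statistics μ, σ², scale γ, shift β, and the `epsilon` attribute ε:

$$
\gamma \, \frac{\operatorname{conv}(x, w) - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta
  \;=\; \operatorname{conv}(x, w') + b',
\qquad
w' = \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}}\, w,
\qquad
b' = \beta - \frac{\gamma\,\mu}{\sqrt{\sigma^2 + \varepsilon}}.
$$

The hunks below build exactly this: a `FullOp`/`AddOp`/`SqrtOp`/`DivideOp` chain for the per-channel factor γ/√(σ² + ε), a `MultiplyOp` that scales the filter, and a `SubtractOp` that produces the new bias.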
@@ -183,7 +183,7 @@ TEST(PatternRewrite, FrozenRewritePatternSet) {
             2U);
 }
 
-class TransposePatternRewrite
+class RedundantTransposeFusePattern
     : public ir::OpRewritePattern<paddle::dialect::TransposeOp> {
  public:
   using ir::OpRewritePattern<paddle::dialect::TransposeOp>::OpRewritePattern;
@@ -265,37 +265,26 @@ class Conv2dBnFusePattern
     ir::Value bn_scale = op.scale();
     ir::Value bn_bias = op.bias();
-    ir::OpResult bn_mean_result = bn_mean.dyn_cast<ir::OpResult>();
-    IR_ENFORCE(bn_mean_result);
-    ir::OpResult bn_variance_result = bn_variance.dyn_cast<ir::OpResult>();
-    IR_ENFORCE(bn_variance_result);
-    ir::OpResult bn_scale_result = bn_scale.dyn_cast<ir::OpResult>();
-    IR_ENFORCE(bn_scale_result);
-    ir::OpResult bn_bias_result = bn_bias.dyn_cast<ir::OpResult>();
-    IR_ENFORCE(bn_bias_result);
     // --- deal with filter ---
-    rewriter.SetInsertionPoint(conv2d_op);
+    rewriter.SetInsertionPoint(op);
     phi::DDim bn_variance_shape =
         bn_variance.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
     float epsilon = op.attribute<ir::FloatAttribute>("epsilon").data();
     paddle::dialect::FullOp full_op = rewriter.Build<paddle::dialect::FullOp>(
         phi::vectorize(bn_variance_shape), epsilon);
     paddle::dialect::AddOp add_op = rewriter.Build<paddle::dialect::AddOp>(
-        bn_variance_result, full_op.out());
+        bn_variance.dyn_cast<ir::OpResult>(), full_op.out());
     paddle::dialect::SqrtOp sqrt_op =
         rewriter.Build<paddle::dialect::SqrtOp>(add_op.out());
     paddle::dialect::DivideOp div_op =
-        rewriter.Build<paddle::dialect::DivideOp>(bn_scale_result,
-                                                  sqrt_op.out());
+        rewriter.Build<paddle::dialect::DivideOp>(
+            bn_scale.dyn_cast<ir::OpResult>(), sqrt_op.out());
     // reshape scale
     phi::DDim conv2d_filter_shape = ir::GetShapeFromValue(conv2d_filter);
     phi::DDim bn_scale_shape =
         bn_scale.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
     std::vector<int64_t> bn_scale_new_shape(conv2d_filter_shape.size(), 1);
     bn_scale_new_shape[0] = bn_scale_shape[0];
     paddle::dialect::ReshapeOp reshape_scale_op =
         rewriter.Build<paddle::dialect::ReshapeOp>(div_op.out(),
                                                    bn_scale_new_shape);
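A minimal standalone sketch (plain C++ with made-up values, not the Paddle dialect API) of what this hunk computes: the per-channel factor that `div_op` produces, before it is reshaped to `{O, 1, 1, 1}` so it broadcasts over an OIHW filter.

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const float eps = 1e-5f;                  // the "epsilon" attribute
  std::vector<float> gamma = {1.5f, 0.8f};  // bn_scale, one per channel
  std::vector<float> var = {4.0f, 1.0f};    // bn_variance

  // div_op.out(): gamma / sqrt(var + eps), computed per output channel.
  std::vector<float> scale(gamma.size());
  for (size_t c = 0; c < gamma.size(); ++c) {
    scale[c] = gamma[c] / std::sqrt(var[c] + eps);
  }
  std::printf("scale = {%f, %f}\n", scale[0], scale[1]);
  return 0;
}
```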
@@ -303,39 +292,39 @@ class Conv2dBnFusePattern
     paddle::dialect::MultiplyOp mul_op =
         rewriter.Build<paddle::dialect::MultiplyOp>(conv2d_filter_result,
                                                     reshape_scale_op.out());
-    // TODO(liuyuanle): Use rewriter.
-    conv2d_op->op_operand(1).set_source(mul_op.out());
+    auto conv2d_attributes = conv2d_op->attributes();
+    auto new_conv2d_op = rewriter.Build<paddle::dialect::Conv2dOp>(
+        conv2d_op.input().dyn_cast<ir::OpResult>(),
+        mul_op.out(),
+        conv2d_attributes);
     // --- deal with bias ---
-    rewriter.SetInsertionPoint(op);
     paddle::dialect::MultiplyOp mul_bias_op =
-        rewriter.Build<paddle::dialect::MultiplyOp>(bn_mean_result,
-                                                    div_op.out());
+        rewriter.Build<paddle::dialect::MultiplyOp>(
+            bn_mean.dyn_cast<ir::OpResult>(), div_op.out());
     // new bias --> sub_op.out()
     paddle::dialect::SubtractOp sub_op =
-        rewriter.Build<paddle::dialect::SubtractOp>(bn_bias_result,
-                                                    mul_bias_op.out());
+        rewriter.Build<paddle::dialect::SubtractOp>(
+            bn_bias.dyn_cast<ir::OpResult>(), mul_bias_op.out());
     // reshape new bias
-    phi::DDim conv2d_out_shape = ir::GetShapeFromValue(conv2d_out);
-    std::vector<int64_t> new_bias_new_shape(conv2d_out_shape.size(), 1);
+    phi::DDim new_conv2d_out_shape = ir::GetShapeFromValue(new_conv2d_op.out());
+    std::vector<int64_t> new_bias_new_shape(new_conv2d_out_shape.size(), 1);
     std::string data_format =
-        conv2d_op.attribute<ir::StrAttribute>("data_format").data();
+        new_conv2d_op.attribute<ir::StrAttribute>("data_format").data();
     IR_ENFORCE(data_format == "NCHW", "Only support NCHW now.");
-    new_bias_new_shape[0] = conv2d_out_shape[0];
-    new_bias_new_shape[1] = conv2d_out_shape[1];
+    new_bias_new_shape[0] = new_conv2d_out_shape[0];
+    new_bias_new_shape[1] = new_conv2d_out_shape[1];
     paddle::dialect::ReshapeOp reshape_bias_op =
         rewriter.Build<paddle::dialect::ReshapeOp>(sub_op.out(),
                                                    new_bias_new_shape);
     paddle::dialect::AddOp add_bias_op = rewriter.Build<paddle::dialect::AddOp>(
-        conv2d_out, reshape_bias_op.out());
-    auto next_op = ir::GetFirstUseOperationForOutput<0>(op);
-    rewriter.ReplaceAllUsesWith(next_op->operand(0), add_bias_op.out());
+        new_conv2d_op.out(), reshape_bias_op.out());
+    rewriter.ReplaceAllUsesWith(op.out(), add_bias_op.out());
     rewriter.EraseOp(op);
+    rewriter.EraseOp(conv2d_op);
     return true;
   }
 };
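The substantive fix is at the end of this hunk: the old code mutated the matched `conv2d` in place via `op_operand(1).set_source(...)` (with a TODO to use the rewriter) and located the `batch_norm`'s first user to patch its operand; the new code builds a fresh `Conv2dOp` through the rewriter, replaces all uses of `op.out()` directly, and erases both matched ops, so every mutation goes through the rewriter. A scalar sanity check of the folding (illustrative C++ only, values made up):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  const float eps = 1e-5f;
  const float gamma = 1.5f, beta = 0.2f, mean = 0.1f, var = 4.0f;
  const float y = 0.7f;  // one element of the original conv2d output

  // Direct batch_norm applied to the conv output.
  const float bn = gamma * (y - mean) / std::sqrt(var + eps) + beta;

  // Folded form: output of the rescaled filter plus the new bias.
  const float scale = gamma / std::sqrt(var + eps);       // div_op.out()
  const float fused = y * scale + (beta - mean * scale);  // add_bias_op

  std::printf("bn = %f, fused = %f\n", bn, fused);  // values agree
  return 0;
}
```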
@@ -345,7 +334,7 @@ class TestPass : public ir::Pass {
   TestPass() : ir::Pass("TestPass", 1) {}
   void Run(ir::Operation *op) override {
     ir::RewritePatternSet ps(op->ir_context());
-    ps.Add<TransposePatternRewrite>(op->ir_context());
+    ps.Add<RedundantTransposeFusePattern>(op->ir_context());
     ps.Add<Conv2dBnFusePattern>(op->ir_context());
     ir::FrozenRewritePatternSet frozen_ps(std::move(ps));