未验证 提交 9c17e45e 编写于 作者: Y Yuanle Liu 提交者: GitHub

fix conv+bn pattern (#55073)

上级 3f5c2b5f
......@@ -183,7 +183,7 @@ TEST(PatternRewrite, FrozenRewritePatternSet) {
2U);
}
class TransposePatternRewrite
class RedundantTransposeFusePattern
: public ir::OpRewritePattern<paddle::dialect::TransposeOp> {
public:
using ir::OpRewritePattern<paddle::dialect::TransposeOp>::OpRewritePattern;
......@@ -265,37 +265,26 @@ class Conv2dBnFusePattern
ir::Value bn_scale = op.scale();
ir::Value bn_bias = op.bias();
ir::OpResult bn_mean_result = bn_mean.dyn_cast<ir::OpResult>();
IR_ENFORCE(bn_mean_result);
ir::OpResult bn_variance_result = bn_variance.dyn_cast<ir::OpResult>();
IR_ENFORCE(bn_variance_result);
ir::OpResult bn_scale_result = bn_scale.dyn_cast<ir::OpResult>();
IR_ENFORCE(bn_scale_result);
ir::OpResult bn_bias_result = bn_bias.dyn_cast<ir::OpResult>();
IR_ENFORCE(bn_bias_result);
// --- deal with filter ---
rewriter.SetInsertionPoint(conv2d_op);
rewriter.SetInsertionPoint(op);
phi::DDim bn_variance_shape =
bn_variance.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
float epsilon = op.attribute<ir::FloatAttribute>("epsilon").data();
paddle::dialect::FullOp full_op = rewriter.Build<paddle::dialect::FullOp>(
phi::vectorize(bn_variance_shape), epsilon);
paddle::dialect::AddOp add_op = rewriter.Build<paddle::dialect::AddOp>(
bn_variance_result, full_op.out());
bn_variance.dyn_cast<ir::OpResult>(), full_op.out());
paddle::dialect::SqrtOp sqrt_op =
rewriter.Build<paddle::dialect::SqrtOp>(add_op.out());
paddle::dialect::DivideOp div_op =
rewriter.Build<paddle::dialect::DivideOp>(bn_scale_result,
sqrt_op.out());
rewriter.Build<paddle::dialect::DivideOp>(
bn_scale.dyn_cast<ir::OpResult>(), sqrt_op.out());
// reshape scale
phi::DDim conv2d_filter_shape = ir::GetShapeFromValue(conv2d_filter);
phi::DDim bn_scale_shape =
bn_scale.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
std::vector<int64_t> bn_scale_new_shape(conv2d_filter_shape.size(), 1);
bn_scale_new_shape[0] = bn_scale_shape[0];
paddle::dialect::ReshapeOp reshape_scale_op =
rewriter.Build<paddle::dialect::ReshapeOp>(div_op.out(),
bn_scale_new_shape);
......@@ -303,39 +292,39 @@ class Conv2dBnFusePattern
paddle::dialect::MultiplyOp mul_op =
rewriter.Build<paddle::dialect::MultiplyOp>(conv2d_filter_result,
reshape_scale_op.out());
// TODO(liuyuanle): Use rewriter.
conv2d_op->op_operand(1).set_source(mul_op.out());
auto conv2d_attributes = conv2d_op->attributes();
auto new_conv2d_op = rewriter.Build<paddle::dialect::Conv2dOp>(
conv2d_op.input().dyn_cast<ir::OpResult>(),
mul_op.out(),
conv2d_attributes);
// --- deal with bias ---
rewriter.SetInsertionPoint(op);
paddle::dialect::MultiplyOp mul_bias_op =
rewriter.Build<paddle::dialect::MultiplyOp>(bn_mean_result,
div_op.out());
rewriter.Build<paddle::dialect::MultiplyOp>(
bn_mean.dyn_cast<ir::OpResult>(), div_op.out());
// new bias --> sub_op.out()
paddle::dialect::SubtractOp sub_op =
rewriter.Build<paddle::dialect::SubtractOp>(bn_bias_result,
mul_bias_op.out());
rewriter.Build<paddle::dialect::SubtractOp>(
bn_bias.dyn_cast<ir::OpResult>(), mul_bias_op.out());
// reshape new bias
phi::DDim conv2d_out_shape = ir::GetShapeFromValue(conv2d_out);
std::vector<int64_t> new_bias_new_shape(conv2d_out_shape.size(), 1);
phi::DDim new_conv2d_out_shape = ir::GetShapeFromValue(new_conv2d_op.out());
std::vector<int64_t> new_bias_new_shape(new_conv2d_out_shape.size(), 1);
std::string data_format =
conv2d_op.attribute<ir::StrAttribute>("data_format").data();
new_conv2d_op.attribute<ir::StrAttribute>("data_format").data();
IR_ENFORCE(data_format == "NCHW", "Only support NCHW now.");
new_bias_new_shape[0] = conv2d_out_shape[0];
new_bias_new_shape[1] = conv2d_out_shape[1];
new_bias_new_shape[0] = new_conv2d_out_shape[0];
new_bias_new_shape[1] = new_conv2d_out_shape[1];
paddle::dialect::ReshapeOp reshape_bias_op =
rewriter.Build<paddle::dialect::ReshapeOp>(sub_op.out(),
new_bias_new_shape);
paddle::dialect::AddOp add_bias_op = rewriter.Build<paddle::dialect::AddOp>(
conv2d_out, reshape_bias_op.out());
auto next_op = ir::GetFirstUseOperationForOutput<0>(op);
rewriter.ReplaceAllUsesWith(next_op->operand(0), add_bias_op.out());
new_conv2d_op.out(), reshape_bias_op.out());
rewriter.ReplaceAllUsesWith(op.out(), add_bias_op.out());
rewriter.EraseOp(op);
rewriter.EraseOp(conv2d_op);
return true;
}
};
......@@ -345,7 +334,7 @@ class TestPass : public ir::Pass {
TestPass() : ir::Pass("TestPass", 1) {}
void Run(ir::Operation *op) override {
ir::RewritePatternSet ps(op->ir_context());
ps.Add<TransposePatternRewrite>(op->ir_context());
ps.Add<RedundantTransposeFusePattern>(op->ir_context());
ps.Add<Conv2dBnFusePattern>(op->ir_context());
ir::FrozenRewritePatternSet frozen_ps(std::move(ps));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册