// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include #include #include #include "paddle/fluid/ir/dialect/pd_attribute.h" #include "paddle/fluid/ir/transforms/constant_folding_pass.h" #include "paddle/fluid/ir/transforms/transform_general_functions.h" #include "paddle/ir/builtin_transforms/dead_code_elimination_pass.h" #include "paddle/ir/core/builder.h" #include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_dialect.h" #include "paddle/ir/core/builtin_op.h" #include "paddle/ir/core/cast_utils.h" #include "paddle/ir/core/dialect.h" #include "paddle/ir/core/enforce.h" #include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/op_info.h" #include "paddle/ir/core/parameter.h" #include "paddle/ir/core/program.h" #include "paddle/ir/core/value.h" #include "paddle/ir/pass/pass.h" #include "paddle/ir/pass/pass_manager.h" #include "paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h" #include "paddle/ir/pattern_rewrite/pattern_applicator.h" #include "paddle/ir/pattern_rewrite/pattern_match.h" #include "paddle/ir/pattern_rewrite/pattern_rewrite_driver.h" #include "paddle/phi/core/kernel_registry.h" // NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in // paddle/fluid/ir/dialect/CMakeLists.txt. 
#include "paddle/fluid/ir/dialect/pd_op.h"

#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/phi/core/ddim.h"

// build Conv2dFusionOp
#include "paddle/fluid/ir/interface/infermeta.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/ir/core/op_base.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/infermeta/multiary.h"

// Declare the CPU kernels referenced by the ops this test builds.
PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(sqrt, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(divide, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(multiply, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(subtract, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(full_int_array, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(reshape, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(fetch, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(conv2d, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(transpose, CPU, ALL_LAYOUT);

// Define op1: a test-only op named "test.Operation1" whose verifier requires
// two string attributes, "op2_attr1" and "op2_attr2".
class Operation1 : public ir::Op<Operation1> {
 public:
  using Op::Op;
  static const char *name() { return "test.Operation1"; }
  static constexpr uint32_t attributes_num = 2;
  static const char *attributes_name[attributes_num];
  void Verify();
  static void InferShape() { VLOG(2) << "This is op1's InferShape interface."; }
};

// Checks that both declared attributes exist and are ir::StrAttribute.
// On failure throws a C-string (the exception type is kept unchanged for
// existing callers) that names the offending attribute.
void Operation1::Verify() {
  auto &attributes = this->attributes();
  if (attributes.count("op2_attr1") == 0 ||
      (!attributes.at("op2_attr1").isa<ir::StrAttribute>())) {
    throw("Type of attribute: op2_attr1 is not right.");
  }
  if (attributes.count("op2_attr2") == 0 ||
      (!attributes.at("op2_attr2").isa<ir::StrAttribute>())) {
    throw("Type of attribute: op2_attr2 is not right.");
  }
}
const char *Operation1::attributes_name[attributes_num] = {"op2_attr1",
                                                           "op2_attr2"};
IR_DECLARE_EXPLICIT_TYPE_ID(Operation1)
IR_DEFINE_EXPLICIT_TYPE_ID(Operation1)

// Define a dialect that registers the test op defined above.
// NOTE(review): in this copy of the file, template argument lists appear to be
// missing throughout (e.g. on ir::OpRewritePattern, ir::Op, dyn_cast, isa,
// Build, RegisterOps, GetOrRegisterDialect, std::vector, std::make_unique), so
// the code below will not compile as written -- restore them from upstream.

// Test dialect named "test"; initialize() registers the test op(s).
class TestDialect : public ir::Dialect {
 public:
  explicit TestDialect(ir::IrContext *context)
      : ir::Dialect(name(), context, ir::TypeId::get()) {
    initialize();
  }
  static const char *name() { return "test"; }

 private:
  void initialize() { RegisterOps(); }
};
IR_DECLARE_EXPLICIT_TYPE_ID(TestDialect)
IR_DEFINE_EXPLICIT_TYPE_ID(TestDialect)

// TODO(wilber): Add logical when ir support erase, replace or update.
// Placeholder pattern using the split Match/Rewrite interface; Match() always
// reports "no match".
class TestPatternRewrite : public ir::OpRewritePattern {
 public:
  using ir::OpRewritePattern::OpRewritePattern;

  void Rewrite(Operation1 op, ir::PatternRewriter &rewriter) const override {}
  bool Match(Operation1 op) const override { return false; }
};

// Placeholder pattern using the combined MatchAndRewrite interface; always
// reports "no match".
class TestPatternRewrite2 : public ir::OpRewritePattern {
 public:
  using ir::OpRewritePattern::OpRewritePattern;
  bool MatchAndRewrite(
      Operation1 op,
      ir::PatternRewriter &rewriter) const override {  // NOLINT
    return false;
  }
};

// Checks construction and the comparison operators of ir::PatternBenefit.
TEST(PatternRewrite, PatternBenefit) {
  ir::PatternBenefit benefit1(1);
  EXPECT_EQ(benefit1.benefit(), 1U);
  ir::PatternBenefit benefit2(2);
  EXPECT_EQ(benefit2.benefit(), 2U);

  EXPECT_TRUE(benefit2 > benefit1);
  EXPECT_TRUE(benefit2 >= benefit1);
  EXPECT_TRUE(benefit1 < benefit2);
  EXPECT_TRUE(benefit1 <= benefit2);
  EXPECT_TRUE(benefit1 != benefit2);
  ir::PatternBenefit benefit3(2);
  EXPECT_TRUE(benefit2 == benefit3);
}

// Checks Add / AddWithLabel / Clear on ir::RewritePatternSet, and that each
// stored pattern records its benefit and debug labels.
TEST(RewritePattern, RewritePatternSet) {
  ir::IrContext *ctx = ir::IrContext::Instance();
  ctx->GetOrRegisterDialect();
  auto *test_dialect = ctx->GetOrRegisterDialect();
  test_dialect->RegisterOp();

  ir::RewritePatternSet ps(ctx);
  ps.Add(ctx, 1);
  EXPECT_EQ(ps.native_patterns().size(), 1U);
  EXPECT_TRUE(ps.native_patterns().back()->debug_labels().empty());
  EXPECT_EQ(ps.native_patterns().back()->benefit(), 1U);
  ps.AddWithLabel({"TestPatternRewrite2"}, ctx, 2);
  EXPECT_EQ(ps.native_patterns().size(), 2U);
  EXPECT_EQ(ps.native_patterns().back()->debug_labels()[0],
            "TestPatternRewrite2");
  EXPECT_EQ(ps.native_patterns().back()->benefit(), 2U);

  ps.Clear();
  // A single Add call here yields two stored patterns sharing benefit 2
  // (presumably two pattern types passed as template arguments, which are
  // missing in this copy).
  ps.Add(ctx, 2);
  EXPECT_EQ(ps.native_patterns().size(), 2U);
  EXPECT_EQ(ps.native_patterns()[0]->benefit(), 2U);
  EXPECT_EQ(ps.native_patterns()[1]->benefit(), 2U);
}

// TODO(wilber): Add actual case.
// TEST(PatternRewrite, PatternApplicator) {
//   ir::IrContext *ctx = ir::IrContext::Instance();
//   ctx->GetOrRegisterDialect();
//   auto *test_dialect = ctx->GetOrRegisterDialect();
//   test_dialect->RegisterOp();
//   ir::RewritePatternSet ps(ctx);
//   ps.Add(ctx, 2);
//   ir::FrozenRewritePatternSet frozen_set(std::move(ps));
//   ir::PatternApplicator applicator(frozen_set);
//   applicator.ApplyDefaultCostModel();
// }

// TODO(wilber): Add actual case.
// Checks that an empty FrozenRewritePatternSet has no patterns, and that a
// frozen set built from op-specific patterns groups them under their op's
// OpInfo.
TEST(PatternRewrite, FrozenRewritePatternSet) {
  ir::FrozenRewritePatternSet frozen_set;
  EXPECT_TRUE(frozen_set.match_any_op_native_patterns().empty());
  EXPECT_TRUE(frozen_set.op_specific_native_patterns().empty());

  ir::IrContext *ctx = ir::IrContext::Instance();
  ctx->GetOrRegisterDialect();
  auto *test_dialect = ctx->GetOrRegisterDialect();
  test_dialect->RegisterOp();
  ir::RewritePatternSet ps(ctx);
  ps.Add(ctx, 2);

  ir::FrozenRewritePatternSet frozen_set2(std::move(ps));
  EXPECT_TRUE(frozen_set2.match_any_op_native_patterns().empty());
  const auto &pattern_maps = frozen_set2.op_specific_native_patterns();
  EXPECT_EQ(pattern_maps.size(), 1U);
  EXPECT_EQ(pattern_maps.at(ctx->GetRegisteredOpInfo("test.Operation1")).size(),
            2U);
}

// Fuses transpose(transpose(x, perm_first), perm_last) into a single
// transpose(x, new_perm), where new_perm[i] == perm_first[perm_last[i]].
class RedundantTransposeFusePattern : public ir::OpRewritePattern {
 public:
  using ir::OpRewritePattern::OpRewritePattern;

  bool MatchAndRewrite(paddle::dialect::TransposeOp op,
                       ir::PatternRewriter &rewriter) const override {
    // NOTE(review): GetDefiningOpForInput may return nullptr when the input is
    // not produced by an op (e.g. a block argument); prev_op is dereferenced
    // below without a check -- confirm this cannot happen.
    auto prev_op = ir::GetDefiningOpForInput(op, 0);
    std::vector axis_last = GetAxis(op);
    auto prev_trans_op = prev_op->dyn_cast();
    if (prev_trans_op) {
      std::vector axis_first = GetAxis(prev_trans_op);
      IR_ENFORCE(axis_first.size() == axis_last.size(),
                 "tranpose op's perm rank should be same.");
      auto new_perm = GetPerm(axis_first, axis_last);
      rewriter.SetInsertionPoint(op);
      // NOTE(review): uses result(0) of the first transpose's producer rather
      // than the exact input value; these differ if that producer has more
      // than one result.
      auto new_transpose_op = rewriter.Build(
          ir::GetDefiningOpForInput(prev_trans_op, 0)->result(0), new_perm);
      rewriter.ReplaceOp(op, {new_transpose_op.out()});
      return true;
    }

    return false;
  }

 private:
  // Reads the "perm" array attribute of a transpose op, element by element.
  std::vector GetAxis(paddle::dialect::TransposeOp op) const {
    auto array_attr = op.attribute("perm").AsVector();
    std::vector axis(array_attr.size());
    for (size_t i = 0; i < array_attr.size(); ++i) {
      axis[i] = array_attr[i].dyn_cast().data();
    }
    return axis;
  }

  // Composes two permutations: the result r satisfies r[i] == perm1[perm2[i]].
  // Under the convention out.dim[i] = in.dim[perm[i]] this is the single
  // permutation equivalent to transposing by perm1 and then by perm2.
  std::vector GetPerm(const std::vector &perm1,
                      const std::vector &perm2) const {
    int n = perm1.size();
    std::vector axis(n), axis1(n), axis2(n);
    std::iota(axis.begin(), axis.end(), 0);
    for (int i = 0; i < n; ++i) {
      axis1[i] = axis[perm1[i]];
    }
    for (int i = 0; i < n; ++i) {
      axis2[i] = axis1[perm2[i]];
    }
    return axis2;
  }
};

// Folds batch_norm(conv2d(x, w)) into conv2d(x, w') + b'. The ops built below
// compute
//   factor = scale / sqrt(variance + epsilon)
//   w'     = w * reshape(factor)           (reshaped to [C, 1, ..., 1])
//   b'     = reshape(bias - mean * factor) (reshaped to [1, C, 1, ...])
class Conv2dBnFusePattern : public ir::OpRewritePattern {
 public:
  using ir::OpRewritePattern::OpRewritePattern;
  bool MatchAndRewrite(
      paddle::dialect::BatchNormOp op,
      ir::PatternRewriter &rewriter) const override {  // NOLINT
    // The producer of batch_norm's first input should be conv2d.
    // NOTE(review): GetDefiningOpForInput may return nullptr; not checked.
    paddle::dialect::Conv2dOp conv2d_op =
        ir::GetDefiningOpForInput(op, 0)->dyn_cast();
    if (!conv2d_op) return false;

    // Only fuse when the conv2d output has exactly one use.
    ir::OpResult conv2d_out = conv2d_op.out();
    if (!conv2d_out.HasOneUse()) return false;

    ir::Value conv2d_filter = conv2d_op.filter();

    // ir::GetParameterOp filter_parameter_op =
    //     conv2d_filter.GetDefiningOp()->dyn_cast();
    // if (!filter_parameter_op) return false;

    ir::OpResult conv2d_filter_result = conv2d_filter.dyn_cast();
    IR_ENFORCE(conv2d_filter_result);

    ir::Value bn_input = op.x();
    IR_ENFORCE(bn_input == conv2d_out);

    ir::Value bn_mean = op.mean();
    ir::Value bn_variance = op.variance();
    ir::Value bn_scale = op.scale();
    ir::Value bn_bias = op.bias();

    // --- deal with filter ---
    rewriter.SetInsertionPoint(op);
    phi::DDim bn_variance_shape =
        bn_variance.type().dyn_cast().dims();
    float epsilon = op.attribute("epsilon").data();
    // epsilon filled to variance's shape.
    paddle::dialect::FullOp full_op = rewriter.Build(
        phi::vectorize(bn_variance_shape), epsilon);
    // variance + epsilon
    paddle::dialect::AddOp add_op = rewriter.Build(
        bn_variance.dyn_cast(), full_op.out());
    // sqrt(variance + epsilon)
    paddle::dialect::SqrtOp sqrt_op =
        rewriter.Build(add_op.out());
    // factor = scale / sqrt(variance + epsilon)
    paddle::dialect::DivideOp div_op = rewriter.Build(
        bn_scale.dyn_cast(), sqrt_op.out());

    // reshape scale: factor -> [bn_scale_shape[0], 1, ..., 1], with the same
    // rank as the conv2d filter.
    phi::DDim conv2d_filter_shape = ir::GetShapeFromValue(conv2d_filter);
    phi::DDim bn_scale_shape =
        bn_scale.type().dyn_cast().dims();
    std::vector bn_scale_new_shape(conv2d_filter_shape.size(), 1);
    bn_scale_new_shape[0] = bn_scale_shape[0];

    paddle::dialect::ReshapeOp reshape_scale_op =
        rewriter.Build(div_op.out(), bn_scale_new_shape);

    // new filter --> mul_op.out()
    paddle::dialect::MultiplyOp mul_op =
        rewriter.Build(conv2d_filter_result,
                       reshape_scale_op.out());

    // New conv2d on the original input, with the scaled filter and a copy of
    // the original conv2d's attributes.
    auto conv2d_attributes = conv2d_op->attributes();
    auto new_conv2d_op = rewriter.Build(
        conv2d_op.input().dyn_cast(), mul_op.out(), conv2d_attributes);

    // --- deal with bias ---
    // mean * factor
    paddle::dialect::MultiplyOp mul_bias_op = rewriter.Build(
        bn_mean.dyn_cast(), div_op.out());

    // new bias --> sub_op.out()
    paddle::dialect::SubtractOp sub_op = rewriter.Build(
        bn_bias.dyn_cast(), mul_bias_op.out());

    // reshape new bias: all-ones shape with the conv output's rank, except
    // dim 1 which takes the conv output's dim 1 (channel dim for NCHW).
    phi::DDim new_conv2d_out_shape = ir::GetShapeFromValue(new_conv2d_op.out());
    std::vector new_bias_new_shape(new_conv2d_out_shape.size(), 1);
    std::string data_format =
        new_conv2d_op.attribute("data_format").AsString();
    IR_ENFORCE(data_format == "NCHW", "Only support NCHW now.");
    new_bias_new_shape[1] = new_conv2d_out_shape[1];
    paddle::dialect::ReshapeOp reshape_bias_op =
        rewriter.Build(sub_op.out(), new_bias_new_shape);
    paddle::dialect::AddOp add_bias_op = rewriter.Build(
        new_conv2d_op.out(), reshape_bias_op.out());

    // Redirect users of batch_norm's `out` to the fused result, then erase the
    // original batch_norm and conv2d.
    // NOTE(review): only op.out() is redirected; if batch_norm's other results
    // still have uses, erasing it may be invalid -- confirm.
    rewriter.ReplaceAllUsesWith(op.out(), add_bias_op.out());

    rewriter.EraseOp(op);
    rewriter.EraseOp(conv2d_op);
    return true;
  }
};

namespace paddle {
namespace dialect {
// Test-only fused conv op "pd.conv2d_fusion_test".
// Operands: input, filter, bias, residual.
// Attributes: the ten listed in attributes_name.
// Results: `output` and a vector-typed `outputs`.
class Conv2dFusionOpTest : public ir::Op {
 public:
  using Op::Op;
  static const char *name() { return "pd.conv2d_fusion_test"; }
  static const char *attributes_name[10];
  static constexpr uint32_t attributes_num = 10;
  static OpInfoTuple GetOpInfo();
  // NOTE(review): no definition of this explicit-argument overload is visible
  // in this file; only the AttributeMap overload below is defined.
  static void Build(ir::Builder &builder,             // NOLINT
                    ir::OperationArgument &argument,  // NOLINT
                    ir::OpResult input_,
                    ir::OpResult filter_,
                    ir::OpResult bias_,
                    ir::OpResult residual_,
                    const std::vector &strides,
                    const std::vector &paddings_t,
                    std::string padding_algorithm,
                    const std::vector &dilations_t,
                    int groups,
                    std::string data_format,
                    std::string activation,
                    bool exhaustive_search,
                    const std::vector &channels,
                    int user_workspace_size);

  // Builds the op from an attribute map carrying the ten attributes.
  static void Build(ir::Builder &builder,             // NOLINT
                    ir::OperationArgument &argument,  // NOLINT
                    ir::OpResult input_,
                    ir::OpResult filter_,
                    ir::OpResult bias_,
                    ir::OpResult residual_,
                    ir::AttributeMap attributes);
  void Verify();
  ir::Value input() { return operand_source(0); }
  ir::Value filter() { return operand_source(1); }
  ir::Value bias() { return operand_source(2); }
  ir::Value residual() { return operand_source(3); }
  ir::OpResult output() { return result(0); }
  ir::OpResult outputs() { return result(1); }

  // Returns the named attribute; enforces (PreconditionNotMet) that it exists.
  ir::Attribute attribute(const std::string &name) {
    {
      PADDLE_ENFORCE(
          attributes().count(name) > 0,
          phi::errors::PreconditionNotMet("Attribute is not exist."));
      return attributes().at(name);
    }
  }

  // Typed accessor: enforces that the attribute exists and has the requested
  // type, then returns it cast to that type.
  template T attribute(const std::string &name) {
    {
      PADDLE_ENFORCE(
          attributes().count(name) > 0 && attributes().at(name).isa(),
          phi::errors::PreconditionNotMet("Attribute is not right."));
      return attributes().at(name).dyn_cast();
    }
  }

  static void InferMeta(phi::InferMetaContext *infer_meta);
};

const char *Conv2dFusionOpTest::attributes_name[10] = {"strides",
                                                       "paddings_t",
                                                       "padding_algorithm",
                                                       "dilations_t",
                                                       "groups",
                                                       "data_format",
                                                       "activation",
                                                       "exhaustive_search",
                                                       "channels",
                                                       "user_workspace_size"};

// Describes the op for the framework in four parts:
//   - inputs: residual's first flag is true, unlike the other inputs --
//     presumably marking it optional;
//   - attribute types;
//   - outputs;
//   - runtime info: infer-meta function name, kernel name, and their
//     argument lists.
// NOTE(review): the runtime info names "Conv2dFusionInferMeta", while
// InferMeta() below calls phi::FusedConvInferMeta -- confirm which is intended.
OpInfoTuple Conv2dFusionOpTest::GetOpInfo() {
  std::vector inputs = {
      OpInputInfo(
          "input", "paddle::dialect::DenseTensorType", false, false, false),
      OpInputInfo(
          "filter", "paddle::dialect::DenseTensorType", false, false, false),
      OpInputInfo(
          "bias", "paddle::dialect::DenseTensorType", false, false, false),
      OpInputInfo(
          "residual", "paddle::dialect::DenseTensorType", true, false, false)};
  std::vector attributes = {
      OpAttributeInfo("strides", "ir::ArrayAttribute", ""),
      OpAttributeInfo(
          "paddings_t", "ir::ArrayAttribute", ""),
      OpAttributeInfo("padding_algorithm", "ir::StrAttribute", ""),
      OpAttributeInfo(
          "dilations_t", "ir::ArrayAttribute", ""),
      OpAttributeInfo("groups", "ir::Int32Attribute", ""),
      OpAttributeInfo("data_format", "ir::StrAttribute", ""),
      OpAttributeInfo("activation", "ir::StrAttribute", ""),
      OpAttributeInfo("exhaustive_search", "ir::BoolAttribute", ""),
      OpAttributeInfo("channels", "ir::ArrayAttribute", ""),
      OpAttributeInfo("user_workspace_size", "ir::Int32Attribute", "")};
  std::vector outputs = {
      OpOutputInfo("output", "paddle::dialect::DenseTensorType", false, false),
      OpOutputInfo("outputs", "ir::VectorType", false, false)};
  paddle::dialect::OpRunTimeInfo run_time_info =
      OpRunTimeInfo("Conv2dFusionInferMeta",
                    {"input",
                     "filter",
                     "bias",
                     "residual",
                     "strides",
                     "paddings_t",
                     "padding_algorithm",
                     "dilations_t",
                     "groups",
                     "data_format",
                     "activation",
                     "exhaustive_search",
                     "channels",
                     "user_workspace_size"},
                    {"ConvFusionKernel"},
                    {"input",
                     "filter",
                     "bias",
                     "residual",
                     "strides",
                     "paddings_t",
                     "padding_algorithm",
                     "dilations_t",
                     "groups",
                     "data_format",
                     "activation",
                     "exhaustive_search",
                     "channels",
                     "user_workspace_size"},
                    {"input"},
                    {},
                    {});

  return std::make_tuple(inputs, attributes, outputs, run_time_info);
}

// Builds the op from an AttributeMap in four steps:
//   1. Decode each attribute into a plain C++ value.
//   2. Add the four operands and re-encode the attributes onto `argument`.
//   3. Run phi::FusedConvInferMeta on metadata-only tensors to compute the
//      type of `output`.
//   4. Add both result types.
void Conv2dFusionOpTest::Build(ir::Builder &builder,
                               ir::OperationArgument &argument,
                               ir::OpResult input_,
                               ir::OpResult filter_,
                               ir::OpResult bias_,
                               ir::OpResult residual_,
                               ir::AttributeMap attributes) {
  // Decode attributes; array-valued ones are unpacked element by element.
  std::vector strides;
  for (size_t i = 0; i < attributes.at("strides").dyn_cast().size();
       i++) {
    strides.push_back(attributes.at("strides")
                          .dyn_cast()
                          .at(i)
                          .dyn_cast()
                          .data());
  }

  std::vector paddings_t;
  for (size_t i = 0;
       i < attributes.at("paddings_t").dyn_cast().size();
       i++) {
    paddings_t.push_back(attributes.at("paddings_t")
                             .dyn_cast()
                             .at(i)
                             .dyn_cast()
                             .data());
  }

  std::string padding_algorithm = attributes.at("padding_algorithm")
                                      .dyn_cast()
                                      .AsString();
  std::vector dilations_t;
  for (size_t i = 0;
       i < attributes.at("dilations_t").dyn_cast().size();
       i++) {
    dilations_t.push_back(attributes.at("dilations_t")
                              .dyn_cast()
                              .at(i)
                              .dyn_cast()
                              .data());
  }
  int groups = attributes.at("groups").dyn_cast().data();
  std::string data_format =
      attributes.at("data_format").dyn_cast().AsString();
  std::string activation =
      attributes.at("activation").dyn_cast().AsString();
  bool exhaustive_search =
      attributes.at("exhaustive_search").dyn_cast().data();
  std::vector channels;
  for (size_t i = 0; i < attributes.at("channels").dyn_cast().size();
       i++) {
    channels.push_back(attributes.at("channels")
                           .dyn_cast()
                           .at(i)
                           .dyn_cast()
                           .data());
  }
  int user_workspace_size = attributes.at("user_workspace_size")
                                .dyn_cast()
                                .data();

  VLOG(4) << "Builder construction inputs";
  std::vector argument_inputs = {
      input_, filter_, bias_, residual_};
  argument.AddOperands(argument_inputs.begin(), argument_inputs.end());

  // Re-encode every decoded value as an ir attribute on the argument.
  VLOG(4) << "Builder construction attributes";
  std::vector vec_strides;
  for (size_t i = 0; i < static_cast(strides.size()); i++) {
    ir::Attribute attr_strides =
        ir::Int32Attribute::get(ir::IrContext::Instance(), strides[i]);

    vec_strides.push_back(attr_strides);
  }
  ir::Attribute attr_strides =
      ir::ArrayAttribute::get(ir::IrContext::Instance(), vec_strides);
  argument.AddAttribute("strides", attr_strides);
  std::vector vec_paddings_t;
  for (size_t i = 0; i < static_cast(paddings_t.size()); i++) {
    ir::Attribute attr_paddings_t =
        ir::Int32Attribute::get(ir::IrContext::Instance(), paddings_t[i]);

    vec_paddings_t.push_back(attr_paddings_t);
  }
  ir::Attribute attr_paddings_t =
      ir::ArrayAttribute::get(ir::IrContext::Instance(), vec_paddings_t);
  argument.AddAttribute("paddings_t", attr_paddings_t);
  ir::Attribute attr_padding_algorithm =
      ir::StrAttribute::get(ir::IrContext::Instance(), padding_algorithm);
  argument.AddAttribute("padding_algorithm", attr_padding_algorithm);
  std::vector vec_dilations_t;
  for (size_t i = 0; i < static_cast(dilations_t.size()); i++) {
    ir::Attribute attr_dilations_t =
        ir::Int32Attribute::get(ir::IrContext::Instance(), dilations_t[i]);

    vec_dilations_t.push_back(attr_dilations_t);
  }
  ir::Attribute attr_dilations_t =
      ir::ArrayAttribute::get(ir::IrContext::Instance(), vec_dilations_t);
  argument.AddAttribute("dilations_t", attr_dilations_t);
  ir::Attribute attr_groups =
      ir::Int32Attribute::get(ir::IrContext::Instance(), groups);
  argument.AddAttribute("groups", attr_groups);
  ir::Attribute attr_data_format =
      ir::StrAttribute::get(ir::IrContext::Instance(), data_format);
  argument.AddAttribute("data_format", attr_data_format);
  ir::Attribute attr_activation =
      ir::StrAttribute::get(ir::IrContext::Instance(), activation);
  argument.AddAttribute("activation", attr_activation);
  ir::Attribute attr_exhaustive_search =
      ir::BoolAttribute::get(ir::IrContext::Instance(), exhaustive_search);
  argument.AddAttribute("exhaustive_search", attr_exhaustive_search);
  std::vector vec_channels;
  for (size_t i = 0; i < static_cast(channels.size()); i++) {
    ir::Attribute attr_channels =
        ir::Int32Attribute::get(ir::IrContext::Instance(), channels[i]);

    vec_channels.push_back(attr_channels);
  }
  ir::Attribute attr_channels =
      ir::ArrayAttribute::get(ir::IrContext::Instance(), vec_channels);
  argument.AddAttribute("channels", attr_channels);
  ir::Attribute attr_user_workspace_size =
      ir::Int32Attribute::get(ir::IrContext::Instance(), user_workspace_size);
  argument.AddAttribute("user_workspace_size", attr_user_workspace_size);

  VLOG(4) << "Builder construction outputs";
  paddle::dialect::DenseTensorType input =
      input_.type().dyn_cast();
  (void)input;
  paddle::dialect::DenseTensorType filter =
      filter_.type().dyn_cast();
  (void)filter;
  paddle::dialect::DenseTensorType bias =
      bias_.type().dyn_cast();
  (void)bias;
  // paddle::dialect::DenseTensorType residual =
  //     residual_.type().dyn_cast();
  // (void)residual;

  // Metadata-only tensors built from the operand types, for infer-meta.
  // NOTE(review): each std::make_unique(...) allocator is a temporary
  // destroyed at the end of its statement; if DenseTensor retains the raw
  // pointer from .get(), it dangles afterwards -- confirm this is harmless.
  VLOG(4) << "Builder construction  dense_input";
  phi::DenseTensor dense_input(
      std::make_unique(
          paddle::platform::CPUPlace())
          .get(),
      phi::DenseTensorMeta(TransToPhiDataType(input.dtype()),
                           input.dims(),
                           input.data_layout(),
                           input.lod(),
                           input.offset()));
  VLOG(4) << "Builder construction  meta_input";
  phi::MetaTensor meta_input(&dense_input);

  VLOG(4) << "Builder construction  dense_filter";
  phi::DenseTensor dense_filter(
      std::make_unique(
          paddle::platform::CPUPlace())
          .get(),
      phi::DenseTensorMeta(TransToPhiDataType(filter.dtype()),
                           filter.dims(),
                           filter.data_layout(),
                           filter.lod(),
                           filter.offset()));
  VLOG(4) << "Builder construction  meta_filter";
  phi::MetaTensor meta_filter(&dense_filter);

  VLOG(4) << "Builder construction  dense_bias";
  phi::DenseTensor dense_bias(
      std::make_unique(
          paddle::platform::CPUPlace())
          .get(),
      phi::DenseTensorMeta(TransToPhiDataType(bias.dtype()),
                           bias.dims(),
                           bias.data_layout(),
                           bias.lod(),
                           bias.offset()));
  VLOG(4) << "Builder construction  meta_bias";
  phi::MetaTensor meta_bias(&dense_bias);

  //   VLOG(4) << "Builder construction  dense_residual";
  //   phi::DenseTensor
  //   dense_residual(std::make_unique(paddle::platform::CPUPlace()).get(),
  //                  phi::DenseTensorMeta(TransToPhiDataType(residual.dtype()),
  //                                       residual.dims(),
  //                                       residual.data_layout(),
  //                                       residual.lod(),
  //                                       residual.offset()));
  VLOG(4) << "Builder construction  meta_residual";
  //   phi::MetaTensor meta_residual(&dense_residual);
  // The residual operand's type is not forwarded; an empty MetaTensor is
  // passed to infer-meta instead.
  phi::MetaTensor meta_residual;
  phi::DenseTensor dense_output;
  phi::MetaTensor meta_output(&dense_output);
  std::vector vec_dense_outputs((channels.size()), phi::DenseTensor());
  std::vector vec_meta_outputs;
  for (size_t i = 0; i < static_cast(channels.size()); i++) {
    vec_meta_outputs.push_back(phi::MetaTensor(&vec_dense_outputs[i]));
  }
  std::vector meta_outputs;
  for (size_t i = 0; i < static_cast(vec_meta_outputs.size()); i++) {
    meta_outputs.push_back(&vec_meta_outputs[i]);
  }

  // NOTE(review): fixed literals ("float32", "identity", false, false) are
  // passed here rather than the decoded `activation` / `exhaustive_search`.
  // `meta_outputs` is never passed, so the `outputs` element types built below
  // come from the default-constructed tensors in vec_dense_outputs.
  phi::FusedConvInferMeta(meta_input,
                          meta_filter,
                          meta_bias,
                          meta_residual,
                          strides,
                          paddings_t,
                          padding_algorithm,
                          dilations_t,
                          groups,
                          data_format,
                          "float32",
                          "identity",
                          false,
                          false,
                          &meta_output,
                          phi::MetaConfig());

  // Result 0: dense tensor type taken from the infer-meta output.
  std::vector argument_outputs;
  auto output_dense_tensor_type = paddle::dialect::DenseTensorType::get(
      ir::IrContext::Instance(),
      TransToIrDataType(dense_output.dtype()),
      dense_output.dims(),
      dense_output.layout(),
      dense_output.lod(),
      dense_output.offset());
  LOG(INFO) << output_dense_tensor_type;

  argument_outputs.push_back(output_dense_tensor_type);

  // Result 1: a VectorType with one dense tensor type per `channels` entry.
  std::vector outputs_types;
  for (size_t i = 0; i < static_cast(channels.size()); i++) {
    outputs_types.push_back(paddle::dialect::DenseTensorType::get(
        ir::IrContext::Instance(),
        TransToIrDataType(vec_dense_outputs[i].dtype()),
        vec_dense_outputs[i].dims(),
        vec_dense_outputs[i].layout(),
        vec_dense_outputs[i].lod(),
        vec_dense_outputs[i].offset()));
  }
  ir::Type outputs_vector_type =
      ir::VectorType::get(ir::IrContext::Instance(), outputs_types);
  argument_outputs.push_back(outputs_vector_type);

  argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
}

// Verifies the op in three stages:
//   - operand count (4) and types; the residual is only checked when present;
//   - presence and type of every attribute, including array elements;
//   - result count (2) and types.
// Failures raise phi::errors::PreconditionNotMet through PADDLE_ENFORCE.
void Conv2dFusionOpTest::Verify() {
  VLOG(4)
      << "Start Verifying inputs, outputs and attributes for: Conv2dFusionOp.";
  VLOG(4) << "Verifying inputs:";
  {
    auto input_size = num_operands();
    PADDLE_ENFORCE_EQ(
        input_size,
        4u,
        phi::errors::PreconditionNotMet(
            "The size %d of inputs must be equal to 4.", input_size));
    PADDLE_ENFORCE((*this)
                       ->operand_source(0)
                       .type()
                       .isa(),
                   phi::errors::PreconditionNotMet(
                       "Type validation failed for the 0th input."));
    PADDLE_ENFORCE((*this)
                       ->operand_source(1)
                       .type()
                       .isa(),
                   phi::errors::PreconditionNotMet(
                       "Type validation failed for the 1th input."));
    PADDLE_ENFORCE((*this)
                       ->operand_source(2)
                       .type()
                       .isa(),
                   phi::errors::PreconditionNotMet(
                       "Type validation failed for the 2th input."));
    // Operand 3 (residual) is type-checked only when it tests true
    // (presumably: when a non-null value was supplied).
    if (auto val = (*this)->operand(3)) {
      PADDLE_ENFORCE(val.type().isa(),
                     phi::errors::PreconditionNotMet(
                         "Type validation failed for the 3th input."));
    }
  }
  VLOG(4) << "Verifying attributes:";
  {
    auto &attributes = this->attributes();
    PADDLE_ENFORCE(attributes.count("strides") > 0 &&
                       attributes.at("strides").isa(),
                   phi::errors::PreconditionNotMet(
                       "Type of attribute: strides is not right."));
    for (size_t i = 0;
         i < attributes.at("strides").dyn_cast().size();
         i++) {
      PADDLE_ENFORCE(attributes.at("strides")
                         .dyn_cast()
                         .at(i)
                         .isa(),
                     phi::errors::PreconditionNotMet(
                         "Type of attribute: strides is not right."));
    }
    PADDLE_ENFORCE(attributes.count("paddings_t") > 0 &&
                       attributes.at("paddings_t").isa(),
                   phi::errors::PreconditionNotMet(
                       "Type of attribute: paddings_t is not right."));
    for (size_t i = 0;
         i < attributes.at("paddings_t").dyn_cast().size();
         i++) {
      PADDLE_ENFORCE(attributes.at("paddings_t")
                         .dyn_cast()
                         .at(i)
                         .isa(),
                     phi::errors::PreconditionNotMet(
                         "Type of attribute: paddings_t is not right."));
    }
    PADDLE_ENFORCE(
        attributes.count("padding_algorithm") > 0 &&
            attributes.at("padding_algorithm").isa(),
        phi::errors::PreconditionNotMet(
            "Type of attribute: padding_algorithm is not right."));
    PADDLE_ENFORCE(attributes.count("dilations_t") > 0 &&
                       attributes.at("dilations_t").isa(),
                   phi::errors::PreconditionNotMet(
                       "Type of attribute: dilations_t is not right."));
    for (size_t i = 0;
         i < attributes.at("dilations_t").dyn_cast().size();
         i++) {
      PADDLE_ENFORCE(attributes.at("dilations_t")
                         .dyn_cast()
                         .at(i)
                         .isa(),
                     phi::errors::PreconditionNotMet(
                         "Type of attribute: dilations_t is not right."));
    }
    PADDLE_ENFORCE(attributes.count("groups") > 0 &&
                       attributes.at("groups").isa(),
                   phi::errors::PreconditionNotMet(
                       "Type of attribute: groups is not right."));
    PADDLE_ENFORCE(attributes.count("data_format") > 0 &&
                       attributes.at("data_format").isa(),
                   phi::errors::PreconditionNotMet(
                       "Type of attribute: data_format is not right."));
    PADDLE_ENFORCE(attributes.count("activation") > 0 &&
                       attributes.at("activation").isa(),
                   phi::errors::PreconditionNotMet(
                       "Type of attribute: activation is not right."));
    PADDLE_ENFORCE(
        attributes.count("exhaustive_search") > 0 &&
            attributes.at("exhaustive_search").isa(),
        phi::errors::PreconditionNotMet(
            "Type of attribute: exhaustive_search is not right."));
    PADDLE_ENFORCE(attributes.count("channels") > 0 &&
                       attributes.at("channels").isa(),
                   phi::errors::PreconditionNotMet(
                       "Type of attribute: channels is not right."));
    for (size_t i = 0;
         i < attributes.at("channels").dyn_cast().size();
         i++) {
      PADDLE_ENFORCE(attributes.at("channels")
                         .dyn_cast()
                         .at(i)
                         .isa(),
                     phi::errors::PreconditionNotMet(
                         "Type of attribute: channels is not right."));
    }
    PADDLE_ENFORCE(
        attributes.count("user_workspace_size") > 0 &&
            attributes.at("user_workspace_size").isa(),
        phi::errors::PreconditionNotMet(
            "Type of attribute: user_workspace_size is not right."));
  }
  VLOG(4) << "Verifying outputs:";
  {
    auto output_size = num_results();
    PADDLE_ENFORCE_EQ(
        output_size,
        2u,
        phi::errors::PreconditionNotMet(
            "The size %d of outputs must be equal to 2.", output_size));
    PADDLE_ENFORCE(
        (*this)->result(0).type().isa(),
        phi::errors::PreconditionNotMet(
            "Type validation failed for the 0th output."));
    // Result 1 may be a vector type (each element is checked) or a single type.
    auto output_1_type = (*this)->result(1).type();
    if (auto vec_type = output_1_type.dyn_cast()) {
      for (size_t i = 0; i < vec_type.size(); i++) {
        PADDLE_ENFORCE(vec_type[i].isa(),
                       phi::errors::PreconditionNotMet(
                           "Type validation failed for the 1th output."));
      }
    } else {
      PADDLE_ENFORCE(output_1_type.isa(),
                     phi::errors::PreconditionNotMet(
                         "Type validation failed for the 1th output."));
    }
  }
  VLOG(4) << "End Verifying for: Conv2dFusionOp.";
}

// Forwards shape/dtype inference to phi::FusedConvInferMeta.
void Conv2dFusionOpTest::InferMeta(phi::InferMetaContext *infer_meta) {
  auto fn = PD_INFER_META(phi::FusedConvInferMeta);
  fn(infer_meta);
}
}  // namespace dialect
}  // namespace paddle
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::Conv2dFusionOpTest)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::Conv2dFusionOpTest)

// Dialect registering the test fused-conv op.
// NOTE(review): the dialect name string reads "con2d" (likely a typo for
// "conv2d"); left unchanged because it is a runtime identifier.
class Conv2dFusionTestDialect : public ir::Dialect {
 public:
  explicit Conv2dFusionTestDialect(ir::IrContext *context)
      : ir::Dialect(name(), context, ir::TypeId::get()) {
    initialize();
  }
  static const char *name() { return "con2d fusion test"; }

 private:
  void initialize() { RegisterOps(); }
};
IR_DECLARE_EXPLICIT_TYPE_ID(Conv2dFusionTestDialect)
IR_DEFINE_EXPLICIT_TYPE_ID(Conv2dFusionTestDialect)

// Rewrites add(conv2d(x, w), y) into a single fused conv op on (x, w, y) with
// an empty residual. conv2d's attributes are copied under the fused op's
// attribute names, and the remaining fused attributes get fixed values.
class Conv2dAddFusePattern : public ir::OpRewritePattern {
 public:
  using ir::OpRewritePattern::OpRewritePattern;
  bool MatchAndRewrite(
      paddle::dialect::AddOp op,
      ir::PatternRewriter &rewriter) const override {  // NOLINT
    // The producer of add's first input should be conv2d.
    // NOTE(review): GetDefiningOpForInput may return nullptr; not checked.
    paddle::dialect::Conv2dOp conv2d_op =
        ir::GetDefiningOpForInput(op, 0)->dyn_cast();
    if (!conv2d_op) return false;

    // Only fuse when the conv2d output has exactly one use.
    ir::OpResult conv2d_out = conv2d_op.out();
    if (!conv2d_out.HasOneUse()) return false;

    ir::Value conv2d_filter = conv2d_op.filter();

    ir::OpResult conv2d_filter_result = conv2d_filter.dyn_cast();
    IR_ENFORCE(conv2d_filter_result);

    ir::Value add_input = op.x();
    IR_ENFORCE(add_input == conv2d_out);

    // add's second operand becomes the fused op's bias.
    ir::Value y = op.y();
    ir::OpResult bias = y.dyn_cast();
    auto conv2d_attributes = conv2d_op.attributes();
    // Fused attribute names, index-aligned with con2d_fusing_attr below.
    // conv2d's "paddings"/"dilations" are stored as "paddings_t"/"dilations_t".
    // The last four get fixed values: "identity", true, an empty array, and 0.
    std::vector conv2d_fusion_attrStr = {"strides",
                                         "paddings_t",
                                         "padding_algorithm",
                                         "dilations_t",
                                         "groups",
                                         "data_format",
                                         "activation",
                                         "exhaustive_search",
                                         "channels",
                                         "user_workspace_size"};
    std::vector con2d_fusing_attr = {
        conv2d_attributes.at("strides"),
        conv2d_attributes.at("paddings"),
        conv2d_attributes.at("padding_algorithm"),
        conv2d_attributes.at("dilations"),
        conv2d_attributes.at("groups"),
        conv2d_attributes.at("data_format"),
        rewriter.str_attr("identity"),
        rewriter.bool_attr(true),
        rewriter.array_attr(std::vector{}),
        rewriter.int32_attr(0)};
    ir::AttributeMap conv2d_fusion_attributes;
    for (size_t i = 0; i < conv2d_fusion_attrStr.size(); ++i) {
      conv2d_fusion_attributes[conv2d_fusion_attrStr[i]] = con2d_fusing_attr[i];
    }

    // Default-constructed OpResult passed for the residual operand.
    // NOTE(review): the conv input is taken as result(0) of its producer
    // rather than conv2d_op.input(); these differ if that producer has
    // multiple results.
    ir::OpResult tmpResidual;
    auto conv2d_fuse_op = rewriter.Build(
        ir::GetDefiningOpForInput(conv2d_op, 0)->result(0),
        conv2d_filter_result,
        bias,
        tmpResidual,
        conv2d_fusion_attributes);
    rewriter.ReplaceOp(op, std::vector{conv2d_fuse_op.output()});
    return true;
  }
};

// Pass that greedily applies its rewrite patterns to region 0 of a
// "builtin.module" op.
class TestPass : public ir::Pass {
 public:
  TestPass() : ir::Pass("TestPass", 1) {}

  // Collects the patterns and freezes them. Per the log message below, the
  // second pattern is the conv2d+batch_norm fusion; it is constructed with the
  // list of op names it may generate, and that list is logged.
  bool Initialize(ir::IrContext *context) override {
    ir::RewritePatternSet ps(context);
    ps.Add(context);
    auto conv_bn_pattern = std::make_unique(
        context,
        1,
        std::vector{paddle::dialect::FullOp::name(),
                    paddle::dialect::AddOp::name(),
                    paddle::dialect::SqrtOp::name(),
                    paddle::dialect::DivideOp::name(),
                    paddle::dialect::ReshapeOp::name(),
                    paddle::dialect::MultiplyOp::name(),
                    paddle::dialect::SubtractOp::name(),
                    paddle::dialect::Conv2dOp::name()});
    LOG(INFO) << "Conv2dBnFusePattern will generate the following operations: ";
    for (auto op_info : conv_bn_pattern->generated_ops()) {
      LOG(INFO) << "--- " << op_info.name();
    }
    ps.Add(std::move(conv_bn_pattern));
    patterns_ = ir::FrozenRewritePatternSet(std::move(ps));
    return true;
  }

  // Top-down greedy application, capped at 10 iterations.
  void Run(ir::Operation *op) override {
    ir::GreedyRewriteConfig cfg;
    cfg.use_top_down_traversal = true;
    cfg.max_iterations = 10;
    ir::ApplyPatternsGreedily(op->region(0), patterns_, cfg);
  }

  bool CanApplyOn(ir::Operation *op) const override {
    return op->name() == "builtin.module" && op->num_regions() > 0;
  }

 private:
  ir::FrozenRewritePatternSet patterns_;
};

// Appends 11 ops to the builder's block:
//   - six `full` ops (input {4, 3, 16, 16}, filter {64, 3, 3, 3}, and four
//     {64} batch_norm parameters), all 1.5 / FLOAT32 / CPU;
//   - conv2d -> batch_norm -> transpose{0, 2, 3, 1} -> transpose{0, 3, 1, 2};
//   - a final op taking the last transpose's output with name "out"
//     (presumably fetch).
// The two transposes compose to the identity permutation.
void BuildProgram(ir::Builder &builder) {  // NOLINT
  paddle::dialect::FullOp full_input_op =
      builder.Build(std::vector{4, 3, 16, 16},
                    1.5,
                    phi::DataType::FLOAT32,
                    phi::CPUPlace());

  paddle::dialect::FullOp full_filter_op =
      builder.Build(std::vector{64, 3, 3, 3},
                    1.5,
                    phi::DataType::FLOAT32,
                    phi::CPUPlace());

  paddle::dialect::FullOp full_mean_op = builder.Build(
      std::vector{64}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace());

  paddle::dialect::FullOp full_variance_op =
      builder.Build(std::vector{64},
                    1.5,
                    phi::DataType::FLOAT32,
                    phi::CPUPlace());

  paddle::dialect::FullOp full_scale_op =
      builder.Build(std::vector{64},
                    1.5,
                    phi::DataType::FLOAT32,
                    phi::CPUPlace());

  paddle::dialect::FullOp full_bias_op = builder.Build(
      std::vector{64}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace());

  paddle::dialect::Conv2dOp conv2d_op =
      builder.Build(full_input_op.out(),
                    full_filter_op.out());

  paddle::dialect::BatchNormOp batch_norm_op =
      builder.Build(conv2d_op.out(),
                    full_mean_op.out(),
                    full_variance_op.out(),
                    full_scale_op.out(),
                    full_bias_op.out(),
                    true,
                    0.9,
                    1e-6,
                    "NCHW",
                    false,
                    false);

  auto transpose1_op = builder.Build(
      batch_norm_op.out(), std::vector{0, 2, 3, 1});

  auto transpose2_op = builder.Build(
      transpose1_op.out(), std::vector{0, 3, 1, 2});

  builder.Build(transpose2_op.out(), "out", 0);
}

// TODO(wilber): Add a normal test.
// Builds the 11-op program, then runs the test pass (presumably TestPass; the
// template argument is missing in this copy) followed by dead-code
// elimination. Only checks that the pass manager run succeeds; the rewritten
// IR itself is not asserted on.
TEST(pattern_rewrite, Patterns) {
  ir::IrContext *ctx = ir::IrContext::Instance();
  auto *test_dialect = ctx->GetOrRegisterDialect();
  test_dialect->RegisterOp();
  ctx->GetOrRegisterDialect();
  ir::Program program(ctx);
  ir::Builder builder = ir::Builder(ctx, program.block());
  BuildProgram(builder);

  EXPECT_EQ(program.block()->size(), 11u);

  ir::PassManager pm(ctx);
  pm.AddPass(std::make_unique());
  // pm.AddPass(ir::CreateConstantFoldingPass());
  pm.AddPass(ir::CreateDeadCodeEliminationPass());
  pm.EnablePassTiming();
  pm.EnableIRPrinting();
  //   pm.EnableIRPrinting(std::make_unique(
  //       [](ir::Pass *pass, ir::Operation *op) {
  //         return pass->name() == "ConstantFoldingPass";
  //       },
  //       [](ir::Pass *pass, ir::Operation *op) {
  //         return pass->name() == "ConstantFoldingPass";
  //       },
  //       true,
  //       true));

  CHECK_EQ(pm.Run(&program), true);
}