// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "glog/logging.h" // NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in // paddle/fluid/ir/dialect/CMakeLists.txt. #include "paddle/fluid/ir/dialect/pd_op.h" #include "paddle/fluid/ir/dialect/pd_dialect.h" #include "paddle/fluid/ir/dialect/pd_type.h" #include "paddle/fluid/ir/dialect/utils.h" #include "paddle/fluid/ir/interface/op_yaml_info.h" #include "paddle/ir/core/builtin_dialect.h" #include "paddle/ir/core/builtin_op.h" #include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/dialect.h" #include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/op_base.h" #include "paddle/ir/core/operation.h" #include "paddle/ir/pass/pass.h" #include "paddle/ir/pass/pass_manager.h" #include "paddle/phi/kernels/elementwise_add_kernel.h" #ifndef _WIN32 class TestAnalysis1 {}; class TestAnalysis2 {}; IR_DECLARE_EXPLICIT_TYPE_ID(TestAnalysis1) IR_DEFINE_EXPLICIT_TYPE_ID(TestAnalysis1) IR_DECLARE_EXPLICIT_TYPE_ID(TestAnalysis2) IR_DEFINE_EXPLICIT_TYPE_ID(TestAnalysis2) TEST(pass_manager, PreservedAnalyses) { ir::detail::PreservedAnalyses pa; CHECK_EQ(pa.IsNone(), true); CHECK_EQ(pa.IsPreserved(), false); pa.Preserve(); CHECK_EQ(pa.IsPreserved(), true); pa.Unpreserve(); CHECK_EQ(pa.IsPreserved(), false); CHECK_EQ(pa.IsPreserved(), false); pa.Preserve(); CHECK_EQ(pa.IsPreserved(), true); CHECK_EQ(pa.IsPreserved(), true); CHECK_EQ(pa.IsAll(), false); pa.PreserveAll(); CHECK_EQ(pa.IsAll(), true); CHECK_EQ(pa.IsNone(), false); } #endif class AddOp : public ir::Op { public: using Op::Op; static const char *name() { return "test.add"; } static constexpr const char **attributes_name = nullptr; static constexpr uint32_t attributes_num = 0; void Verify(); static void Build(ir::Builder &builder, // NOLINT ir::OperationArgument &argument, // NOLINT ir::OpResult l_operand, ir::OpResult r_operand, ir::Type sum_type); }; void AddOp::Verify() { if (num_operands() != 2) { throw("The size of inputs must be equal to 2."); } if (num_results() != 1) { throw("The size of outputs must be equal to 1."); } } void AddOp::Build(ir::Builder &, ir::OperationArgument &argument, ir::OpResult l_operand, ir::OpResult r_operand, ir::Type sum_type) { argument.AddOperand(l_operand); argument.AddOperand(r_operand); argument.AddOutput(sum_type); } IR_DECLARE_EXPLICIT_TYPE_ID(AddOp) IR_DEFINE_EXPLICIT_TYPE_ID(AddOp) struct CountOpAnalysis { explicit CountOpAnalysis(ir::Operation *container_op) { IR_ENFORCE(container_op->num_regions() > 0, "op must be a container with zero or multiple regions."); LOG(INFO) << "In CountOpAnalysis, op is " << container_op->name() << "\n"; for (size_t i = 0; i < container_op->num_regions(); ++i) { auto ®ion = container_op->region(i); for (auto it = region.begin(); it != region.end(); ++it) { auto *block = *it; for (auto it = block->begin(); it != block->end(); ++it) { ++count; } } } LOG(INFO) << "-- count is " << count << "\n"; } int count = 0; }; IR_DECLARE_EXPLICIT_TYPE_ID(CountOpAnalysis) IR_DEFINE_EXPLICIT_TYPE_ID(CountOpAnalysis) class TestPass : public ir::Pass { public: TestPass() : ir::Pass("TestPass", 1) {} void Run(ir::Operation *op) override { auto count_op_analysis = analysis_manager().GetAnalysis(); pass_state().preserved_analyses.Preserve(); CHECK_EQ(pass_state().preserved_analyses.IsPreserved(), true); CHECK_EQ(count_op_analysis.count, 11); auto module_op = op->dyn_cast(); CHECK_EQ(module_op.operation(), op); CHECK_EQ(module_op.name(), module_op->name()); LOG(INFO) << "In " << pass_info().name << ": " << module_op->name() << std::endl; pass_state().preserved_analyses.Unpreserve(); CHECK_EQ(pass_state().preserved_analyses.IsPreserved(), false); } bool CanApplyOn(ir::Operation *op) const override { return op->name() == "builtin.module" && op->num_regions() > 0; } }; void BuildProgram(ir::Builder &builder) { // NOLINT paddle::dialect::FullOp full_input_op = builder.Build(std::vector{4, 3, 16, 16}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace()); paddle::dialect::FullOp full_filter_op = builder.Build(std::vector{64, 3, 3, 3}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace()); paddle::dialect::FullOp full_mean_op = builder.Build( std::vector{64}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace()); paddle::dialect::FullOp full_variance_op = builder.Build(std::vector{64}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace()); paddle::dialect::FullOp full_scale_op = builder.Build(std::vector{64}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace()); paddle::dialect::FullOp full_bias_op = builder.Build( std::vector{64}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace()); paddle::dialect::Conv2dOp conv2d_op = builder.Build(full_input_op.out(), full_filter_op.out()); paddle::dialect::BatchNormOp batch_norm_op = builder.Build(conv2d_op.out(), full_mean_op.out(), full_variance_op.out(), full_scale_op.out(), full_bias_op.out(), true, 0.9, 1e-6, "NCHW", false, false); auto transpose1_op = builder.Build( batch_norm_op.out(), std::vector{0, 2, 3, 1}); auto transpose2_op = builder.Build( transpose1_op.out(), std::vector{0, 3, 1, 2}); builder.Build(transpose2_op.out(), "out", 0); } TEST(pass_manager, PassManager) { ir::IrContext *ctx = ir::IrContext::Instance(); ctx->GetOrRegisterDialect(); ir::Program program(ctx); ir::Builder builder = ir::Builder(ctx, program.block()); BuildProgram(builder); EXPECT_EQ(program.block()->size(), 11u); // (9) Test pass manager for program. ir::PassManager pm(ctx); pm.AddPass(std::make_unique()); // pm.EnableIRPrinting(); pm.EnableIRPrinting(std::make_unique( [](ir::Pass *pass, ir::Operation *op) { return pass->name() == "TestPass"; }, [](ir::Pass *pass, ir::Operation *op) { return pass->name() == "TestPass"; }, true, true)); pm.EnablePassTiming(true); CHECK_EQ(pm.Run(&program), true); }