From 4905a247c7aac101bdaa9aa378aab85a04363be2 Mon Sep 17 00:00:00 2001
From: Yuanle Liu
Date: Mon, 10 Jul 2023 20:00:01 +0800
Subject: [PATCH] [PASS] add constant folding pass (#55099)

---
 .../ir/phi_kernel_adaptor/phi_kernel_util.cc  |  14 +-
 paddle/fluid/ir/transforms/CMakeLists.txt     |   6 +
 .../ir/transforms/constant_folding_pass.cc    | 203 ++++++++++++++++++
 .../ir/transforms/constant_folding_pass.h}    |   3 +-
 .../transforms/transform_general_functions.cc |  35 ++-
 .../transforms/transform_general_functions.h  |  43 ++--
 paddle/ir/CMakeLists.txt                      |   2 +-
 .../CMakeLists.txt                            |   0
 .../dead_code_elimination_pass.cc}            |  40 ++--
 .../dead_code_elimination_pass.h              |  26 +++
 paddle/ir/pass/pass.h                         |   4 +-
 paddle/ir/pass/pass_manager.h                 |   3 +-
 paddle/ir/pattern_rewrite/pattern_match.h     |  11 +-
 test/cpp/ir/pass/pass_manager_test.cc         | 191 ++++++----------
 test/cpp/ir/pattern_rewrite/CMakeLists.txt    |  18 +-
 .../pattern_rewrite/pattern_rewrite_test.cc   |  97 ++++++---
 .../standalone_executor_new_ir_test.cc        |   5 -
 17 files changed, 482 insertions(+), 219 deletions(-)
 create mode 100644 paddle/fluid/ir/transforms/constant_folding_pass.cc
 rename paddle/{ir/transforms/dce.h => fluid/ir/transforms/constant_folding_pass.h} (92%)
 rename paddle/ir/{transforms => builtin_transforms}/CMakeLists.txt (100%)
 rename paddle/ir/{transforms/dce.cc => builtin_transforms/dead_code_elimination_pass.cc} (56%)
 create mode 100644 paddle/ir/builtin_transforms/dead_code_elimination_pass.h

diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
index bc454897e73..e3d9e222836 100644
--- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
+++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
@@ -139,14 +139,12 @@ void HandleForSpecialOp(ir::Operation* op,
   if (op_name == "pd.fetch") {
     // fetch is a very special op, with no output
     VLOG(6) << "Handle for pd.fetch:";
-    for (size_t i = 0; i < input_num; ++i) {
-      auto var = scope->Var("fetch");
-      VLOG(6) << "Create var: fetch in scope " << scope;
-      auto fetch_list = var->GetMutable<paddle::framework::FetchList>();
-      int index =
-          op->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();
-      fetch_list->resize(index + 1);
-    }
+    auto var = scope->Var("fetch");
+    VLOG(6) << "Create var: fetch in scope " << scope;
+    auto fetch_list = var->GetMutable<paddle::framework::FetchList>();
+    int index =
+        op->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();
+    fetch_list->resize(index + 1);
   }
 
   if (op_name == "pd.feed") {
diff --git a/paddle/fluid/ir/transforms/CMakeLists.txt b/paddle/fluid/ir/transforms/CMakeLists.txt
index 83e508dbd40..d23e253d83b 100644
--- a/paddle/fluid/ir/transforms/CMakeLists.txt
+++ b/paddle/fluid/ir/transforms/CMakeLists.txt
@@ -7,3 +7,9 @@ cc_library(
   pd_op_to_kernel_pass
   SRCS pd_op_to_kernel_pass.cc
   DEPS phi_utils pd_interface pd_trait ir)
+
+cc_library(
+  _constant_folding_pass
+  SRCS constant_folding_pass.cc
+  DEPS standalone_executor phi pd_op_to_kernel_pass transform_general_functions
+       ir)
diff --git a/paddle/fluid/ir/transforms/constant_folding_pass.cc b/paddle/fluid/ir/transforms/constant_folding_pass.cc
new file mode 100644
index 00000000000..3fcdee6748b
--- /dev/null
+++ b/paddle/fluid/ir/transforms/constant_folding_pass.cc
@@ -0,0 +1,203 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/ir/transforms/constant_folding_pass.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
+// paddle/fluid/ir/dialect/CMakeLists.txt.
+#include "paddle/fluid/ir/dialect/pd_op.h"
+
+#include "paddle/fluid/framework/new_executor/interpretercore.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/ir/dialect/pd_dialect.h"
+#include "paddle/fluid/ir/dialect/pd_type.h"
+#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
+#include "paddle/fluid/ir/transforms/transform_general_functions.h"
+#include "paddle/ir/core/builtin_op.h"
+#include "paddle/ir/core/ir_context.h"
+#include "paddle/ir/core/operation.h"
+#include "paddle/ir/core/parameter.h"
+#include "paddle/ir/core/program.h"
+#include "paddle/ir/pass/pass.h"
+#include "paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h"
+#include "paddle/ir/pattern_rewrite/pattern_match.h"
+#include "paddle/ir/pattern_rewrite/pattern_rewrite_driver.h"
+#include "paddle/phi/common/place.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/enforce.h"
+
+namespace {
+
+class ConstantFoldingPattern : public ir::RewritePattern {
+ public:
+  ConstantFoldingPattern(ir::IrContext* context,
+                         ir::PatternBenefit benefit = 1,
+                         const std::vector<std::string>& generated_names = {})
+      : RewritePattern(MatchAnyOpTypeTag(), benefit, context, generated_names) {
+  }
+
+  bool Match(ir::Operation* op) const override {
+    // TODO(liuyuanle): Use trait to improve robustness.
+    if (op->dyn_cast<ir::GetParameterOp>() ||
+        op->dyn_cast<ir::SetParameterOp>() ||
+        op->dyn_cast<paddle::dialect::FetchOp>())
+      return false;
+
+    // Inputs must come from get parameter op.
+    for (uint32_t i = 0; i < op->num_operands(); ++i)
+      if (ir::GetDefiningOpForInput(op, i)->dyn_cast<ir::GetParameterOp>() ==
+          nullptr)
+        return false;
+    return true;
+  }
+
+  void Rewrite(ir::Operation* op,
+               ir::PatternRewriter& rewriter) const override {  // NOLINT
+    ir::Program* program = op->GetParentProgram();
+    auto temp_program = BuildProgramFromOperation(op);
+
+    // Execute program
+    paddle::framework::interpreter::ExecutionConfig exe_config;
+    exe_config.create_local_scope = false;
+    paddle::framework::InterpreterCore core(
+        phi::CPUPlace{},
+        paddle::dialect::PdOpLowerToKernelPass(temp_program.get()),
+        &scope_,
+        exe_config);
+
+    paddle::framework::FetchList fetch_list = core.Run({});
+
+    // TODO(liuyuanle): Support multiple output.
+    auto out_tensor = PADDLE_GET_CONST(phi::DenseTensor, fetch_list[0]);
+    std::unique_ptr<ir::Parameter> parameter = std::make_unique<ir::Parameter>(
+        reinterpret_cast<void*>(out_tensor.data()),
+        out_tensor.numel() * phi::SizeOf(out_tensor.dtype()),
+        op->result(0).type());
+
+    std::string param_name =
+        "@constant_folding_pass@_" + std::to_string(suffix_++);
+
+    auto* param_var = scope_.Var(param_name);
+    auto* param_tensor = param_var->GetMutable<phi::DenseTensor>();
+    *param_tensor = out_tensor;
+    program->SetParameter(param_name, std::move(parameter));
+    // rewriter.SetInsertionPoint(op);
+    auto get_parameter_op =
+        rewriter.Build<ir::GetParameterOp>(param_name, op->result(0).type());
+
+    rewriter.ReplaceAllUsesWith(op->result(0), get_parameter_op->result(0));
+    rewriter.EraseOp(op);
+  }
+
+ private:
+  std::unique_ptr<ir::Program> BuildProgramFromOperation(
+      ir::Operation* op) const {
+    auto program = std::make_unique<ir::Program>(ir_context());
+    ir::Builder builder = ir::Builder(ir_context(), program->block());
+
+    // prepare op inputs
+    std::vector<ir::OpResult> op_inputs;
+    for (uint32_t i = 0; i < op->num_operands(); i++) {
+      PADDLE_ENFORCE_EQ(
+          op->operand(i).type().isa<paddle::dialect::DenseTensorType>(),
+          true,
+          phi::errors::InvalidArgument(
+              "Op's input must be a dense tensor type."));
+
+      auto [param_name, param] = ir::GetParameterFromValue(op->operand(i));
+      program->SetParameter(param_name,
+                            std::make_unique<ir::Parameter>(*param));
+
+      auto* param_var = scope_.FindVar(param_name);
+      PADDLE_ENFORCE_NOT_NULL(
+          param_var,
+          phi::errors::InvalidArgument("Parameter var not in scope."));
+
+      auto get_parameter_op =
+          builder.Build<ir::GetParameterOp>(param_name, op->operand(i).type());
+      op_inputs.push_back(get_parameter_op->result(0));
+    }
+
+    // prepare op outputs
+    std::vector<ir::Type> output_types;
+    for (uint32_t i = 0; i < op->num_results(); i++) {
+      output_types.push_back(op->result(i).type());
+    }
+
+    auto* temp_op =
+        builder.Build(op_inputs, op->attributes(), output_types, op->info());
+
+    // TODO(liuyuanle): Support multiple output.
+    // for (uint32_t i = 0; i < op->num_results(); i++) {
+    PADDLE_ENFORCE_EQ(
+        temp_op->result(0).type().isa<paddle::dialect::DenseTensorType>(),
+        true,
+        phi::errors::InvalidArgument(
+            "Op's output must be a dense tensor type."));
+
+    builder.Build<paddle::dialect::FetchOp>(
+        temp_op->result(0), "fetch_" + std::to_string(suffix_++), 0);
+    // }
+
+    return program;
+  }
+
+ private:
+  static size_t suffix_;
+  static paddle::framework::Scope scope_;
+};
+
+size_t ConstantFoldingPattern::suffix_ = 0;
+paddle::framework::Scope ConstantFoldingPattern::scope_ = {};
+
+class ConstantFoldingPass : public ir::Pass {
+ public:
+  // TODO(liuyuanle): Naming convention for pass.
+  ConstantFoldingPass() : ir::Pass("ConstantFoldingPass", 1) {}
+
+  bool Initialize(ir::IrContext* context) override {
+    ir::RewritePatternSet ps(context);
+    ps.Add<ConstantFoldingPattern>(context);
+    patterns_ = ir::FrozenRewritePatternSet(std::move(ps));
+    return true;
+  }
+
+  void Run(ir::Operation* op) override {
+    ir::GreedyRewriteConfig cfg;
+    cfg.use_top_down_traversal = true;
+    cfg.max_iterations = 10;
+    ir::ApplyPatternsGreedily(op->region(0), patterns_, cfg);
+  }
+
+  bool CanApplyOn(ir::Operation* op) const override {
+    return op->name() == "builtin.module" && op->num_regions() > 0;
+  }
+
+ private:
+  ir::FrozenRewritePatternSet patterns_;
+};
+
+}  // namespace
+
+namespace ir {
+
+std::unique_ptr<Pass> CreateConstantFoldingPass() {
+  return std::make_unique<ConstantFoldingPass>();
+}
+
+}  // namespace ir
diff --git a/paddle/ir/transforms/dce.h b/paddle/fluid/ir/transforms/constant_folding_pass.h
similarity index 92%
rename from paddle/ir/transforms/dce.h
rename to paddle/fluid/ir/transforms/constant_folding_pass.h
index 6e51b1b5b1d..0c5ca794ad5 100644
--- a/paddle/ir/transforms/dce.h
+++ b/paddle/fluid/ir/transforms/constant_folding_pass.h
@@ -18,8 +18,9 @@
 #include "paddle/ir/core/dll_decl.h"
 
 namespace ir {
+
 class Pass;
 
-IR_API std::unique_ptr<Pass> CreateDcePass();
+IR_API std::unique_ptr<Pass> CreateConstantFoldingPass();
 
 }  // namespace ir
diff --git a/paddle/fluid/ir/transforms/transform_general_functions.cc b/paddle/fluid/ir/transforms/transform_general_functions.cc
index ca2a9ccc6d5..0de36ffd20b 100644
--- a/paddle/fluid/ir/transforms/transform_general_functions.cc
+++ b/paddle/fluid/ir/transforms/transform_general_functions.cc
@@ -17,22 +17,28 @@
 #include "paddle/fluid/ir/dialect/pd_dialect.h"
 #include "paddle/fluid/ir/dialect/pd_type.h"
 #include "paddle/ir/core/builtin_op.h"
+#include "paddle/ir/core/parameter.h"
 #include "paddle/ir/core/program.h"
 
 namespace ir {
 
-ir::Parameter* GetParameterFromValue(ir::Value value) {
+std::pair<std::string, ir::Parameter*> GetParameterFromValue(ir::Value value) {
   ir::GetParameterOp op =
       value.GetDefiningOp()->dyn_cast<ir::GetParameterOp>();
   PADDLE_ENFORCE_NOT_NULL(
       op,
       phi::errors::InvalidArgument(
           "Value must be a weight from a GetParameter op."));
   ir::Program* program = op->GetParentProgram();
+  PADDLE_ENFORCE_NOT_NULL(
+      program, phi::errors::InvalidArgument("Program should not be null."));
   std::string name = op->attributes()
                          .at(op.attributes_name[0])
                          .dyn_cast<ir::StrAttribute>()
                          .data();
-  return program->GetParameter(name);
+  ir::Parameter* param = program->GetParameter(name);
+  PADDLE_ENFORCE_NOT_NULL(
+      param, phi::errors::InvalidArgument("Parameter should not be null."));
+  return {name, param};
 }
 
 const phi::DDim& GetShapeFromValue(ir::Value value) {
@@ -44,4 +50,29 @@
   return value.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
 }
 
+ir::Type GetDataTypeFromValue(ir::Value value) {
+  // TODO(dev): Support other types like DenseTensor.
+  PADDLE_ENFORCE_EQ(
+      value.type().isa<paddle::dialect::DenseTensorType>(),
+      true,
+      phi::errors::InvalidArgument("Value's type must be a DenseTensorType."));
+  return value.type().dyn_cast<paddle::dialect::DenseTensorType>().dtype();
+}
+
+Operation* GetDefiningOpForInput(Operation* op, uint32_t index) {
+  PADDLE_ENFORCE_EQ(
+      index < op->num_operands(),
+      true,
+      phi::errors::InvalidArgument("Input operand's index must be valid."));
+  return op->operand(index).GetDefiningOp();
+}
+
+Operation* GetFirstUseOperationForOutput(Operation* op, uint32_t index) {
+  PADDLE_ENFORCE_EQ(
+      index < op->num_results(),
+      true,
+      phi::errors::InvalidArgument("Output op result's index must be valid."));
+  return op->result(index).first_use().owner();
+}
+
 }  // namespace ir
diff --git a/paddle/fluid/ir/transforms/transform_general_functions.h b/paddle/fluid/ir/transforms/transform_general_functions.h
index 69dafbe1517..b086af090f7 100644
--- a/paddle/fluid/ir/transforms/transform_general_functions.h
+++ b/paddle/fluid/ir/transforms/transform_general_functions.h
@@ -16,6 +16,7 @@
 
 #include "paddle/ir/core/operation.h"
 #include "paddle/ir/core/parameter.h"
+#include "paddle/ir/core/type.h"
 #include "paddle/ir/core/value.h"
 #include "paddle/phi/core/ddim.h"
 #include "paddle/phi/core/enforce.h"
@@ -24,15 +25,16 @@
 namespace ir {
 
 /**
- * @brief Get the parameter from a value.
+ * @brief Get the [name, parameter] pair of the parameter from a value.
  *
  * @note The value must be an output of a GetParameterOp.
  *
  * @param ir::Value
  *
- * @return ir::Parameter*
+ * @return std::pair<std::string, ir::Parameter*>
  */
-ir::Parameter* GetParameterFromValue(ir::Value value);
+
+std::pair<std::string, ir::Parameter*> GetParameterFromValue(ir::Value value);
 
 /**
  * @brief Get tensor's shape from a value.
  *
  * @param ir::Value
  *
  */
 const phi::DDim& GetShapeFromValue(ir::Value value);
 
+/**
+ * @brief Get tensor's data type from a value.
+ *
+ * @param ir::Value
+ *
+ * @return ir::Type
+ */
+ir::Type GetDataTypeFromValue(ir::Value value);
+
 /**
  * @brief Get an operation that defines the specific input of the operation.
  *
- * @param Operation*
+ * @param Operation* pointer to an operation
+ * @param uint32_t index of operand of the operation
  *
  * @return Operation*
  */
-template <uint32_t Index>
-Operation* GetDefiningOpForInput(Operation* op) {
-  PADDLE_ENFORCE_EQ(
-      Index < op->num_operands(),
-      true,
-      phi::errors::InvalidArgument("Intput operand's index must be valid."));
-  return op->operand(Index).GetDefiningOp();
-}
+Operation* GetDefiningOpForInput(Operation* op, uint32_t index);
 
 /**
  * @brief Get an operation that is the first to use the specific output of the
 * operation.
 *
- * @param Operation*
- *
+ * @param Operation* pointer to an operation
+ * @param uint32_t index of result of the operation
+ *
 * @return Operation*
 */
-template <uint32_t Index>
-Operation* GetFirstUseOperationForOutput(Operation* op) {
-  PADDLE_ENFORCE_EQ(
-      Index < op->num_results(),
-      true,
-      phi::errors::InvalidArgument("Output op result's index must be valid."));
-  return op->result(Index).first_use().owner();
-}
+Operation* GetFirstUseOperationForOutput(Operation* op, uint32_t index);
 
 }  // namespace ir
diff --git a/paddle/ir/CMakeLists.txt b/paddle/ir/CMakeLists.txt
index e82c944fb9f..39e5ff3fda6 100644
--- a/paddle/ir/CMakeLists.txt
+++ b/paddle/ir/CMakeLists.txt
@@ -37,7 +37,7 @@ endfunction()
 add_subdirectory(core)
 add_subdirectory(pass)
 add_subdirectory(pattern_rewrite)
-add_subdirectory(transforms)
+add_subdirectory(builtin_transforms)
 
 if(WIN32)
   if(WITH_SHARED_IR)
diff --git a/paddle/ir/transforms/CMakeLists.txt b/paddle/ir/builtin_transforms/CMakeLists.txt
similarity index 100%
rename from paddle/ir/transforms/CMakeLists.txt
rename to paddle/ir/builtin_transforms/CMakeLists.txt
diff --git a/paddle/ir/transforms/dce.cc b/paddle/ir/builtin_transforms/dead_code_elimination_pass.cc
similarity index 56%
rename from paddle/ir/transforms/dce.cc
rename to paddle/ir/builtin_transforms/dead_code_elimination_pass.cc
index 94613fc017a..c595a7cae03 100644
--- a/paddle/ir/transforms/dce.cc
+++ b/paddle/ir/builtin_transforms/dead_code_elimination_pass.cc
@@ -12,9 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/ir/transforms/dce.h"
-#include <memory>
+#include "paddle/ir/builtin_transforms/dead_code_elimination_pass.h"
+
 #include "paddle/ir/core/builtin_op.h"
+#include "paddle/ir/core/program.h"
 #include "paddle/ir/pass/pass.h"
 
 namespace {
@@ -22,30 +23,43 @@ namespace {
 // TODO(wilber): After support SideEffectTrait, Only NoSideEffectTrait op can be
 // removed by dce pass.
 // Now just a naive implementation.
-class DcePass : public ir::Pass {
+class DeadCodeEliminationPass : public ir::Pass {
  public:
-  DcePass() : ir::Pass("DcePass", 0) {}
+  DeadCodeEliminationPass() : ir::Pass("DeadCodeEliminationPass", 0) {}
 
   void Run(ir::Operation *op) override {
     auto module_op = op->dyn_cast<ir::ModuleOp>();
     IR_ENFORCE(module_op, "DcePass should run on module op.");
     auto *block = module_op.block();
-    std::vector<ir::Operation> erased_op;
+    std::vector<ir::Operation *> erased_op;
     for (auto it = block->begin(); it != block->end(); ++it) {
+      auto &op = *it;
       // TODO(wilber): Support NoSideEffect trait.
-      // if (!(*it)->HasTrait<ir::NoSideEffectTrait>()) continue;
+      // if (!op->HasTrait<ir::NoSideEffectTrait>()) continue;
 
       bool use_empty = true;
-      for (uint32_t i = 0; i < (*it)->num_results(); ++i) {
-        use_empty &= (*it)->result(i).use_empty();
+      for (uint32_t i = 0; i < op->num_results(); ++i) {
+        use_empty &= op->result(i).use_empty();
       }
       // TODO(wilber): Support Terminator trait.
-      if (use_empty && (*it)->name() != "pd.fetch") {
-        erased_op.push_back(**it);
+      if (use_empty && op->name() != "pd.fetch") {
+        erased_op.push_back(op);
       }
     }
 
-    for (auto ep : erased_op) block->erase(ep);
+    for (auto *op : erased_op) {
+      if (op->dyn_cast<ir::GetParameterOp>()) {
+        // Delete parameter from program.
+        ir::GetParameterOp get_parameter_op =
+            op->dyn_cast<ir::GetParameterOp>();
+        get_parameter_op->GetParentProgram()->parameters().erase(
+            get_parameter_op->attributes()
+                .at(get_parameter_op.attributes_name[0])
+                .dyn_cast<ir::StrAttribute>()
+                .data());
+      }
+      block->erase(*op);
+    }
   }
 
   bool CanApplyOn(ir::Operation *op) const override {
@@ -57,6 +71,8 @@ class DcePass : public ir::Pass {
 
 namespace ir {
 
-std::unique_ptr<Pass> CreateDcePass() { return std::make_unique<DcePass>(); }
+std::unique_ptr<Pass> CreateDeadCodeEliminationPass() {
+  return std::make_unique<DeadCodeEliminationPass>();
+}
 
 }  // namespace ir
diff --git a/paddle/ir/builtin_transforms/dead_code_elimination_pass.h b/paddle/ir/builtin_transforms/dead_code_elimination_pass.h
new file mode 100644
index 00000000000..f03c024ae1d
--- /dev/null
+++ b/paddle/ir/builtin_transforms/dead_code_elimination_pass.h
@@ -0,0 +1,26 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include "paddle/ir/core/dll_decl.h"
+
+namespace ir {
+
+class Pass;
+
+IR_API std::unique_ptr<Pass> CreateDeadCodeEliminationPass();
+
+}  // namespace ir
diff --git a/paddle/ir/pass/pass.h b/paddle/ir/pass/pass.h
index d785f3a801f..484651e87e2 100644
--- a/paddle/ir/pass/pass.h
+++ b/paddle/ir/pass/pass.h
@@ -55,8 +55,8 @@ struct PassInfo {
   std::string name;
 
   // opt_level=0: the basic pass which framework need.
-  // opt_level=1: the fusion logical pass.
-  // opt_level=2: constant fold, cse, memory optimize, etc.
+  // opt_level=1: constant fold, cse, memory optimize, etc.
+  // opt_level=2: the fusion logical pass.
   // opt_level=3: layout, etc.
   uint8_t opt_level;
diff --git a/paddle/ir/pass/pass_manager.h b/paddle/ir/pass/pass_manager.h
index ec4191ebccb..67ac2d1ba34 100644
--- a/paddle/ir/pass/pass_manager.h
+++ b/paddle/ir/pass/pass_manager.h
@@ -110,7 +110,8 @@ class IR_API PassManager {
     // TODO(liuyuanle): Add flags to control printing behavior.
   };
 
-  void EnableIRPrinting(std::unique_ptr<IRPrinterOption> config);
+  void EnableIRPrinting(std::unique_ptr<IRPrinterOption> option =
+                            std::make_unique<IRPrinterOption>());
 
   void EnablePassTiming(bool print_module = true);
diff --git a/paddle/ir/pattern_rewrite/pattern_match.h b/paddle/ir/pattern_rewrite/pattern_match.h
index 26a6a1842b9..8b3bbaa5b1c 100644
--- a/paddle/ir/pattern_rewrite/pattern_match.h
+++ b/paddle/ir/pattern_rewrite/pattern_match.h
@@ -26,6 +26,7 @@
 
 #include "paddle/ir/core/builder.h"
 #include "paddle/ir/core/dll_decl.h"
+#include "paddle/ir/core/enforce.h"
 #include "paddle/ir/core/ir_context.h"
 #include "paddle/ir/core/op_info.h"
 #include "paddle/ir/core/operation.h"
@@ -148,6 +149,8 @@ class IR_API Pattern {
   const PatternBenefit benefit_;
   IrContext* context_;
 
+  // A list of the potential operations that may be generated when rewriting an
+  // op with this pattern.
   std::vector<OpInfo> generated_ops_;
 
   std::string debug_name_;
@@ -162,13 +165,13 @@ class IR_API RewritePattern : public Pattern {
 
   virtual void Rewrite(Operation* op,
                        PatternRewriter& rewriter) const {  // NOLINT
-    throw(
+    IR_THROW(
         "need to implement either MatchAndRewrite or one of the rewrite "
         "functions.");
   }
 
   virtual bool Match(Operation* op) const {
-    throw("need to implement either MatchAndRewrite or Match.");
+    IR_THROW("need to implement either MatchAndRewrite or Match.");
     return false;
   }
 
@@ -220,10 +223,10 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
 
   virtual void Rewrite(SourceOp op,
                        PatternRewriter& rewriter) const {  // NOLINT
-    throw("must override Rewrite or MatchAndRewrite");
+    IR_THROW("must override Rewrite or MatchAndRewrite");
   }
   virtual bool Match(SourceOp op) const {
-    throw("must override Match or MatchAndRewrite");
+    IR_THROW("must override Match or MatchAndRewrite");
   }
   virtual bool MatchAndRewrite(SourceOp op,
                                PatternRewriter& rewriter) const {  // NOLINT
diff --git a/test/cpp/ir/pass/pass_manager_test.cc b/test/cpp/ir/pass/pass_manager_test.cc
index 501249dd6f5..0558c82f2f9 100644
--- a/test/cpp/ir/pass/pass_manager_test.cc
+++ b/test/cpp/ir/pass/pass_manager_test.cc
@@ -15,6 +15,10 @@
 #include <gtest/gtest.h>
 #include "glog/logging.h"
 
+// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
+// paddle/fluid/ir/dialect/CMakeLists.txt.
+#include "paddle/fluid/ir/dialect/pd_op.h"
+
 #include "paddle/fluid/ir/dialect/pd_dialect.h"
 #include "paddle/fluid/ir/dialect/pd_type.h"
 #include "paddle/fluid/ir/dialect/utils.h"
@@ -125,7 +129,7 @@ class TestPass : public ir::Pass {
     pass_state().preserved_analyses.Preserve<CountOpAnalysis>();
     CHECK_EQ(pass_state().preserved_analyses.IsPreserved<CountOpAnalysis>(),
              true);
-    CHECK_EQ(count_op_analysis.count, 4);
+    CHECK_EQ(count_op_analysis.count, 11);
 
     auto module_op = op->dyn_cast<ir::ModuleOp>();
     CHECK_EQ(module_op.operation(), op);
@@ -143,139 +147,78 @@ class TestPass : public ir::Pass {
   }
 };
 
-TEST(pass_manager, PassManager) {
-  //
-  // TODO(liuyuanle): remove test code other than pass manager
-  //
+void BuildProgram(ir::Builder &builder) {  // NOLINT
+  paddle::dialect::FullOp full_input_op =
+      builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{4, 3, 16, 16},
+                                             1.5,
+                                             phi::DataType::FLOAT32,
+                                             phi::CPUPlace());
+
+  paddle::dialect::FullOp full_filter_op =
+      builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64, 3, 3, 3},
+                                             1.5,
+                                             phi::DataType::FLOAT32,
+                                             phi::CPUPlace());
+
+  paddle::dialect::FullOp full_mean_op = builder.Build<paddle::dialect::FullOp>(
+      std::vector<int64_t>{64}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace());
+
+  paddle::dialect::FullOp full_variance_op =
+      builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64},
+                                             1.5,
+                                             phi::DataType::FLOAT32,
+                                             phi::CPUPlace());
+
+  paddle::dialect::FullOp full_scale_op =
+      builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64},
+                                             1.5,
+                                             phi::DataType::FLOAT32,
+                                             phi::CPUPlace());
+
+  paddle::dialect::FullOp full_bias_op = builder.Build<paddle::dialect::FullOp>(
+      std::vector<int64_t>{64}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace());
+
+  paddle::dialect::Conv2dOp conv2d_op =
+      builder.Build<paddle::dialect::Conv2dOp>(full_input_op.out(),
+                                               full_filter_op.out());
+
+  paddle::dialect::BatchNormOp batch_norm_op =
+      builder.Build<paddle::dialect::BatchNormOp>(conv2d_op.out(),
+                                                  full_mean_op.out(),
+                                                  full_variance_op.out(),
+                                                  full_scale_op.out(),
+                                                  full_bias_op.out(),
+                                                  true,
+                                                  0.9,
+                                                  1e-6,
+                                                  "NCHW",
+                                                  false,
+                                                  false);
+
+  auto transpose1_op = builder.Build<paddle::dialect::TransposeOp>(
+      batch_norm_op.out(), std::vector<int>{0, 2, 3, 1});
+
+  auto transpose2_op = builder.Build<paddle::dialect::TransposeOp>(
+      transpose1_op.out(), std::vector<int>{0, 3, 1, 2});
+
+  builder.Build<paddle::dialect::FetchOp>(transpose2_op.out(), "out", 0);
+}
-
-  // (1) Init environment.
+TEST(pass_manager, PassManager) {
   ir::IrContext *ctx = ir::IrContext::Instance();
-  ir::Dialect *builtin_dialect =
-      ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
-  builtin_dialect->RegisterOp<AddOp>();
-  ir::Dialect *paddle_dialect =
-      ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
-
-  // (2) Create an empty program object
+  ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
   ir::Program program(ctx);
+  ir::Builder builder = ir::Builder(ctx, program.block());
+  BuildProgram(builder);
 
-  // (3) Create a float32 DenseTensor Parameter and save into Program
-  ir::Type fp32_dtype = ir::Float32Type::get(ctx);
-  phi::DDim dims = {2, 2};
-  phi::DataLayout data_layout = phi::DataLayout::NCHW;
-  phi::LoD lod = {{0, 1, 2}};
-  size_t offset = 0;
-  ir::Type dense_tensor_dtype = paddle::dialect::DenseTensorType::get(
-      ctx, fp32_dtype, dims, data_layout, lod, offset);
-
-  std::vector<float> data_a = {1, 2, 3, 4};
-  std::unique_ptr<ir::Parameter> parameter_a =
-      std::make_unique<ir::Parameter>(reinterpret_cast<void *>(data_a.data()),
-                                      4 * sizeof(float),
-                                      dense_tensor_dtype);
-  program.SetParameter("a", std::move(parameter_a));
-  EXPECT_EQ(program.parameters_num() == 1, true);
-
-  std::vector<float> data_b = {5, 6, 7, 8};
-  std::unique_ptr<ir::Parameter> parameter_b =
-      std::make_unique<ir::Parameter>(reinterpret_cast<void *>(data_b.data()),
-                                      4 * sizeof(float),
-                                      dense_tensor_dtype);
-  program.SetParameter("b", std::move(parameter_b));
-  EXPECT_EQ(program.parameters_num() == 2, true);
-
-  // (4) Def a = GetParameterOp("a"), and create DenseTensor for a.
-  ir::Builder builder(ctx, program.block());
-  auto op1 = builder.Build<ir::GetParameterOp>("a", dense_tensor_dtype);
-
-  EXPECT_EQ(&program, op1->GetParentProgram());
-  EXPECT_EQ(op1->result(0).type().dialect().id(), paddle_dialect->id());
-  using Interface = paddle::dialect::ParameterConvertInterface;
-  Interface *a_interface =
-      op1->result(0).type().dialect().GetRegisteredInterface<Interface>();
-  std::shared_ptr<paddle::framework::Variable> a_var =
-      a_interface->ParameterToVariable(program.GetParameter("a"));
-  const phi::DenseTensor &a_tensor = a_var->Get<phi::DenseTensor>();
-  EXPECT_EQ(a_tensor.numel(), 4);
-  EXPECT_EQ(a_tensor.dims(), dims);
-  EXPECT_EQ(a_tensor.dtype(), paddle::dialect::TransToPhiDataType(fp32_dtype));
-  EXPECT_EQ(a_tensor.layout(), data_layout);
-  EXPECT_EQ(a_tensor.lod(), lod);
-  EXPECT_EQ(a_tensor.offset(), offset);
-  for (int64_t i = 0; i < a_tensor.numel(); i++) {
-    EXPECT_EQ(*(a_tensor.data<float>() + i), data_a[i]);
-  }
-
-  // (5) Def b = GetParameterOp("b"), and create DenseTensor for b.
-  auto op2 = builder.Build<ir::GetParameterOp>("b", dense_tensor_dtype);
-  EXPECT_EQ(op2->result(0).type().dialect().id(), paddle_dialect->id());
-  Interface *b_interface =
-      op2->result(0).type().dialect().GetRegisteredInterface<Interface>();
-  std::shared_ptr<paddle::framework::Variable> b_var =
-      b_interface->ParameterToVariable(program.GetParameter("b"));
-  const phi::DenseTensor &b_tensor = b_var->Get<phi::DenseTensor>();
-  EXPECT_EQ(b_tensor.numel(), 4);
-  EXPECT_EQ(b_tensor.dims(), dims);
-  EXPECT_EQ(b_tensor.dtype(), paddle::dialect::TransToPhiDataType(fp32_dtype));
-  EXPECT_EQ(b_tensor.layout(), data_layout);
-  EXPECT_EQ(b_tensor.lod(), lod);
-  EXPECT_EQ(b_tensor.offset(), offset);
-  for (int64_t i = 0; i < b_tensor.numel(); i++) {
-    EXPECT_EQ(*(b_tensor.data<float>() + i), data_b[i]);
-  }
-
-  // (6) Def c = AddOp(a, b), execute this op.
-  auto op3 =
-      builder.Build<AddOp>(op1->result(0), op2->result(0), dense_tensor_dtype);
-  phi::CPUContext *dev_ctx = static_cast<phi::CPUContext *>(
-      paddle::platform::DeviceContextPool::Instance().Get(
-          paddle::platform::CPUPlace()));
-  phi::DenseTensor c_tensor =
-      phi::Add<float>(*dev_ctx, a_tensor, b_tensor);
-  std::shared_ptr<paddle::framework::Variable> variable_c =
-      std::make_shared<paddle::framework::Variable>();
-  auto *dst_tensor = variable_c->GetMutable<phi::DenseTensor>();
-  *dst_tensor = c_tensor;
-  EXPECT_EQ(dst_tensor->numel(), b_tensor.numel());
-  EXPECT_EQ(dst_tensor->dims(), b_tensor.dims());
-  EXPECT_EQ(dst_tensor->dtype(), b_tensor.dtype());
-  EXPECT_EQ(dst_tensor->layout(), b_tensor.layout());
-  EXPECT_EQ(dst_tensor->lod(), b_tensor.lod());
-  EXPECT_EQ(dst_tensor->offset(), b_tensor.offset());
-  for (int64_t i = 0; i < dst_tensor->numel(); i++) {
-    EXPECT_EQ(*(dst_tensor->data<float>() + i), data_a[i] + data_b[i]);
-  }
-
-  // (7) Def SetParameterOp(c, "c")
-  auto op4 = builder.Build<ir::SetParameterOp>(op3->result(0), "c");
-  EXPECT_EQ(op4->operand(0).type().dialect().id(), paddle_dialect->id());
-  Interface *c_interface =
-      op4->op_operand(0).type().dialect().GetRegisteredInterface<Interface>();
-  // ir::Parameter *parameter_c =
-  //     c_interface->VariableToParameter(variable_c.get());
-
-  std::unique_ptr<ir::Parameter> parameter_c =
-      c_interface->VariableToParameter(variable_c.get());
-  EXPECT_EQ(parameter_c->type(), dense_tensor_dtype);
-  for (int64_t i = 0; i < dst_tensor->numel(); i++) {
-    EXPECT_EQ(*(dst_tensor->data<float>() + i),
-              *(static_cast<float *>(parameter_c->data()) + i));
-  }
-  program.SetParameter("c", std::move(parameter_c));
-
-  // (8) Traverse Program
-  EXPECT_EQ(program.block()->size() == 4, true);
-  EXPECT_EQ(program.parameters_num() == 3, true);
-
-  //
-  // TODO(liuyuanle): remove the code above.
-  //
+  EXPECT_EQ(program.block()->size(), 11u);
 
   // (9) Test pass manager for program.
   ir::PassManager pm(ctx);
   pm.AddPass(std::make_unique<TestPass>());
+  // pm.EnableIRPrinting();
   pm.EnableIRPrinting(std::make_unique<ir::PassManager::IRPrinterOption>(
       [](ir::Pass *pass, ir::Operation *op) {
         return pass->name() == "TestPass";
diff --git a/test/cpp/ir/pattern_rewrite/CMakeLists.txt b/test/cpp/ir/pattern_rewrite/CMakeLists.txt
index f3e2cbbee22..fd527db5550 100644
--- a/test/cpp/ir/pattern_rewrite/CMakeLists.txt
+++ b/test/cpp/ir/pattern_rewrite/CMakeLists.txt
@@ -1,9 +1,9 @@
-cc_test_old(
-  pattern_rewrite_test
-  SRCS
-  pattern_rewrite_test.cc
-  DEPS
-  ir
-  pd_dialect
-  transform_general_functions
-  gtest)
+set(PATTERN_REWRITE_TEST_DEPS _constant_folding_pass
+    transform_general_functions gtest pd_dialect ir)
+
+if(WITH_DISTRIBUTE)
+  set(PATTERN_REWRITE_TEST_DEPS ${PATTERN_REWRITE_TEST_DEPS} fleet_executor)
+endif()
+
+cc_test_old(pattern_rewrite_test SRCS pattern_rewrite_test.cc DEPS
+            ${PATTERN_REWRITE_TEST_DEPS})
diff --git a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
index ee51316f482..1edfdead8e7 100644
--- a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
+++ b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
@@ -15,12 +15,15 @@
 #include <gtest/gtest.h>
 #include <cstdint>
 #include <iostream>
+#include <memory>
 #include <numeric>
 #include <string>
 #include <vector>
 
 #include "paddle/fluid/ir/dialect/pd_attribute.h"
+#include "paddle/fluid/ir/transforms/constant_folding_pass.h"
 #include "paddle/fluid/ir/transforms/transform_general_functions.h"
+#include "paddle/ir/builtin_transforms/dead_code_elimination_pass.h"
 #include "paddle/ir/core/builder.h"
 #include "paddle/ir/core/builtin_attribute.h"
 #include "paddle/ir/core/builtin_dialect.h"
@@ -39,7 +42,7 @@
 #include "paddle/ir/pattern_rewrite/pattern_applicator.h"
 #include "paddle/ir/pattern_rewrite/pattern_match.h"
 #include "paddle/ir/pattern_rewrite/pattern_rewrite_driver.h"
-#include "paddle/ir/transforms/dce.h"
+#include "paddle/phi/core/kernel_registry.h"
 
 // NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
 // paddle/fluid/ir/dialect/CMakeLists.txt.
@@ -56,6 +59,18 @@
 #include "paddle/phi/api/lib/utils/allocator.h"
 #include "paddle/phi/infermeta/multiary.h"
 
+PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(sqrt, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(divide, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(multiply, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(subtract, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(full_int_array, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(reshape, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(fetch, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(conv2d, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(transpose, CPU, ALL_LAYOUT);
+
 // Define op1.
 class Operation1 : public ir::Op<Operation1> {
  public:
@@ -197,7 +212,7 @@ class RedundantTransposeFusePattern
 
   bool MatchAndRewrite(paddle::dialect::TransposeOp op,
                        ir::PatternRewriter &rewriter) const override {
-    auto prev_op = ir::GetDefiningOpForInput<0>(op);
+    auto prev_op = ir::GetDefiningOpForInput(op, 0);
     std::vector<int> axis_last = GetAxis(op);
     auto prev_trans_op = prev_op->dyn_cast<paddle::dialect::TransposeOp>();
     if (prev_trans_op) {
@@ -207,7 +222,7 @@ class RedundantTransposeFusePattern
       auto new_perm = GetPerm(axis_first, axis_last);
       rewriter.SetInsertionPoint(op);
       auto new_transpose_op = rewriter.Build<paddle::dialect::TransposeOp>(
-          ir::GetDefiningOpForInput<0>(prev_trans_op)->result(0), new_perm);
+          ir::GetDefiningOpForInput(prev_trans_op, 0)->result(0), new_perm);
       rewriter.ReplaceOp(op, {new_transpose_op.out()});
       return true;
     }
@@ -249,7 +264,7 @@ class Conv2dBnFusePattern
                        ir::PatternRewriter &rewriter) const override {  // NOLINT
     // The next op should be batch_norm.
     paddle::dialect::Conv2dOp conv2d_op =
-        ir::GetDefiningOpForInput<0>(op)->dyn_cast<paddle::dialect::Conv2dOp>();
+        ir::GetDefiningOpForInput(op, 0)->dyn_cast<paddle::dialect::Conv2dOp>();
     if (!conv2d_op) return false;
 
     ir::OpResult conv2d_out = conv2d_op.out();
@@ -320,7 +335,6 @@ class Conv2dBnFusePattern
     std::string data_format =
         new_conv2d_op.attribute<ir::StrAttribute>("data_format").data();
     IR_ENFORCE(data_format == "NCHW", "Only support NCHW now.");
-    new_bias_new_shape[0] = new_conv2d_out_shape[0];
     new_bias_new_shape[1] = new_conv2d_out_shape[1];
     paddle::dialect::ReshapeOp reshape_bias_op =
         rewriter.Build<paddle::dialect::ReshapeOp>(sub_op.out(),
@@ -895,7 +909,7 @@ class Conv2dAddFusePattern
                        ir::PatternRewriter &rewriter) const override {  // NOLINT
     // The next op should be add.
     paddle::dialect::Conv2dOp conv2d_op =
-        ir::GetDefiningOpForInput<0>(op)->dyn_cast<paddle::dialect::Conv2dOp>();
+        ir::GetDefiningOpForInput(op, 0)->dyn_cast<paddle::dialect::Conv2dOp>();
     if (!conv2d_op) return false;
 
     ir::OpResult conv2d_out = conv2d_op.out();
@@ -929,12 +943,10 @@ class Conv2dAddFusePattern
         conv2d_attributes.at("dilations"),
         conv2d_attributes.at("groups"),
         conv2d_attributes.at("data_format"),
-        ir::StrAttribute::get(ir::IrContext::Instance(), "identity"),
-        ir::BoolAttribute::get(ir::IrContext::Instance(), true),
-        ir::ArrayAttribute::get(ir::IrContext::Instance(),
-                                std::vector<ir::Attribute>()),
-        ir::Int32Attribute::get(ir::IrContext::Instance(), int32_t(0)),
-    };
+        rewriter.str_attr("identity"),
+        rewriter.bool_attr(true),
+        rewriter.array_attr(std::vector<ir::Attribute>{}),
+        rewriter.int32_attr(0)};
     ir::AttributeMap conv2d_fusion_attributes;
     for (size_t i = 0; i < conv2d_fusion_attrStr.size(); ++i) {
       conv2d_fusion_attributes[conv2d_fusion_attrStr[i]] = con2d_fusing_attr[i];
@@ -943,7 +955,7 @@ class Conv2dAddFusePattern
     ir::OpResult tmpResidual;
 
     auto conv2d_fuse_op = rewriter.Build<paddle::dialect::Conv2dFusionOp>(
-        ir::GetDefiningOpForInput<0>(conv2d_op)->result(0),
+        ir::GetDefiningOpForInput(conv2d_op, 0)->result(0),
         conv2d_filter_result,
         bias,
         tmpResidual,
@@ -956,27 +968,48 @@ class Conv2dAddFusePattern
 class TestPass : public ir::Pass {
  public:
   TestPass() : ir::Pass("TestPass", 1) {}
-  void Run(ir::Operation *op) override {
-    ir::RewritePatternSet ps(op->ir_context());
-    ps.Add<RedundantTransposeFusePattern>(op->ir_context());
-    ps.Add<Conv2dBnFusePattern>(op->ir_context());
-    ps.Add<Conv2dAddFusePattern>(op->ir_context());
-    ir::FrozenRewritePatternSet frozen_ps(std::move(ps));
+
+  bool Initialize(ir::IrContext *context) override {
+    ir::RewritePatternSet ps(context);
+    ps.Add<RedundantTransposeFusePattern>(context);
+    auto conv_bn_pattern = std::make_unique<Conv2dBnFusePattern>(
+        context,
+        1,
+        std::vector<std::string>{paddle::dialect::FullOp::name(),
+                                 paddle::dialect::AddOp::name(),
+                                 paddle::dialect::SqrtOp::name(),
+                                 paddle::dialect::DivideOp::name(),
+                                 paddle::dialect::ReshapeOp::name(),
+                                 paddle::dialect::MultiplyOp::name(),
+                                 paddle::dialect::SubtractOp::name(),
+                                 paddle::dialect::Conv2dOp::name()});
+    LOG(INFO) << "Conv2dBnFusePattern will generate the following operations: ";
+    for (auto op_info : conv_bn_pattern->generated_ops()) {
+      LOG(INFO) << "--- " << op_info.name();
+    }
+    ps.Add(std::move(conv_bn_pattern));
+    patterns_ = ir::FrozenRewritePatternSet(std::move(ps));
+    return true;
+  }
+
+  void Run(ir::Operation *op) override {
     ir::GreedyRewriteConfig cfg;
     cfg.use_top_down_traversal = true;
     cfg.max_iterations = 10;
-    ir::ApplyPatternsGreedily(op->region(0), frozen_ps, cfg);
+    ir::ApplyPatternsGreedily(op->region(0), patterns_, cfg);
   }
 
   bool CanApplyOn(ir::Operation *op) const override {
     return op->name() == "builtin.module" && op->num_regions() > 0;
   }
+
+ private:
+  ir::FrozenRewritePatternSet patterns_;
 };
 
 void BuildProgram(ir::Builder &builder) {  // NOLINT
   paddle::dialect::FullOp full_input_op =
-      builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{1, 3, 16, 16},
+      builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{4, 3, 16, 16},
                                              1.5,
                                              phi::DataType::FLOAT32,
                                              phi::CPUPlace());
@@ -1045,11 +1078,19 @@ TEST(pattern_rewrite, Patterns) {
   ir::PassManager pm(ctx);
   pm.AddPass(std::make_unique<TestPass>());
-  pm.AddPass(ir::CreateDcePass());
-  program.Print(std::cout);
-  std::cout << std::endl;
-  pm.Run(&program);
-  LOG(INFO) << "After Pass.";
-  program.Print(std::cout);
-  std::cout << std::endl;
+  pm.AddPass(ir::CreateConstantFoldingPass());
+  pm.AddPass(ir::CreateDeadCodeEliminationPass());
+  pm.EnablePassTiming();
+  pm.EnableIRPrinting();
+  // pm.EnableIRPrinting(std::make_unique<ir::PassManager::IRPrinterOption>(
+  //     [](ir::Pass *pass, ir::Operation *op) {
+  //       return pass->name() == "ConstantFoldingPass";
+  //     },
+  //     [](ir::Pass *pass, ir::Operation *op) {
+  //       return pass->name() == "ConstantFoldingPass";
+  //     },
+  //     true,
+  //     true));
+
+  CHECK_EQ(pm.Run(&program), true);
 }
diff --git a/test/cpp/new_executor/standalone_executor_new_ir_test.cc b/test/cpp/new_executor/standalone_executor_new_ir_test.cc
index 8f88a744d5c..858d96d4434 100644
--- a/test/cpp/new_executor/standalone_executor_new_ir_test.cc
+++ b/test/cpp/new_executor/standalone_executor_new_ir_test.cc
@@ -69,7 +69,6 @@ TEST(StandaloneExecutor, run) {
 
   auto place = platform::CPUPlace();
   Scope scope;
-  ProgramDesc prog_desc;
 
   InterpreterCore test_core(place, std::move(kernel_program), &scope);
 
   test_core.Run({});
@@ -141,8 +140,6 @@ TEST(StandaloneExecutor, run_2) {
 
   auto place = platform::CPUPlace();
   Scope scope;
-  ProgramDesc prog_desc;
-
   InterpreterCore test_core(place, std::move(kernel_program), &scope);
 
   test_core.Run({});
@@ -216,8 +213,6 @@ TEST(StandaloneExecutor, data_transfer) {
 
   auto place = platform::CPUPlace();
   Scope scope;
-  ProgramDesc prog_desc;
-
   InterpreterCore test_core(place, std::move(kernel_program), &scope);
 
   test_core.Run({});
-- 
GitLab
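
Usage note (reviewer-facing sketch, not part of the patch): the two passes added
here are meant to be scheduled through ir::PassManager, exactly as the updated
pattern_rewrite_test does. A minimal driver, assuming an ir::Program built
against the pd dialect as in BuildProgram above, would look like:

    #include "paddle/fluid/ir/transforms/constant_folding_pass.h"
    #include "paddle/ir/builtin_transforms/dead_code_elimination_pass.h"
    #include "paddle/ir/pass/pass_manager.h"

    // Fold parameter-only computations first, then sweep away the ops that
    // the folding leaves without users.
    ir::IrContext *ctx = ir::IrContext::Instance();
    ir::PassManager pm(ctx);
    pm.AddPass(ir::CreateConstantFoldingPass());      // registered at opt_level 1
    pm.AddPass(ir::CreateDeadCodeEliminationPass());  // registered at opt_level 0
    pm.EnableIRPrinting();       // now argument-free thanks to the default option
    bool ok = pm.Run(&program);  // program: an ir::Program, as in the tests

The ordering matters: ConstantFoldingPattern replaces each foldable op with a
new GetParameterOp, so the producers it bypasses only become dead afterwards;
DeadCodeEliminationPass then erases them and, for a dead GetParameterOp, also
removes the corresponding parameter from the program.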
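A second note: the helper API change in transform_general_functions.h is
source-breaking for downstream patterns. The compile-time index template is
gone, and GetParameterFromValue now returns the parameter's name together with
the parameter. A before/after sketch (the caller code is hypothetical):

    // Before this patch:
    //   ir::Operation *def = ir::GetDefiningOpForInput<0>(op);
    //   ir::Parameter *param = ir::GetParameterFromValue(op->operand(0));
    // After this patch:
    ir::Operation *def =
        ir::GetDefiningOpForInput(op, 0);  // index validated at runtime
    auto [name, param] = ir::GetParameterFromValue(op->operand(0));

Passing the index at runtime lets a pattern loop over all operands, as
ConstantFoldingPattern::Match does above, which the template version could not.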