// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/ir/transforms/constant_folding_pass.h"

#include <memory>
#include <string>
#include <vector>

// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
// paddle/fluid/ir/dialect/CMakeLists.txt.
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.h"

#include "paddle/fluid/framework/new_executor/interpretercore.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h"
#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
#include "paddle/fluid/ir/transforms/transform_general_functions.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/parameter.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h"
#include "paddle/ir/pattern_rewrite/pattern_match.h"
#include "paddle/ir/pattern_rewrite/pattern_rewrite_driver.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"

namespace {

// Folds an operation whose inputs all come from builtin.get_parameter ops:
// the op is executed once on CPU and then replaced by a get_parameter op
// that reads the precomputed result.
class ConstantFoldingPattern : public ir::RewritePattern {
 public:
  ConstantFoldingPattern(ir::IrContext* context,
                         ir::PatternBenefit benefit = 1,
                         const std::vector<std::string>& generated_names = {})
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context, generated_names) {
  }

  bool Match(ir::Operation* op) const override {
    // TODO(liuyuanle): Use trait to improve robustness.
    if (op->dyn_cast<ir::GetParameterOp>() ||
        op->dyn_cast<ir::SetParameterOp>() ||
        op->dyn_cast<paddle::dialect::FetchOp>())
      return false;

    // Inputs must come from get parameter op.
    for (uint32_t i = 0; i < op->num_operands(); ++i)
      if (ir::GetDefiningOpForInput(op, i)->dyn_cast<ir::GetParameterOp>() ==
          nullptr)
        return false;

    return true;
  }

  void Rewrite(ir::Operation* op,
               ir::PatternRewriter& rewriter) const override {  // NOLINT
    ir::Program* program = op->GetParentProgram();
    auto temp_program = BuildProgramFromOperation(op);

    // Collect fetch variable names from the temporary program, ordered by
    // each fetch op's "col" attribute.
    std::vector<std::string> fetch_var_names;
    auto block = temp_program->block();
    for (auto it = block->begin(); it != block->end(); ++it) {
      if ((*it)->name() == "pd.fetch") {
        size_t index = (*it)
                           ->attributes()
                           .at("col")
                           .dyn_cast<ir::Int32Attribute>()
                           .data();

        if (fetch_var_names.size() < index + 1) {
          fetch_var_names.resize(index + 1);
        }

        fetch_var_names[index] = (*it)
                                     ->attributes()
                                     .at("name")
                                     .dyn_cast<ir::StrAttribute>()
                                     .AsString() +
                                 "@fetch";
      }
    }

    // Execute program
    exe_config_.create_local_scope = false;
    paddle::framework::InterpreterCore core(
        phi::CPUPlace{},
        fetch_var_names,
        paddle::dialect::PdOpLowerToKernelPass(temp_program.get()),
        &scope_,
        exe_config_);

    paddle::framework::FetchList fetch_list = core.Run({});

    // TODO(liuyuanle): Support multiple output.
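    // Materialize the computed result: wrap the fetched tensor's storage in
    // an ir::Parameter, register it on the parent program under a fresh
    // name, and keep the tensor alive in the static scope (exempt from GC)
    // so later executions can still read it.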
    auto out_tensor = PADDLE_GET_CONST(phi::DenseTensor, fetch_list[0]);
    std::unique_ptr<ir::Parameter> parameter = std::make_unique<ir::Parameter>(
        reinterpret_cast<void*>(out_tensor.data()),
        out_tensor.numel() * phi::SizeOf(out_tensor.dtype()),
        op->result(0).type());

    std::string param_name =
        "@constant_folding_pass@_" + std::to_string(suffix_++);
    exe_config_.skip_gc_vars.insert(param_name);

    auto* param_var = scope_.Var(param_name);
    auto* param_tensor = param_var->GetMutable<phi::DenseTensor>();
    *param_tensor = out_tensor;

    program->SetParameter(param_name, std::move(parameter));
    // rewriter.SetInsertionPoint(op);
    auto get_parameter_op =
        rewriter.Build<ir::GetParameterOp>(param_name, op->result(0).type());

    rewriter.ReplaceAllUsesWith(op->result(0), get_parameter_op->result(0));
    rewriter.EraseOp(op);
  }

 private:
  // Builds a one-op program: get_parameter ops feeding a clone of `op`,
  // whose output is fetched so the interpreter can return it.
  std::unique_ptr<ir::Program> BuildProgramFromOperation(
      ir::Operation* op) const {
    auto program = std::make_unique<ir::Program>(ir_context());
    ir::Builder builder = ir::Builder(ir_context(), program->block());

    // prepare op inputs
    std::vector<ir::OpResult> op_inputs;
    for (uint32_t i = 0; i < op->num_operands(); i++) {
      PADDLE_ENFORCE_EQ(
          op->operand_source(i).type().isa<paddle::dialect::DenseTensorType>(),
          true,
          phi::errors::InvalidArgument(
              "Op's input must be a dense tensor type."));

      auto [param_name, param] =
          ir::GetParameterFromValue(op->operand_source(i));
      program->SetParameter(param_name,
                            std::make_unique<ir::Parameter>(*param));

      auto* param_var = scope_.FindVar(param_name);
      PADDLE_ENFORCE_NOT_NULL(
          param_var,
          phi::errors::InvalidArgument("Parameter var not in scope."));

      auto get_parameter_op = builder.Build<ir::GetParameterOp>(
          param_name, op->operand_source(i).type());
      op_inputs.push_back(get_parameter_op->result(0));
    }

    // prepare op outputs
    std::vector<ir::Type> output_types;
    for (uint32_t i = 0; i < op->num_results(); i++) {
      output_types.push_back(op->result(i).type());
    }

    auto* temp_op =
        builder.Build(op_inputs, op->attributes(), output_types, op->info());

    // TODO(liuyuanle): Support multiple output.
    // for (uint32_t i = 0; i < op->num_results(); i++) {
    PADDLE_ENFORCE_EQ(
        temp_op->result(0).type().isa<paddle::dialect::DenseTensorType>(),
        true,
        phi::errors::InvalidArgument(
            "Op's output must be a dense tensor type."));

    builder.Build<paddle::dialect::FetchOp>(
        temp_op->result(0), "fetch_" + std::to_string(suffix_++), 0);
    // }

    return program;
  }

 private:
  inline static size_t suffix_{0};
  inline static paddle::framework::Scope scope_{};
  inline static paddle::framework::interpreter::ExecutionConfig exe_config_{};
};

class ConstantFoldingPass : public ir::Pass {
 public:
  // TODO(liuyuanle): Naming convention for pass.
  ConstantFoldingPass() : ir::Pass("ConstantFoldingPass", 1) {}

  bool Initialize(ir::IrContext* context) override {
    ir::RewritePatternSet ps(context);
    ps.Add<ConstantFoldingPattern>(context);
    patterns_ = ir::FrozenRewritePatternSet(std::move(ps));
    return true;
  }

  void Run(ir::Operation* op) override {
    ir::GreedyRewriteConfig cfg;
    cfg.use_top_down_traversal = true;
    cfg.max_iterations = 10;
    ir::ApplyPatternsGreedily(op->region(0), patterns_, cfg);
  }

  bool CanApplyOn(ir::Operation* op) const override {
    return op->name() == "builtin.module" && op->num_regions() > 0;
  }

 private:
  ir::FrozenRewritePatternSet patterns_;
};

}  // namespace

namespace ir {

std::unique_ptr<Pass> CreateConstantFoldingPass() {
  return std::make_unique<ConstantFoldingPass>();
}

}  // namespace ir
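
// Usage (a minimal sketch, assuming the ir::PassManager API declared in
// paddle/ir/pass/pass_manager.h; adapt to the actual interface):
//
//   ir::IrContext* ctx = ir::IrContext::Instance();
//   ir::PassManager pm(ctx);
//   pm.AddPass(ir::CreateConstantFoldingPass());
//   pm.Run(&program);  // `program` is an ir::Program; constant subgraphs
//                      // are replaced by precomputed parameters.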